Learn practical skills, build real-world projects, and advance your career
Updated 4 years ago
import torch
import jovian
import torchvision
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import random_split
from torch.utils.data import DataLoader, TensorDataset, random_split
# Training hyperparameters.
batch_size, learning_rate = 100, 0.001

# Problem dimensions: ten target classes, and 3 * 32 * 32 = 3072 values per
# input (presumably channels x height x width of one image -- confirm
# against the model that consumes input_size).
num_classes = 10
input_size = 3 * 32 * 32
# Root directory of the local CIFAR-10 copy. Written as a raw string: in a
# normal literal "\P" and "\c" are invalid escape sequences, which Python
# reports as a DeprecationWarning (SyntaxWarning from 3.12 on). The runtime
# value is unchanged.
data_root = r'D:\PyTorch\cifar-10-python'

# Convert each image to a tensor, then normalize every one of the three
# channels with mean 0.5 and std 0.5.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# download=False: the dataset files must already exist under data_root.
testset = torchvision.datasets.CIFAR10(
    root=data_root, train=False, download=False, transform=transform)
trainvalset = torchvision.datasets.CIFAR10(
    root=data_root, train=True, download=False, transform=transform)

# Hold out 5000 of the 50000 training samples (10%) for validation.
trainset, valset = torch.utils.data.random_split(trainvalset, [45000, 5000])

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = torch.utils.data.DataLoader(trainset, batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(testset, batch_size, shuffle=False)
val_loader = torch.utils.data.DataLoader(valset, batch_size, shuffle=False)
# Disabled alternative data pipeline (training-time augmentation and a
# 40000/10000 train/validation split), kept for reference inside a string
# literal. The literal is a bare expression statement, so none of it runs.
# It is a raw string because the Windows paths contain "\P" and "\c", which
# are invalid escape sequences in a normal literal (DeprecationWarning, and
# SyntaxWarning on Python 3.12+); the string value itself is unchanged.
r'''
transform_train = transforms.Compose([transforms.Resize((32,32)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.RandomAffine(0, shear=10, scale=(0.8,1.2)),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
transform = transforms.Compose([transforms.Resize((32,32)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
training_dataset = CIFAR10(root='D:\PyTorch\cifar-10-python', train=True, download=True, transform=transform_train)
train_ds, val_ds = random_split(training_dataset, [40000, 10000])
test_ds = CIFAR10(root='D:\PyTorch\cifar-10-python', train=False, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size , shuffle = False)
test_loader = DataLoader(test_ds, batch_size, shuffle=False)
'''
# Notebook echo of the string above (the cell's displayed value). Also a bare
# string expression with no effect at runtime.
"\ntransform_train = transforms.Compose([transforms.Resize((32,32)),\n transforms.RandomHorizontalFlip(),\n transforms.RandomRotation(10),\n transforms.RandomAffine(0, shear=10, scale=(0.8,1.2)),\n transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),\n transforms.ToTensor(),\n transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n ])\n\n\ntransform = transforms.Compose([transforms.Resize((32,32)),\n transforms.ToTensor(),\n transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n ])\ntraining_dataset = CIFAR10(root='D:\\PyTorch\\cifar-10-python', train=True, download=True, transform=transform_train)\ntrain_ds, val_ds = random_split(training_dataset, [40000, 10000])\ntest_ds = CIFAR10(root='D:\\PyTorch\\cifar-10-python', train=False, download=True, transform=transform)\n\ntrain_loader = DataLoader(train_ds, batch_size, shuffle=True)\nval_loader = DataLoader(val_ds, batch_size , shuffle = False)\ntest_loader = DataLoader(test_ds, batch_size, shuffle=False)\n"
# Preview the first sample of the training split. The active pipeline names
# that split `trainset`; `train_ds` is only assigned inside the disabled
# string block, so indexing it here raised NameError.
image, label = trainset[0]
# image[0] takes the first slice along the leading axis (presumably the first
# colour channel of a channels-first tensor -- confirm) and draws it as a
# grayscale map.
plt.imshow(image[0], cmap='gray')
print('Label:', label)
Label: 9