Learn practical skills, build real-world projects, and advance your career
Updated 4 years ago
import os
from random import randrange
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.models as models
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
import torchvision.transforms as tt
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
%matplotlib inline
data_dir = '../input/kermany2018/oct2017/OCT2017'
train_dir = os.path.join(data_dir, 'train')

# ImageFolder assigns label indices from the *sorted* class sub-directory
# names, so `classes` must be sorted the same way for `classes[label]` to name
# the right disease. Only directories count as classes (stray files such as
# hidden metadata files are ignored, as ImageFolder also ignores them).
classes = sorted(
    entry for entry in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, entry))
)
num_classes = len(classes)
print("number of diseases:", num_classes)
# Number of training images per class.
print({cls: len(os.listdir(os.path.join(train_dir, cls))) for cls in classes})
number of diseases: 4
{'CNV': 37205, 'DME': 11348, 'DRUSEN': 8616, 'NORMAL': 26315}
# Scalar normalization statistics, broadcast over every image channel.
# NOTE(review): presumably the pixel mean/std of the OCT2017 training images —
# confirm how they were computed.
NORM_MEAN = 0.1817
NORM_STD = 0.1797

# Training preprocessing: resize the shorter side to 256 px, center-crop to
# 224x224, convert to a float tensor in [0, 1], then normalize.
# NOTE(review): no data augmentation is applied; this is identical to the
# validation pipeline below.
train_tfms = tt.Compose([
    tt.Resize(256),
    tt.CenterCrop(224),
    tt.ToTensor(),
    tt.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Validation preprocessing: deterministic, same steps as training.
valid_tfms = tt.Compose([
    tt.Resize(256),
    tt.CenterCrop(224),
    tt.ToTensor(),
    tt.Normalize(mean=NORM_MEAN, std=NORM_STD),
])
# Create datasets: ImageFolder expects one sub-directory per class under each
# split. The dataset's "test" split is used here as the validation set.
train_ds = ImageFolder(os.path.join(data_dir, 'train'), train_tfms)
valid_ds = ImageFolder(os.path.join(data_dir, 'test'), valid_tfms)
# Mini-batch size for training. Validation batches are twice as large
# (presumably affordable because evaluation keeps no gradients — confirm).
batch_size = 100

# Options shared by both loaders: load in the main process and copy batches
# into pinned (page-locked) host memory.
_loader_kwargs = dict(num_workers=0, pin_memory=True)

# PyTorch data loaders; only the training set is shuffled.
train_dl = DataLoader(train_ds, batch_size, shuffle=True, **_loader_kwargs)
valid_dl = DataLoader(valid_ds, batch_size * 2, **_loader_kwargs)
def show_example(data, mean=0.1817, std=0.1797):
    """Print the class name of one dataset sample and display its image.

    Args:
        data: an ``(img, label)`` pair, e.g. ``train_ds[i]``; ``img`` is a
            (C, H, W) tensor, as ``permute(1, 2, 0)`` below requires.
        mean: scalar mean that was passed to ``tt.Normalize``.
        std: scalar std that was passed to ``tt.Normalize``.

    Relies on the module-level ``classes`` list to map ``label`` to a name.
    """
    img, label = data
    print(classes[label])
    # Normalized pixel values fall outside [0, 1], and imshow clips float RGB
    # data to that range; undo the normalization and clamp before plotting.
    img = (img * std + mean).clamp(0, 1)
    # imshow expects channels last: (C, H, W) -> (H, W, C).
    plt.imshow(img.permute(1, 2, 0))
# Show a random training image. Draw from the whole dataset: ImageFolder keeps
# samples grouped by (sorted) class, so a fixed bound such as 20000 would only
# ever pick images from the first class (CNV alone has 37205 images).
image_number = randrange(len(train_ds))
show_example(train_ds[image_number])
DRUSEN
def show_batch(dl, n=10, mean=0.1817, std=0.1797):
    """Display the first ``n`` images of the first batch from ``dl`` in one row.

    Args:
        dl: an iterable of ``(images, labels)`` batches, e.g. a DataLoader;
            ``images`` is presumably a (B, C, H, W) tensor — confirm.
        n: how many images of the batch to show (default 10).
        mean: scalar mean that was passed to ``tt.Normalize``.
        std: scalar std that was passed to ``tt.Normalize``.

    Does nothing if ``dl`` yields no batches.
    """
    batch = next(iter(dl), None)
    if batch is None:
        return
    images, _ = batch
    _, ax = plt.subplots(figsize=(10, 10))
    ax.set_xticks([])
    ax.set_yticks([])
    # Undo the normalization before gridding so pixels are in [0, 1] and
    # imshow does not clip them; the grid's zero padding stays black.
    shown = (images[:n] * std + mean).clamp(0, 1)
    ax.imshow(make_grid(shown, nrow=n).permute(1, 2, 0))