Learn practical skills, build real-world projects, and advance your career
import os
from random import randrange
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.models as models
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
import torchvision.transforms as tt
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
%matplotlib inline
# Root of the OCT2017 dataset.
# NOTE(review): every sub-path below is joined with a leading space
# (" /train"); this looks intentional (the on-disk directory name appears to
# end with a space) -- confirm before normalizing these paths.
data_dir = '../input/kermany2018/oct2017/OCT2017'

# os.listdir returns entries in arbitrary order. Sort them so that classes[i]
# lines up with the integer label i that ImageFolder assigns (ImageFolder
# indexes the class folders in sorted order); otherwise classes[label] can
# print the wrong class name.
classes = sorted(os.listdir(data_dir + " /train"))
num_classes = len(classes)
print("number of diseases:", num_classes)

# Number of training images per class.
print({cls: len(os.listdir(data_dir + f" /train/{cls}/")) for cls in classes})
number of diseases: 4
{'CNV': 37205, 'DME': 11348, 'DRUSEN': 8616, 'NORMAL': 26315}
# Training-set preprocessing: resize, center-crop to 224, convert to a tensor,
# then normalize with a scalar mean/std. A single Compose is enough; nesting
# one Compose inside another applies exactly the same steps.
train_tfms = tt.Compose([
    tt.Resize(256),
    tt.CenterCrop(224),
    tt.ToTensor(),
    tt.Normalize(mean=0.1817, std=0.1797),
])

# Validation-set preprocessing: the same steps as for training.
valid_tfms = tt.Compose([
    tt.Resize(256),
    tt.CenterCrop(224),
    tt.ToTensor(),
    tt.Normalize(mean=0.1817, std=0.1797),
])


# Datasets: one ImageFolder per split, each with its own transform pipeline
train_ds = ImageFolder(root=data_dir + ' /train', transform=train_tfms)
valid_ds = ImageFolder(root=data_dir + ' /test', transform=valid_tfms)

# Number of samples per training batch
batch_size = 100

# Loaders: training batches are shuffled; validation uses twice the batch size
train_dl = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)
valid_dl = DataLoader(
    valid_ds,
    batch_size=batch_size * 2,
    num_workers=0,
    pin_memory=True,
)
def show_example(data):
    """Print the class name of one dataset sample and display its image.

    `data` is an (image, label) pair. The label is used as an index into the
    module-level `classes` list, and the image axes are reordered with
    permute(1, 2, 0) before being handed to matplotlib.
    """
    image, target = data
    print(classes[target])
    # assumes image is a channel-first (C, H, W) tensor -- TODO confirm
    channels_last = image.permute(1, 2, 0)
    plt.imshow(channels_last)
    
# show an image!
# Draw the index from the full dataset instead of a hard-coded 20000: a fixed
# bound can never reach samples beyond it and would raise IndexError on a
# dataset with fewer images.
image_number = randrange(len(train_ds))
show_example(train_ds[image_number])
DRUSEN
Notebook Image
def show_batch(dl):
    """Plot up to ten images from the first batch of `dl` as a single grid.

    Only the first batch is consumed. If the loader yields nothing, no
    figure is created.
    """
    try:
        first_batch = next(iter(dl))
    except StopIteration:
        # Empty loader: nothing to draw
        return

    images, _labels = first_batch
    _, axis = plt.subplots(figsize=(10, 10))
    # Hide the tick marks on both axes
    axis.set_xticks([])
    axis.set_yticks([])
    grid = make_grid(images[:10], nrow=10)
    axis.imshow(grid.permute(1, 2, 0))