import numpy as np
import torch
import torch.utils.data
import os
import random
The dataset is a PyTorch class. Dataset objects, used together with a DataLoader, provide the data in a PyTorch-friendly format. The following class takes NumPy arrays for the input data and the targets. It can also apply a random crop to the data.
Note: this class does not normalize the data; normalization must either be done before creating the dataset, or the class definition must be modified to include it.
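For example, a minimal normalization sketch, assuming the arrays hold 8-bit RGB values in [0, 255] (the array names here are placeholders, not the variables used later in this notebook):
# minimal sketch, assuming 8-bit RGB arrays; names are placeholders
data_arr = data_arr.astype(np.float32) / 255.0   # scale inputs to [0, 1]
gt_arr = gt_arr.astype(np.float32) / 255.0       # scale targets identically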
class ImageDataset(torch.utils.data.Dataset):
    """Dataset wrapping NumPy arrays of input images and targets."""

    def __init__(self, data, targets, crop=False, imsize=256):
        """Store the arrays and the cropping configuration."""
        self.data = data
        self.targets = targets
        self.crop = crop
        self.imsize = imsize

    def __getitem__(self, index):
        """Return one (input, target) pair as float32 tensors."""
        data, target = self.data[index], self.targets[index]
        if self.crop:
            # pick a random imsize x imsize window, identical for input and target
            h, w, _ = data.shape
            y1 = random.randint(0, h - self.imsize)
            x1 = random.randint(0, w - self.imsize)
            data = data[y1:y1 + self.imsize, x1:x1 + self.imsize]
            target = target[y1:y1 + self.imsize, x1:x1 + self.imsize]
        # PyTorch expects channels first: (H, W, C) -> (C, H, W)
        data = data.transpose(2, 0, 1)
        target = target.transpose(2, 0, 1)
        # convert to float32
        data = data.astype(np.float32)
        target = target.astype(np.float32)
        # convert to torch tensors
        data = torch.from_numpy(data)
        target = torch.from_numpy(target)
        return data, target

    def __len__(self):
        """Number of images in the dataset."""
        return self.data.shape[0]
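As a quick sanity check of the random crop, the class can be exercised on random arrays (a hypothetical example, not the provided data):
fake_data = np.random.rand(4, 480, 640, 3)   # 4 fake images, H=480, W=640
fake_gt = np.random.rand(4, 480, 640, 3)     # matching fake ground truth
ds = ImageDataset(fake_data, fake_gt, crop=True, imsize=256)
x, y = ds[0]
print(x.shape, y.shape)  # torch.Size([3, 256, 256]) for both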
We provide the data in the form of NumPy arrays; you can download them here:
Or from here (mirror):
Assuming they are stored on your Google Drive in the data/dehazing
folder, you can mount the drive using the following code. Set USE_COLAB
to True.
USE_COLAB = False
if USE_COLAB:
    # mount Google Drive so the notebook can access the data folder
    from google.colab import drive
    drive.mount('/content/drive')
    # path to the dehazing data stored on Google Drive
    data_dir = "/content/drive/My Drive/data/dehazing"
else:
    data_dir = "data/dehazing"
# load the training inputs and their ground truth
data_th = np.load(os.path.join(data_dir, "train_data.npy"))
gt_th = np.load(os.path.join(data_dir, "train_gt.npy"))
dataset = ImageDataset(data_th, gt_th, crop=True, imsize=256)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
for inputs, targets in dataloader:
    print(inputs.size(), targets.size())
val_th = np.load(os.path.join(data_dir, "val_data.npy"))
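The validation inputs can be wrapped in the same class; a sketch, assuming a matching val_gt.npy file exists next to val_data.npy (cropping is disabled so evaluation runs on full images):
val_gt_th = np.load(os.path.join(data_dir, "val_gt.npy"))  # assumed filename
val_dataset = ImageDataset(val_th, val_gt_th, crop=False)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False)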