"""Dataset setting and data loader for USPS. Modified from https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py """ import gzip import os import pickle import urllib from PIL import Image import numpy as np import torch import torch.utils.data as data from torch.utils.data.sampler import WeightedRandomSampler from torchvision import datasets, transforms class USPS(data.Dataset): """USPS Dataset. Args: root (string): Root directory of dataset where dataset file exist. train (bool, optional): If True, resample from dataset randomly. download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` """ url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl" def __init__(self, root, train=True, transform=None, download=False): """Init USPS dataset.""" # init params self.root = os.path.expanduser(root) self.filename = "usps_28x28.pkl" self.train = train # Num of Train = 7438, Num ot Test 1860 self.transform = transform self.dataset_size = None # download dataset. if download: self.download() if not self._check_exists(): raise RuntimeError("Dataset not found." + " You can use download=True to download it") self.train_data, self.train_labels = self.load_samples() if self.train: total_num_samples = self.train_labels.shape[0] indices = np.arange(total_num_samples) self.train_data = self.train_data[indices[0:self.dataset_size], ::] self.train_labels = self.train_labels[indices[0:self.dataset_size]] self.train_data *= 255.0 self.train_data = np.squeeze(self.train_data).astype(np.uint8) def __getitem__(self, index): """Get images and target for data loader. Args: index (int): Index Returns: tuple: (image, target) where target is index of the target class. """ img, label = self.train_data[index], self.train_labels[index] img = Image.fromarray(img, mode='L') img = img.copy() if self.transform is not None: img = self.transform(img) return img, label.astype("int64") def __len__(self): """Return size of dataset.""" return len(self.train_data) def _check_exists(self): """Check if dataset is download and in right place.""" return os.path.exists(os.path.join(self.root, self.filename)) def download(self): """Download dataset.""" filename = os.path.join(self.root, self.filename) dirname = os.path.dirname(filename) if not os.path.isdir(dirname): os.makedirs(dirname) if os.path.isfile(filename): return print("Download %s to %s" % (self.url, os.path.abspath(filename))) urllib.request.urlretrieve(self.url, filename) print("[DONE]") return def load_samples(self): """Load sample images from dataset.""" filename = os.path.join(self.root, self.filename) f = gzip.open(filename, "rb") data_set = pickle.load(f, encoding="bytes") f.close() if self.train: images = data_set[0][0] labels = data_set[0][1] self.dataset_size = labels.shape[0] else: images = data_set[1][0] labels = data_set[1][1] self.dataset_size = labels.shape[0] return images, labels class USPS_idx(data.Dataset): """USPS Dataset. Args: root (string): Root directory of dataset where dataset file exist. train (bool, optional): If True, resample from dataset randomly. download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. 
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` """ url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl" def __init__(self, root, train=True, transform=None, download=False): """Init USPS dataset.""" # init params self.root = os.path.expanduser(root) self.filename = "usps_28x28.pkl" self.train = train # Num of Train = 7438, Num ot Test 1860 self.transform = transform self.dataset_size = None # download dataset. if download: self.download() if not self._check_exists(): raise RuntimeError("Dataset not found." + " You can use download=True to download it") self.train_data, self.train_labels = self.load_samples() if self.train: total_num_samples = self.train_labels.shape[0] indices = np.arange(total_num_samples) self.train_data = self.train_data[indices[0:self.dataset_size], ::] self.train_labels = self.train_labels[indices[0:self.dataset_size]] self.train_data *= 255.0 self.train_data = np.squeeze(self.train_data).astype(np.uint8) def __getitem__(self, index): """Get images and target for data loader. Args: index (int): Index Returns: tuple: (image, target) where target is index of the target class. """ img, label = self.train_data[index], self.train_labels[index] img = Image.fromarray(img, mode='L') img = img.copy() if self.transform is not None: img = self.transform(img) return img, label.astype("int64"), index def __len__(self): """Return size of dataset.""" return len(self.train_data) def _check_exists(self): """Check if dataset is download and in right place.""" return os.path.exists(os.path.join(self.root, self.filename)) def download(self): """Download dataset.""" filename = os.path.join(self.root, self.filename) dirname = os.path.dirname(filename) if not os.path.isdir(dirname): os.makedirs(dirname) if os.path.isfile(filename): return print("Download %s to %s" % (self.url, os.path.abspath(filename))) urllib.request.urlretrieve(self.url, filename) print("[DONE]") return def load_samples(self): """Load sample images from dataset.""" filename = os.path.join(self.root, self.filename) f = gzip.open(filename, "rb") data_set = pickle.load(f, encoding="bytes") f.close() if self.train: images = data_set[0][0] labels = data_set[0][1] self.dataset_size = labels.shape[0] else: images = data_set[1][0] labels = data_set[1][1] self.dataset_size = labels.shape[0] return images, labels
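

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of how
# these datasets can be wrapped in a DataLoader. The root path "data/usps",
# the batch size, and the normalization constants below are illustrative
# assumptions, not values taken from this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pre_process = transforms.Compose([
        transforms.ToTensor(),
        # Assumed normalization; adjust to whatever the training code expects.
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ])

    # Plain dataset: yields (image, label) pairs.
    usps_train = USPS(root="data/usps", train=True,
                      transform=pre_process, download=True)
    train_loader = torch.utils.data.DataLoader(
        usps_train, batch_size=64, shuffle=True)

    # Index-aware variant: yields (image, label, index), useful when sample
    # indices must be tracked across epochs (e.g. for pseudo-labeling).
    usps_idx = USPS_idx(root="data/usps", train=True,
                        transform=pre_process, download=True)
    idx_loader = torch.utils.data.DataLoader(
        usps_idx, batch_size=64, shuffle=True)

    images, labels = next(iter(train_loader))
    print(images.shape, labels.shape)  # e.g. torch.Size([64, 1, 28, 28])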