def prepare_data(self):
    """Save the MNIST train and test files to ``self.data_dir``.

    Builds a ``BinaryMNIST`` for each split with ``download=True``. The
    resulting dataset objects are discarded; only the on-disk side effect of
    construction is wanted here.
    """
    # Train split first, then test split, same order as before.
    for use_train_split in (True, False):
        BinaryMNIST(
            self.data_dir,
            train=use_train_split,
            download=True,
            transform=transform_lib.ToTensor(),
        )
def train_dataloader(self, batch_size=32, transforms=None):
    """Build a shuffled DataLoader over the MNIST train split minus a validation hold-out.

    The full BinaryMNIST train split is partitioned by ``random_split`` with a
    generator seeded from ``self.seed``; the last ``self.val_split`` samples of
    that partition are reserved for validation and only the remainder is served.

    Args:
        batch_size: size of batch
        transforms: custom transforms; when falsy, ``self.train_transforms`` is
            used, and when that is also falsy, ``self._default_transforms()``

    Returns:
        DataLoader over the training subset (shuffled, last partial batch
        dropped, memory pinned).
    """
    if not transforms:
        transforms = self.train_transforms or self._default_transforms()

    full_train = BinaryMNIST(self.data_dir, train=True, download=False, transform=transforms)

    n_holdout = self.val_split
    n_keep = len(full_train) - n_holdout
    # A fixed seed makes the train/val partition reproducible across calls.
    split_rng = torch.Generator().manual_seed(self.seed)
    kept_subset, _ = random_split(full_train, [n_keep, n_holdout], generator=split_rng)

    return DataLoader(
        kept_subset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        drop_last=True,
        pin_memory=True,
    )
def test_dataloader(self):
    """MNIST test set uses the test split.

    Uses ``self.test_transforms`` when it is set, otherwise
    ``self._default_transforms()``.

    Returns:
        Unshuffled DataLoader over the full BinaryMNIST test split, batched by
        ``self.batch_size``.
    """
    transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms
    dataset = BinaryMNIST(self.data_dir, train=False, download=False, transform=transforms)
    # drop_last=False: every test image must contribute to evaluation metrics;
    # dropping the final partial batch would silently skip up to
    # batch_size - 1 samples.
    loader = DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
    )
    return loader
def test_dataloader(self, batch_size=32, transforms=None):
    """MNIST test set uses the test split.

    Args:
        batch_size: size of batch
        transforms: custom transforms; when falsy, ``self.test_transforms`` is
            used, and when that is also falsy, ``self._default_transforms()``

    Returns:
        Unshuffled DataLoader over the full BinaryMNIST test split.
    """
    # Fall back to the *test* transforms: the test split must not silently
    # pick up whatever augmentation/normalisation was configured for validation.
    transforms = transforms or self.test_transforms or self._default_transforms()
    dataset = BinaryMNIST(self.data_dir, train=False, download=False, transform=transforms)
    # drop_last=False: every test image must contribute to evaluation metrics;
    # dropping the final partial batch would silently skip up to
    # batch_size - 1 samples.
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
    )
    return loader
def val_dataloader(self):
    """MNIST val set uses a subset of the training set for validation.

    Splits the BinaryMNIST train split with ``random_split`` into
    ``len - self.val_split`` and ``self.val_split`` parts, using a generator
    seeded from ``self.seed``, and serves the second (held-out) part. The
    training loaders in this class split with the same seed and lengths, so
    this subset is the complement of the training subset.

    Uses ``self.val_transforms`` when it is set, otherwise
    ``self._default_transforms()``.

    Returns:
        Unshuffled DataLoader over the validation subset, batched by
        ``self.batch_size``.
    """
    transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
    dataset = BinaryMNIST(self.data_dir, train=True, download=False, transform=transforms)
    train_length = len(dataset)
    _, dataset_val = random_split(
        dataset,
        [train_length - self.val_split, self.val_split],
        generator=torch.Generator().manual_seed(self.seed),
    )
    # drop_last=False: every held-out sample must count toward validation
    # metrics; dropping the final partial batch would silently skip some.
    loader = DataLoader(
        dataset_val,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=True,
    )
    return loader
def train_dataloader(self):
    """Return the training DataLoader.

    The BinaryMNIST train split is divided by a seeded ``random_split`` into
    ``len - self.val_split`` training samples and ``self.val_split`` held-out
    samples; only the training portion is served here. Transforms come from
    ``self.train_transforms`` (or ``self._default_transforms()`` when it is
    ``None``). Batch size, shuffling, last-batch dropping and memory pinning
    follow ``self.batch_size``, ``self.shuffle``, ``self.drop_last`` and
    ``self.pin_memory``.
    """
    if self.train_transforms is None:
        tfms = self._default_transforms()
    else:
        tfms = self.train_transforms

    full_train = BinaryMNIST(self.data_dir, train=True, download=False, transform=tfms)

    hold_out = self.val_split
    # Seeded generator => identical partition on every call.
    seeded = torch.Generator().manual_seed(self.seed)
    train_part, _ = random_split(
        full_train, [len(full_train) - hold_out, hold_out], generator=seeded
    )

    return DataLoader(
        train_part,
        batch_size=self.batch_size,
        shuffle=self.shuffle,
        num_workers=self.num_workers,
        drop_last=self.drop_last,
        pin_memory=self.pin_memory,
    )