Example #1
    def prepare_data(self):
        """
        Saves MNIST files to data_dir
        """
        BinaryMNIST(self.data_dir,
                    train=True,
                    download=True,
                    transform=transform_lib.ToTensor())
        BinaryMNIST(self.data_dir,
                    train=False,
                    download=True,
                    transform=transform_lib.ToTensor())
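(prepare_data is the Lightning hook the Trainer calls once, on a single process, to do the download; the dataloader methods in the later examples therefore reopen the same files with download=False.)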
Example #2
    def train_dataloader(self, batch_size=32, transforms=None):
        """
        MNIST train set, with a subset removed for validation

        Args:
            batch_size: size of batch
            transforms: custom transforms
        """
        transforms = transforms or self.train_transforms or self._default_transforms()

        dataset = BinaryMNIST(self.data_dir,
                              train=True,
                              download=False,
                              transform=transforms)
        train_length = len(dataset)
        dataset_train, _ = random_split(
            dataset, [train_length - self.val_split, self.val_split],
            generator=torch.Generator().manual_seed(self.seed))
        loader = DataLoader(dataset_train,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=self.num_workers,
                            drop_last=True,
                            pin_memory=True)
        return loader
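Every example in this listing falls back to self._default_transforms(), which the snippets never show. A minimal sketch of what such a helper could look like, assuming the default pipeline does nothing beyond converting images to tensors (an illustration, not the library's exact implementation):

    def _default_transforms(self):
        # fallback used when no custom transforms are supplied:
        # convert PIL images to float tensors
        return transform_lib.Compose([transform_lib.ToTensor()])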
Example #3
    def test_dataloader(self):
        """
        MNIST test set uses the test split
        """
        transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms

        dataset = BinaryMNIST(self.data_dir,
                              train=False,
                              download=False,
                              transform=transforms)
        loader = DataLoader(dataset,
                            batch_size=self.batch_size,
                            shuffle=False,
                            num_workers=self.num_workers,
                            drop_last=True,
                            pin_memory=True)
        return loader
Example #4
    def test_dataloader(self, batch_size=32, transforms=None):
        """
        MNIST test set uses the test split

        Args:
            batch_size: size of batch
            transforms: custom transforms
        """
        transforms = transforms or self.test_transforms or self._default_transforms()

        dataset = BinaryMNIST(self.data_dir, train=False, download=False, transform=transforms)
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True
        )
        return loader
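Note that the test loaders in Examples #3 and #4 pass drop_last=True, which discards the final incomplete batch; whenever the test-set size is not a multiple of the batch size, some test samples are silently skipped, so drop_last=False is the safer choice if full test coverage matters.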
Example #5
    def val_dataloader(self):
        """
        MNIST val set uses a subset of the training set for validation
        """
        transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
        dataset = BinaryMNIST(self.data_dir,
                              train=True,
                              download=False,
                              transform=transforms)
        train_length = len(dataset)
        _, dataset_val = random_split(
            dataset, [train_length - self.val_split, self.val_split],
            generator=torch.Generator().manual_seed(self.seed))
        loader = DataLoader(dataset_val,
                            batch_size=self.batch_size,
                            shuffle=False,
                            num_workers=self.num_workers,
                            drop_last=True,
                            pin_memory=True)
        return loader
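Because this val_dataloader reruns random_split with the same manual seed used by the train_dataloader in Examples #2 and #6, both calls reproduce the identical permutation, so the training and validation subsets are guaranteed to be disjoint. A small standalone illustration of that property (toy data, names chosen here only for the example):

import torch
from torch.utils.data import random_split

data = list(range(10))
train_subset, _ = random_split(data, [8, 2],
                               generator=torch.Generator().manual_seed(1234))
_, val_subset = random_split(data, [8, 2],
                             generator=torch.Generator().manual_seed(1234))
# the identical seed yields the identical shuffle, so the two subsets are complementary
assert set(train_subset).isdisjoint(set(val_subset))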
Example #6
    def train_dataloader(self):
        """
        MNIST train set, with a subset removed for validation
        """
        transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms

        dataset = BinaryMNIST(self.data_dir,
                              train=True,
                              download=False,
                              transform=transforms)
        train_length = len(dataset)
        dataset_train, _ = random_split(
            dataset, [train_length - self.val_split, self.val_split],
            generator=torch.Generator().manual_seed(self.seed))
        loader = DataLoader(dataset_train,
                            batch_size=self.batch_size,
                            shuffle=self.shuffle,
                            num_workers=self.num_workers,
                            drop_last=self.drop_last,
                            pin_memory=self.pin_memory)
        return loader
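For context, these methods are meant to sit on a LightningDataModule and are normally driven by the Trainer rather than called directly. A hypothetical end-to-end usage, assuming the enclosing class is named BinaryMNISTDataModule and LitClassifier stands in for some LightningModule (both names are assumptions for illustration):

import pytorch_lightning as pl

dm = BinaryMNISTDataModule(data_dir="./data")  # assumed name of the enclosing datamodule class
model = LitClassifier()                        # placeholder LightningModule
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, datamodule=dm)   # runs prepare_data() once, then the train/val dataloaders
trainer.test(model, datamodule=dm)  # evaluates with test_dataloader()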