def get_data_loaders(self):
        """Build the masked train/test loaders for Tiny ImageNet.

        The training split is loaded with ``self.TRANSFORM``. The evaluation
        split only applies ``transforms.ToTensor()`` followed by
        ``self.get_normalization_transform()``. When ``self.args.validation``
        is truthy, the evaluation split is produced by ``get_train_val`` from
        the training data; otherwise the ``train=False`` split is loaded.

        Returns:
            The ``(train, test)`` pair produced by ``store_masked_loaders``.
        """
        train_tf = self.TRANSFORM
        eval_tf = transforms.Compose(
            [transforms.ToTensor(), self.get_normalization_transform()])

        train_dataset = MyTinyImagenet(
            base_path() + 'TINYIMG', train=True, download=True,
            transform=train_tf)

        if not self.args.validation:
            test_dataset = TinyImagenet(
                base_path() + 'TINYIMG', train=False, download=True,
                transform=eval_tf)
        else:
            # Hold out part of the training data as the evaluation split.
            train_dataset, test_dataset = get_train_val(
                train_dataset, eval_tf, self.NAME)

        train_loader, test_loader = store_masked_loaders(
            train_dataset, test_dataset, self)
        return train_loader, test_loader
    def get_data_loaders(self, nomask=False):
        """Build the Fashion-MNIST train/test data.

        Both splits use ``transforms.ToTensor()`` as their only transform.
        When ``self.args.validation`` is truthy, the test split is produced by
        ``get_train_val`` from the training data; otherwise the
        ``train=False`` split is loaded.

        Args:
            nomask: if falsy (default), both datasets are passed through
                ``store_masked_loaders`` and its ``(train, test)`` result is
                returned. If truthy, the raw ``(train_dataset, test_dataset)``
                pair is returned without masking.
        """
        to_tensor = transforms.ToTensor()
        train_dataset = MyFMNIST(base_path() + 'FMNIST', train=True,
                                 download=True, transform=to_tensor)

        if not self.args.validation:
            test_dataset = FMNIST(base_path() + 'FMNIST', train=False,
                                  download=True, transform=to_tensor)
        else:
            train_dataset, test_dataset = get_train_val(
                train_dataset, to_tensor, self.NAME)

        if nomask:
            # Caller wants the datasets themselves, not masked loaders.
            return train_dataset, test_dataset

        train_loader, test_loader = store_masked_loaders(
            train_dataset, test_dataset, self)
        return train_loader, test_loader
Example #3
0
    def get_data_loaders(self, nomask=False, *, batch_size=32, num_workers=2):
        """Build the CIFAR-10 train/test loaders.

        The training split is loaded with ``self.TRANSFORM``. The evaluation
        split only applies ``transforms.ToTensor()`` followed by
        ``self.get_normalization_transform()``. When ``self.args.validation``
        is truthy, the evaluation split is produced by ``get_train_val`` from
        the training data; otherwise the ``train=False`` split is loaded.

        Args:
            nomask: if falsy (default), list-typed ``targets`` on either
                dataset are converted to ``torch.long`` tensors and both
                datasets are passed through ``store_masked_loaders``, whose
                ``(train, test)`` result is returned. If truthy, plain
                ``DataLoader`` objects over the whole datasets are returned
                instead (training loader shuffled, test loader not).
            batch_size: batch size of the loaders built when ``nomask`` is
                truthy. Defaults to 32, the previously hard-coded value.
            num_workers: worker count of the loaders built when ``nomask`` is
                truthy. Defaults to 2, the previously hard-coded value.

        Returns:
            A ``(train, test)`` pair of loaders.
        """
        transform = self.TRANSFORM
        test_transform = transforms.Compose(
            [transforms.ToTensor(),
             self.get_normalization_transform()])

        train_dataset = MyCIFAR10(base_path() + 'CIFAR10',
                                  train=True,
                                  download=True,
                                  transform=transform)
        if self.args.validation:
            # Hold out part of the training data as the evaluation split.
            train_dataset, test_dataset = get_train_val(
                train_dataset, test_transform, self.NAME)
        else:
            test_dataset = CIFAR10(base_path() + 'CIFAR10',
                                   train=False,
                                   download=True,
                                   transform=test_transform)

        if nomask:
            # Unmasked path: plain loaders over the complete datasets.
            train_loader = DataLoader(train_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=num_workers)
            test_loader = DataLoader(test_dataset,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     num_workers=num_workers)
            return train_loader, test_loader

        # ``targets`` may arrive as a Python list; convert such lists to long
        # tensors before masking. Presumably store_masked_loaders relies on
        # tensor-style targets -- NOTE(review): confirm against its definition.
        for dataset in (train_dataset, test_dataset):
            if isinstance(dataset.targets, list):
                dataset.targets = torch.tensor(dataset.targets,
                                               dtype=torch.long)

        train, test = store_masked_loaders(train_dataset, test_dataset, self)
        return train, test