Example #1
 def setup(self, stage=None):
     self.usps_test = USPS(self.data_dir,
                           train=False,
                           transform=transforms.ToTensor(),
                           download=True)
     usps_full = USPS(self.data_dir,
                      train=True,
                      transform=transforms.ToTensor(),
                      download=True)
     self.usps_train, self.usps_val = random_split(usps_full, [6000, 1291])
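The snippet above is the setup hook of a PyTorch Lightning data module. A minimal sketch of the module it might live in follows; the class name, the data_dir/batch_size defaults, and the dataloader methods are assumptions, not part of the original snippet.

import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import USPS

class USPSDataModule(pl.LightningDataModule):
    def __init__(self, data_dir="./data", batch_size=64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage=None):
        # download both splits; USPS has 7291 train and 2007 test images
        self.usps_test = USPS(self.data_dir, train=False,
                              transform=transforms.ToTensor(), download=True)
        usps_full = USPS(self.data_dir, train=True,
                         transform=transforms.ToTensor(), download=True)
        # 6000 + 1291 = 7291, the full training split
        self.usps_train, self.usps_val = random_split(usps_full, [6000, 1291])

    def train_dataloader(self):
        return DataLoader(self.usps_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.usps_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.usps_test, batch_size=self.batch_size)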
Example #2
def load_usps(img_size=28, augment=False, **kwargs):
    transformations = [transforms.Resize(img_size)]
    transformations.append(transforms.ToTensor())
    if augment:
        transformations.append(
            transforms.Lambda(lambda x: random_affine_augmentation(x)))
        transformations.append(transforms.Lambda(lambda x: gaussian_blur(x)))
    img_transform = transforms.Compose(transformations)
    test_transform = transforms.Compose(
        [transforms.Resize(img_size),
         transforms.ToTensor()])
    # train flags were missing: without train=False the "test" set
    # silently loads the USPS training split again
    train_set = USPS('../data', train=True, transform=img_transform, download=True)
    test_set = USPS('../data', train=False, transform=test_transform, download=True)
    return get_loader(train_set, **kwargs), get_loader(test_set, **kwargs)
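A hedged usage sketch for load_usps, assuming the repo's get_loader forwards standard DataLoader keyword arguments such as batch_size and shuffle (the values below are illustrative):

# hypothetical call; random_affine_augmentation, gaussian_blur and
# get_loader are helpers defined elsewhere in the original repository
train_loader, test_loader = load_usps(img_size=28, augment=True,
                                      batch_size=128, shuffle=True)
images, targets = next(iter(train_loader))
print(images.shape)  # expected: torch.Size([128, 1, 28, 28])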
Example #3
def get_data(domain_name: str, split="train"):
    train = split == "train"
    if domain_name == "mnist":
        return do.from_pytorch(MNIST(data_path, download=True, train=train))
    if domain_name == "usps":
        return do.from_pytorch(USPS(data_path, download=True, train=train))
    # fail loudly instead of silently returning None for unknown domains
    raise ValueError(f"Unknown domain: {domain_name}")
Example #4
def select_test_dataset(dataset_name, testing=False):
    """
    Selects a dataset from the options below
    Parameters
    ----------
    dataset_name: dataset name given as string: 'fmnist'
    testing: testing flag. If testing is True, then the function returns 1000 samples only

    Returns
    -------
    vec_data: the dataset as a numpy array. The dimensions are N X D where N is the number of samples in the data and
    D is the dimensions of the feature vector.
    labels: the labels of the samples. The dimensions are N X 1 where N is the number of samples in the data and
    1 is label of the sample.
    """
    if dataset_name == 'fmnist':
        f_mnist = FashionMNIST(root="./datasets", train=False, download=True)
        data = f_mnist.data.numpy()
        vec_data = np.reshape(data, (data.shape[0], -1))
        vec_data = np.float32(vec_data)
        labels = f_mnist.targets.numpy()
    elif dataset_name == 'usps':
        f_mnist = USPS(root="./datasets", train=False, download=True)
        data = f_mnist.data
        vec_data = np.reshape(data, (data.shape[0], -1))
        vec_data = np.float32(vec_data)
        labels = np.float32(f_mnist.targets)
    elif dataset_name == 'char':
        digits = datasets.load_digits()
        n_samples = len(digits.images)
        data = digits.images.reshape((n_samples, -1))
        vec_data = np.float32(data)
        labels = digits.target
    elif dataset_name == 'charx':
        file_name = file_path + "/datasets/char_x.npy"
        data = np.load(file_name, allow_pickle=True)
        vec_data, labels = data[2], data[3]
    else:
        print('The dataset you asked for is not available. Giving you MNIST instead.')
        d_mnist = MNIST(root="./datasets", train=False, download=True)
        data = d_mnist.data.numpy()
        vec_data = np.reshape(data, (data.shape[0], -1))
        vec_data = np.float32(vec_data)
        labels = d_mnist.targets.numpy()

    if testing:
        return vec_data[:1000], labels[:1000]
    else:
        return vec_data, labels
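A quick shape check for this helper, as a sketch; the USPS test split has 2007 images of 16x16 pixels (D = 256), and FashionMNIST's test split has 10000 images of 28x28 (D = 784):

vec_data, labels = select_test_dataset('usps')
print(vec_data.shape, labels.shape)   # (2007, 256) (2007,)

small_x, small_y = select_test_dataset('fmnist', testing=True)
print(small_x.shape)                  # (1000, 784): first 1000 test samples only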
Example #5
def main():
    args = parse_args()

    if args.debug or not args.non_deterministic:
        np.random.seed(1)
        torch.manual_seed(1)
        torch.cuda.manual_seed(1)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

        # torch.set_deterministic(True) # grid_sampler_2d_backward_cuda does not have a deterministic implementation

    if args.debug:
        torch.autograd.set_detect_anomaly(True)

    dataloader_args = EasyDict(
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0 if args.debug else args.data_workers)
    if args.dataset == 'mnist':
        args.num_classes = 10
        args.im_channels = 1
        args.image_size = (40, 40)

        from torchvision.datasets import MNIST

        t = transforms.Compose([
            transforms.RandomCrop(size=(40, 40), pad_if_needed=True),
            transforms.ToTensor(),
            # norm_1c
        ])
        train_dataloader = DataLoader(
            MNIST(data_path / 'mnist', train=True, transform=t, download=True),
            **dataloader_args)
        val_dataloader = DataLoader(
            MNIST(data_path / 'mnist', train=False, transform=t,
                  download=True), **dataloader_args)
    elif args.dataset == 'usps':
        args.num_classes = 10
        args.im_channels = 1
        args.image_size = (40, 40)

        from torchvision.datasets import USPS

        t = transforms.Compose([
            transforms.RandomCrop(size=(40, 40), pad_if_needed=True),
            transforms.ToTensor(),
            # norm_1c
        ])
        train_dataloader = DataLoader(
            USPS(data_path / 'usps', train=True, transform=t, download=True),
            **dataloader_args)
        val_dataloader = DataLoader(
            USPS(data_path / 'usps', train=False, transform=t, download=True),
            **dataloader_args)
    elif args.dataset == 'constellation':

        data_gen = create_constellation(
            batch_size=args.batch_size,
            shuffle_corners=True,
            gaussian_noise=.0,
            drop_prob=0.5,
            which_patterns=[[0], [1], [0]],
            rotation_percent=180 / 360.,
            max_scale=3.,
            min_scale=3.,
            use_scale_schedule=False,
            schedule_steps=0,
        )

        train_dataloader = DataLoader(data_gen, **dataloader_args)
        val_dataloader = DataLoader(data_gen, **dataloader_args)

    elif args.dataset == 'cifar10':
        args.num_classes = 10
        args.im_channels = 3
        args.image_size = (32, 32)

        from torchvision.datasets import CIFAR10

        t = transforms.Compose([transforms.ToTensor()])
        train_dataloader = DataLoader(
            CIFAR10(data_path / 'cifar10',
                    train=True,
                    transform=t,
                    download=True), **dataloader_args)
        val_dataloader = DataLoader(
            CIFAR10(data_path / 'cifar10',
                    train=False,
                    transform=t,
                    download=True), **dataloader_args)
    elif args.dataset == 'svhn':
        args.num_classes = 10
        args.im_channels = 3
        args.image_size = (32, 32)

        from torchvision.datasets import SVHN

        t = transforms.Compose([transforms.ToTensor()])
        train_dataloader = DataLoader(
            SVHN(data_path / 'svhn', split='train', transform=t,
                 download=True), **dataloader_args)
        val_dataloader = DataLoader(
            SVHN(data_path / 'svhn', split='test', transform=t, download=True),
            **dataloader_args)
    else:
        raise NotImplementedError()

    logger = WandbLogger(project=args.log.project,
                         name=args.log.run_name,
                         entity=args.log.team,
                         config=args,
                         offline=not args.log.upload)

    if args.model == 'ccae':
        from scae.modules.attention import SetTransformer
        from scae.modules.capsule import CapsuleLayer
        from scae.models.ccae import CCAE

        encoder = SetTransformer(2)
        decoder = CapsuleLayer(input_dims=32,
                               n_caps=3,
                               n_caps_dims=2,
                               n_votes=4,
                               n_caps_params=32,
                               n_hiddens=128,
                               learn_vote_scale=True,
                               deformations=True,
                               noise_type='uniform',
                               noise_scale=4.,
                               similarity_transform=False)

        model = CCAE(encoder, decoder, args)

        # logger.watch(encoder._encoder, log='all', log_freq=args.log_frequency)
        # logger.watch(decoder, log='all', log_freq=args.log_frequency)
    elif args.model == 'pcae':
        from scae.modules.part_capsule_ae import CapsuleImageEncoder, TemplateImageDecoder
        from scae.models.pcae import PCAE

        encoder = CapsuleImageEncoder(args)
        decoder = TemplateImageDecoder(args)
        model = PCAE(encoder, decoder, args)

        logger.watch(encoder._encoder, log='all', log_freq=args.log.frequency)
        logger.watch(decoder, log='all', log_freq=args.log.frequency)
    elif args.model == 'ocae':
        from scae.modules.object_capsule_ae import SetTransformer, ImageCapsule
        from scae.models.ocae import OCAE

        encoder = SetTransformer()
        decoder = ImageCapsule()
        model = OCAE(encoder, decoder, args)

        #  TODO: after ccae
    else:
        raise NotImplementedError()

    # Execute Experiment
    lr_logger = cb.LearningRateMonitor(logging_interval='step')
    best_checkpointer = cb.ModelCheckpoint(save_top_k=1,
                                           monitor='val_rec_ll',
                                           filepath=logger.experiment.dir)
    last_checkpointer = cb.ModelCheckpoint(save_last=True,
                                           filepath=logger.experiment.dir)
    trainer = pl.Trainer(
        max_epochs=args.num_epochs,
        logger=logger,
        callbacks=[lr_logger, best_checkpointer, last_checkpointer])
    trainer.fit(model, train_dataloader, val_dataloader)
Example #6
 def setup(self, stage=None):
     self.usps_test = self.colorize_dataset(
         USPS(self.data_dir, train=False, download=True))
     usps_full = self.colorize_dataset(
         USPS(self.data_dir, train=True, download=True))
     self.usps_train, self.usps_val = random_split(usps_full, [6000, 1291])
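colorize_dataset is defined elsewhere in the original repository. Purely as a hypothetical stand-in, a version that replicates the grayscale channel three times could look like this; it is an assumption about the helper's intent, not its actual implementation:

from torchvision import transforms

def colorize_dataset(dataset):
    # hypothetical: turn each grayscale USPS image into a 3-channel tensor
    dataset.transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
    ])
    return dataset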
Example #7
def select_dataset(dataset_name, input_dim=2, n_samples=10000):
    """
    :params n_samples: number of points returned. If 0, all datapoints will be returned. For artificial data, it will throw an error.
    """
    if dataset_name == 'fmnist':
        f_mnist = FashionMNIST(root="./datasets", download=True)
        data = f_mnist.data.numpy()
        vec_data = np.reshape(data, (data.shape[0], -1))
        vec_data = np.float32(vec_data)
        labels = f_mnist.targets.numpy()
    elif dataset_name == 'emnist':
        f_mnist = EMNIST(root="./datasets", download=True, split='byclass')
        data = f_mnist.data.numpy()
        vec_data = np.reshape(data, (data.shape[0], -1))
        vec_data = np.float32(vec_data)
        labels = f_mnist.targets.numpy()
    elif dataset_name == 'kmnist':
        f_mnist = KMNIST(root="./datasets", download=True)
        data = f_mnist.data.numpy()
        vec_data = np.reshape(data, (data.shape[0], -1))
        vec_data = np.float32(vec_data)
        labels = f_mnist.targets.numpy()
    elif dataset_name == 'usps':
        f_mnist = USPS(root="./datasets", download=True)
        data = f_mnist.data
        vec_data = np.reshape(data, (data.shape[0], -1))
        vec_data = np.float32(vec_data)
        labels = np.float32(f_mnist.targets)
    elif dataset_name == 'news':
        newsgroups_train = fetch_20newsgroups(data_home='./datasets', subset='train',
                                              remove=('headers', 'footers', 'quotes'))
        vectorizer = TfidfVectorizer()
        vec_data = vectorizer.fit_transform(newsgroups_train.data).toarray()
        vec_data = np.float32(vec_data)
        labels = newsgroups_train.target
        labels = np.float32(labels)
    elif dataset_name == 'cover_type':
        file_name = file_path + "/datasets/covtype.data"
        train_data = np.array(pd.read_csv(file_name, sep=','))
        vec_data = np.float32(train_data[:, :-1])
        labels = np.float32(train_data[:, -1])
    elif dataset_name == 'char':
        digits = datasets.load_digits()
        # use a local count so the n_samples argument is not clobbered
        n_digits = len(digits.images)
        data = digits.images.reshape((n_digits, -1))
        vec_data = np.float32(data)
        labels = digits.target

    elif dataset_name == 'charx':
        file_name = file_path + "/datasets/char_x.npy"
        data = np.load(file_name, allow_pickle=True)
        vec_data, labels = data[0], data[1]

    elif dataset_name == 'kdd_cup':
        cover_train = fetch_kddcup99(data_home='./datasets', download_if_missing=True)
        vec_data = cover_train.data
        string_labels = cover_train.target
        vec_data, labels = feature_tranformers.vectorizer_kdd(data=vec_data, labels=string_labels)
    elif dataset_name == 'aggregation':
        file_name = file_path + "/2d_data/Aggregation.csv"
        a = np.array(pd.read_csv(file_name, sep=';'))
        vec_data = a[:, 0:2]
        labels = a[:, 2]
    elif dataset_name == 'compound':
        file_name = file_path + "/2d_data/Compound.txt"
        a = np.array(pd.read_csv(file_name, sep='\t'))
        vec_data = a[:, 0:2]
        labels = a[:, 2]
    elif dataset_name == 'd31':
        file_name = file_path + "/2d_data/D31.txt"
        a = np.array(pd.read_csv(file_name, sep='\t'))
        vec_data = a[:, 0:2]
        labels = a[:, 2]
    elif dataset_name == 'flame':
        file_name = file_path + "/2d_data/flame.txt"
        a = np.array(pd.read_csv(file_name, sep='\t'))
        vec_data = a[:, 0:2]
        labels = a[:, 2]
    elif dataset_name == 'path_based':
        file_name = file_path + "/2d_data/pathbased.txt"
        a = np.array(pd.read_csv(file_name, sep='\t'))
        vec_data = a[:, 0:2]
        labels = a[:, 2]
    elif dataset_name == 'r15':
        file_name = file_path + "/2d_data/R15.txt"
        a = np.array(pd.read_csv(file_name, sep='\t'))
        vec_data = a[:, 0:2]
        labels = a[:, 2]
    elif dataset_name == 'spiral':
        file_name = file_path + "/2d_data/spiral.txt"
        a = np.array(pd.read_csv(file_name, sep='\t'))
        vec_data = a[:, 0:2]
        labels = a[:, 2]
    elif dataset_name == 'birch1':
        file_name = file_path + "/2d_data/birch1.txt"
        a = np.array(pd.read_csv(file_name, delimiter=r"\s+"))
        vec_data = a[:, :]
        labels = np.ones((vec_data.shape[0]))
    elif dataset_name == 'birch2':
        file_name = file_path + "/2d_data/birch2.txt"
        a = np.array(pd.read_csv(file_name, delimiter=r"\s+"))
        vec_data = a[:, :]
        labels = np.ones((vec_data.shape[0]))
    elif dataset_name == 'birch3':
        file_name = file_path + "/2d_data/birch3.txt"
        a = np.array(pd.read_csv(file_name, delimiter=r"\s+"))
        vec_data = a[:, :]
        labels = np.ones((vec_data.shape[0]))
    elif dataset_name == 'worms':
        file_name = file_path + "/2d_data/worms/worms_2d.txt"
        a = np.array(pd.read_csv(file_name, sep=' '))
        vec_data = a[:, :]
        labels = np.ones((vec_data.shape[0]))
    elif dataset_name == 't48k':
        file_name = file_path + "/2d_data/t4.8k.txt"
        a = np.array(pd.read_csv(file_name, sep=' '))
        vec_data = a[1:, :]
        labels = np.ones((vec_data.shape[0]))
    elif dataset_name == 'moons':
        data, labels = make_moons(n_samples=5000)
        vec_data = np.float32(data)
        labels = np.float32(labels)
    elif dataset_name == 'circles':
        data, labels = make_circles(n_samples=5000)
        vec_data = np.float32(data)
        labels = np.float32(labels)
    elif dataset_name == 'blobs':
        data, labels = make_blobs(n_samples=n_samples, centers=3)
        vec_data = np.float32(data)
        labels = np.float32(labels)
    elif dataset_name == 'gmm':
        mean_1 = np.zeros(input_dim)
        mean_2 = 100 * np.ones(input_dim)
        cov = np.eye(input_dim)
        data_1 = np.random.multivariate_normal(mean_1, cov, int(n_samples / 2))
        labels_1 = np.ones(int(n_samples / 2))
        labels_2 = 2 * np.ones(int(n_samples / 2))
        data_2 = np.random.multivariate_normal(mean_2, cov, int(n_samples / 2))
        vec_data = np.concatenate([data_1, data_2], axis=0)
        labels = np.concatenate([labels_1, labels_2], axis=0)
    elif dataset_name == 'uniform':
        vec_data = np.random.uniform(0, 1, size=(n_samples, input_dim)) * 10
        labels = np.ones(n_samples)
    elif dataset_name == 'mnist_pc':
        d_mnist = MNIST(root="./datasets", download=True)
        mnist = d_mnist.data.numpy()
        data = np.float32(np.reshape(mnist, (mnist.shape[0], -1)))
        pca_data = PCA(n_components=input_dim).fit_transform(data)
        n_indices = np.random.randint(pca_data.shape[0], size=n_samples)
        vec_data = pca_data[n_indices]
        labels = d_mnist.targets.numpy()[n_indices]
    elif dataset_name == 'usps_pc':
        d_mnist = USPS(root="./datasets", download=True)
        mnist = d_mnist.data
        data = np.float32(np.reshape(mnist, (mnist.shape[0], -1)))
        pca_data = PCA(n_components=input_dim).fit_transform(data)
        n_indices = np.random.randint(pca_data.shape[0], size=n_samples)
        vec_data = pca_data[n_indices]
        labels = np.float32(d_mnist.targets)
    elif dataset_name == 'char_pc':
        digits = datasets.load_digits()
        # use a local count so the n_samples argument is not clobbered
        n_digits = len(digits.images)
        data = digits.images.reshape((n_digits, -1))
        data = np.float32(data)
        targets = digits.target
        pca_data = PCA(n_components=input_dim).fit_transform(data)
        n_indices = np.random.randint(pca_data.shape[0], size=n_samples)
        vec_data = pca_data[n_indices]
        # index the labels with the same random indices as the data
        labels = targets[n_indices]
    else:
        d_mnist = MNIST(root="./datasets", download=True)
        data = d_mnist.data.numpy()
        vec_data = np.reshape(data, (data.shape[0], -1))
        vec_data = np.float32(vec_data)
        labels = d_mnist.targets.numpy()

    if 0 < n_samples < vec_data.shape[0]:
        rand_indices = np.random.choice(vec_data.shape[0], size=(n_samples,), replace=False)
        return vec_data[rand_indices], labels[rand_indices]
    else:
        return vec_data, labels
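A sketch exercising the synthetic 'gmm' branch, which draws two well-separated Gaussian clusters (means 0 and 100, identity covariance, labels 1 and 2) and returns exactly n_samples points:

import numpy as np

vec_data, labels = select_dataset('gmm', input_dim=2, n_samples=10000)
print(vec_data.shape)        # (10000, 2)
print(np.unique(labels))     # [1. 2.]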
Example #8
                    if use_y_to_verify_performance:
                        plot_3d(transformed, y_for_verification, show=False)

                    plt.show()


if __name__ == "__main__":
    import numpy as np
    import torch
    from torch.utils.data import ConcatDataset
    from torchvision.datasets import MNIST, USPS, FashionMNIST, CIFAR10
    from torchtext.datasets import AG_NEWS

    n = None
    # semisupervised_proportion = .2

    e = DEN(n_components=2, internal_dim=128)

    USPS_data_train = USPS("./", train=True, download=True)
    USPS_data_test = USPS("./", train=False, download=True)
    USPS_data = ConcatDataset([USPS_data_test, USPS_data_train])
    X, y = zip(*USPS_data)

    y_numpy = np.array(y[:n])
    X_numpy = np.array(
        [np.asarray(X[i]) for i in range(n if n is not None else len(X))])
    X = torch.Tensor(X_numpy).unsqueeze(1)

    # which = np.random.choice(len(y_numpy), int((1-semisupervised_proportion)*len(y_numpy)), replace = False)
    # y_for_verification = copy.deepcopy(y_numpy)
    # y_numpy[which] = -1

    # news_train, news_test = AG_NEWS('./', ngrams = 1)
    # X, y = zip(*([item[1], item[0]] for item in news_test))
Example #9
def get_dataset(args, config):
    if config.data.dataset == 'CIFAR10':
        if (config.data.random_flip):
            dataset = CIFAR10(os.path.join('datasets', 'cifar10'),
                              train=True,
                              download=True,
                              transform=transforms.Compose([
                                  transforms.Resize(config.data.image_size),
                                  transforms.RandomHorizontalFlip(p=0.5),
                                  transforms.ToTensor()
                              ]))
            test_dataset = CIFAR10(os.path.join('datasets', 'cifar10_test'),
                                   train=False,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(
                                           config.data.image_size),
                                       transforms.ToTensor()
                                   ]))

        else:
            dataset = CIFAR10(os.path.join('datasets', 'cifar10'),
                              train=True,
                              download=True,
                              transform=transforms.Compose([
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor()
                              ]))
            test_dataset = CIFAR10(os.path.join('datasets', 'cifar10_test'),
                                   train=False,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(
                                           config.data.image_size),
                                       transforms.ToTensor()
                                   ]))

    elif config.data.dataset == 'CELEBA':
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]),
                             download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]),
                             download=True)

        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                              split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]),
                              download=True)

    elif (config.data.dataset == "CELEBA-32px"):
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(32),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]),
                             download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(32),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]),
                             download=True)

        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                              split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(32),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]),
                              download=True)

    elif (config.data.dataset == "CELEBA-8px"):
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(8),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]),
                             download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(8),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]),
                             download=True)

        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                              split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(8),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]),
                              download=True)

    elif config.data.dataset == 'LSUN':
        train_folder = '{}_train'.format(config.data.category)
        val_folder = '{}_val'.format(config.data.category)
        if config.data.random_flip:
            dataset = LSUN(root=os.path.join('datasets', 'lsun'),
                           classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor(),
                           ]))
        else:
            dataset = LSUN(root=os.path.join('datasets', 'lsun'),
                           classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.ToTensor(),
                           ]))

        test_dataset = LSUN(root=os.path.join('datasets', 'lsun'),
                            classes=[val_folder],
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.CenterCrop(config.data.image_size),
                                transforms.ToTensor(),
                            ]))

    elif config.data.dataset == "FFHQ":
        if config.data.random_flip:
            dataset = FFHQ(path=os.path.join('datasets', 'FFHQ'),
                           transform=transforms.Compose([
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor()
                           ]),
                           resolution=config.data.image_size)
        else:
            dataset = FFHQ(path=os.path.join('datasets', 'FFHQ'),
                           transform=transforms.ToTensor(),
                           resolution=config.data.image_size)

        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices = indices[:int(num_items * 0.9)]
        test_indices = indices[int(num_items * 0.9):]
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)

    elif config.data.dataset == "MNIST":
        if config.data.random_flip:
            dataset = MNIST(root=os.path.join('datasets', 'MNIST'),
                            train=True,
                            download=True,
                            transform=transforms.Compose([
                                transforms.RandomHorizontalFlip(p=0.5),
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor()
                            ]))
        else:
            dataset = MNIST(root=os.path.join('datasets', 'MNIST'),
                            train=True,
                            download=True,
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor()
                            ]))
        test_dataset = MNIST(root=os.path.join('datasets', 'MNIST'),
                             train=False,
                             download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor()
                             ]))
    elif config.data.dataset == "USPS":
        if config.data.random_flip:
            dataset = USPS(root=os.path.join('datasets', 'USPS'),
                           train=True,
                           download=True,
                           transform=transforms.Compose([
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.Resize(config.data.image_size),
                               transforms.ToTensor()
                           ]))
        else:
            dataset = USPS(root=os.path.join('datasets', 'USPS'),
                           train=True,
                           download=True,
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.ToTensor()
                           ]))
        test_dataset = USPS(root=os.path.join('datasets', 'USPS'),
                            train=False,
                            download=True,
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor()
                            ]))
    elif config.data.dataset == "USPS-Pad":
        if config.data.random_flip:
            dataset = USPS(
                root=os.path.join('datasets', 'USPS'),
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(20),  # resize and pad like MNIST
                    transforms.Pad(4),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.Resize(config.data.image_size),
                    transforms.ToTensor()
                ]))
        else:
            dataset = USPS(
                root=os.path.join('datasets', 'USPS'),
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(20),  # resize and pad like MNIST
                    transforms.Pad(4),
                    transforms.Resize(config.data.image_size),
                    transforms.ToTensor()
                ]))
        test_dataset = USPS(
            root=os.path.join('datasets', 'USPS'),
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.Resize(20),  # resize and pad like MNIST
                transforms.Pad(4),
                transforms.Resize(config.data.image_size),
                transforms.ToTensor()
            ]))
    elif (config.data.dataset.upper() == "GAUSSIAN"):
        if (config.data.num_workers != 0):
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. \
            Gaussian data is sampled at runtime and doing so with multiple workers may cause a CUDA error."
            )
        if (config.data.isotropic):
            dim = config.data.dim
            rank = config.data.rank
            cov = np.diag(np.pad(np.ones((rank, )), [(0, dim - rank)]))
            mean = np.zeros((dim, ))
        else:
            cov = np.array(config.data.cov)
            mean = np.array(config.data.mean)

        # hasattr must check config.data itself; config.data.dataset is a
        # string and never has a "shape" attribute
        shape = config.data.shape if hasattr(config.data, "shape") else None

        dataset = Gaussian(device=args.device, cov=cov, mean=mean, shape=shape)
        test_dataset = Gaussian(device=args.device,
                                cov=cov,
                                mean=mean,
                                shape=shape)

    elif (config.data.dataset.upper() == "GAUSSIAN-HD"):
        if (config.data.num_workers != 0):
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. \
            Gaussian data is sampled at runtime and doing so with multiple workers may cause a CUDA error."
            )
        cov = np.load(config.data.cov_path)
        mean = np.load(config.data.mean_path)
        dataset = Gaussian(device=args.device, cov=cov, mean=mean)
        test_dataset = Gaussian(device=args.device, cov=cov, mean=mean)

    elif (config.data.dataset.upper() == "GAUSSIAN-HD-UNIT"):
        # This dataset is to be used when GAUSSIAN with the isotropic option is infeasible due to high dimensionality
        #   of the desired samples. If the dimension is too high, passing a huge covariance matrix is slow.
        if (config.data.num_workers != 0):
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. \
            Gaussian data is sampled at runtime and doing so with multiple workers may cause a CUDA error."
            )
        shape = config.data.shape if hasattr(config.data, "shape") else None
        dataset = Gaussian(device=args.device,
                           mean=None,
                           cov=None,
                           shape=shape,
                           iid_unit=True)
        test_dataset = Gaussian(device=args.device,
                                mean=None,
                                cov=None,
                                shape=shape,
                                iid_unit=True)

    else:
        raise NotImplementedError(
            "Unknown dataset: {}".format(config.data.dataset))

    return dataset, test_dataset
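A sketch of driving get_dataset with a minimal config; argparse.Namespace stands in here for whatever config loader the original project uses, and only the attributes read by the MNIST branch are set:

from argparse import Namespace

config = Namespace(data=Namespace(dataset='MNIST',
                                  image_size=28,
                                  random_flip=False))
args = Namespace(device='cpu')   # only consulted by the Gaussian branches
train_set, test_set = get_dataset(args, config)
print(len(train_set), len(test_set))   # 60000 10000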