        transforms = Compose([
            Resize(96),
            CenterCrop(96),
            ToTensor()
        ])
        dataset = LSUN('exp/datasets/lsun', ['church_outdoor_train'], transform=transforms)

    elif args.dataset == 'tower' or args.dataset == 'bedroom':
        transforms = Compose([
            Resize(128),
            CenterCrop(128),
            ToTensor()
        ])
        dataset = LSUN('exp/datasets/lsun', ['{}_train'.format(args.dataset)], transform=transforms)

    elif args.dataset == 'celeba':
        transforms = Compose([
            CenterCrop(140),
            Resize(64),
            ToTensor(),
        ])
        dataset = CelebA('exp/datasets/celeba', split='train', transform=transforms)

    elif args.dataset == 'cifar10':
        dataset = CIFAR10('exp/datasets/cifar10', train=True, transform=ToTensor())
    elif args.dataset == 'ffhq':
        dataset = FFHQ(path='exp/datasets/FFHQ', transform=ToTensor(), resolution=256)

    dataloader = DataLoader(dataset, batch_size=128, drop_last=False)
    get_nearest_neighbors(dataloader, args.path, args.i, args.n_samples, args.k, torch.cuda.is_available())
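
The FFHQ class used here and in the later examples is not a torchvision dataset; every snippet assumes a project-specific class that takes path, transform and resolution. A minimal sketch of one way to satisfy that interface, assuming the images are stored as individual PNG files under path (the original code bases typically read them from an LMDB database instead):

import os
from PIL import Image
from torch.utils.data import Dataset

class FFHQ(Dataset):
    """Hypothetical stand-in for the project-specific FFHQ dataset used above.

    Assumes `path` holds pre-exported PNG images; the original repositories
    usually back this with an LMDB database rather than loose files.
    """

    def __init__(self, path, transform=None, resolution=256):
        self.files = sorted(
            os.path.join(path, name)
            for name in os.listdir(path) if name.endswith('.png'))
        self.transform = transform
        self.resolution = resolution

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        img = Image.open(self.files[index]).convert('RGB')
        img = img.resize((self.resolution, self.resolution))
        if self.transform is not None:
            img = self.transform(img)
        # The callers above expect (image, target) pairs, so return a dummy target.
        return img, 0
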
Example #2
def get_dataset(d_config, data_folder):
    cmp = lambda x: transforms.Compose([*x])

    if d_config.dataset == 'CIFAR10':

        train_transform = [
            transforms.Resize(d_config.image_size),
            transforms.ToTensor()
        ]
        test_transform = [
            transforms.Resize(d_config.image_size),
            transforms.ToTensor()
        ]
        if d_config.random_flip:
            train_transform.insert(1, transforms.RandomHorizontalFlip())

        path = os.path.join(data_folder, 'CIFAR10')
        dataset = CIFAR10(path,
                          train=True,
                          download=True,
                          transform=cmp(train_transform))
        test_dataset = CIFAR10(path,
                               train=False,
                               download=True,
                               transform=cmp(test_transform))

    elif d_config.dataset == 'CELEBA':

        train_transform = [
            transforms.CenterCrop(140),
            transforms.Resize(d_config.image_size),
            transforms.ToTensor()
        ]
        test_transform = [
            transforms.CenterCrop(140),
            transforms.Resize(d_config.image_size),
            transforms.ToTensor()
        ]
        if d_config.random_flip:
            train_transform.insert(2, transforms.RandomHorizontalFlip())

        path = os.path.join(data_folder, 'celeba')
        dataset = CelebA(path,
                         split='train',
                         transform=cmp(train_transform),
                         download=True)
        test_dataset = CelebA(path,
                              split='test',
                              transform=cmp(test_transform),
                              download=True)

    elif d_config.dataset == 'Stacked_MNIST':

        dataset = Stacked_MNIST(root=os.path.join(data_folder,
                                                  'stackedmnist_train'),
                                load=False,
                                source_root=data_folder,
                                train=True)
        test_dataset = Stacked_MNIST(root=os.path.join(data_folder,
                                                       'stackedmnist_test'),
                                     load=False,
                                     source_root=data_folder,
                                     train=False)

    elif d_config.dataset == 'LSUN':

        ims = d_config.image_size
        train_transform = [
            transforms.Resize(ims),
            transforms.CenterCrop(ims),
            transforms.ToTensor()
        ]
        test_transform = [
            transforms.Resize(ims),
            transforms.CenterCrop(ims),
            transforms.ToTensor()
        ]
        if d_config.random_flip:
            train_transform.insert(2, transforms.RandomHorizontalFlip())

        path = data_folder
        dataset = LSUN(path,
                       classes=[d_config.category + "_train"],
                       transform=cmp(train_transform))
        test_dataset = LSUN(path,
                            classes=[d_config.category + "_val"],
                            transform=cmp(test_transform))

    elif d_config.dataset == "FFHQ":

        train_transform = [transforms.ToTensor()]
        test_transform = [transforms.ToTensor()]
        if d_config.random_flip:
            train_transform.insert(0, transforms.RandomHorizontalFlip())

        path = os.path.join(data_folder, 'FFHQ')
        dataset = FFHQ(path,
                       transform=cmp(train_transform),
                       resolution=d_config.image_size)
        test_dataset = FFHQ(path,
                            transform=cmp(test_transform),
                            resolution=d_config.image_size)

        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = indices[:int(num_items * 0.9
                                                   )], indices[int(num_items *
                                                                   0.9):]
        dataset = Subset(dataset, train_indices)
        test_dataset = Subset(test_dataset, test_indices)

    else:
        raise ValueError("Dataset [" + d_config.dataset + "] not configured.")

    return dataset, test_dataset
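
The Stacked_MNIST class above is likewise project-specific. Below is a minimal sketch of a dataset with the same constructor signature, following the usual stacked-MNIST convention of stacking three random digits into the three image channels; the size and seed arguments, the download behaviour and the dummy handling of root/load are assumptions made only for illustration.

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import MNIST

class Stacked_MNIST(Dataset):
    """Hypothetical stacked-MNIST dataset matching the constructor calls above.

    Each sample stacks three random MNIST digits as the three channels,
    giving 1000 possible labels (d0*100 + d1*10 + d2). `root` and `load`
    are accepted only for interface compatibility.
    """

    def __init__(self, root, load=False, source_root='.', train=True,
                 size=60000, seed=0):
        self.base = MNIST(source_root, train=train, download=True,
                          transform=transforms.ToTensor())
        rng = np.random.RandomState(seed)
        # Pre-draw which three source digits make up each stacked sample.
        self.index = rng.randint(0, len(self.base), size=(size, 3))

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        channels, label = [], 0
        for j in self.index[idx]:
            img, digit = self.base[int(j)]   # img is a 1x28x28 tensor in [0, 1]
            channels.append(img)
            label = label * 10 + digit
        return torch.cat(channels, dim=0), label  # 3x28x28 tensor, label in [0, 999]
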
Example #3
def get_dataset(args, config):
    if config.data.dataset == 'CIFAR10':
        if (config.data.random_flip):
            dataset = CIFAR10(os.path.join('datasets', 'cifar10'),
                              train=True,
                              download=True,
                              transform=transforms.Compose([
                                  transforms.Resize(config.data.image_size),
                                  transforms.RandomHorizontalFlip(p=0.5),
                                  transforms.ToTensor()
                              ]))
            test_dataset = CIFAR10(os.path.join('datasets', 'cifar10_test'),
                                   train=False,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(
                                           config.data.image_size),
                                       transforms.ToTensor()
                                   ]))

        else:
            dataset = CIFAR10(os.path.join('datasets', 'cifar10'),
                              train=True,
                              download=True,
                              transform=transforms.Compose([
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor()
                              ]))
            test_dataset = CIFAR10(os.path.join('datasets', 'cifar10_test'),
                                   train=False,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(
                                           config.data.image_size),
                                       transforms.ToTensor()
                                   ]))

    elif config.data.dataset == 'CELEBA':
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]),
                             download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]),
                             download=True)

        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                              split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]),
                              download=True)

    elif (config.data.dataset == "CELEBA-32px"):
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(32),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]),
                             download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(32),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]),
                             download=True)

        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                              split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(32),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]),
                              download=True)

    elif (config.data.dataset == "CELEBA-8px"):
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(8),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]),
                             download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(8),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]),
                             download=True)

        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                              split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(8),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]),
                              download=True)

    elif config.data.dataset == 'LSUN':
        train_folder = '{}_train'.format(config.data.category)
        val_folder = '{}_val'.format(config.data.category)
        if config.data.random_flip:
            dataset = LSUN(root=os.path.join('datasets', 'lsun'),
                           classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor(),
                           ]))
        else:
            dataset = LSUN(root=os.path.join('datasets', 'lsun'),
                           classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.ToTensor(),
                           ]))

        test_dataset = LSUN(root=os.path.join('datasets', 'lsun'),
                            classes=[val_folder],
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.CenterCrop(config.data.image_size),
                                transforms.ToTensor(),
                            ]))

    elif config.data.dataset == "FFHQ":
        if config.data.random_flip:
            dataset = FFHQ(path=os.path.join('datasets', 'FFHQ'),
                           transform=transforms.Compose([
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor()
                           ]),
                           resolution=config.data.image_size)
        else:
            dataset = FFHQ(path=os.path.join('datasets', 'FFHQ'),
                           transform=transforms.ToTensor(),
                           resolution=config.data.image_size)

        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = indices[:int(num_items * 0.9
                                                   )], indices[int(num_items *
                                                                   0.9):]
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)

    elif config.data.dataset == "MNIST":
        if config.data.random_flip:
            dataset = MNIST(root=os.path.join('datasets', 'MNIST'),
                            train=True,
                            download=True,
                            transform=transforms.Compose([
                                transforms.RandomHorizontalFlip(p=0.5),
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor()
                            ]))
        else:
            dataset = MNIST(root=os.path.join('datasets', 'MNIST'),
                            train=True,
                            download=True,
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor()
                            ]))
        test_dataset = MNIST(root=os.path.join('datasets', 'MNIST'),
                             train=False,
                             download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor()
                             ]))
    elif config.data.dataset == "USPS":
        if config.data.random_flip:
            dataset = USPS(root=os.path.join('datasets', 'USPS'),
                           train=True,
                           download=True,
                           transform=transforms.Compose([
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.Resize(config.data.image_size),
                               transforms.ToTensor()
                           ]))
        else:
            dataset = USPS(root=os.path.join('datasets', 'USPS'),
                           train=True,
                           download=True,
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.ToTensor()
                           ]))
        test_dataset = USPS(root=os.path.join('datasets', 'USPS'),
                            train=False,
                            download=True,
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor()
                            ]))
    elif config.data.dataset == "USPS-Pad":
        if config.data.random_flip:
            dataset = USPS(
                root=os.path.join('datasets', 'USPS'),
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(20),  # resize and pad like MNIST
                    transforms.Pad(4),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.Resize(config.data.image_size),
                    transforms.ToTensor()
                ]))
        else:
            dataset = USPS(
                root=os.path.join('datasets', 'USPS'),
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(20),  # resize and pad like MNIST
                    transforms.Pad(4),
                    transforms.Resize(config.data.image_size),
                    transforms.ToTensor()
                ]))
        test_dataset = USPS(
            root=os.path.join('datasets', 'USPS'),
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.Resize(20),  # resize and pad like MNIST
                transforms.Pad(4),
                transforms.Resize(config.data.image_size),
                transforms.ToTensor()
            ]))
    elif (config.data.dataset.upper() == "GAUSSIAN"):
        if (config.data.num_workers != 0):
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. \
            Gaussian data is sampled at runtime and doing so with multiple workers may cause a CUDA error."
            )
        if (config.data.isotropic):
            dim = config.data.dim
            rank = config.data.rank
            cov = np.diag(np.pad(np.ones((rank, )), [(0, dim - rank)]))
            mean = np.zeros((dim, ))
        else:
            cov = np.array(config.data.cov)
            mean = np.array(config.data.mean)

        shape = config.data.shape if hasattr(config.data, "shape") else None

        dataset = Gaussian(device=args.device, cov=cov, mean=mean, shape=shape)
        test_dataset = Gaussian(device=args.device,
                                cov=cov,
                                mean=mean,
                                shape=shape)

    elif (config.data.dataset.upper() == "GAUSSIAN-HD"):
        if (config.data.num_workers != 0):
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. \
            Gaussian data is sampled at runtime and doing so with multiple workers may cause a CUDA error."
            )
        cov = np.load(config.data.cov_path)
        mean = np.load(config.data.mean_path)
        dataset = Gaussian(device=args.device, cov=cov, mean=mean)
        test_dataset = Gaussian(device=args.device, cov=cov, mean=mean)

    elif (config.data.dataset.upper() == "GAUSSIAN-HD-UNIT"):
        # This dataset is to be used when GAUSSIAN with the isotropic option is infeasible due to high dimensionality
        #   of the desired samples. If the dimension is too high, passing a huge covariance matrix is slow.
        if (config.data.num_workers != 0):
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. \
            Gaussian data is sampled at runtime and doing so with multiple workers may cause a CUDA error."
            )
        shape = config.data.shape if hasattr(config.data, "shape") else None
        dataset = Gaussian(device=args.device,
                           mean=None,
                           cov=None,
                           shape=shape,
                           iid_unit=True)
        test_dataset = Gaussian(device=args.device,
                                mean=None,
                                cov=None,
                                shape=shape,
                                iid_unit=True)

    else:
        raise ValueError("Dataset [" + config.data.dataset + "] not configured.")

    return dataset, test_dataset
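
The Gaussian class used by the last three branches is also not a standard dataset. A minimal sketch of a runtime-sampled dataset matching those constructor calls (device, mean, cov, shape, iid_unit); the length argument, the dummy target and the exact sampling code are assumptions:

import numpy as np
import torch
from torch.utils.data import Dataset

class Gaussian(Dataset):
    """Hypothetical runtime-sampled Gaussian dataset matching the calls above."""

    def __init__(self, device, mean=None, cov=None, shape=None,
                 iid_unit=False, length=60000):
        self.device = device
        self.mean, self.cov = mean, cov
        self.shape = shape
        self.iid_unit = iid_unit
        self.length = length  # nominal epoch length; samples are drawn fresh each call

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        if self.iid_unit:
            # i.i.d. unit-variance noise; no covariance matrix is materialised.
            # `shape` is expected to be provided in this mode.
            sample = torch.randn(self.shape)
        else:
            sample = torch.from_numpy(
                np.random.multivariate_normal(self.mean, self.cov).astype(np.float32))
            if self.shape is not None:
                sample = sample.reshape(self.shape)
        # Sampling on the training device is why the code above insists on num_workers == 0.
        return sample.to(self.device), 0
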
Example #4
def get_dataset(args, config):
    if config.data.random_flip is False:
        tran_transform = test_transform = transforms.Compose([
            transforms.Resize([config.data.image_size] * 2),
            transforms.Transpose(), lambda x: x
            if x.dtype != np.uint8 else x.astype('float32') / 255.0
        ])
    else:
        tran_transform = transforms.Compose([
            transforms.Resize([config.data.image_size] * 2),
            transforms.RandomHorizontalFlip(prob=0.5),
            transforms.Transpose(),
            lambda x: x
            if x.dtype != np.uint8 else x.astype('float32') / 255.0,
        ])
        test_transform = transforms.Compose([
            transforms.Resize([config.data.image_size] * 2),
            transforms.Transpose(), lambda x: x
            if x.dtype != np.uint8 else x.astype('float32') / 255.0
        ])

    if config.data.dataset == "CIFAR10":
        dataset = Cifar10(
            # os.path.join(args.exp, "datasets", "cifar10"),
            mode="train",
            download=True,
            transform=tran_transform,
        )
        test_dataset = Cifar10(
            # os.path.join(args.exp, "datasets", "cifar10_test"),
            mode="test",
            download=True,
            transform=test_transform,
        )

    elif config.data.dataset == "CELEBA":
        cx = 89
        cy = 121
        x1 = cy - 64
        x2 = cy + 64
        y1 = cx - 64
        y2 = cx + 64
        if config.data.random_flip:
            dataset = CelebA(
                root=os.path.join(args.exp, "datasets", "celeba"),
                split="train",
                transform=transforms.Compose([
                    Crop(x1, x2, y1, y2),
                    transforms.Resize([config.data.image_size] * 2),
                    transforms.RandomHorizontalFlip(),
                    transforms.Transpose(),
                    lambda x: x
                    if x.dtype != np.uint8 else x.astype('float32') / 255.0,
                ]),
                download=True,
            )
        else:
            dataset = CelebA(
                root=os.path.join(args.exp, "datasets", "celeba"),
                split="train",
                transform=transforms.Compose([
                    Crop(x1, x2, y1, y2),
                    transforms.Resize([config.data.image_size] * 2),
                    transforms.Transpose(),
                    lambda x: x
                    if x.dtype != np.uint8 else x.astype('float32') / 255.0,
                ]),
                download=True,
            )

        test_dataset = CelebA(
            root=os.path.join(args.exp, "datasets", "celeba"),
            split="test",
            transform=transforms.Compose([
                Crop(x1, x2, y1, y2),
                transforms.Resize([config.data.image_size] * 2),
                transforms.Transpose(),
                lambda x: x
                if x.dtype != np.uint8 else x.astype('float32') / 255.0,
            ]),
            download=True,
        )

    elif config.data.dataset == "LSUN":
        train_folder = "{}_train".format(config.data.category)
        val_folder = "{}_val".format(config.data.category)
        if config.data.random_flip:
            dataset = LSUN(
                root=os.path.join(args.exp, "datasets", "lsun"),
                classes=[train_folder],
                transform=transforms.Compose([
                    transforms.Resize([config.data.image_size] * 2),
                    transforms.CenterCrop((config.data.image_size, ) * 2),
                    transforms.RandomHorizontalFlip(prob=0.5),
                    transforms.Transpose(),
                    lambda x: x
                    if x.dtype != np.uint8 else x.astype('float32') / 255.0,
                ]),
            )
        else:
            dataset = LSUN(
                root=os.path.join(args.exp, "datasets", "lsun"),
                classes=[train_folder],
                transform=transforms.Compose([
                    transforms.Resize([config.data.image_size] * 2),
                    transforms.CenterCrop((config.data.image_size, ) * 2),
                    transforms.Transpose(),
                    lambda x: x
                    if x.dtype != np.uint8 else x.astype('float32') / 255.0,
                ]),
            )

        test_dataset = LSUN(
            root=os.path.join(args.exp, "datasets", "lsun"),
            classes=[val_folder],
            transform=transforms.Compose([
                transforms.Resize([config.data.image_size] * 2),
                transforms.CenterCrop((config.data.image_size, ) * 2),
                transforms.Transpose(),
                lambda x: x
                if x.dtype != np.uint8 else x.astype('float32') / 255.0,
            ]),
        )

    elif config.data.dataset == "FFHQ":
        if config.data.random_flip:
            dataset = FFHQ(
                path=os.path.join(args.exp, "datasets", "FFHQ"),
                transform=transforms.Compose([
                    transforms.RandomHorizontalFlip(prob=0.5),
                    transforms.Transpose(), lambda x: x
                    if x.dtype != np.uint8 else x.astype('float32') / 255.0
                ]),
                resolution=config.data.image_size,
            )
        else:
            dataset = FFHQ(
                path=os.path.join(args.exp, "datasets", "FFHQ"),
                transform=transforms.Compose([
                    transforms.Transpose(), lambda x: x
                    if x.dtype != np.uint8 else x.astype('float32') / 255.0
                ]),
                resolution=config.data.image_size,
            )

        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = (
            indices[:int(num_items * 0.9)],
            indices[int(num_items * 0.9):],
        )
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)
    else:
        dataset, test_dataset = None, None

    return dataset, test_dataset
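
This variant appears to target the PaddlePaddle transforms API (transforms.Transpose, RandomHorizontalFlip(prob=...), Cifar10 with mode=...) rather than torchvision, and its CELEBA branch relies on a custom Crop(x1, x2, y1, y2) callable that is not shown. A minimal sketch of such a crop, assuming a PIL image and interpreting (x1, x2) as vertical and (y1, y2) as horizontal bounds, which matches how cx and cy are used above:

class Crop(object):
    """Hypothetical fixed-window crop for the CELEBA branch above."""

    def __init__(self, x1, x2, y1, y2):
        self.x1, self.x2, self.y1, self.y2 = x1, x2, y1, y2

    def __call__(self, img):
        # PIL's crop box is (left, upper, right, lower).
        return img.crop((self.y1, self.x1, self.y2, self.x2))

    def __repr__(self):
        return '{}({}, {}, {}, {})'.format(
            type(self).__name__, self.x1, self.x2, self.y1, self.y2)
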
Example #5
def get_dataset(args, config):
    if config.data.random_flip is False:
        tran_transform = test_transform = transforms.Compose(
            [transforms.Resize(config.data.image_size),
             transforms.ToTensor()])
    else:
        tran_transform = transforms.Compose([
            transforms.Resize(config.data.image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor()
        ])
        test_transform = transforms.Compose(
            [transforms.Resize(config.data.image_size),
             transforms.ToTensor()])

    if config.data.dataset == 'CIFAR10':
        dataset = CIFAR10(os.path.join(args.exp, 'datasets', 'cifar10'),
                          train=True,
                          download=True,
                          transform=tran_transform)
        test_dataset = CIFAR10(os.path.join(args.exp, 'datasets',
                                            'cifar10_test'),
                               train=False,
                               download=True,
                               transform=test_transform)

    elif config.data.dataset == 'CELEBA':
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join(args.exp, 'datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]),
                             download=False)
        else:
            dataset = CelebA(root=os.path.join(args.exp, 'datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]),
                             download=False)

        test_dataset = CelebA(root=os.path.join(args.exp, 'datasets',
                                                'celeba_test'),
                              split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]),
                              download=False)

    elif config.data.dataset == 'LSUN':
        # import ipdb; ipdb.set_trace()
        train_folder = '{}_train'.format(config.data.category)
        val_folder = '{}_val'.format(config.data.category)
        if config.data.random_flip:
            dataset = LSUN(root=os.path.join(args.exp, 'datasets', 'lsun'),
                           classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor(),
                           ]))
        else:
            dataset = LSUN(root=os.path.join(args.exp, 'datasets', 'lsun'),
                           classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.ToTensor(),
                           ]))

        test_dataset = LSUN(root=os.path.join(args.exp, 'datasets', 'lsun'),
                            classes=[val_folder],
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.CenterCrop(config.data.image_size),
                                transforms.ToTensor(),
                            ]))

    elif config.data.dataset == "FFHQ":
        if config.data.random_flip:
            dataset = FFHQ(path=os.path.join(args.exp, 'datasets', 'FFHQ'),
                           transform=transforms.Compose([
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor()
                           ]),
                           resolution=config.data.image_size)
        else:
            dataset = FFHQ(path=os.path.join(args.exp, 'datasets', 'FFHQ'),
                           transform=transforms.ToTensor(),
                           resolution=config.data.image_size)

        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = indices[:int(num_items * 0.9
                                                   )], indices[int(num_items *
                                                                   0.9):]
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)

    else:
        raise ValueError("Dataset [" + config.data.dataset + "] not configured.")

    return dataset, test_dataset
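
A short usage sketch for this last variant: the only fields it reads are args.exp and config.data.{dataset, image_size, random_flip, category}, so a hypothetical namespace config is enough to build loaders from the returned pair. The batch size and worker count below are placeholders.

import argparse
from types import SimpleNamespace
from torch.utils.data import DataLoader

parser = argparse.ArgumentParser()
parser.add_argument('--exp', default='exp', help='experiment root containing datasets/')
args = parser.parse_args()

# Hypothetical config mirroring only the fields get_dataset reads.
config = SimpleNamespace(data=SimpleNamespace(
    dataset='CIFAR10', image_size=32, random_flip=True, category=None))

dataset, test_dataset = get_dataset(args, config)
train_loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)
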