Example #1
def load_ds(flags=None):
    """Constructs the dataset that is set in flags.
    
    Args:
        flags: A FLAGS object with the dataset properties. If None, the
            global tf.app.flags.FLAGS is used.
        
    Returns:
        The Dataset object that is set in the flags.
    """

    if flags is None:  # Load the default flags.
        flags = tf.app.flags.FLAGS

    if flags.dataset.lower() == 'mnist':
        ds = Mnist()
    elif flags.dataset.lower() == 'f-mnist':
        ds = FMnist()
    elif flags.dataset.lower() == 'celeba':
        if hasattr(flags, 'attribute'):
            ds = CelebA(resize_size=flags.output_height,
                        attribute=flags.attribute)
        else:
            ds = CelebA(resize_size=flags.output_height)
    else:
        raise ValueError('[!] Dataset {} is not supported.'.format(
            flags.dataset.lower()))
    return ds
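
A minimal usage sketch (not part of the original source). Since load_ds only reads attributes from the object it is given, a lightweight namespace can stand in for tf.app.flags.FLAGS; the attribute value 'Smiling' below is purely illustrative:

from types import SimpleNamespace

# Hypothetical stand-in for tf.app.flags.FLAGS.
celeba_flags = SimpleNamespace(dataset='celeba', output_height=64, attribute='Smiling')
celeba_ds = load_ds(flags=celeba_flags)  # -> CelebA(resize_size=64, attribute='Smiling')

mnist_ds = load_ds(SimpleNamespace(dataset='mnist', output_height=28))  # -> Mnist()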
Example #2
def get_dataset(args, config):
    tran_transform = transforms.Compose([
        transforms.Resize(config.data.image_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.Resize(config.data.image_size),
        transforms.ToTensor(),
    ])

    if config.data.dataset == "CIFAR10":
        dataset = CIFAR10(
            os.path.join(args.exp, "datasets", "cifar10"),
            train=True,
            download=True,
            transform=tran_transform,
        )
        test_dataset = CIFAR10(
            os.path.join(args.exp, "datasets", "cifar10_test"),
            train=False,
            download=True,
            transform=test_transform,
        )

    elif config.data.dataset == "CELEBA":
        cx = 89
        cy = 121
        x1 = cy - 64
        x2 = cy + 64
        y1 = cx - 64
        y2 = cx + 64
        dataset = CelebA(
            root=os.path.join(args.exp, "datasets", "celeba"),
            split="train",
            transform=transforms.Compose([
                Crop(x1, x2, y1, y2),
                transforms.Resize(config.data.image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]),
            download=True,
        )

        test_dataset = CelebA(
            root=os.path.join(args.exp, "datasets", "celeba"),
            split="test",
            transform=transforms.Compose([
                Crop(x1, x2, y1, y2),
                transforms.Resize(config.data.image_size),
                transforms.ToTensor(),
            ]),
            download=True,
        )

    else:
        dataset, test_dataset = None, None

    return dataset, test_dataset
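
Note that Crop(x1, x2, y1, y2) is not a torchvision transform but a helper defined elsewhere in the repository. A minimal sketch, assuming it crops the PIL image to the pixel window spanning rows [x1, x2) and columns [y1, y2):

import torchvision.transforms.functional as F


class Crop:
    """Crop a PIL image to rows [x1, x2) and columns [y1, y2)."""

    def __init__(self, x1, x2, y1, y2):
        self.x1, self.x2, self.y1, self.y2 = x1, x2, y1, y2

    def __call__(self, img):
        # torchvision's functional.crop takes (top, left, height, width).
        return F.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1)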
Example #3
def create_generator(dataset_name,
                     split,
                     batch_size,
                     randomize,
                     attribute=None):
    """Creates a batch generator for the dataset.

    Args:
        dataset_name: `str`. The name of the dataset.
        split: `str`. The split of data. It can be `train`, `val`, or `test`.
        batch_size: An integer. The batch size.
        randomize: `bool`. Whether to randomize the order of images before
            batching.
        attribute (optional): For CelebA, the attribute to use as the label.

    Returns:
        A zero-argument function that builds a Python generator yielding
        (image_batch, label_batch) tuples.
    """
    flags = tf.app.flags.FLAGS

    if dataset_name.lower() == 'mnist':
        ds = Mnist()
    elif dataset_name.lower() == 'f-mnist':
        ds = FMnist()
    elif dataset_name.lower() == 'cifar-10':
        ds = Cifar10()
    elif dataset_name.lower() == 'celeba':
        ds = CelebA(attribute=attribute)
    else:
        raise ValueError("Dataset {} is not supported.".format(dataset_name))

    ds.load(split=split, randomize=randomize)

    def get_gen():
        for i in range(0, len(ds) - batch_size, batch_size):
            image_batch = ds.images[i:i + batch_size]
            label_batch = ds.labels[i:i + batch_size]
            yield image_batch, label_batch

    return get_gen
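
A usage sketch (assuming the MNIST data referenced by the FLAGS above is available). Because create_generator returns a generator factory rather than a generator, it is called once per pass over the data:

get_gen = create_generator('mnist', split='train', batch_size=64, randomize=True)

for epoch in range(2):
    for image_batch, label_batch in get_gen():  # fresh generator each epoch
        pass  # feed the batch to the model here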
Example #4
if args.dataset == 'fashion-mnist':
    dataset = datasets.FashionMNIST(root='./data/fashion-mnist',
                                    train=True,
                                    transform=transforms.ToTensor(),
                                    download=True)
    vae = VAE28(num_channels=1, zdim=10)
elif args.dataset == 'mnist':
    dataset = datasets.MNIST(root='./data/mnist',
                             train=True,
                             transform=transforms.ToTensor(),
                             download=True)
    vae = VAE28(num_channels=1, zdim=10)
elif args.dataset == 'dsprites':
    dataset = DSprites(root='./data/dsprites',
                       transform=transforms.ToTensor(),
                       download=True)
    vae = VAE64(num_channels=1, zdim=10)
elif args.dataset == 'celeba':
    dataset = CelebA(root='./data/celeba', transform=transforms.ToTensor())
    vae = VAE64(num_channels=3, zdim=32)
    args.obs = 'normal'
else:
    raise ValueError(
        'The `dataset` argument must be fashion-mnist, mnist, dsprites or celeba'
    )

data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=args.num_images,
                                          shuffle=True)

fixed_x, _ = next(iter(data_loader))
fixed_x = to_var(fixed_x, args.cuda, volatile=True)

if args.save_file is not None:
Example #5
    def train(self):
        if self.config.data.random_flip is False:
            tran_transform = test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])
        else:
            tran_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor()
            ])
            test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])

        if self.config.data.dataset == 'CIFAR10':
            dataset = CIFAR10(os.path.join(self.args.run, 'datasets',
                                           'cifar10'),
                              train=True,
                              download=True,
                              transform=tran_transform)
            test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets',
                                                'cifar10_test'),
                                   train=False,
                                   download=True,
                                   transform=test_transform)
        elif self.config.data.dataset == 'MNIST':
            dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'),
                            train=True,
                            download=True,
                            transform=tran_transform)
            test_dataset = MNIST(os.path.join(self.args.run, 'datasets',
                                              'mnist_test'),
                                 train=False,
                                 download=True,
                                 transform=test_transform)

        elif self.config.data.dataset == 'CELEBA':
            if self.config.data.random_flip:
                dataset = CelebA(
                    root=os.path.join(self.args.run, 'datasets', 'celeba'),
                    split='train',
                    transform=transforms.Compose([
                        transforms.CenterCrop(140),
                        transforms.Resize(self.config.data.image_size),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                    ]),
                    download=False)
            else:
                dataset = CelebA(
                    root=os.path.join(self.args.run, 'datasets', 'celeba'),
                    split='train',
                    transform=transforms.Compose([
                        transforms.CenterCrop(140),
                        transforms.Resize(self.config.data.image_size),
                        transforms.ToTensor(),
                    ]),
                    download=False)

            test_dataset = CelebA(
                root=os.path.join(self.args.run, 'datasets', 'celeba_test'),
                split='test',
                transform=transforms.Compose([
                    transforms.CenterCrop(140),
                    transforms.Resize(self.config.data.image_size),
                    transforms.ToTensor(),
                ]),
                download=False)

        elif self.config.data.dataset == 'SVHN':
            dataset = SVHN(os.path.join(self.args.run, 'datasets', 'svhn'),
                           split='train',
                           download=True,
                           transform=tran_transform)
            test_dataset = SVHN(os.path.join(self.args.run, 'datasets',
                                             'svhn_test'),
                                split='test',
                                download=True,
                                transform=test_transform)

        dataloader = DataLoader(dataset,
                                batch_size=self.config.training.batch_size,
                                shuffle=True,
                                num_workers=4)
        test_loader = DataLoader(test_dataset,
                                 batch_size=self.config.training.batch_size,
                                 shuffle=True,
                                 num_workers=4,
                                 drop_last=True)

        test_iter = iter(test_loader)
        self.config.input_dim = self.config.data.image_size**2 * self.config.data.channels

        tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)

        tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)
        score = CondRefineNetDilated(self.config).to(self.config.device)

        score = torch.nn.DataParallel(score)

        optimizer = self.get_optimizer(score.parameters())

        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
            score.load_state_dict(states[0])
            optimizer.load_state_dict(states[1])

        step = 0

        sigmas = torch.tensor(
            np.exp(
                np.linspace(np.log(self.config.model.sigma_begin),
                            np.log(self.config.model.sigma_end),
                            self.config.model.num_classes))).float().to(
                                self.config.device)

        time_record = []
        for epoch in range(self.config.training.n_epochs):
            for i, (X, y) in enumerate(dataloader):
                step += 1
                score.train()
                X = X.to(self.config.device)
                # Uniform dequantization: scale the [0, 1] tensor to [0, 255/256]
                # and add U(0, 1/256) noise.
                X = X / 256. * 255. + torch.rand_like(X) / 256.

                if self.config.data.logit_transform:
                    X = self.logit_transform(X)

                labels = torch.randint(0,
                                       len(sigmas), (X.shape[0], ),
                                       device=X.device)
                if self.config.training.algo == 'dsm':
                    t = time.time()
                    loss = anneal_dsm_score_estimation(
                        score, X, labels, sigmas,
                        self.config.training.anneal_power)
                elif self.config.training.algo == 'dsm_tracetrick':
                    t = time.time()
                    loss = anneal_dsm_score_estimation_TraceTrick(
                        score, X, labels, sigmas,
                        self.config.training.anneal_power)
                elif self.config.training.algo == 'ssm':
                    t = time.time()
                    loss = anneal_sliced_score_estimation_vr(
                        score,
                        X,
                        labels,
                        sigmas,
                        n_particles=self.config.training.n_particles)
                elif self.config.training.algo == 'esm_scorenet':
                    t = time.time()
                    loss = anneal_ESM_scorenet(
                        score,
                        X,
                        labels,
                        sigmas,
                        n_particles=self.config.training.n_particles)
                elif self.config.training.algo == 'esm_scorenet_VR':
                    t = time.time()
                    loss = anneal_ESM_scorenet_VR(
                        score,
                        X,
                        labels,
                        sigmas,
                        n_particles=self.config.training.n_particles)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                t = time.time() - t
                time_record.append(t)

                if step >= self.config.training.n_iters:
                    return 0

                if step % 100 == 0:
                    tb_logger.add_scalar('loss', loss, global_step=step)
                    logging.info(
                        "step: {}, loss: {}, time per step: {:.3f} +- {:.3f} ms"
                        .format(step, loss.item(),
                                np.mean(time_record) * 1e3,
                                np.std(time_record) * 1e3))

                    # if step % 2000 == 0:
                    #     score.eval()
                    #     try:
                    #         test_X, test_y = next(test_iter)
                    #     except StopIteration:
                    #         test_iter = iter(test_loader)
                    #         test_X, test_y = next(test_iter)

                    #     test_X = test_X.to(self.config.device)
                    #     test_X = test_X / 256. * 255. + torch.rand_like(test_X) / 256.

                    #     if self.config.data.logit_transform:
                    #         test_X = self.logit_transform(test_X)

                    #     test_labels = torch.randint(0, len(sigmas), (test_X.shape[0],), device=test_X.device)

                    #     #if self.config.training.algo == 'dsm':
                    #     with torch.no_grad():
                    #         test_dsm_loss = anneal_dsm_score_estimation(score, test_X, test_labels, sigmas,
                    #                                                         self.config.training.anneal_power)

                    #     tb_logger.add_scalar('test_dsm_loss', test_dsm_loss, global_step=step)
                    #     logging.info("step: {}, test dsm loss: {}".format(step, test_dsm_loss.item()))

                    # elif self.config.training.algo == 'ssm':
                    #     test_ssm_loss = anneal_sliced_score_estimation_vr(score, test_X, test_labels, sigmas,
                    #                                          n_particles=self.config.training.n_particles)

                    #     tb_logger.add_scalar('test_ssm_loss', test_ssm_loss, global_step=step)
                    #     logging.info("step: {}, test ssm loss: {}".format(step, test_ssm_loss.item()))

                if step >= 140000 and step % self.config.training.snapshot_freq == 0:
                    states = [
                        score.state_dict(),
                        optimizer.state_dict(),
                    ]
                    torch.save(
                        states,
                        os.path.join(self.args.log,
                                     'checkpoint_{}.pth'.format(step)))
                    torch.save(states,
                               os.path.join(self.args.log, 'checkpoint.pth'))
Example #6
def get_dataset(d_config, data_folder):
    cmp = lambda x: transforms.Compose([*x])

    if d_config.dataset == 'CIFAR10':

        train_transform = [
            transforms.Resize(d_config.image_size),
            transforms.ToTensor()
        ]
        test_transform = [
            transforms.Resize(d_config.image_size),
            transforms.ToTensor()
        ]
        if d_config.random_flip:
            train_transform.insert(1, transforms.RandomHorizontalFlip())

        path = os.path.join(data_folder, 'CIFAR10')
        dataset = CIFAR10(path,
                          train=True,
                          download=True,
                          transform=cmp(train_transform))
        test_dataset = CIFAR10(path,
                               train=False,
                               download=True,
                               transform=cmp(test_transform))

    elif d_config.dataset == 'CELEBA':

        train_transform = [
            transforms.CenterCrop(140),
            transforms.Resize(d_config.image_size),
            transforms.ToTensor()
        ]
        test_transform = [
            transforms.CenterCrop(140),
            transforms.Resize(d_config.image_size),
            transforms.ToTensor()
        ]
        if d_config.random_flip:
            train_transform.insert(2, transforms.RandomHorizontalFlip())

        path = os.path.join(data_folder, 'celeba')
        dataset = CelebA(path,
                         split='train',
                         transform=cmp(train_transform),
                         download=True)
        test_dataset = CelebA(path,
                              split='test',
                              transform=cmp(test_transform),
                              download=True)

    elif d_config.dataset == 'Stacked_MNIST':

        dataset = Stacked_MNIST(root=os.path.join(data_folder,
                                                  'stackedmnist_train'),
                                load=False,
                                source_root=data_folder,
                                train=True)
        test_dataset = Stacked_MNIST(root=os.path.join(data_folder,
                                                       'stackedmnist_test'),
                                     load=False,
                                     source_root=data_folder,
                                     train=False)

    elif d_config.dataset == 'LSUN':

        ims = d_config.image_size
        train_transform = [
            transforms.Resize(ims),
            transforms.CenterCrop(ims),
            transforms.ToTensor()
        ]
        test_transform = [
            transforms.Resize(ims),
            transforms.CenterCrop(ims),
            transforms.ToTensor()
        ]
        if d_config.random_flip:
            train_transform.insert(2, transforms.RandomHorizontalFlip())

        path = data_folder
        dataset = LSUN(path,
                       classes=[d_config.category + "_train"],
                       transform=cmp(train_transform))
        test_dataset = LSUN(path,
                            classes=[d_config.category + "_val"],
                            transform=cmp(test_transform))

    elif d_config.dataset == "FFHQ":

        train_transform = [transforms.ToTensor()]
        test_transform = [transforms.ToTensor()]
        if d_config.random_flip:
            train_transform.insert(0, transforms.RandomHorizontalFlip())

        path = os.path.join(data_folder, 'FFHQ')
        dataset = FFHQ(path,
                       transform=cmp(train_transform),
                       resolution=d_config.image_size)
        test_dataset = FFHQ(path,
                            transform=cmp(test_transform),
                            resolution=d_config.image_size)

        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices = indices[:int(num_items * 0.9)]
        test_indices = indices[int(num_items * 0.9):]
        dataset = Subset(dataset, train_indices)
        test_dataset = Subset(test_dataset, test_indices)

    else:
        raise ValueError("Dataset [" + d_config.dataset + "] not configured.")

    return dataset, test_dataset
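
A usage sketch with a hypothetical config object; dataset, image_size, and random_flip are the attributes this function actually reads, and './data' is an arbitrary download directory:

from types import SimpleNamespace
from torch.utils.data import DataLoader

d_config = SimpleNamespace(dataset='CIFAR10', image_size=32, random_flip=True)
train_set, test_set = get_dataset(d_config, data_folder='./data')

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4)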
Example #7
def main(args):
    ### config
    global noise_multiplier
    dataset = args.dataset
    num_discriminators = args.num_discriminators
    noise_multiplier = args.noise_multiplier
    z_dim = args.z_dim
    if dataset == 'celeba':
        z_dim = 100
    model_dim = args.model_dim
    batchsize = args.batchsize
    L_gp = args.L_gp
    L_epsilon = args.L_epsilon
    critic_iters = args.critic_iters
    latent_type = args.latent_type
    load_dir = args.load_dir
    save_dir = args.save_dir
    if_dp = (args.noise_multiplier > 0.)
    gen_arch = args.gen_arch
    num_gpus = args.num_gpus

    ### CUDA
    use_cuda = torch.cuda.is_available()
    devices = [
        torch.device("cuda:%d" % i if use_cuda else "cpu")
        for i in range(num_gpus)
    ]
    device0 = devices[0]
    if use_cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    ### Random seed
    if args.random_seed == 1:
        args.random_seed = np.random.randint(10000, size=1)[0]
    print('random_seed: {}'.format(args.random_seed))
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    ### Fix noise for visualization
    if latent_type == 'normal':
        fix_noise = torch.randn(10, z_dim)
    elif latent_type == 'bernoulli':
        p = 0.5
        bernoulli = torch.distributions.Bernoulli(torch.tensor([p]))
        fix_noise = bernoulli.sample((10, z_dim)).view(10, z_dim)
    else:
        raise NotImplementedError

    ### Set up models
    print('gen_arch:' + gen_arch)
    if dataset == 'mnist':
        netG = GeneratorDCGAN(z_dim=z_dim, model_dim=model_dim, num_classes=10)
    elif dataset == 'cifar_10':
        netG = GeneratorDCGAN_cifar_ch3(z_dim=z_dim,
                                        model_dim=model_dim,
                                        num_classes=10)
        netG.apply(weights_init)
    elif dataset == 'celeba':
        ngpu = 1
        netG = Generator_celeba(ngpu)

        # Handle multi-gpu if desired
        if (device0.type == 'cuda') and (ngpu > 1):
            netG = nn.DataParallel(netG, list(range(ngpu)))

        # Apply the weights_init function to randomly initialize all weights
        #  to mean=0, stdev=0.02.
        netG.apply(weights_init)

    netGS = copy.deepcopy(netG)
    netD_list = []
    for i in range(num_discriminators):
        if dataset == 'mnist':
            netD = DiscriminatorDCGAN()
        elif dataset == 'cifar_10':
            netD = DiscriminatorDCGAN_cifar_ch3()
            #netD.apply(weights_init)
        elif dataset == 'celeba':
            ngpu = 1
            netD = Discriminator_celeba(ngpu)

            # Handle multi-gpu if desired
            if (device0.type == 'cuda') and (ngpu > 1):
                netD = nn.DataParallel(netD, list(range(ngpu)))

            # Apply the weights_init function to randomly initialize all weights
            #  to mean=0, stdev=0.2.
            #netD.apply(weights_init)
        netD_list.append(netD)

    ### Load pre-trained discriminators
    print("load pre-training...")
    if load_dir is not None:
        for netD_id in range(num_discriminators):
            print('Load NetD ', str(netD_id))
            network_path = os.path.join(load_dir, 'netD_%d' % netD_id,
                                        'netD.pth')
            netD = netD_list[netD_id]
            netD.load_state_dict(torch.load(network_path))

    netG = netG.to(device0)
    for netD_id, netD in enumerate(netD_list):
        device = devices[get_device_id(netD_id, num_discriminators, num_gpus)]
        netD.to(device)

    ### Set up optimizers
    optimizerD_list = []
    for i in range(num_discriminators):
        netD = netD_list[i]
        optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.999))
        optimizerD_list.append(optimizerD)
    optimizerG = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))

    ### Data loaders
    if dataset == 'mnist':
        transform_train = transforms.Compose([
            transforms.CenterCrop((28, 28)),
            transforms.ToTensor(),
        ])
    elif dataset == 'cifar_10':
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    elif dataset == 'celeba':
        transform_train = transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    if dataset == 'mnist':
        IMG_DIM = 784
        NUM_CLASSES = 10
        dataloader = datasets.MNIST
        trainset = dataloader(root=os.path.join(DATA_ROOT, 'MNIST'),
                              train=True,
                              download=True,
                              transform=transform_train)
    elif dataset == 'cifar_10':
        IMG_DIM = 3072
        NUM_CLASSES = 10
        dataloader = datasets.CIFAR10
        trainset = dataloader(root=os.path.join(DATA_ROOT, 'CIFAR10'),
                              train=True,
                              download=True,
                              transform=transform_train)
    elif dataset == 'celeba':
        IMG_DIM = 64 * 64 * 3
        NUM_CLASSES = 2
        trainset = CelebA(
            root=os.path.join('../data'),
            split='train',  #'/work/u5366584/exp/datasets/celeba/'
            transform=transform_train,
            download=False,
            custom_subset=True)
    else:
        raise NotImplementedError

    ### Fix the sub-training set (save a fixed 20000-sample index file)
    if args.update_train_dataset:
        if dataset == 'mnist':
            indices_full = np.arange(60000)
        elif dataset == 'cifar_10':
            indices_full = np.arange(50000)
        elif dataset == 'celeba':
            indices_full = np.arange(len(trainset))
        np.random.shuffle(indices_full)
        indices_slice = indices_full[:20000]
        np.savetxt('index_20k.txt', indices_slice, fmt='%i')
    #indices = np.loadtxt('index_20k.txt', dtype=np.int_)
    #trainset = torch.utils.data.Subset(trainset, indices)

    print('create indices file')
    indices_full = np.arange(len(trainset))
    np.random.shuffle(indices_full)
    #indices_full.dump(os.path.join(save_dir, 'indices.npy'))
    trainset_size = int(len(trainset) / num_discriminators)
    print('Size of the dataset: ', trainset_size)

    input_pipelines = []
    for i in range(num_discriminators):
        start = i * trainset_size
        end = (i + 1) * trainset_size
        indices = indices_full[start:end]
        trainloader = DataLoader(trainset,
                                 batch_size=args.batchsize,
                                 drop_last=False,
                                 num_workers=args.num_workers,
                                 sampler=SubsetRandomSampler(indices))
        #input_data = inf_train_gen(trainloader)
        input_pipelines.append(trainloader)

    if if_dp:
        ### Register hook
        global dynamic_hook_function
        for netD in netD_list:
            netD.conv1.register_backward_hook(master_hook_adder)

    prg_bar = tqdm(range(args.iterations + 1))
    for iters in prg_bar:
        #########################
        ### Update D network
        #########################
        netD_id = np.random.randint(num_discriminators, size=1)[0]
        device = devices[get_device_id(netD_id, num_discriminators, num_gpus)]
        netD = netD_list[netD_id]
        optimizerD = optimizerD_list[netD_id]
        input_data = input_pipelines[netD_id]

        for p in netD.parameters():
            p.requires_grad = True

        for iter_d in range(critic_iters):
            real_data, real_y = next(iter(input_data))
            real_data = real_data.view(-1, IMG_DIM)
            real_data = real_data.to(device)
            real_y = real_y.to(device)

            ###########################
            if dataset == 'celeba':
                gender = 20  # column index of the 'Male' attribute in CelebA
                real_y = real_y[:, gender]
            ##################################################

            real_data_v = autograd.Variable(real_data)

            ### train with real
            dynamic_hook_function = dummy_hook
            netD.zero_grad()
            D_real_score = netD(real_data_v, real_y)
            D_real = -D_real_score.mean()

            ### train with fake
            batchsize = real_data.shape[0]
            if latent_type == 'normal':
                noise = torch.randn(batchsize, z_dim).to(device0)
            elif latent_type == 'bernoulli':
                noise = bernoulli.sample(
                    (batchsize, z_dim)).view(batchsize, z_dim).to(device0)
            else:
                raise NotImplementedError
            noisev = autograd.Variable(noise)
            fake = autograd.Variable(netG(noisev, real_y.to(device0)).data)
            inputv = fake.to(device)
            D_fake = netD(inputv, real_y.to(device))
            D_fake = D_fake.mean()

            ### train with gradient penalty
            gradient_penalty = netD.calc_gradient_penalty(
                real_data_v.data, fake.data, real_y, L_gp, device)
            D_cost = D_fake + D_real + gradient_penalty

            ### train with epsilon penalty
            logit_cost = L_epsilon * torch.pow(D_real_score, 2).mean()
            D_cost += logit_cost

            ### update
            D_cost.backward()
            Wasserstein_D = -D_real - D_fake
            optimizerD.step()

        del real_data, real_y, fake, noise, inputv, D_real, D_fake, logit_cost, gradient_penalty
        torch.cuda.empty_cache()

        for iter_d in range(3):
            ############################
            # Update G network
            ###########################
            if if_dp:
                ### Sanitize the gradients passed to the Generator
                dynamic_hook_function = dp_conv_hook
            else:
                ### Only modify the gradient norm, without adding noise
                dynamic_hook_function = modify_gradnorm_conv_hook

            for p in netD.parameters():
                p.requires_grad = False
            netG.zero_grad()

            ### train with sanitized discriminator output
            if latent_type == 'normal':
                noise = torch.randn(batchsize, z_dim).to(device0)
            elif latent_type == 'bernoulli':
                noise = bernoulli.sample(
                    (batchsize, z_dim)).view(batchsize, z_dim).to(device0)
            else:
                raise NotImplementedError
            label = torch.randint(0, NUM_CLASSES, [batchsize]).to(device0)
            noisev = autograd.Variable(noise)
            fake = netG(noisev, label)
            #summary(netG, input_data=[noisev,label])
            fake = fake.to(device)
            label = label.to(device)
            G = netD(fake, label)
            G = -G.mean()

            ### update
            G.backward()
            G_cost = G
            optimizerG.step()

        ### update the exponential moving average
        exp_mov_avg(netGS, netG, alpha=0.999, global_step=iters)

        ############################
        ### Results visualization
        ############################
        prg_bar.set_description(
            'iter:{}, G_cost:{:.2f}, D_cost:{:.2f}, Wasserstein:{:.2f}'.format(
                iters,
                G_cost.cpu().data,
                D_cost.cpu().data,
                Wasserstein_D.cpu().data))
        if iters % args.vis_step == 0:
            if dataset == 'mnist':
                generate_image_mnist(iters, netGS, fix_noise, save_dir,
                                     device0)
            elif dataset == 'cifar_10':
                generate_image_cifar10_ch3(str(iters + 0), netGS, fix_noise,
                                           save_dir, device0)
            elif dataset == 'celeba':
                generate_image_celeba(str(iters + 0), netGS, fix_noise,
                                      save_dir, device0)

        if iters % args.save_step == 0:
            ### save model
            torch.save(netGS.state_dict(),
                       os.path.join(save_dir, 'netGS_%s.pth' % str(iters + 0)))
            torch.save(netD.state_dict(),
                       os.path.join(save_dir, 'netD_%s.pth' % str(iters + 0)))

        del label, fake, noisev, noise, G, G_cost, D_cost
        torch.cuda.empty_cache()
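
exp_mov_avg(netGS, netG, alpha, global_step) is defined elsewhere in the repository; from its use here, netGS tracks an exponential moving average of the generator's weights. A minimal sketch under that assumption:

import torch


def exp_mov_avg(ema_model, model, alpha=0.999, global_step=None):
    """Update ema_model's parameters as an EMA of model's parameters."""
    if global_step is not None:
        # Warm up the decay so early steps are not dominated by the random init.
        alpha = min(alpha, 1. - 1. / (global_step + 1))
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(alpha).add_(p, alpha=1. - alpha)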
Example #8
    def test(self):
        # Load the score network
        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                            map_location=self.config.device)
        scorenet = CondRefineNetDilated(self.config).to(self.config.device)
        scorenet = torch.nn.DataParallel(scorenet)
        scorenet.load_state_dict(states[0])
        scorenet.eval()

        # Load the CelebA test split and grab a single sample
        dataset = CelebA(os.path.join(self.args.run, 'datasets', 'celeba'),
                         split='test',
                         download=True)
        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
        input_image = np.array(dataset[0][0]).astype(np.float64).transpose(
            2, 0, 1)

        # input_image = cv2.imread("/projects/grail/vjayaram/source_separation/ncsn/run/datasets/celeba/celeba/img_align_celeba/012690.jpg")

        # input_image = cv2.resize(input_image, (32, 32))[:,:,::-1].transpose(2, 0, 1)
        input_image = input_image / 255.
        noise = np.random.randn(*input_image.shape) / 10
        cv2.imwrite("input_image.png", (input_image * 255).astype(
            np.uint8).transpose(1, 2, 0)[:, :, ::-1])
        input_image += noise
        input_image = np.clip(input_image, 0, 1)

        cv2.imwrite("input_image_noisy.png", (input_image * 255).astype(
            np.uint8).transpose(1, 2, 0)[:, :, ::-1])

        input_image = torch.Tensor(input_image).cuda()
        x = nn.Parameter(torch.Tensor(3, 32, 32).uniform_()).cuda()

        step_lr = 0.00002

        # Noise amounts
        sigmas = np.array([
            1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
            0.04641589, 0.02782559, 0.01668101, 0.01
        ])
        n_steps_each = 100
        lambda_recon = 1.5  # Weight to put on reconstruction error vs p(x)

        for idx, sigma in enumerate(sigmas):
            # Condition the score network on the index of the current noise level
            labels = torch.ones(1, device=x.device) * idx
            labels = labels.long()
            step_size = step_lr * (sigma / sigmas[-1])**2

            for step in range(n_steps_each):
                noise_x = torch.randn_like(x) * np.sqrt(step_size * 2)

                grad_x = scorenet(x.view(1, 3, 32, 32), labels).detach()

                recon_loss = (torch.norm(torch.flatten(input_image - x))**2)
                print(recon_loss)
                recon_grads = torch.autograd.grad(recon_loss, [x])

                #x = x + (step_size * grad_x) + noise_x
                x = (x + step_size * grad_x
                     - step_size * lambda_recon * recon_grads[0].detach()
                     + noise_x)

            lambda_recon *= 1.6

        # # Write x and y
        x_np = x.detach().cpu().numpy()[0, :, :, :]
        x_np = np.clip(x_np, 0, 1)
        cv2.imwrite("x.png",
                    (x_np * 255).astype(np.uint8).transpose(1, 2,
                                                            0)[:, :, ::-1])

        # y_np = y.detach().cpu().numpy()[0,:,:,:]
        # y_np = np.clip(y_np, 0, 1)
        # cv2.imwrite("y.png", (y_np * 255).astype(np.uint8).transpose(1, 2, 0)[:,:,::-1])

        # cv2.imwrite("out_mixed.png", (y_np * 127.5).astype(np.uint8).transpose(1, 2, 0)[:,:,::-1] + (x_np * 127.5).astype(np.uint8).transpose(1, 2, 0)[:,:,::-1])

        import pdb
        pdb.set_trace()
Example #9
def main(args):
    ### config
    global noise_multiplier
    dataset = args.dataset
    num_discriminators = args.num_discriminators
    noise_multiplier = args.noise_multiplier
    z_dim = args.z_dim
    if dataset == 'celeba':
        z_dim = 100
    model_dim = args.model_dim
    batchsize = args.batchsize
    L_gp = args.L_gp
    L_epsilon = args.L_epsilon
    critic_iters = args.critic_iters
    latent_type = args.latent_type
    load_dir = args.load_dir
    save_dir = args.save_dir
    if_dp = (args.noise_multiplier > 0.)
    gen_arch = args.gen_arch
    num_gpus = args.num_gpus

    ### CUDA
    use_cuda = torch.cuda.is_available()
    devices = [
        torch.device("cuda:%d" % i if use_cuda else "cpu")
        for i in range(num_gpus)
    ]
    device0 = devices[0]
    if use_cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    ### Random seed
    if args.random_seed == 1:
        args.random_seed = np.random.randint(10000, size=1)[0]
    print('random_seed: {}'.format(args.random_seed))
    os.system('rm ' + os.path.join(save_dir, 'seed*'))
    os.system('touch ' +
              os.path.join(save_dir, 'seed=%s' % str(args.random_seed)))
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    ### Set up models
    print('gen_arch:' + gen_arch)
    if dataset == 'celeba':
        ngpu = 1
        netG = Generator_celeba(ngpu).to(device0)
        #netG.load_state_dict(torch.load('../results/celeba/main/d_1_2e-4_g_1_2e-4_SN_full/netG_15000.pth'))

        # Handle multi-gpu if desired
        if (device0.type == 'cuda') and (ngpu > 1):
            netG = nn.DataParallel(netG, list(range(ngpu)))

        # Apply the weights_init function to randomly initialize all weights
        #  to mean=0, stdev=0.02.
        netG.apply(weights_init)

    netGS = copy.deepcopy(netG).to(device0)
    if dataset == 'celeba':
        ngpu = 1
        netD = Discriminator_celeba(ngpu).to(device0)
        #netD.load_state_dict(torch.load('../results/celeba/main/d_1_2e-4_g_1_2e-4_SN_full/netD_15000.pth'))
        # Handle multi-gpu if desired
        if (device0.type == 'cuda') and (ngpu > 1):
            netD = nn.DataParallel(netD, list(range(ngpu)))

        # Apply the weights_init function to randomly initialize all weights
        #  to mean=0, stdev=0.2.
        #netD.apply(weights_init)

    ### Set up optimizers
    optimizerD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.99))
    optimizerG = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.99))

    ### Data loaders
    if dataset == 'celeba':
        transform_train = transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    if dataset == 'celeba':
        IMG_DIM = 64 * 64 * 3
        NUM_CLASSES = 2
        trainset = CelebA(
            root=os.path.join('/work/u5366584/exp/datasets/celeba'),
            split='train',
            transform=transform_train,
            download=False)  #, custom_subset=True)
        #trainset = CelebA(root=os.path.join('../data'), split='train',
        #    transform=transform_train, download=False, custom_subset=True)
    else:
        raise NotImplementedError

    ###fix sub-training set (fix to 10000 training samples)
    if args.update_train_dataset:
        if dataset == 'mnist':
            indices_full = np.arange(60000)
        elif dataset == 'cifar_10':
            indices_full = np.arange(50000)
        elif dataset == 'celeba':
            indices_full = np.arange(len(trainset))
        np.random.shuffle(indices_full)
        '''
        #####ref
        indices = np.loadtxt('index_20k.txt', dtype=np.int_)
        remove_idx = [np.argwhere(indices_full==x) for x in indices]
        indices_ref = np.delete(indices_full, remove_idx)
        
        indices_slice = indices_ref[:20000]
        np.savetxt('index_20k_ref.txt', indices_slice, fmt='%i')   ##ref index is disjoint to original index
        '''

        ### growing dataset
        indices = np.loadtxt('index_20k.txt', dtype=np.int_)
        remove_idx = [np.argwhere(indices_full == x) for x in indices]
        indices_rest = np.delete(indices_full, remove_idx)

        indices_rest = indices_rest[:20000]
        indices_slice = np.concatenate((indices, indices_rest), axis=0)
        np.savetxt('index_40k.txt', indices_slice, fmt='%i')
    indices = np.loadtxt('index_100k.txt', dtype=np.int_)
    trainset = torch.utils.data.Subset(trainset, indices)
    print(len(trainset))

    workers = 4
    dataloader = torch.utils.data.DataLoader(trainset,
                                             batch_size=batchsize,
                                             shuffle=True,
                                             num_workers=workers)

    if if_dp:
        ### Register hook (this variant trains a single discriminator)
        global dynamic_hook_function
        netD.conv1.register_backward_hook(master_hook_adder)

    criterion = nn.BCELoss()
    real_label = 1.
    fake_label = 0.
    nz = 100
    fixed_noise = torch.randn(100, nz, 1, 1, device=device0)
    iters = 0
    num_epochs = 256 * 5 + 1

    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(num_epochs):
        # For each batch in the dataloader
        for i, (data, y) in enumerate(dataloader, 0):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD.zero_grad()
            # Format batch
            real_cpu = data.to(device0)
            b_size = real_cpu.size(0)
            label = torch.full((b_size, ),
                               real_label,
                               dtype=torch.float,
                               device=device0)
            # Forward pass real batch through D
            output = netD(real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=device0)
            # Generate fake image batch with G
            fake = netG(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = netD(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch, accumulated (summed) with previous gradients
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Compute error of D as sum over the fake and the real batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            iters += 1

            for iter_g in range(1):
                ############################
                # Update G network
                ###########################
                if if_dp:
                    ### Sanitize the gradients passed to the Generator
                    dynamic_hook_function = dp_conv_hook
                else:
                    ### Only modify the gradient norm, without adding noise
                    dynamic_hook_function = modify_gradnorm_conv_hook

                ############################
                # (2) Update G network: maximize log(D(G(z)))
                ###########################

                noise = torch.randn(b_size, nz, 1, 1, device=device0)
                fake = netG(noise)
                label = torch.full((b_size, ),
                                   real_label,
                                   dtype=torch.float,
                                   device=device0)

                netG.zero_grad()
                label.fill_(
                    real_label)  # fake labels are real for generator cost
                # Since we just updated D, perform another forward pass of all-fake batch through D
                output = netD(fake).view(-1)
                # Calculate G's loss based on this output
                errG = criterion(output, label)
                # Calculate gradients for G
                errG.backward()
                D_G_z2 = output.mean().item()
                # Update G
                optimizerG.step()

            ### update the exponential moving average
            exp_mov_avg(netGS, netG, alpha=0.999, global_step=iters)

            ############################
            ### Results visualization
            ############################
            if iters % 10 == 0:
                print('iter:{}, G_cost:{:.2f}, D_cost:{:.2f}'.format(
                    iters,
                    errG.item(),
                    errD.item(),
                ))
            if iters % args.vis_step == 0:
                if dataset == 'celeba':
                    generate_image_celeba(str(iters + 0), netGS, fixed_noise,
                                          save_dir, device0)

            if iters % args.save_step == 0:
                ### save model
                torch.save(
                    netGS.state_dict(),
                    os.path.join(save_dir, 'netGS_%s.pth' % str(iters + 0)))
                torch.save(
                    netD.state_dict(),
                    os.path.join(save_dir, 'netD_%s.pth' % str(iters + 0)))

        torch.cuda.empty_cache()
Example #10
def load_data(dataset, data_path, model_name):
    img_size, _, _ = get_data_info(dataset, model_name)
    if dataset == 'MNIST':
        data_train = MNIST(data_path,
                           transform=transforms.Compose([
                               transforms.Resize((img_size, img_size)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307, ), (0.3081, ))
                           ]),
                           download=True)  # True for the first time
        data_test = MNIST(data_path,
                          train=False,
                          transform=transforms.Compose([
                              transforms.Resize((img_size, img_size)),
                              transforms.ToTensor(),
                              transforms.Normalize((0.1307, ), (0.3081, ))
                          ]))

    elif dataset == 'SVHN':
        mean = (0.4377, 0.4438, 0.4728)
        std = (0.1980, 0.2010, 0.1970)
        transform_train = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        transform_test = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean, std)])

        data_train = SVHN(data_path + '/SVHN',
                          split='train',
                          download=True,
                          transform=transform_train)
        data_test = SVHN(data_path + '/SVHN',
                         split='test',
                         download=True,
                         transform=transform_test)

    elif dataset == 'cifar10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        data_train = CIFAR10(data_path,
                             transform=transform_train,
                             download=True)  # True for the first time
        data_test = CIFAR10(data_path, train=False, transform=transform_test)

    elif dataset == 'cifar100':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        data_train = CIFAR100(data_path,
                              transform=transform_train,
                              download=True)  # True for the first time
        data_test = CIFAR100(data_path, train=False, transform=transform_test)

    elif dataset == 'FEMNIST':
        from datasets.femnist import FEMNIST

        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
        transform = transforms.Compose([
            # transforms.RandomCrop((224, 224)),
            transforms.ToTensor(),
            # transforms.Normalize(mean, std)
        ])
        data_train = FEMNIST(data_path + '/femnist/',
                             transform=transform,
                             train=True)
        data_test = FEMNIST(data_path + '/femnist/',
                            transform=transform,
                            train=False)

    elif dataset == 'celeba':
        from datasets.celeba import CelebA

        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
        transform = transforms.Compose([
            # transforms.RandomCrop((84, 84)),
            transforms.CenterCrop((84, 84)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        data_train = CelebA(data_path + '/celeba/',
                            transform=transform,
                            train=True,
                            read_all_data_to_mem=False)
        data_test = CelebA(data_path + '/celeba/',
                           transform=transform,
                           train=False,
                           read_all_data_to_mem=False)

    elif dataset == 'shakespeare':
        from datasets.shakespeare import SHAKESPEARE
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        data_train = SHAKESPEARE(data_path + '/shakespeare/',
                                 transform=transform,
                                 train=True)
        data_test = SHAKESPEARE(data_path + '/shakespeare/',
                                transform=transform,
                                train=False)

    else:
        raise Exception('Unknown dataset name.')

    return data_train, data_test
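
A usage sketch (the dataset name and model_name value are illustrative, and get_data_info must be importable from the same project). The returned objects are plain Dataset instances, so batching is left to the caller:

from torch.utils.data import DataLoader

data_train, data_test = load_data('cifar10', data_path='./data', model_name='resnet')

train_loader = DataLoader(data_train, batch_size=64, shuffle=True, num_workers=2)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # e.g. torch.Size([64, 3, 32, 32]) torch.Size([64])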
Example #11
def get_dataset(args, config):
    if config.data.random_flip is False:
        tran_transform = test_transform = transforms.Compose([
            transforms.Resize([config.data.image_size] * 2),
            transforms.Transpose(), lambda x: x
            if x.dtype != np.uint8 else x.astype('float32') / 255.0
        ])
    else:
        tran_transform = transforms.Compose([
            transforms.Resize([config.data.image_size] * 2),
            transforms.RandomHorizontalFlip(prob=0.5),
            transforms.Transpose(),
            lambda x: x
            if x.dtype != np.uint8 else x.astype('float32') / 255.0,
        ])
        test_transform = transforms.Compose([
            transforms.Resize([config.data.image_size] * 2),
            transforms.Transpose(), lambda x: x
            if x.dtype != np.uint8 else x.astype('float32') / 255.0
        ])

    if config.data.dataset == "CIFAR10":
        dataset = Cifar10(
            # os.path.join(args.exp, "datasets", "cifar10"),
            mode="train",
            download=True,
            transform=tran_transform,
        )
        test_dataset = Cifar10(
            # os.path.join(args.exp, "datasets", "cifar10_test"),
            mode="test",
            download=True,
            transform=test_transform,
        )

    elif config.data.dataset == "CELEBA":
        cx = 89
        cy = 121
        x1 = cy - 64
        x2 = cy + 64
        y1 = cx - 64
        y2 = cx + 64
        if config.data.random_flip:
            dataset = CelebA(
                root=os.path.join(args.exp, "datasets", "celeba"),
                split="train",
                transform=transforms.Compose([
                    Crop(x1, x2, y1, y2),
                    transforms.Resize([config.data.image_size] * 2),
                    transforms.RandomHorizontalFlip(),
                    transforms.Transpose(),
                    lambda x: x
                    if x.dtype != np.uint8 else x.astype('float32') / 255.0,
                ]),
                download=True,
            )
        else:
            dataset = CelebA(
                root=os.path.join(args.exp, "datasets", "celeba"),
                split="train",
                transform=transforms.Compose([
                    Crop(x1, x2, y1, y2),
                    transforms.Resize([config.data.image_size] * 2),
                    transforms.Transpose(),
                    lambda x: x
                    if x.dtype != np.uint8 else x.astype('float32') / 255.0,
                ]),
                download=True,
            )

        test_dataset = CelebA(
            root=os.path.join(args.exp, "datasets", "celeba"),
            split="test",
            transform=transforms.Compose([
                Crop(x1, x2, y1, y2),
                transforms.Resize([config.data.image_size] * 2),
                transforms.Transpose(),
                lambda x: x
                if x.dtype != np.uint8 else x.astype('float32') / 255.0,
            ]),
            download=True,
        )

    elif config.data.dataset == "LSUN":
        train_folder = "{}_train".format(config.data.category)
        val_folder = "{}_val".format(config.data.category)
        if config.data.random_flip:
            dataset = LSUN(
                root=os.path.join(args.exp, "datasets", "lsun"),
                classes=[train_folder],
                transform=transforms.Compose([
                    transforms.Resize([config.data.image_size] * 2),
                    transforms.CenterCrop((config.data.image_size, ) * 2),
                    transforms.RandomHorizontalFlip(prob=0.5),
                    transforms.Transpose(),
                    lambda x: x
                    if x.dtype != np.uint8 else x.astype('float32') / 255.0,
                ]),
            )
        else:
            dataset = LSUN(
                root=os.path.join(args.exp, "datasets", "lsun"),
                classes=[train_folder],
                transform=transforms.Compose([
                    transforms.Resize([config.data.image_size] * 2),
                    transforms.CenterCrop((config.data.image_size, ) * 2),
                    transforms.Transpose(),
                    lambda x: x
                    if x.dtype != np.uint8 else x.astype('float32') / 255.0,
                ]),
            )

        test_dataset = LSUN(
            root=os.path.join(args.exp, "datasets", "lsun"),
            classes=[val_folder],
            transform=transforms.Compose([
                transforms.Resize([config.data.image_size] * 2),
                transforms.CenterCrop((config.data.image_size, ) * 2),
                transforms.Transpose(),
                lambda x: x
                if x.dtype != np.uint8 else x.astype('float32') / 255.0,
            ]),
        )

    elif config.data.dataset == "FFHQ":
        if config.data.random_flip:
            dataset = FFHQ(
                path=os.path.join(args.exp, "datasets", "FFHQ"),
                transform=transforms.Compose([
                    transforms.RandomHorizontalFlip(prob=0.5),
                    transforms.Transpose(), lambda x: x
                    if x.dtype != np.uint8 else x.astype('float32') / 255.0
                ]),
                resolution=config.data.image_size,
            )
        else:
            dataset = FFHQ(
                path=os.path.join(args.exp, "datasets", "FFHQ"),
                transform=transforms.Compose([
                    transforms.Transpose(), lambda x: x
                    if x.dtype != np.uint8 else x.astype('float32') / 255.0
                ]),
                resolution=config.data.image_size,
            )

        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = (
            indices[:int(num_items * 0.9)],
            indices[int(num_items * 0.9):],
        )
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)
    else:
        dataset, test_dataset = None, None

    return dataset, test_dataset
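
The pipelines above end with a lambda that only rescales uint8 arrays and passes float inputs through unchanged. A minimal, self-contained sketch of that behavior (NumPy only; the array shapes are illustrative, not taken from the example):

import numpy as np

def to_float01(x):
    # Convert HWC uint8 images to float32 in [0, 1]; float arrays are returned as-is.
    return x if x.dtype != np.uint8 else x.astype('float32') / 255.0

img_u8 = np.full((64, 64, 3), 255, dtype=np.uint8)
img_f32 = np.zeros((64, 64, 3), dtype=np.float32)

assert to_float01(img_u8).dtype == np.float32 and to_float01(img_u8).max() == 1.0
assert to_float01(img_f32) is img_f32  # already float: passed through unchanged
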
Example No. 12
0
def get_dataset(args, config):
    if config.data.dataset == 'CIFAR10':
        if config.data.random_flip:
            dataset = CIFAR10(os.path.join('datasets', 'cifar10'),
                              train=True,
                              download=True,
                              transform=transforms.Compose([
                                  transforms.Resize(config.data.image_size),
                                  transforms.RandomHorizontalFlip(p=0.5),
                                  transforms.ToTensor()
                              ]))
            test_dataset = CIFAR10(os.path.join('datasets', 'cifar10_test'),
                                   train=False,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(
                                           config.data.image_size),
                                       transforms.ToTensor()
                                   ]))

        else:
            dataset = CIFAR10(os.path.join('datasets', 'cifar10'),
                              train=True,
                              download=True,
                              transform=transforms.Compose([
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor()
                              ]))
            test_dataset = CIFAR10(os.path.join('datasets', 'cifar10_test'),
                                   train=False,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(
                                           config.data.image_size),
                                       transforms.ToTensor()
                                   ]))

    elif config.data.dataset == 'CELEBA':
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]),
                             download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]),
                             download=True)

        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                              split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]),
                              download=True)

    elif (config.data.dataset == "CELEBA-32px"):
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(32),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]),
                             download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(32),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]),
                             download=True)

        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                              split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(32),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]),
                              download=True)

    elif (config.data.dataset == "CELEBA-8px"):
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(8),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]),
                             download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(8),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]),
                             download=True)

        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                              split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(8),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]),
                              download=True)

    elif config.data.dataset == 'LSUN':
        train_folder = '{}_train'.format(config.data.category)
        val_folder = '{}_val'.format(config.data.category)
        if config.data.random_flip:
            dataset = LSUN(root=os.path.join('datasets', 'lsun'),
                           classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor(),
                           ]))
        else:
            dataset = LSUN(root=os.path.join('datasets', 'lsun'),
                           classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.ToTensor(),
                           ]))

        test_dataset = LSUN(root=os.path.join('datasets', 'lsun'),
                            classes=[val_folder],
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.CenterCrop(config.data.image_size),
                                transforms.ToTensor(),
                            ]))

    elif config.data.dataset == "FFHQ":
        if config.data.random_flip:
            dataset = FFHQ(path=os.path.join('datasets', 'FFHQ'),
                           transform=transforms.Compose([
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor()
                           ]),
                           resolution=config.data.image_size)
        else:
            dataset = FFHQ(path=os.path.join('datasets', 'FFHQ'),
                           transform=transforms.ToTensor(),
                           resolution=config.data.image_size)

        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = (
            indices[:int(num_items * 0.9)],
            indices[int(num_items * 0.9):],
        )
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)

    elif config.data.dataset == "MNIST":
        if config.data.random_flip:
            dataset = MNIST(root=os.path.join('datasets', 'MNIST'),
                            train=True,
                            download=True,
                            transform=transforms.Compose([
                                transforms.RandomHorizontalFlip(p=0.5),
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor()
                            ]))
        else:
            dataset = MNIST(root=os.path.join('datasets', 'MNIST'),
                            train=True,
                            download=True,
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor()
                            ]))
        test_dataset = MNIST(root=os.path.join('datasets', 'MNIST'),
                             train=False,
                             download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor()
                             ]))
    elif config.data.dataset == "USPS":
        if config.data.random_flip:
            dataset = USPS(root=os.path.join('datasets', 'USPS'),
                           train=True,
                           download=True,
                           transform=transforms.Compose([
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.Resize(config.data.image_size),
                               transforms.ToTensor()
                           ]))
        else:
            dataset = USPS(root=os.path.join('datasets', 'USPS'),
                           train=True,
                           download=True,
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.ToTensor()
                           ]))
        test_dataset = USPS(root=os.path.join('datasets', 'USPS'),
                            train=False,
                            download=True,
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor()
                            ]))
    elif config.data.dataset == "USPS-Pad":
        if config.data.random_flip:
            dataset = USPS(
                root=os.path.join('datasets', 'USPS'),
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(20),  # resize and pad like MNIST
                    transforms.Pad(4),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.Resize(config.data.image_size),
                    transforms.ToTensor()
                ]))
        else:
            dataset = USPS(
                root=os.path.join('datasets', 'USPS'),
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(20),  # resize and pad like MNIST
                    transforms.Pad(4),
                    transforms.Resize(config.data.image_size),
                    transforms.ToTensor()
                ]))
        test_dataset = USPS(
            root=os.path.join('datasets', 'USPS'),
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.Resize(20),  # resize and pad like MNIST
                transforms.Pad(4),
                transforms.Resize(config.data.image_size),
                transforms.ToTensor()
            ]))
    elif (config.data.dataset.upper() == "GAUSSIAN"):
        if (config.data.num_workers != 0):
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. \
            Gaussian data is sampled at runtime and doing so with multiple workers may cause a CUDA error."
            )
        if (config.data.isotropic):
            dim = config.data.dim
            rank = config.data.rank
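            # Illustration (not from the original): with dim=4 and rank=2,
            # np.pad(np.ones((2,)), [(0, 2)]) -> [1., 1., 0., 0.], so cov becomes
            # diag(1, 1, 0, 0): unit variance on the first `rank` coordinates and
            # zero variance (a degenerate distribution) on the remaining ones.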
            cov = np.diag(np.pad(np.ones((rank, )), [(0, dim - rank)]))
            mean = np.zeros((dim, ))
        else:
            cov = np.array(config.data.cov)
            mean = np.array(config.data.mean)

        shape = config.data.shape if hasattr(config.data, "shape") else None

        dataset = Gaussian(device=args.device, cov=cov, mean=mean, shape=shape)
        test_dataset = Gaussian(device=args.device,
                                cov=cov,
                                mean=mean,
                                shape=shape)

    elif (config.data.dataset.upper() == "GAUSSIAN-HD"):
        if (config.data.num_workers != 0):
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. \
            Gaussian data is sampled at runtime and doing so with multiple workers may cause a CUDA error."
            )
        cov = np.load(config.data.cov_path)
        mean = np.load(config.data.mean_path)
        dataset = Gaussian(device=args.device, cov=cov, mean=mean)
        test_dataset = Gaussian(device=args.device, cov=cov, mean=mean)

    elif (config.data.dataset.upper() == "GAUSSIAN-HD-UNIT"):
        # This dataset is to be used when GAUSSIAN with the isotropic option is infeasible due to high dimensionality
        #   of the desired samples. If the dimension is too high, passing a huge covariance matrix is slow.
        if (config.data.num_workers != 0):
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. \
            Gaussian data is sampled at runtime and doing so with multiple workers may cause a CUDA error."
            )
        shape = config.data.shape if hasattr(config.data, "shape") else None
        dataset = Gaussian(device=args.device,
                           mean=None,
                           cov=None,
                           shape=shape,
                           iid_unit=True)
        test_dataset = Gaussian(device=args.device,
                                mean=None,
                                cov=None,
                                shape=shape,
                                iid_unit=True)

    return dataset, test_dataset
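
The FFHQ branch above draws a reproducible 90/10 split by seeding NumPy, shuffling the indices, and then restoring the caller's RNG state so the global random stream is left untouched. A minimal sketch of that pattern in isolation (assuming torch.utils.data.Subset; the imports sit outside this excerpt, and the toy dataset is only for demonstration):

import numpy as np
import torch
from torch.utils.data import Subset, TensorDataset

def split_90_10(dataset, seed=2019):
    # Shuffle indices with a fixed seed, then restore the previous NumPy RNG
    # state so callers relying on the global stream are not perturbed.
    indices = list(range(len(dataset)))
    state = np.random.get_state()
    np.random.seed(seed)
    np.random.shuffle(indices)
    np.random.set_state(state)
    cut = int(len(indices) * 0.9)
    return Subset(dataset, indices[:cut]), Subset(dataset, indices[cut:])

toy = TensorDataset(torch.arange(100).float())
train_set, test_set = split_90_10(toy)
assert len(train_set) == 90 and len(test_set) == 10

Saving and restoring the RNG state keeps the split deterministic across runs without pinning every other source of randomness in training.
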
Example No. 13
0
    def test_classifier(self, input_split='test', save_result=False, model_name=None, labels_filename=None,
                        acc_filename=None, acc_filenames_i=None):
        """Predicts labels and compares them to ground truth labels from given split. Returns test accuracy.
        Args:
            input_split: What split to test on [train|val|test].
            save_result: Optional, boolean. If True saves predicted labels and accuracy.
            model_name:  For neural network classifiers, model name to load and use to predict.
            labels_filename: Optional, string. Path to save predicted labels in.
            acc_filename: Optional, string. Path to save predicted accuracy in.
            acc_filenames_i: Optional, array of strings. Path to save class-specific predicted labels in.

        Returns:
            predicted_labels: Predicted labels for the input split.
            accuracy: Accuracy on the input split.
            per_class_accuracies: Array of per-class accuracies on the input split.

        Raises:
            IOError: If an input error occurs when loading features, or an output error occurs when saving results.
            ValueError: If the specified dataset [mnist|f-mnist|celeba] or classifier type
            [svm|linear-svm|lmnn|logistic|knn|nn] is not supported.
        """

        # If save_result is True, but no labels_filename was specified, use default filename.
        if save_result and (labels_filename is None):
            output_dir = self.get_output_dir()
            labels_filename = self.get_labels_filename(input_split)
            labels_filename = os.path.join(output_dir, labels_filename)

        # If save_result is True, but no acc_filename was specified, use default filename.
        if save_result and (acc_filename is None):
            output_dir = self.get_output_dir()
            acc_filename, acc_filenames_i = self.get_acc_filename(input_split)
            acc_filename = os.path.join(output_dir, acc_filename)
            for i in range(self.classifier_params['num_classes']):
                acc_filenames_i[i] = os.path.join(output_dir, acc_filenames_i[i])

        # Load feature vectors.
        feature_dir = os.path.dirname(self.classifier_params['feature_file'])
        feature_file = os.path.basename(self.classifier_params['feature_file'])
        feature_file = feature_file.replace('train', input_split)
        feature_file = os.path.join(feature_dir, feature_file)

        try:
            with open(feature_file, 'rb') as f:
                features = cPickle.load(f)
        except IOError as err:
            print('[!] I/O error({0}): {1}.'.format(err.errno, err.strerror))

        if 'verbose' in self.classifier_params and self.classifier_params['verbose']:
            print('[*] Loaded feature vectors from {}.'.format(feature_file))

        # Initialize the dataset object to load ground-truth labels.
        if self.classifier_params['dataset'] == 'mnist':
            ds = Mnist()
        elif self.classifier_params['dataset'] == 'f-mnist':
            ds = FMnist()
        elif self.classifier_params['dataset'] == 'celeba':
            ds = CelebA(resize_size=self.classifier_params['output_height'],
                        attribute=self.classifier_params['attribute'])
        else:
            raise ValueError('[!] Dataset {} is not supported.'.format(self.classifier_params['dataset']))

        # Load ground-truth labels.
        _, labels, _ = ds.load(input_split)
        num_samples = min(np.shape(features)[0], len(labels))
        labels = labels[:num_samples]
        features = features[:num_samples, :]

        if 'verbose' in self.classifier_params and self.classifier_params['verbose']:
            print('[*] Loaded ground-truth labels from: {}.'.format(
                self.classifier_params['dataset']))

        # Predict labels.
        if self.classifier_type in ('svm', 'logistic', 'knn', 'linear-svm', 'lmnn'):
            predicted_labels = self.predict(features, save_result, labels_filename)
        elif self.classifier_type == 'nn':
            predicted_labels = self.predict(features, save_result, model_name, labels_filename)
        else:
            raise ValueError('[!] Classifier type {} is not supported.'.format(self.classifier_type))

        # Compare predicted labels to ground-truth labels and calculate accuracy.
        num_correct = np.sum(np.equal(predicted_labels, labels))
        accuracy = num_correct / (1.0 * len(labels))
        per_class_accuracies = []
        for i in range(self.classifier_params['num_classes']):
            idx = np.where(np.equal(labels, i))[0]
            num_correct = np.sum(np.equal(predicted_labels[idx], labels[idx]))
            accuracy_i = num_correct / (1.0 * len(labels[idx]))
            per_class_accuracies.append(accuracy_i)

        # Save results.
        if save_result:
            try:
                with open(acc_filename, 'w') as fp:
                    fp.write("{}".format(accuracy))
            except IOError as err:
                print("[!] I/O error({0}): {1}.".format(err.errno,
                                                        err.strerror))

            if 'verbose' in self.classifier_params and self.classifier_params['verbose']:
                print('[*] Saved predicted labels {}.'.format(labels_filename))
                print('[*] Saved predicted accuracy {}.'.format(acc_filename))

            for i in range(self.classifier_params['num_classes']):
                try:
                    with open(acc_filenames_i[i], 'w') as fp:
                        fp.write("{}".format(per_class_accuracies[i]))
                except IOError as err:
                    print("[!] I/O error({0}): {1}.".format(err.errno,
                                                            err.strerror))

        if 'verbose' in self.classifier_params and self.classifier_params['verbose']:
            print('[*] Testing complete. Accuracy on {} split {}.'.format(
                input_split, accuracy))
            for i in range(self.classifier_params['num_classes']):
                print('[*] Testing complete. Accuracy on {} split, class {}: {}.'.format(input_split, i,
                                                                                         per_class_accuracies[i]))

        return predicted_labels, accuracy, per_class_accuracies
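
The accuracy bookkeeping above reduces to element-wise comparisons with np.equal. A small NumPy-only sketch of the overall and per-class computation (toy labels, not taken from the example):

import numpy as np

labels = np.array([0, 0, 1, 1, 2, 2])
predicted = np.array([0, 1, 1, 1, 2, 0])

# Overall accuracy: fraction of positions where prediction matches ground truth.
accuracy = np.sum(np.equal(predicted, labels)) / (1.0 * len(labels))

# Per-class accuracy: restrict to the indices of each ground-truth class.
per_class = []
for i in range(3):
    idx = np.where(np.equal(labels, i))[0]
    per_class.append(np.sum(np.equal(predicted[idx], labels[idx])) / (1.0 * len(idx)))

print(accuracy)    # 0.666...
print(per_class)   # [0.5, 1.0, 0.5]
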
Example No. 14
0
    def validate(self):
        """Only needed for neural networks. Validates different checkpoints by testing them on the validation split and
        retaining the one with the top accuracy.

        Returns:
            best_model: Name of the chosen best model. An empty string is returned for non-neural-network
            classifiers, for which no validation is performed.

        Raises:
            IOError: If an input error occurs when loading feature vectors, or an output error occurs when saving the
            chosen model.
            ValueError: If the specified dataset [mnist|f-mnist|celeba] or classifier type
            [svm|linear-svm|lmnn|logistic|knn|nn] is not supported.
        """

        if 'verbose' in self.classifier_params and self.classifier_params['verbose']:
            print("[*] Validating.")

        # Get feature file paths.
        feature_dir = os.path.dirname(self.classifier_params['feature_file'])
        feature_file = os.path.basename(self.classifier_params['feature_file'])
        feature_file = feature_file.replace('train', 'val')
        feature_file = os.path.join(feature_dir, feature_file)

        # Load feature vectors.
        try:
            with open(feature_file, 'rb') as f:
                features = cPickle.load(f)
        except IOError as err:
            print("[!] I/O error({0}): {1}.".format(err.errno, err.strerror))

        if 'verbose' in self.classifier_params and self.classifier_params['verbose']:
            print('[*] Loaded feature vectors from {}.'.format(feature_file))

        # Initialize the dataset object to load ground-truth labels.
        if self.classifier_params['dataset'] == 'mnist':
            ds = Mnist()
        elif self.classifier_params['dataset'] == 'f-mnist':
            ds = FMnist()
        elif self.classifier_params['dataset'] == 'celeba':
            ds = CelebA(resize_size=self.classifier_params['output_height'],
                        attribute=self.classifier_params['attribute'])
        else:
            raise ValueError('[!] Dataset {} is not supported.'.format(self.classifier_params['dataset']))

        # Load ground-truth labels from the validation split.
        _, labels, _ = ds.load('val')
        num_samples = min(np.shape(features)[0], len(labels))
        labels = labels[:num_samples]
        features = features[:num_samples, :]

        if 'verbose' in self.classifier_params and self.classifier_params['verbose']:
            print('[*] Loaded ground-truth labels from {}.'.format(
                self.classifier_params['dataset']))

        # Non-neural-network classifiers do not require validation, as no intermediate models exist.
        if self.classifier_type in ('svm', 'logistic', 'knn', 'linear-svm', 'lmnn'):
            print('[!] No validation needed.')
            return ""

        # Neural network classifiers.
        elif self.classifier_type == 'nn':
            # Call the neural network validate function on the features.
            best_acc, best_model, _ = self.estimator.validate(features, labels, session=self.session)

            # Save results.
            try:
                with open(os.path.join(self.get_output_dir(), self.tf_checkpoint_dir(), 'chosen_model.txt'), 'w') as fp:
                    fp.write("{} {}".format(os.path.basename(best_model), best_acc))
            except IOError as err:
                print("[!] I/O error({0}): {1}.".format(err.errno,
                                                        err.strerror))

            if 'verbose' in self.classifier_params and self.classifier_params['verbose']:
                print(
                    '[*] Chose model: {}, with validation accuracy {}.'.format(os.path.basename(best_model), best_acc))
            return best_model

        else:
            raise ValueError('[!] Classifier type {} is not supported.'.format(self.classifier_type))
Example No. 15
0
    def train(self, features=None, labels=None, retrain=False, num_train=-1):
        """Trains classifier using training features and ground truth training labels.

        Args:
            features: Path to training feature vectors (use None to automatically load saved features from experiment
            output directory).
            labels: Path to ground truth train labels (use None to automatically load from dataset).
            retrain: Boolean, whether or not to retrain if classifier is already saved.
            num_train: Number of training samples to use (use -1 to include all training samples).

        Raises:
            ValueError: If the specified dataset [mnist|f-mnist|celeba] or classifier type
            [svm|linear-svm|lmnn|logistic|knn|nn] is not supported.
        """

        # If no feature vector is provided load from experiment output directory.
        if features is None:
            feature_file = self.classifier_params['feature_file']
            try:
                with open(feature_file, 'rb') as f:
                    features = cPickle.load(f)
            except IOError as err:
                print("[!] I/O error({0}): {1}.".format(err.errno,
                                                        err.strerror))
            if 'verbose' in self.classifier_params and self.classifier_params['verbose']:
                print('[*] Loaded feature file from {}.'.format(feature_file))

        # If no label vector is provided load from dataset.
        if labels is None:
            # Create dataset object based on dataset name.
            if self.classifier_params['dataset'] == 'mnist':
                ds = Mnist()
            elif self.classifier_params['dataset'] == 'f-mnist':
                ds = FMnist()
            elif self.classifier_params['dataset'] == 'celeba':
                ds = CelebA(resize_size=self.classifier_params['output_height'],
                            attribute=self.classifier_params['attribute'])
            else:
                raise ValueError('[!] Dataset {} is not supported.'.format(self.classifier_params['dataset']))
            # Load labels from the train split.
            _, labels, _ = ds.load('train')
            num_samples = min(np.shape(features)[0], len(labels))

            # Restrict to the first num_train samples if num_train is not -1.
            if num_train > -1:
                num_samples = min(num_train, num_samples)

            labels = labels[:num_samples]
            features = features[:num_samples, :]

            if 'verbose' in self.classifier_params and self.classifier_params['verbose']:
                print('[*] Loaded ground truth labels from {}.'.format(
                    self.classifier_params['dataset']))

        # Train the classifier.
        if self.classifier_type in ('svm', 'logistic', 'knn', 'linear-svm'):
            self.estimator.fit(features, labels)

        # Neural network classifiers.
        elif self.classifier_type == 'nn':
            self.estimator.fit(features, labels, retrain=retrain, session=self.session)

        # For LMNN, first transform the feature vector then perform k-NN.
        elif self.classifier_type == 'lmnn':
            # Learn the metric.
            self.helper_estimator.fit(features, labels)
            # Transform feature space.
            transformed_features = self.helper_estimator.transform(features)
            # Create k-nn graph.
            self.estimator.fit(transformed_features, labels)

        else:
            raise ValueError('[!] Classifier type {} is not supported.'.format(self.classifier_type))

        if ('verbose' in self.classifier_params) and self.classifier_params['verbose']:
            print('[*] Trained classifier.')
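
For the LMNN branch, the estimator pair is not shown in this excerpt. Below is a hedged sketch of the same two-step idea using metric-learn's LMNN and scikit-learn's KNeighborsClassifier; the library choice and the toy data are assumptions, not necessarily what helper_estimator and estimator actually wrap:

import numpy as np
from metric_learn import LMNN
from sklearn.neighbors import KNeighborsClassifier

X = np.random.RandomState(0).randn(60, 8)   # toy features
y = np.repeat(np.arange(3), 20)             # toy labels, 3 classes

metric = LMNN()               # assumed stand-in for self.helper_estimator
metric.fit(X, y)              # learn the metric
X_t = metric.transform(X)     # transform the feature space
knn = KNeighborsClassifier()  # assumed stand-in for self.estimator
knn.fit(X_t, y)               # build the k-NN model in the learned space
print(knn.score(X_t, y))

Running k-NN in the learned space is the point of LMNN: the metric is trained so that same-class neighbors are pulled together before the nearest-neighbor lookup.
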
Example No. 16
0
def get_dataset(args, config):
    if config.data.random_flip is False:
        tran_transform = test_transform = transforms.Compose(
            [transforms.Resize(config.data.image_size),
             transforms.ToTensor()])
    else:
        tran_transform = transforms.Compose([
            transforms.Resize(config.data.image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor()
        ])
        test_transform = transforms.Compose(
            [transforms.Resize(config.data.image_size),
             transforms.ToTensor()])

    if config.data.dataset == 'CIFAR10':
        dataset = CIFAR10(os.path.join(args.exp, 'datasets', 'cifar10'),
                          train=True,
                          download=True,
                          transform=tran_transform)
        test_dataset = CIFAR10(os.path.join(args.exp, 'datasets',
                                            'cifar10_test'),
                               train=False,
                               download=True,
                               transform=test_transform)

    elif config.data.dataset == 'CELEBA':
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join(args.exp, 'datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]),
                             download=False)
        else:
            dataset = CelebA(root=os.path.join(args.exp, 'datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]),
                             download=False)

        test_dataset = CelebA(root=os.path.join(args.exp, 'datasets',
                                                'celeba_test'),
                              split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]),
                              download=False)

    elif config.data.dataset == 'LSUN':
        # import ipdb; ipdb.set_trace()
        train_folder = '{}_train'.format(config.data.category)
        val_folder = '{}_val'.format(config.data.category)
        if config.data.random_flip:
            dataset = LSUN(root=os.path.join(args.exp, 'datasets', 'lsun'),
                           classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor(),
                           ]))
        else:
            dataset = LSUN(root=os.path.join(args.exp, 'datasets', 'lsun'),
                           classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.ToTensor(),
                           ]))

        test_dataset = LSUN(root=os.path.join(args.exp, 'datasets', 'lsun'),
                            classes=[val_folder],
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.CenterCrop(config.data.image_size),
                                transforms.ToTensor(),
                            ]))

    elif config.data.dataset == "FFHQ":
        if config.data.random_flip:
            dataset = FFHQ(path=os.path.join(args.exp, 'datasets', 'FFHQ'),
                           transform=transforms.Compose([
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor()
                           ]),
                           resolution=config.data.image_size)
        else:
            dataset = FFHQ(path=os.path.join(args.exp, 'datasets', 'FFHQ'),
                           transform=transforms.ToTensor(),
                           resolution=config.data.image_size)

        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = (
            indices[:int(num_items * 0.9)],
            indices[int(num_items * 0.9):],
        )
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)

    return dataset, test_dataset
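
A short usage sketch for this get_dataset variant, with minimal stand-ins for args and config. The attribute names mirror the ones read above; the batch size and worker count are assumptions, and the function plus its torchvision imports from the example are expected to be in scope:

import argparse
from types import SimpleNamespace
from torch.utils.data import DataLoader

args = argparse.Namespace(exp='exp')
config = SimpleNamespace(
    data=SimpleNamespace(dataset='CIFAR10', image_size=32, random_flip=True))

dataset, test_dataset = get_dataset(args, config)
train_loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
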