from torchvision import transforms
from torchvision.datasets import CIFAR10, MNIST, LSUN, ImageFolder


def get_data(dataname, path, img_size=64):
    if dataname == 'CIFAR10':
        dataset = CIFAR10(path,
                          train=True,
                          transform=transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5),
                                                   (0.5, 0.5, 0.5))
                          ]),
                          download=True)
        print('CIFAR10')
    elif dataname == 'MNIST':
        dataset = MNIST(path,
                        train=True,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, ), (0.5, ))
                        ]),
                        download=True)
        print('MNIST')
    elif dataname == 'LSUN-dining':
        dataset = LSUN(path,
                       classes=['dining_room_train'],
                       transform=transforms.Compose([
                           transforms.Resize(img_size),
                           transforms.CenterCrop(img_size),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))
                       ]))
        print('LSUN-dining')
    elif dataname == 'LSUN-bedroom':
        dataset = LSUN(path,
                       classes=['bedroom_train'],
                       transform=transforms.Compose([
                           transforms.Resize(img_size),
                           transforms.CenterCrop(img_size),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))
                       ]))
        print('LSUN-bedroom')
    elif dataname == 'CelebA':
        dataset = ImageFolder(root=path,
                              transform=transforms.Compose([
                                  transforms.Resize(img_size),
                                  transforms.CenterCrop(img_size),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.5, 0.5, 0.5),
                                                       (0.5, 0.5, 0.5)),
                              ]))
        print('CelebA')
    else:
        raise ValueError('Unknown dataset: {}'.format(dataname))
    return dataset
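The returned dataset plugs directly into a DataLoader; for example (the path and batch size here are illustrative):

from torch.utils.data import DataLoader

dataset = get_data('CIFAR10', './data')
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)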
Example #2
def get_lsun_dataset(dataset_path: str, classes: str, image_size: int,
                     mean: Tuple[float, float, float],
                     std: Tuple[float, float, float]) -> Dataset:
    return LSUN(
        root=dataset_path,
        transform=_get_lsun_transform(image_size, mean, std),
        classes=classes,
    )
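The helper `_get_lsun_transform` is not shown in this snippet. A minimal sketch, assuming the standard resize/center-crop/normalize pipeline used elsewhere on this page (the real helper may differ):

from typing import Tuple

from torchvision import transforms


def _get_lsun_transform(image_size: int,
                        mean: Tuple[float, float, float],
                        std: Tuple[float, float, float]) -> transforms.Compose:
    # Assumed pipeline: resize the shorter edge, crop to a square,
    # convert to a tensor, and normalize per channel.
    return transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])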
Example #3
def cli_main(args=None):
    seed_everything(1234)

    parser = ArgumentParser()
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--dataset",
                        default="mnist",
                        type=str,
                        choices=["lsun", "mnist"])
    parser.add_argument("--data_dir", default="./", type=str)
    parser.add_argument("--image_size", default=64, type=int)
    parser.add_argument("--num_workers", default=8, type=int)

    script_args, _ = parser.parse_known_args(args)

    if script_args.dataset == "lsun":
        transforms = transform_lib.Compose([
            transform_lib.Resize(script_args.image_size),
            transform_lib.CenterCrop(script_args.image_size),
            transform_lib.ToTensor(),
            transform_lib.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        dataset = LSUN(root=script_args.data_dir,
                       classes=["bedroom_train"],
                       transform=transforms)
        image_channels = 3
    elif script_args.dataset == "mnist":
        transforms = transform_lib.Compose([
            transform_lib.Resize(script_args.image_size),
            transform_lib.ToTensor(),
            transform_lib.Normalize((0.5, ), (0.5, )),
        ])
        dataset = MNIST(root=script_args.data_dir,
                        download=True,
                        transform=transforms)
        image_channels = 1

    dataloader = DataLoader(dataset,
                            batch_size=script_args.batch_size,
                            shuffle=True,
                            num_workers=script_args.num_workers)

    parser = DCGAN.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args(args)

    model = DCGAN(**vars(args), image_channels=image_channels)
    callbacks = [
        TensorboardGenerativeModelImageSampler(num_samples=5),
        LatentDimInterpolator(interpolate_epoch_interval=5),
    ]
    trainer = Trainer.from_argparse_args(args, callbacks=callbacks)
    trainer.fit(model, dataloader)
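For a quick smoke test, the entry point can be called with an explicit argument list; `--max_epochs` is assumed here to be one of the flags contributed by `Trainer.add_argparse_args`:

cli_main(["--dataset", "mnist", "--batch_size", "16", "--max_epochs", "1"])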
Example #4
def main(W):
    classes = [
        'bedroom', 'kitchen', 'conference_room', 'dining_room',
        'church_outdoor'
    ]
    N = 100000
    image_size = (W, W)
    rootdir = '.'

    for i, c in enumerate(classes):
        print(c)
        # the output directories must exist before saving (requires os)
        os.makedirs('data_split_%d/%s/0' % (W, c), exist_ok=True)
        os.makedirs('data_%d/0' % W, exist_ok=True)
        lsun = LSUN(rootdir, ['%s_train' % c],
                    transform=transforms.Compose(
                        [transforms.Resize(image_size)]))
        for n in tqdm(range(N)):
            img = lsun[n][0]  # decode each sample once, save it twice
            img.save('data_split_%d/%s/0/%07d.jpg' % (W, c, n))
            img.save('data_%d/0/%07d.jpg' % (W, n + i * N))
Example #5
def main():
    batch_size = 8

    net = PSPNet(pretrained=False, num_classes=num_classes, input_size=(512, 1024)).cuda()
    snapshot = 'epoch_48_validation_loss_5.1326_mean_iu_0.3172_lr_0.00001000.pth'
    net.load_state_dict(torch.load(os.path.join(ckpt_path, snapshot)))
    net.eval()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    transform = transforms.Compose([
        expanded_transform.FreeScale((512, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(*mean_std)
    ])
    restore = transforms.Compose([
        expanded_transform.DeNormalize(*mean_std),
        transforms.ToPILImage()
    ])

    lsun_path = '/home/b3-542/LSUN'

    dataset = LSUN(lsun_path, ['tower_val', 'church_outdoor_val', 'bridge_val'], transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=16, shuffle=True)

    if not os.path.exists(test_results_path):
        os.mkdir(test_results_path)

    for vi, data in enumerate(dataloader, 0):
        inputs, labels = data
        inputs = Variable(inputs, volatile=True).cuda()
        outputs = net(inputs)

        prediction = outputs.cpu().data.max(1)[1].squeeze_(1).numpy()

        for idx, tensor in enumerate(zip(inputs.cpu().data, prediction)):
            pil_input = restore(tensor[0])
            pil_output = colorize_mask(tensor[1])
            pil_input.save(os.path.join(test_results_path, '%d_img.png' % (vi * batch_size + idx)))
            pil_output.save(os.path.join(test_results_path, '%d_out.png' % (vi * batch_size + idx)))
            print('save the #%d batch, %d images' % (vi + 1, idx + 1))
Example #6
    def renew(self, resl):
        # print('[*] Renew dataloader configuration, load data from {}.'.format(self.root))

        self.batchsize = int(self.batch_table[pow(2, resl)])
        self.imsize = int(pow(2, resl))
        # self.dataset = ImageFolder(
        #             root=self.root,
        #             transform=transforms.Compose(   [
        #                                             transforms.Resize(size=(self.imsize,self.imsize), interpolation=Image.NEAREST),
        #                                             transforms.ToTensor(),
        #                                             ]))

        # self.dataloader = DataLoader(
        #     dataset=self.dataset,
        #     batch_size=self.batchsize,
        #     shuffle=True,
        #     num_workers=self.num_workers
        # )
        self.transform = transforms.Compose([
            transforms.Resize(size=(self.imsize, self.imsize),
                              interpolation=Image.NEAREST),
            transforms.ToTensor(),
        ])
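        # Note: self.dataset initially holds the dataset *name* (a string);
        # the branches below replace it with the dataset object, so a second
        # call to renew() would no longer match these string comparisons.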
        if self.dataset == 'cifar10':
            self.dataset = CIFAR10(self.root,
                                   train=True,
                                   transform=self.transform)
        # elif self.dataset == 'celeba':
        # self.dataset = CelebA(self.root, train=True, transform=self.transform)
        elif self.dataset == 'lsun_bedroom':
            self.dataset = LSUN(self.root,
                                classes=['bedroom_train'],
                                transform=self.transform)
        self.dataloader = DataLoader(dataset=self.dataset,
                                     batch_size=self.batchsize,
                                     shuffle=True,
                                     num_workers=self.num_workers)
Example #7
    def prepare_data(self):
        train_resize = transforms.Resize((self.hparams.image_size, self.hparams.image_size))
        train_normalize = transforms.Normalize(mean=[0.5], std=[0.5])
        train_transform = transforms.Compose([train_resize, transforms.ToTensor(), train_normalize])

        if self.hparams.dataset == "mnist":
            self.train_dataset = MNIST(self.hparams.dataset_path, train=True, download=True, transform=train_transform)
            # self.test_dataset = MNIST(self.hparams.dataset_path, train=False, download=True, transform=test_transform)
        elif self.hparams.dataset == "fashion_mnist":
            self.train_dataset = FashionMNIST(self.hparams.dataset_path, train=True, download=True, transform=train_transform)
            # self.test_dataset = FashionMNIST(self.hparams.dataset_path, train=False, download=True, transform=test_transform)
        elif self.hparams.dataset == "cifar10":
            self.train_dataset = CIFAR10(self.hparams.dataset_path, train=True, download=True, transform=train_transform)
            # self.test_dataset = CIFAR10(self.hparams.dataset_path, train=False, download=True, transform=test_transform)
        elif self.hparams.dataset == "image_net":
            # torchvision's ImageNet takes split= rather than train= and no
            # longer supports automatic download; the data must be on disk.
            self.train_dataset = ImageNet(self.hparams.dataset_path, split="train", transform=train_transform)
            # self.test_dataset = ImageNet(self.hparams.dataset_path, split="val", transform=test_transform)
        elif self.hparams.dataset == "lsun":
            self.train_dataset = LSUN(self.hparams.dataset_path + "/lsun", classes=[cls + "_train" for cls in self.hparams.dataset_classes], transform=train_transform)
            # self.test_dataset = LSUN(self.hparams.dataset_path, classes=[cls + "_test" for cls in self.hparams.dataset_classes], transform=test_transform)
        elif self.hparams.dataset == "celeba_hq":
            self.train_dataset = CelebAHQ(self.hparams.dataset_path, image_size=self.hparams.image_size, transform=train_transform)
        else:
            raise NotImplementedError("Custom dataset is not implemented yet")
Example #8
def get_dataset(args, config):
    if config.data.random_flip is False:
        tran_transform = test_transform = transforms.Compose(
            [transforms.Resize(config.data.image_size),
             transforms.ToTensor()])
    else:
        tran_transform = transforms.Compose([
            transforms.Resize(config.data.image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor()
        ])
        test_transform = transforms.Compose(
            [transforms.Resize(config.data.image_size),
             transforms.ToTensor()])

    if config.data.dataset == 'CIFAR10':
        dataset = CIFAR10(os.path.join(args.exp, 'datasets', 'cifar10'),
                          train=True,
                          download=True,
                          transform=tran_transform)
        test_dataset = CIFAR10(os.path.join(args.exp, 'datasets',
                                            'cifar10_test'),
                               train=False,
                               download=True,
                               transform=test_transform)

    elif config.data.dataset == 'CELEBA':
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join(args.exp, 'datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]),
                             download=False)
        else:
            dataset = CelebA(root=os.path.join(args.exp, 'datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]),
                             download=False)

        test_dataset = CelebA(root=os.path.join(args.exp, 'datasets',
                                                'celeba_test'),
                              split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]),
                              download=False)

    elif config.data.dataset == 'LSUN':
        train_folder = '{}_train'.format(config.data.category)
        val_folder = '{}_val'.format(config.data.category)
        if config.data.random_flip:
            dataset = LSUN(root=os.path.join(args.exp, 'datasets', 'lsun'),
                           classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor(),
                           ]))
        else:
            dataset = LSUN(root=os.path.join(args.exp, 'datasets', 'lsun'),
                           classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.ToTensor(),
                           ]))

        test_dataset = LSUN(root=os.path.join(args.exp, 'datasets', 'lsun'),
                            classes=[val_folder],
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.CenterCrop(config.data.image_size),
                                transforms.ToTensor(),
                            ]))

    elif config.data.dataset == "FFHQ":
        if config.data.random_flip:
            dataset = FFHQ(path=os.path.join(args.exp, 'datasets', 'FFHQ'),
                           transform=transforms.Compose([
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor()
                           ]),
                           resolution=config.data.image_size)
        else:
            dataset = FFHQ(path=os.path.join(args.exp, 'datasets', 'FFHQ'),
                           transform=transforms.ToTensor(),
                           resolution=config.data.image_size)

        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = (indices[:int(num_items * 0.9)],
                                       indices[int(num_items * 0.9):])
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)

    else:
        raise ValueError('Unknown dataset: {}'.format(config.data.dataset))

    return dataset, test_dataset
Example #9
def get_image_loader(args, b_size, num_workers):
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    if args.dataset == 'cifar10':
        spatial_size = 32

        trainset = CIFAR10(root='./data',
                           train=True,
                           download=True,
                           transform=transform_train)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=b_size,
                                                  shuffle=True,
                                                  num_workers=num_workers)
        n_classes = 10
        testset = CIFAR10(root='./data',
                          train=False,
                          download=True,
                          transform=transform_test)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=b_size,
                                                 shuffle=False,
                                                 num_workers=num_workers)
    elif args.dataset == 'lsun':
        spatial_size = 32

        transform_lsun = transforms.Compose([
            transforms.Resize(spatial_size),
            transforms.CenterCrop(spatial_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        trainset = LSUN(root=args.data_path,
                        classes=['bedroom_train'],
                        transform=transform_lsun)
        testset = LSUN(root=args.data_path,
                       classes=['bedroom_val'],
                       transform=transform_lsun)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=b_size,
                                                  shuffle=True,
                                                  num_workers=num_workers)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=b_size,
                                                 shuffle=True,
                                                 num_workers=num_workers)

    elif args.dataset == 'imagenet32':

        from utils.imagenet import Imagenet32

        spatial_size = 32

        transforms_train = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]

        trainset = Imagenet32(args.imagenet_train_path,
                              transform=transforms.Compose(transforms_train),
                              sz=spatial_size)
        valset = Imagenet32(args.imagenet_test_path,
                            transform=transforms.Compose(transforms_train),
                            sz=spatial_size)
        n_classes = 1000
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=b_size,
                                                  shuffle=True,
                                                  num_workers=num_workers)
        testloader = torch.utils.data.DataLoader(valset,
                                                 batch_size=b_size,
                                                 shuffle=False,
                                                 num_workers=num_workers)

    elif args.dataset == 'celebA':
        size = 32
        transform_train = transforms.Compose([
            transforms.Resize(size),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        transform_test = transforms.Compose([
            transforms.Resize(size),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        trainset = CelebA(root=args.data_path,
                          split='train',
                          transform=transform_train,
                          download=False)
        testset = CelebA(root=args.data_path,
                         split='test',
                         transform=transform_test,
                         download=False)

        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=b_size,
                                                  shuffle=True,
                                                  num_workers=num_workers)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=b_size,
                                                 shuffle=False,
                                                 num_workers=num_workers)

    else:
        raise NotImplementedError
    return trainloader, testloader, None
Example #10
if DATASET_NAME == 'CIFAR10':
    from torchvision.datasets import CIFAR10
    transforms = Compose([ToTensor(), Normalize(mean=[0.5], std=[0.5])])
    dataset = CIFAR10(root='./datasets',
                      train=True,
                      transform=transforms,
                      download=True)
elif DATASET_NAME == 'LSUN':
    from torchvision.datasets import LSUN
    transforms = Compose([
        Resize(64),
        CenterCrop(64),  # LSUN images vary in size; crop so batches stack
        ToTensor(),
        Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    dataset = LSUN(root='./datasets/LSUN',
                   classes=['bedroom_train'],
                   transform=transforms)
elif DATASET_NAME == 'MNIST':
    from torchvision.datasets import MNIST
    transforms = Compose([ToTensor(), Normalize(mean=[0.5], std=[0.5])])
    dataset = MNIST(root='./datasets',
                    train=True,
                    transform=transforms,
                    download=True)
else:
    raise NotImplementedError

    data_loader = DataLoader(dataset=dataset,
                             batch_size=BATCH_SIZE,
                             num_workers=0,
                             shuffle=True)
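Iterating the resulting loader yields normalized batches; a quick sanity check:

images, labels = next(iter(data_loader))
print(images.shape)  # e.g. torch.Size([BATCH_SIZE, 3, 64, 64]) for the LSUN branch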
Example #11
def main(
    lsun_data_dir: ('Base directory for the LSUN data'),
    image_output_prefix: ('Prefix for image output', 'option', 'o') = 'glo',
    code_dim: ('Dimensionality of latent representation space', 'option', 'd',
               int) = 128,
    epochs: ('Number of epochs to train', 'option', 'e', int) = 25,
    use_cuda: ('Use GPU?', 'flag', 'gpu') = False,
    batch_size: ('Batch size', 'option', 'b', int) = 128,
    lr_g: ('Learning rate for generator', 'option', None, float) = 1.,
    lr_z: ('Learning rate for representation_space', 'option', None,
           float) = 10.,
    max_num_samples: ('Cap on the number of samples from the LSUN dataset',
                      'option', 'n', int) = -1,
    init: ('Initialization strategy for latent representation vectors',
           'option', 'i', str, ['pca', 'random']) = 'pca',
    n_pca: ('Number of samples to take for PCA', 'option', None,
            int) = (64 * 64 * 3 * 2),
    loss: ('Loss type (Laplacian loss as in the paper, or L2 loss)', 'option',
           'l', str, ['lap_l1', 'l2']) = 'lap_l1',
):
    def maybe_cuda(tensor):
        return tensor.cuda() if use_cuda else tensor

    train_set = IndexedDataset(
        LSUN(lsun_data_dir,
             classes=['bedroom_train'],
             transform=transforms.Compose([
                 transforms.Resize(64),
                 transforms.CenterCrop(64),
                 transforms.ToTensor(),
                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
             ])))
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=8,
        pin_memory=use_cuda,
    )
    # we don't really have a validation set here, but for visualization let us
    # just take the first couple images from the dataset
    val_loader = torch.utils.data.DataLoader(train_set,
                                             shuffle=False,
                                             batch_size=8 * 8)

    if max_num_samples > 0:
        train_set.base.length = max_num_samples
        train_set.base.indices = [max_num_samples]

    # initialize representation space:
    if init == 'pca':
        from sklearn.decomposition import PCA

        # first, take a subset of train set to fit the PCA
        X_pca = np.vstack([
            X.cpu().numpy().reshape(len(X), -1) for i, (X, _, _) in zip(
                tqdm(range(n_pca // train_loader.batch_size),
                     'collect data for PCA'), train_loader)
        ])
        print("perform PCA...")
        pca = PCA(n_components=code_dim)
        pca.fit(X_pca)
        # then, initialize latent vectors to the pca projections of the complete dataset
        Z = np.empty((len(train_loader.dataset), code_dim))
        for X, _, idx in tqdm(train_loader, 'pca projection'):
            Z[idx] = pca.transform(X.cpu().numpy().reshape(len(X), -1))

    elif init == 'random':
        Z = np.random.randn(len(train_set), code_dim)

    Z = project_l2_ball(Z)

    g = maybe_cuda(Generator(code_dim))  # initial a Generator g
    loss_fn = LapLoss(max_levels=3) if loss == 'lap_l1' else nn.MSELoss()
    zi = maybe_cuda(torch.zeros((batch_size, code_dim)))
    zi = Variable(zi, requires_grad=True)
    optimizer = SGD([{
        'params': g.parameters(),
        'lr': lr_g
    }, {
        'params': zi,
        'lr': lr_z
    }])

    Xi_val, _, idx_val = next(iter(val_loader))
    imsave(
        'target.png',
        make_grid(Xi_val.cpu() / 2. + 0.5, nrow=8).numpy().transpose(1, 2, 0))

    for epoch in range(epochs):
        losses = []
        progress = tqdm(total=len(train_loader), desc='epoch % 3d' % epoch)

        for i, (Xi, yi, idx) in enumerate(train_loader):
            Xi = Variable(maybe_cuda(Xi))
            zi.data = maybe_cuda(torch.FloatTensor(Z[idx.numpy()]))

            optimizer.zero_grad()
            rec = g(zi)
            loss = loss_fn(rec, Xi)
            loss.backward()
            optimizer.step()

            Z[idx.numpy()] = project_l2_ball(zi.data.cpu().numpy())

            losses.append(loss.item())
            progress.set_postfix({'loss': np.mean(losses[-100:])})
            progress.update()

        progress.close()

        # visualize reconstructions
        rec = g(Variable(maybe_cuda(torch.FloatTensor(Z[idx_val.numpy()]))))
        imsave(
            '%s_rec_epoch_%03d.png' % (image_output_prefix, epoch),
            make_grid(rec.data.cpu() / 2. + 0.5,
                      nrow=8).numpy().transpose(1, 2, 0))
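Neither `IndexedDataset` nor `project_l2_ball` is shown in this snippet; both come from the surrounding repository. Sketches consistent with how they are used above (the projection follows the GLO paper, which constrains latents to the unit l2 ball):

import numpy as np
import torch.utils.data


def project_l2_ball(z):
    # Project each row of z onto the unit l2 ball: scale down any vector
    # whose norm exceeds 1, leave the rest untouched.
    norms = np.maximum(np.sqrt((z ** 2).sum(axis=1))[:, np.newaxis], 1)
    return z / norms


class IndexedDataset(torch.utils.data.Dataset):
    # Wraps a dataset so __getitem__ also returns the sample index, which
    # the training loop uses to look up each sample's latent vector in Z.
    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, target = self.base[idx]
        return img, target, idx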
Example #12
f.write(spec)
f.close()

fixed_z_ = torch.randn((5 * 5, 100)).view(-1, 100, 1, 1)  # fixed noise
fixed_z_ = Variable(fixed_z_.cuda(gpu_id), volatile=True)

# data_loader
transform = transforms.Compose([
    transforms.Resize([img_size, img_size], Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

if 'LSUN' in data_name:
    from torchvision.datasets import LSUN
    dset = LSUN('data/LSUN', classes=[data_dir], transform=transform)
else:
    dset = datasets.ImageFolder(data_dir, transform)

train_loader = torch.utils.data.DataLoader(dset,
                                           batch_size=batch_size,
                                           shuffle=True)
# temp = plt.imread(train_loader.dataset.imgs[0][0])

# network
vgg = Vgg19()
vgg.cuda(gpu_id)

G = generator(d=128, mlp_dim=256, s_dim=40, img_size=img_size)
D = discriminator(128, img_size=img_size)
Example #13
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='Amortized approximation on Cifar10')
    parser.add_argument('--batch-size', type=int, default=256, metavar='N',
                        help='input batch size for training (default: 256)')
    parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--approx-epochs', type=int, default=200, metavar='N',
                        help='number of epochs to approx (default: 200)')
    parser.add_argument('--lr', type=float, default=1e-2, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--dropout-rate', type=float, default=0.5, metavar='p_drop',
                        help='dropout rate')
    parser.add_argument('--model-path', type=str, default='./checkpoint/', metavar='N',
                        help='path where the model params are saved.')
    parser.add_argument('--from-approx-model', type=int, default=1,
                        help='if our model is loaded or trained')
    parser.add_argument('--test-ood-from-disk', type=int, default=1,
                        help='generate test samples or load from disk')
    parser.add_argument('--ood-name', type=str, default='lsun',
                        help='name of the used ood dataset')

    ood = 'lsun'

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 8, 'pin_memory': False} if use_cuda else {}

    trans_norm = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        trans_norm,
    ])

    transform_test = transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor(),
        trans_norm,
    ])

    tr_data = CIFAR10(root='../../data', train=True, download=True, transform=transform_train)

    te_data = CIFAR10(root='../../data', train=False, download=True, transform=transform_test)

    ood_data = LSUN(root='../../data', classes='test', transform=transform_test)

    train_loader = torch.utils.data.DataLoader(
        tr_data,
        batch_size=args.batch_size, shuffle=False, **kwargs)

    test_loader = torch.utils.data.DataLoader(
        te_data,
        batch_size=args.batch_size, shuffle=False,  **kwargs)

    ood_loader = torch.utils.data.DataLoader(
        ood_data,
        batch_size=args.batch_size, shuffle=False, **kwargs)

    model = VGG('VGG19').to(device)

    model.load_state_dict(torch.load('checkpoint/ckpt.pth')['net'])

    test(args, model, device, test_loader)

    if args.from_approx_model == 0:
        output_samples = torch.load('./cifar10-vgg19rand-tr-samples.pt')

    # --------------- training approx ---------

    print('approximating ...')
    fmodel = VGG('VGG19').to(device)
    gmodel = VGG('VGG19', concentration=True).to(device)

    if args.from_approx_model == 0:
        f_optimizer = optim.SGD(fmodel.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
        g_optimizer = optim.SGD(gmodel.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)

        best_acc = 0
        for epoch in range(1, args.approx_epochs + 1):
            train_approx(args, fmodel, gmodel, device, train_loader, f_optimizer, g_optimizer, output_samples, epoch)
            acc = test(args, fmodel, device, test_loader)
            if acc > best_acc:
                torch.save(fmodel.state_dict(), args.model_path + 'cifar10rand-mean-mmd.pt')
                torch.save(gmodel.state_dict(), args.model_path + 'cifar10rand-conc-mmd.pt')
                best_acc = acc

    else:
        fmodel.load_state_dict(torch.load(args.model_path + 'cifar10rand-mean-mmd.pt'))
        gmodel.load_state_dict(torch.load(args.model_path + 'cifar10rand-conc-mmd.pt'))


    teacher_test_samples = torch.load('./cifar10-vgg19rand-te-samples.pt')

    teacher_ood_samples  = torch.load('./cifar10-vgg19rand-lsun-samples.pt')

    # fitting individual Dirichlet is not in the sample code as it's time-consuming
    eval_approx(args, fmodel, gmodel, device, test_loader, ood_loader, teacher_test_samples, teacher_ood_samples)
Example #14
        for ind in indices:
            plot_data.append(data[ind])

    plot_data = torch.stack(plot_data, dim=0)
    save_image(plot_data, '{}.png'.format(name), nrow=k + 1)


if __name__ == '__main__':
    args = parser.parse_args()
    if args.dataset == 'church':
        transforms = Compose([
            Resize(96),
            CenterCrop(96),
            ToTensor()
        ])
        dataset = LSUN('exp/datasets/lsun', ['church_outdoor_train'], transform=transforms)

    elif args.dataset == 'tower' or args.dataset == 'bedroom':
        transforms = Compose([
            Resize(128),
            CenterCrop(128),
            ToTensor()
        ])
        dataset = LSUN('exp/datasets/lsun', ['{}_train'.format(args.dataset)], transform=transforms)

    elif args.dataset == 'celeba':
        transforms = Compose([
            CenterCrop(140),
            Resize(64),
            ToTensor(),
        ])
Example #15
def get_dataset(d_config, data_folder):
    cmp = lambda x: transforms.Compose([*x])

    if d_config.dataset == 'CIFAR10':

        train_transform = [
            transforms.Resize(d_config.image_size),
            transforms.ToTensor()
        ]
        test_transform = [
            transforms.Resize(d_config.image_size),
            transforms.ToTensor()
        ]
        if d_config.random_flip:
            train_transform.insert(1, transforms.RandomHorizontalFlip())

        path = os.path.join(data_folder, 'CIFAR10')
        dataset = CIFAR10(path,
                          train=True,
                          download=True,
                          transform=cmp(train_transform))
        test_dataset = CIFAR10(path,
                               train=False,
                               download=True,
                               transform=cmp(test_transform))

    elif d_config.dataset == 'CELEBA':

        train_transform = [
            transforms.CenterCrop(140),
            transforms.Resize(d_config.image_size),
            transforms.ToTensor()
        ]
        test_transform = [
            transforms.CenterCrop(140),
            transforms.Resize(d_config.image_size),
            transforms.ToTensor()
        ]
        if d_config.random_flip:
            train_transform.insert(2, transforms.RandomHorizontalFlip())

        path = os.path.join(data_folder, 'celeba')
        dataset = CelebA(path,
                         split='train',
                         transform=cmp(train_transform),
                         download=True)
        test_dataset = CelebA(path,
                              split='test',
                              transform=cmp(test_transform),
                              download=True)

    elif d_config.dataset == 'Stacked_MNIST':

        dataset = Stacked_MNIST(root=os.path.join(data_folder,
                                                  'stackedmnist_train'),
                                load=False,
                                source_root=data_folder,
                                train=True)
        test_dataset = Stacked_MNIST(root=os.path.join(data_folder,
                                                       'stackedmnist_test'),
                                     load=False,
                                     source_root=data_folder,
                                     train=False)

    elif d_config.dataset == 'LSUN':

        ims = d_config.image_size
        train_transform = [
            transforms.Resize(ims),
            transforms.CenterCrop(ims),
            transforms.ToTensor()
        ]
        test_transform = [
            transforms.Resize(ims),
            transforms.CenterCrop(ims),
            transforms.ToTensor()
        ]
        if d_config.random_flip:
            train_transform.insert(2, transforms.RandomHorizontalFlip())

        path = data_folder
        dataset = LSUN(path,
                       classes=[d_config.category + "_train"],
                       transform=cmp(train_transform))
        test_dataset = LSUN(path,
                            classes=[d_config.category + "_val"],
                            transform=cmp(test_transform))

    elif d_config.dataset == "FFHQ":

        train_transform = [transforms.ToTensor()]
        test_transform = [transforms.ToTensor()]
        if d_config.random_flip:
            train_transform.insert(0, transforms.RandomHorizontalFlip())

        path = os.path.join(data_folder, 'FFHQ')
        dataset = FFHQ(path,
                       transform=cmp(train_transform),
                       resolution=d_config.image_size)
        test_dataset = FFHQ(path,
                            transform=cmp(test_transform),
                            resolution=d_config.image_size)

        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = (indices[:int(num_items * 0.9)],
                                       indices[int(num_items * 0.9):])
        dataset = Subset(dataset, train_indices)
        test_dataset = Subset(test_dataset, test_indices)

    else:
        raise ValueError("Dataset [" + d_config.dataset + "] not configured.")

    return dataset, test_dataset
Example #16
def get_dataset(args, config):
    if config.data.dataset == 'CIFAR10':
        if (config.data.random_flip):
            dataset = CIFAR10(os.path.join('datasets', 'cifar10'),
                              train=True,
                              download=True,
                              transform=transforms.Compose([
                                  transforms.Resize(config.data.image_size),
                                  transforms.RandomHorizontalFlip(p=0.5),
                                  transforms.ToTensor()
                              ]))
            test_dataset = CIFAR10(os.path.join('datasets', 'cifar10_test'),
                                   train=False,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(
                                           config.data.image_size),
                                       transforms.ToTensor()
                                   ]))

        else:
            dataset = CIFAR10(os.path.join('datasets', 'cifar10'),
                              train=True,
                              download=True,
                              transform=transforms.Compose([
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor()
                              ]))
            test_dataset = CIFAR10(os.path.join('datasets', 'cifar10_test'),
                                   train=False,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(
                                           config.data.image_size),
                                       transforms.ToTensor()
                                   ]))

    elif config.data.dataset == 'CELEBA':
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]),
                             download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]),
                             download=True)

        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                              split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]),
                              download=True)

    elif (config.data.dataset == "CELEBA-32px"):
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(32),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]),
                             download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(32),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]),
                             download=True)

        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                              split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(32),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]),
                              download=True)

    elif (config.data.dataset == "CELEBA-8px"):
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(8),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]),
                             download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                             split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(8),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]),
                             download=True)

        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'),
                              split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(8),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]),
                              download=True)

    elif config.data.dataset == 'LSUN':
        train_folder = '{}_train'.format(config.data.category)
        val_folder = '{}_val'.format(config.data.category)
        if config.data.random_flip:
            dataset = LSUN(root=os.path.join('datasets', 'lsun'),
                           classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor(),
                           ]))
        else:
            dataset = LSUN(root=os.path.join('datasets', 'lsun'),
                           classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.ToTensor(),
                           ]))

        test_dataset = LSUN(root=os.path.join('datasets', 'lsun'),
                            classes=[val_folder],
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.CenterCrop(config.data.image_size),
                                transforms.ToTensor(),
                            ]))

    elif config.data.dataset == "FFHQ":
        if config.data.random_flip:
            dataset = FFHQ(path=os.path.join('datasets', 'FFHQ'),
                           transform=transforms.Compose([
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor()
                           ]),
                           resolution=config.data.image_size)
        else:
            dataset = FFHQ(path=os.path.join('datasets', 'FFHQ'),
                           transform=transforms.ToTensor(),
                           resolution=config.data.image_size)

        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = (indices[:int(num_items * 0.9)],
                                       indices[int(num_items * 0.9):])
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)

    elif config.data.dataset == "MNIST":
        if config.data.random_flip:
            dataset = MNIST(root=os.path.join('datasets', 'MNIST'),
                            train=True,
                            download=True,
                            transform=transforms.Compose([
                                transforms.RandomHorizontalFlip(p=0.5),
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor()
                            ]))
        else:
            dataset = MNIST(root=os.path.join('datasets', 'MNIST'),
                            train=True,
                            download=True,
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor()
                            ]))
        test_dataset = MNIST(root=os.path.join('datasets', 'MNIST'),
                             train=False,
                             download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor()
                             ]))
    elif config.data.dataset == "USPS":
        if config.data.random_flip:
            dataset = USPS(root=os.path.join('datasets', 'USPS'),
                           train=True,
                           download=True,
                           transform=transforms.Compose([
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.Resize(config.data.image_size),
                               transforms.ToTensor()
                           ]))
        else:
            dataset = USPS(root=os.path.join('datasets', 'USPS'),
                           train=True,
                           download=True,
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.ToTensor()
                           ]))
        test_dataset = USPS(root=os.path.join('datasets', 'USPS'),
                            train=False,
                            download=True,
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor()
                            ]))
    elif config.data.dataset == "USPS-Pad":
        if config.data.random_flip:
            dataset = USPS(
                root=os.path.join('datasets', 'USPS'),
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(20),  # resize and pad like MNIST
                    transforms.Pad(4),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.Resize(config.data.image_size),
                    transforms.ToTensor()
                ]))
        else:
            dataset = USPS(
                root=os.path.join('datasets', 'USPS'),
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(20),  # resize and pad like MNIST
                    transforms.Pad(4),
                    transforms.Resize(config.data.image_size),
                    transforms.ToTensor()
                ]))
        test_dataset = USPS(
            root=os.path.join('datasets', 'USPS'),
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.Resize(20),  # resize and pad like MNIST
                transforms.Pad(4),
                transforms.Resize(config.data.image_size),
                transforms.ToTensor()
            ]))
    elif (config.data.dataset.upper() == "GAUSSIAN"):
        if (config.data.num_workers != 0):
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. "
                "Gaussian data is sampled at runtime and doing so with "
                "multiple workers may cause a CUDA error.")
        if (config.data.isotropic):
            dim = config.data.dim
            rank = config.data.rank
            cov = np.diag(np.pad(np.ones((rank, )), [(0, dim - rank)]))
            mean = np.zeros((dim, ))
        else:
            cov = np.array(config.data.cov)
            mean = np.array(config.data.mean)

        shape = config.data.shape if hasattr(config.data, "shape") else None

        dataset = Gaussian(device=args.device, cov=cov, mean=mean, shape=shape)
        test_dataset = Gaussian(device=args.device,
                                cov=cov,
                                mean=mean,
                                shape=shape)

    elif (config.data.dataset.upper() == "GAUSSIAN-HD"):
        if (config.data.num_workers != 0):
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. "
                "Gaussian data is sampled at runtime and doing so with "
                "multiple workers may cause a CUDA error.")
        cov = np.load(config.data.cov_path)
        mean = np.load(config.data.mean_path)
        dataset = Gaussian(device=args.device, cov=cov, mean=mean)
        test_dataset = Gaussian(device=args.device, cov=cov, mean=mean)

    elif (config.data.dataset.upper() == "GAUSSIAN-HD-UNIT"):
        # This dataset is to be used when GAUSSIAN with the isotropic option is infeasible due to high dimensionality
        #   of the desired samples. If the dimension is too high, passing a huge covariance matrix is slow.
        if (config.data.num_workers != 0):
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. "
                "Gaussian data is sampled at runtime and doing so with "
                "multiple workers may cause a CUDA error.")
        shape = config.data.shape if hasattr(config.data, "shape") else None
        dataset = Gaussian(device=args.device,
                           mean=None,
                           cov=None,
                           shape=shape,
                           iid_unit=True)
        test_dataset = Gaussian(device=args.device,
                                mean=None,
                                cov=None,
                                shape=shape,
                                iid_unit=True)

    else:
        raise ValueError('Unknown dataset: {}'.format(config.data.dataset))

    return dataset, test_dataset
Example #17
train_config = {
    'num_iter': 1000000,
    'batch_size': 128,
    'image_size': 64,
    'iter_per_tick': 1000,
    'ticks_per_snapshot': 10,
    'device': torch.device("cuda:0"),
}

run_name = 'LSUN_bedroom'
lsun_dir = '../lsun/'
result_dir = 'results/'

if not os.path.isdir(result_dir):
    os.mkdir(result_dir)

dataset = LSUN(root=lsun_dir, classes=['bedroom_train'], transform=transforms.Compose([
    transforms.Resize(train_config['image_size']),
    transforms.CenterCrop(train_config['image_size']),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))

train_config['dataloader'] = DataLoader(dataset, batch_size=train_config['batch_size'], shuffle=True, num_workers=0)

train_session = TrainSession(run_name, result_dir)

train_session.training_loop(**train_config)
Example #18
def main(args):

    def maybe_cuda(tensor):
        return tensor.cuda() if args.gpu else tensor

    Img_dir = 'Figs/'
    if not os.path.exists(Img_dir):
        os.makedirs(Img_dir)

    Model_dir = 'Models/'

    if not os.path.exists(Model_dir):
        os.makedirs(Model_dir)

    Data_dir = 'Data/'

    if not os.path.exists(Data_dir):
        os.makedirs(Data_dir)

    train_set = utils.IndexedDataset(
        LSUN(args.dir, classes=[args.cl+'_train'],
             transform=transforms.Compose([
                 transforms.Resize(64),
                 transforms.CenterCrop(64),
                 transforms.ToTensor(),
                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
             ]))
    )
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size,
        shuffle=True, drop_last=True,
        num_workers=8, pin_memory=args.gpu,
    )
    
    val_loader = torch.utils.data.DataLoader(train_set, shuffle=False, batch_size=8*8)

    if args.n > 0:
        train_set.base.length = args.n
        train_set.base.indices = [args.n]

    # initialize representation space:
    if args.init == 'pca':
        print('Check if PCA is already calculated...')
        pca_path = 'Data/GLO_pca_init_{}_{}.pt'.format(
            args.cl, args.d)
        if os.path.isfile(pca_path):
            print('[Latent Init] PCA already calculated before and saved '
                  'at {}'.format(pca_path))
            Z = torch.load(pca_path)
        else:
            from sklearn.decomposition import PCA

            # first, take a subset of train set to fit the PCA
            X_pca = np.vstack([
                X.cpu().numpy().reshape(len(X), -1)
                for i, (X, _, _)
                in zip(tqdm(range(args.n_pca // train_loader.batch_size), 'collect data for PCA'),
                       train_loader)
            ])
            print("perform PCA...")
            pca = PCA(n_components=args.d)
            pca.fit(X_pca)
            # then, initialize latent vectors to the pca projections of the complete dataset
            Z = np.empty((len(train_loader.dataset), args.d))
            for X, _, idx in tqdm(train_loader, 'pca projection'):
                Z[idx] = pca.transform(X.cpu().numpy().reshape(len(X), -1))


    elif args.init == 'random':
        Z = np.random.randn(len(train_set), args.d)

    Z = utils.project_l2_ball(Z)

    model_generator = maybe_cuda(models.Generator(args.d))
    loss_fn = utils.LapLoss(max_levels=3) if args.loss == 'lap_l1' else nn.MSELoss()
    zi = maybe_cuda(torch.zeros((args.batch_size, args.d)))
    zi = Variable(zi, requires_grad=True)
    optimizer = SGD([
        {'params': model_generator.parameters(), 'lr': args.lr_g},
        {'params': zi, 'lr': args.lr_z}
    ])

    Xi_val, _, idx_val = next(iter(val_loader))
    utils.imsave(Img_dir + 'target_%s_%s.png' % (args.cl, args.prfx),
                 make_grid(Xi_val.cpu() / 2. + 0.5, nrow=8).numpy().transpose(1, 2, 0))

    for epoch in range(args.e):
        losses = []
        progress = tqdm(total=len(train_loader), desc='epoch % 3d' % epoch)

        for i, (Xi, yi, idx) in enumerate(train_loader):
            Xi = Variable(maybe_cuda(Xi))
            zi.data = maybe_cuda(torch.FloatTensor(Z[idx.numpy()]))

            optimizer.zero_grad()
            rec = model_generator(zi)
            loss = loss_fn(rec, Xi)
            loss.backward()
            optimizer.step()

            Z[idx.numpy()] = utils.project_l2_ball(zi.data.cpu().numpy())

            losses.append(loss.item())
            progress.set_postfix({'loss': np.mean(losses[-100:])})
            progress.update()


        progress.close()

        # visualize reconstructions
        rec = model_generator(Variable(maybe_cuda(torch.FloatTensor(Z[idx_val.numpy()]))))
        utils.imsave(Img_dir + '%s_%s_rec_epoch_%03d_%s.png' % (args.cl, args.prfx, epoch, args.init),
                     make_grid(rec.data.cpu() / 2. + 0.5, nrow=8).numpy().transpose(1, 2, 0))
    print('Saving the model: epoch %3d' % epoch)
    utils.save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model_generator.state_dict(),
    },
        Model_dir + 'Glo_{}_z_{}_epch_{}_init_{}.pt'.format(args.cl,args.d, epoch, args.init))
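`utils.save_checkpoint` is repository-specific and not shown here; a minimal sketch, assuming it simply serializes the checkpoint dict:

import torch


def save_checkpoint(state, filename):
    # Assumed behavior: write the checkpoint dict to disk with torch.save.
    torch.save(state, filename)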