def get_data(dataname, path, img_size=64):
    if dataname == 'CIFAR10':
        dataset = CIFAR10(path, train=True,
                          transform=transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                          ]), download=True)
        print('CIFAR10')
    elif dataname == 'MNIST':
        dataset = MNIST(path, train=True,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.5,), (0.5,))
                        ]), download=True)
        print('MNIST')
    elif dataname == 'LSUN-dining':
        dataset = LSUN(path, classes=['dining_room_train'],
                       transform=transforms.Compose([
                           transforms.Resize(img_size),
                           transforms.CenterCrop(img_size),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                       ]))
        print('LSUN-dining')
    elif dataname == 'LSUN-bedroom':
        dataset = LSUN(path, classes=['bedroom_train'],
                       transform=transforms.Compose([
                           transforms.Resize(img_size),
                           transforms.CenterCrop(img_size),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                       ]))
        print('LSUN-bedroom')
    elif dataname == 'CelebA':
        dataset = ImageFolder(root=path,
                              transform=transforms.Compose([
                                  transforms.Resize(64),
                                  transforms.CenterCrop(64),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                              ]))
    else:
        raise ValueError('Unknown dataset: {}'.format(dataname))
    return dataset
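# Minimal usage sketch for get_data() above; the data path and batch size are
# illustrative assumptions, not values taken from the original snippet.
from torch.utils.data import DataLoader

dataset = get_data('LSUN-bedroom', './data/lsun', img_size=64)
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
for images, labels in loader:
    ...  # feed the batch to the model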
def get_lsun_dataset(dataset_path: str, classes: str, image_size: int,
                     mean: Tuple[float, float, float],
                     std: Tuple[float, float, float]) -> Dataset:
    return LSUN(
        root=dataset_path,
        transform=_get_lsun_transform(image_size, mean, std),
        classes=classes,
    )
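# _get_lsun_transform is referenced above but not shown; a plausible sketch
# (an assumption, mirroring the resize/crop/normalize pattern used by the
# other LSUN snippets in this collection) would be:
def _get_lsun_transform(image_size: int,
                        mean: Tuple[float, float, float],
                        std: Tuple[float, float, float]) -> transforms.Compose:
    return transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])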
def cli_main(args=None):
    seed_everything(1234)

    parser = ArgumentParser()
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--dataset", default="mnist", type=str, choices=["lsun", "mnist"])
    parser.add_argument("--data_dir", default="./", type=str)
    parser.add_argument("--image_size", default=64, type=int)
    parser.add_argument("--num_workers", default=8, type=int)
    script_args, _ = parser.parse_known_args(args)

    if script_args.dataset == "lsun":
        transforms = transform_lib.Compose([
            transform_lib.Resize(script_args.image_size),
            transform_lib.CenterCrop(script_args.image_size),
            transform_lib.ToTensor(),
            transform_lib.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        dataset = LSUN(root=script_args.data_dir, classes=["bedroom_train"], transform=transforms)
        image_channels = 3
    elif script_args.dataset == "mnist":
        transforms = transform_lib.Compose([
            transform_lib.Resize(script_args.image_size),
            transform_lib.ToTensor(),
            transform_lib.Normalize((0.5,), (0.5,)),
        ])
        dataset = MNIST(root=script_args.data_dir, download=True, transform=transforms)
        image_channels = 1

    dataloader = DataLoader(dataset, batch_size=script_args.batch_size,
                            shuffle=True, num_workers=script_args.num_workers)

    parser = DCGAN.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args(args)

    model = DCGAN(**vars(args), image_channels=image_channels)
    callbacks = [
        TensorboardGenerativeModelImageSampler(num_samples=5),
        LatentDimInterpolator(interpolate_epoch_interval=5),
    ]
    trainer = Trainer.from_argparse_args(args, callbacks=callbacks)
    trainer.fit(model, dataloader)
def main(W):
    classes = ['bedroom', 'kitchen', 'conference_room', 'dining_room', 'church_outdoor']
    N = 100000
    image_size = (W, W)
    rootdir = '.'
    for i, c in enumerate(classes):
        print(c)
        lsun = LSUN(rootdir, ['%s_train' % c],
                    transform=transforms.Compose([transforms.Resize(image_size)]))
        for n in tqdm(range(N)):
            lsun[n][0].save('data_split_%d/%s/0/%07d.jpg' % (W, c, n))
            lsun[n][0].save('data_%d/0/%07d.jpg' % (W, n + i * N))
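# The export loop above assumes the output folders already exist; PIL's
# Image.save() does not create directories. A small helper (an addition for
# illustration, not part of the original script) could create them up front:
import os

def make_output_dirs(W, classes):
    for c in classes:
        os.makedirs('data_split_%d/%s/0' % (W, c), exist_ok=True)
    os.makedirs('data_%d/0' % W, exist_ok=True)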
def main():
    batch_size = 8

    net = PSPNet(pretrained=False, num_classes=num_classes, input_size=(512, 1024)).cuda()
    snapshot = 'epoch_48_validation_loss_5.1326_mean_iu_0.3172_lr_0.00001000.pth'
    net.load_state_dict(torch.load(os.path.join(ckpt_path, snapshot)))
    net.eval()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    transform = transforms.Compose([
        expanded_transform.FreeScale((512, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(*mean_std)
    ])
    restore = transforms.Compose([
        expanded_transform.DeNormalize(*mean_std),
        transforms.ToPILImage()
    ])

    lsun_path = '/home/b3-542/LSUN'
    dataset = LSUN(lsun_path, ['tower_val', 'church_outdoor_val', 'bridge_val'], transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=16, shuffle=True)

    if not os.path.exists(test_results_path):
        os.mkdir(test_results_path)

    for vi, data in enumerate(dataloader, 0):
        inputs, labels = data
        inputs = Variable(inputs, volatile=True).cuda()  # legacy API; torch.no_grad() replaces volatile in PyTorch >= 0.4
        outputs = net(inputs)
        prediction = outputs.cpu().data.max(1)[1].squeeze_(1).numpy()
        for idx, tensor in enumerate(zip(inputs.cpu().data, prediction)):
            pil_input = restore(tensor[0])
            pil_output = colorize_mask(tensor[1])
            pil_input.save(os.path.join(test_results_path, '%d_img.png' % (vi * batch_size + idx)))
            pil_output.save(os.path.join(test_results_path, '%d_out.png' % (vi * batch_size + idx)))
        print('save the #%d batch, %d images' % (vi + 1, idx + 1))
def renew(self, resl):
    # print('[*] Renew dataloader configuration, load data from {}.'.format(self.root))
    self.batchsize = int(self.batch_table[pow(2, resl)])
    self.imsize = int(pow(2, resl))
    # self.dataset = ImageFolder(
    #     root=self.root,
    #     transform=transforms.Compose([
    #         transforms.Resize(size=(self.imsize, self.imsize), interpolation=Image.NEAREST),
    #         transforms.ToTensor(),
    #     ]))
    # self.dataloader = DataLoader(
    #     dataset=self.dataset,
    #     batch_size=self.batchsize,
    #     shuffle=True,
    #     num_workers=self.num_workers
    # )
    self.transform = transforms.Compose([
        transforms.Resize(size=(self.imsize, self.imsize), interpolation=Image.NEAREST),
        transforms.ToTensor(),
    ])
    # note: self.dataset starts out as a name string and is replaced by the
    # Dataset object below, so the selection by name only works on the first call
    if self.dataset == 'cifar10':
        self.dataset = CIFAR10(self.root, train=True, transform=self.transform)
    # elif self.dataset == 'celeba':
    #     self.dataset = CelebA(self.root, train=True, transform=self.transform)
    elif self.dataset == 'lsun_bedroom':
        self.dataset = LSUN(self.root, classes=['bedroom_train'], transform=self.transform)
    self.dataloader = DataLoader(dataset=self.dataset, batch_size=self.batchsize,
                                 shuffle=True, num_workers=self.num_workers)
def prepare_data(self):
    train_resize = transforms.Resize((self.hparams.image_size, self.hparams.image_size))
    train_normalize = transforms.Normalize(mean=[0.5], std=[0.5])
    train_transform = transforms.Compose([train_resize, transforms.ToTensor(), train_normalize])

    if self.hparams.dataset == "mnist":
        self.train_dataset = MNIST(self.hparams.dataset_path, train=True, download=True, transform=train_transform)
        # self.test_dataset = MNIST(self.hparams.dataset_path, train=False, download=True, transform=test_transform)
    elif self.hparams.dataset == "fashion_mnist":
        self.train_dataset = FashionMNIST(self.hparams.dataset_path, train=True, download=True, transform=train_transform)
        # self.test_dataset = FashionMNIST(self.hparams.dataset_path, train=False, download=True, transform=test_transform)
    elif self.hparams.dataset == "cifar10":
        self.train_dataset = CIFAR10(self.hparams.dataset_path, train=True, download=True, transform=train_transform)
        # self.test_dataset = CIFAR10(self.hparams.dataset_path, train=False, download=True, transform=test_transform)
    elif self.hparams.dataset == "image_net":
        # torchvision's ImageNet takes a split name and cannot be downloaded automatically
        self.train_dataset = ImageNet(self.hparams.dataset_path, split="train", transform=train_transform)
        # self.test_dataset = ImageNet(self.hparams.dataset_path, split="val", transform=test_transform)
    elif self.hparams.dataset == "lsun":
        self.train_dataset = LSUN(self.hparams.dataset_path + "/lsun",
                                  classes=[cls + "_train" for cls in self.hparams.dataset_classes],
                                  transform=train_transform)
        # self.test_dataset = LSUN(self.hparams.dataset_path, classes=[cls + "_test" for cls in self.hparams.dataset_classes], transform=test_transform)
    elif self.hparams.dataset == "celeba_hq":
        self.train_dataset = CelebAHQ(self.hparams.dataset_path, image_size=self.hparams.image_size, transform=train_transform)
    else:
        raise NotImplementedError("Custom dataset is not implemented yet")
def get_dataset(args, config):
    if config.data.random_flip is False:
        tran_transform = test_transform = transforms.Compose([
            transforms.Resize(config.data.image_size),
            transforms.ToTensor()
        ])
    else:
        tran_transform = transforms.Compose([
            transforms.Resize(config.data.image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor()
        ])
        test_transform = transforms.Compose([
            transforms.Resize(config.data.image_size),
            transforms.ToTensor()
        ])

    if config.data.dataset == 'CIFAR10':
        dataset = CIFAR10(os.path.join(args.exp, 'datasets', 'cifar10'),
                          train=True, download=True, transform=tran_transform)
        test_dataset = CIFAR10(os.path.join(args.exp, 'datasets', 'cifar10_test'),
                               train=False, download=True, transform=test_transform)

    elif config.data.dataset == 'CELEBA':
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join(args.exp, 'datasets', 'celeba'), split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]), download=False)
        else:
            dataset = CelebA(root=os.path.join(args.exp, 'datasets', 'celeba'), split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]), download=False)
        test_dataset = CelebA(root=os.path.join(args.exp, 'datasets', 'celeba_test'), split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]), download=False)

    elif config.data.dataset == 'LSUN':
        # import ipdb; ipdb.set_trace()
        train_folder = '{}_train'.format(config.data.category)
        val_folder = '{}_val'.format(config.data.category)
        if config.data.random_flip:
            dataset = LSUN(root=os.path.join(args.exp, 'datasets', 'lsun'), classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor(),
                           ]))
        else:
            dataset = LSUN(root=os.path.join(args.exp, 'datasets', 'lsun'), classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.ToTensor(),
                           ]))
        test_dataset = LSUN(root=os.path.join(args.exp, 'datasets', 'lsun'), classes=[val_folder],
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.CenterCrop(config.data.image_size),
                                transforms.ToTensor(),
                            ]))

    elif config.data.dataset == "FFHQ":
        if config.data.random_flip:
            dataset = FFHQ(path=os.path.join(args.exp, 'datasets', 'FFHQ'),
                           transform=transforms.Compose([
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor()
                           ]), resolution=config.data.image_size)
        else:
            dataset = FFHQ(path=os.path.join(args.exp, 'datasets', 'FFHQ'),
                           transform=transforms.ToTensor(), resolution=config.data.image_size)

        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = indices[:int(num_items * 0.9)], indices[int(num_items * 0.9):]
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)

    return dataset, test_dataset
def get_image_loader(args, b_size, num_workers):
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    if args.dataset == 'cifar10':
        spatial_size = 32
        trainset = CIFAR10(root='./data', train=True, download=True, transform=transform_train)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=b_size, shuffle=True, num_workers=num_workers)
        n_classes = 10
        testset = CIFAR10(root='./data', train=False, download=True, transform=transform_test)
        testloader = torch.utils.data.DataLoader(testset, batch_size=b_size, shuffle=False, num_workers=num_workers)
    elif args.dataset == 'lsun':
        spatial_size = 32
        transform_lsun = transforms.Compose([
            transforms.Resize(spatial_size),
            transforms.CenterCrop(spatial_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        trainset = LSUN(root=args.data_path, classes=['bedroom_train'], transform=transform_lsun)
        testset = LSUN(root=args.data_path, classes=['bedroom_val'], transform=transform_lsun)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=b_size, shuffle=True, num_workers=num_workers)
        testloader = torch.utils.data.DataLoader(testset, batch_size=b_size, shuffle=True, num_workers=num_workers)
    elif args.dataset == 'imagenet32':
        from utils.imagenet import Imagenet32
        spatial_size = 32
        transforms_train = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
        trainset = Imagenet32(args.imagenet_train_path, transform=transforms.Compose(transforms_train), sz=spatial_size)
        valset = Imagenet32(args.imagenet_test_path, transform=transforms.Compose(transforms_train), sz=spatial_size)
        n_classes = 1000
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=b_size, shuffle=True, num_workers=num_workers)
        testloader = torch.utils.data.DataLoader(valset, batch_size=b_size, shuffle=False, num_workers=num_workers)
    elif args.dataset == 'celebA':
        size = 32
        transform_train = transforms.Compose([
            transforms.Resize(size),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        transform_test = transforms.Compose([
            transforms.Resize(size),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        trainset = CelebA(root=args.data_path, split='train', transform=transform_train, download=False)
        testset = CelebA(root=args.data_path, split='train', transform=transform_train, download=False)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=b_size, shuffle=True, num_workers=num_workers)
        testloader = torch.utils.data.DataLoader(testset, batch_size=b_size, shuffle=False, num_workers=num_workers)
    else:
        raise NotImplementedError
    return trainloader, testloader, None
if DATASET_NAME == 'CIFAR10':
    from torchvision.datasets import CIFAR10
    transforms = Compose([ToTensor(), Normalize(mean=[0.5], std=[0.5])])
    dataset = CIFAR10(root='./datasets', train=True, transform=transforms, download=True)
elif DATASET_NAME == 'LSUN':
    from torchvision.datasets import LSUN
    transforms = Compose([
        Resize(64),
        ToTensor(),
        Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    dataset = LSUN(root='./datasets/LSUN', classes=['bedroom_train'], transform=transforms)
elif DATASET_NAME == 'MNIST':
    from torchvision.datasets import MNIST
    transforms = Compose([ToTensor(), Normalize(mean=[0.5], std=[0.5])])
    dataset = MNIST(root='./datasets', train=True, transform=transforms, download=True)
else:
    raise NotImplementedError

data_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=0, shuffle=True)
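# torchvision's LSUN accepts either a split shorthand ('train', 'val', 'test'),
# which loads every category of that split, or a list of '<category>_<split>'
# names as used above. A brief illustration (the root path here is an
# assumption, not taken from any snippet in this collection):
from torchvision.datasets import LSUN

all_val_scenes = LSUN(root='./datasets/LSUN', classes='val')              # every category's val split
bedrooms_only = LSUN(root='./datasets/LSUN', classes=['bedroom_train'])   # a single category/split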
def main(
    lsun_data_dir: ('Base directory for the LSUN data'),
    image_output_prefix: ('Prefix for image output', 'option', 'o') = 'glo',
    code_dim: ('Dimensionality of latent representation space', 'option', 'd', int) = 128,
    epochs: ('Number of epochs to train', 'option', 'e', int) = 25,
    use_cuda: ('Use GPU?', 'flag', 'gpu') = False,
    batch_size: ('Batch size', 'option', 'b', int) = 128,
    lr_g: ('Learning rate for generator', 'option', None, float) = 1.,
    lr_z: ('Learning rate for representation_space', 'option', None, float) = 10.,
    max_num_samples: ('Cap on the number of samples from the LSUN dataset', 'option', 'n', int) = -1,
    init: ('Initialization strategy for latent representation vectors', 'option', 'i', str, ['pca', 'random']) = 'pca',
    n_pca: ('Number of samples to take for PCA', 'option', None, int) = (64 * 64 * 3 * 2),
    loss: ('Loss type (Laplacian loss as in the paper, or L2 loss)', 'option', 'l', str, ['lap_l1', 'l2']) = 'lap_l1',
):
    def maybe_cuda(tensor):
        return tensor.cuda() if use_cuda else tensor

    train_set = IndexedDataset(
        LSUN(lsun_data_dir, classes=['bedroom_train'],
             transform=transforms.Compose([
                 transforms.Resize(64),
                 transforms.CenterCrop(64),
                 transforms.ToTensor(),
                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
             ])))
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=8,
        pin_memory=use_cuda,
    )
    # we don't really have a validation set here, but for visualization let us
    # just take the first couple images from the dataset
    val_loader = torch.utils.data.DataLoader(train_set, shuffle=False, batch_size=8 * 8)

    if max_num_samples > 0:
        train_set.base.length = max_num_samples
        train_set.base.indices = [max_num_samples]

    # initialize representation space:
    if init == 'pca':
        from sklearn.decomposition import PCA

        # first, take a subset of train set to fit the PCA
        X_pca = np.vstack([
            X.cpu().numpy().reshape(len(X), -1)
            for i, (X, _, _) in zip(
                tqdm(range(n_pca // train_loader.batch_size), 'collect data for PCA'),
                train_loader)
        ])
        print("perform PCA...")
        pca = PCA(n_components=code_dim)
        pca.fit(X_pca)
        # then, initialize latent vectors to the pca projections of the complete dataset
        Z = np.empty((len(train_loader.dataset), code_dim))
        for X, _, idx in tqdm(train_loader, 'pca projection'):
            Z[idx] = pca.transform(X.cpu().numpy().reshape(len(X), -1))
    elif init == 'random':
        Z = np.random.randn(len(train_set), code_dim)

    Z = project_l2_ball(Z)

    g = maybe_cuda(Generator(code_dim))  # initialize a Generator g
    loss_fn = LapLoss(max_levels=3) if loss == 'lap_l1' else nn.MSELoss()
    zi = maybe_cuda(torch.zeros((batch_size, code_dim)))
    zi = Variable(zi, requires_grad=True)
    optimizer = SGD([
        {'params': g.parameters(), 'lr': lr_g},
        {'params': zi, 'lr': lr_z}
    ])

    Xi_val, _, idx_val = next(iter(val_loader))
    imsave('target.png',
           make_grid(Xi_val.cpu() / 2. + 0.5, nrow=8).numpy().transpose(1, 2, 0))

    for epoch in range(epochs):
        losses = []
        progress = tqdm(total=len(train_loader), desc='epoch % 3d' % epoch)
        for i, (Xi, yi, idx) in enumerate(train_loader):
            Xi = Variable(maybe_cuda(Xi))
            zi.data = maybe_cuda(torch.FloatTensor(Z[idx.numpy()]))

            optimizer.zero_grad()
            rec = g(zi)
            loss = loss_fn(rec, Xi)
            loss.backward()
            optimizer.step()

            Z[idx.numpy()] = project_l2_ball(zi.data.cpu().numpy())

            losses.append(loss.data[0])
            progress.set_postfix({'loss': np.mean(losses[-100:])})
            progress.update()
        progress.close()

        # visualize reconstructions
        rec = g(Variable(maybe_cuda(torch.FloatTensor(Z[idx_val.numpy()]))))
        imsave('%s_rec_epoch_%03d.png' % (image_output_prefix, epoch),
               make_grid(rec.data.cpu() / 2. + 0.5, nrow=8).numpy().transpose(1, 2, 0))
f.write(spec)
f.close()

fixed_z_ = torch.randn((5 * 5, 100)).view(-1, 100, 1, 1)  # fixed noise
fixed_z_ = Variable(fixed_z_.cuda(gpu_id), volatile=True)

# data_loader
transform = transforms.Compose([
    transforms.Resize([img_size, img_size], Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
if 'LSUN' in data_name:
    from torchvision.datasets import LSUN
    dset = LSUN('data/LSUN', classes=[data_dir], transform=transform)
else:
    dset = datasets.ImageFolder(data_dir, transform)
train_loader = torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=True)
# temp = plt.imread(train_loader.dataset.imgs[0][0])

# network
vgg = Vgg19()
vgg.cuda(gpu_id)
G = generator(d=128, mlp_dim=256, s_dim=40, img_size=img_size)
D = discriminator(128, img_size=img_size)
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='Amortized approximation on Cifar10')
    parser.add_argument('--batch-size', type=int, default=256, metavar='N',
                        help='input batch size for training (default: 256)')
    parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--approx-epochs', type=int, default=200, metavar='N',
                        help='number of epochs to approx (default: 200)')
    parser.add_argument('--lr', type=float, default=1e-2, metavar='LR',
                        help='learning rate (default: 1e-2)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--dropout-rate', type=float, default=0.5, metavar='p_drop',
                        help='dropout rate')
    parser.add_argument('--model-path', type=str, default='./checkpoint/', metavar='N',
                        help='path where the model params are saved.')
    parser.add_argument('--from-approx-model', type=int, default=1,
                        help='if our model is loaded or trained')
    parser.add_argument('--test-ood-from-disk', type=int, default=1,
                        help='generate test samples or load from disk')
    parser.add_argument('--ood-name', type=str, default='lsun',
                        help='name of the used ood dataset')

    ood = 'lsun'
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 8, 'pin_memory': False} if use_cuda else {}

    trans_norm = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        trans_norm,
    ])
    transform_test = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        trans_norm,
    ])
    tr_data = CIFAR10(root='../../data', train=True, download=True, transform=transform_train)
    te_data = CIFAR10(root='../../data', train=False, download=True, transform=transform_test)
    ood_data = LSUN(root='../../data', classes='test', transform=transform_test)

    train_loader = torch.utils.data.DataLoader(tr_data, batch_size=args.batch_size, shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(te_data, batch_size=args.batch_size, shuffle=False, **kwargs)
    ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=False, **kwargs)

    model = VGG('VGG19').to(device)
    model.load_state_dict(torch.load('checkpoint/ckpt.pth')['net'])
    test(args, model, device, test_loader)

    if args.from_approx_model == 0:
        output_samples = torch.load('./cifar10-vgg19rand-tr-samples.pt')

    # --------------- training approx ---------
    print('approximating ...')
    fmodel = VGG('VGG19').to(device)
    gmodel = VGG('VGG19', concentration=True).to(device)

    if args.from_approx_model == 0:
        f_optimizer = optim.SGD(fmodel.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
        g_optimizer = optim.SGD(gmodel.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)

        best_acc = 0
        for epoch in range(1, args.approx_epochs + 1):
            train_approx(args, fmodel, gmodel, device, train_loader, f_optimizer,
                         g_optimizer, output_samples, epoch)
            acc = test(args, fmodel, device, test_loader)
            if acc > best_acc:
                torch.save(fmodel.state_dict(), args.model_path + 'cifar10rand-mean-mmd.pt')
                torch.save(gmodel.state_dict(), args.model_path + 'cifar10rand-conc-mmd.pt')
                best_acc = acc
    else:
        fmodel.load_state_dict(torch.load(args.model_path + 'cifar10rand-mean-mmd.pt'))
        gmodel.load_state_dict(torch.load(args.model_path + 'cifar10rand-conc-mmd.pt'))

    teacher_test_samples = torch.load('./cifar10-vgg19rand-te-samples.pt')
    teacher_ood_samples = torch.load('./cifar10-vgg19rand-lsun-samples.pt')

    # fitting individual Dirichlet is not in the sample code as it's time-consuming
    eval_approx(args, fmodel, gmodel, device, test_loader, ood_loader,
                teacher_test_samples, teacher_ood_samples)
    for ind in indices:
        plot_data.append(data[ind])
    plot_data = torch.stack(plot_data, dim=0)
    save_image(plot_data, '{}.png'.format(name), nrow=k + 1)


if __name__ == '__main__':
    args = parser.parse_args()

    if args.dataset == 'church':
        transforms = Compose([
            Resize(96),
            CenterCrop(96),
            ToTensor()
        ])
        dataset = LSUN('exp/datasets/lsun', ['church_outdoor_train'], transform=transforms)
    elif args.dataset == 'tower' or args.dataset == 'bedroom':
        transforms = Compose([
            Resize(128),
            CenterCrop(128),
            ToTensor()
        ])
        dataset = LSUN('exp/datasets/lsun', ['{}_train'.format(args.dataset)], transform=transforms)
    elif args.dataset == 'celeba':
        transforms = Compose([
            CenterCrop(140),
            Resize(64),
            ToTensor(),
        ])
def get_dataset(d_config, data_folder):
    cmp = lambda x: transforms.Compose([*x])

    if d_config.dataset == 'CIFAR10':
        train_transform = [transforms.Resize(d_config.image_size), transforms.ToTensor()]
        test_transform = [transforms.Resize(d_config.image_size), transforms.ToTensor()]
        if d_config.random_flip:
            train_transform.insert(1, transforms.RandomHorizontalFlip())
        path = os.path.join(data_folder, 'CIFAR10')
        dataset = CIFAR10(path, train=True, download=True, transform=cmp(train_transform))
        test_dataset = CIFAR10(path, train=False, download=True, transform=cmp(test_transform))

    elif d_config.dataset == 'CELEBA':
        train_transform = [transforms.CenterCrop(140), transforms.Resize(d_config.image_size), transforms.ToTensor()]
        test_transform = [transforms.CenterCrop(140), transforms.Resize(d_config.image_size), transforms.ToTensor()]
        if d_config.random_flip:
            train_transform.insert(2, transforms.RandomHorizontalFlip())
        path = os.path.join(data_folder, 'celeba')
        dataset = CelebA(path, split='train', transform=cmp(train_transform), download=True)
        test_dataset = CelebA(path, split='test', transform=cmp(test_transform), download=True)

    elif d_config.dataset == 'Stacked_MNIST':
        dataset = Stacked_MNIST(root=os.path.join(data_folder, 'stackedmnist_train'),
                                load=False, source_root=data_folder, train=True)
        test_dataset = Stacked_MNIST(root=os.path.join(data_folder, 'stackedmnist_test'),
                                     load=False, source_root=data_folder, train=False)

    elif d_config.dataset == 'LSUN':
        ims = d_config.image_size
        train_transform = [transforms.Resize(ims), transforms.CenterCrop(ims), transforms.ToTensor()]
        test_transform = [transforms.Resize(ims), transforms.CenterCrop(ims), transforms.ToTensor()]
        if d_config.random_flip:
            train_transform.insert(2, transforms.RandomHorizontalFlip())
        path = data_folder
        dataset = LSUN(path, classes=[d_config.category + "_train"], transform=cmp(train_transform))
        test_dataset = LSUN(path, classes=[d_config.category + "_val"], transform=cmp(test_transform))

    elif d_config.dataset == "FFHQ":
        train_transform = [transforms.ToTensor()]
        test_transform = [transforms.ToTensor()]
        if d_config.random_flip:
            train_transform.insert(0, transforms.RandomHorizontalFlip())
        path = os.path.join(data_folder, 'FFHQ')
        # compose the transform lists before passing them to FFHQ, as in the other branches
        dataset = FFHQ(path, transform=cmp(train_transform), resolution=d_config.image_size)
        test_dataset = FFHQ(path, transform=cmp(test_transform), resolution=d_config.image_size)

        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = indices[:int(num_items * 0.9)], indices[int(num_items * 0.9):]
        dataset = Subset(dataset, train_indices)
        test_dataset = Subset(test_dataset, test_indices)

    else:
        raise ValueError("Dataset [" + d_config.dataset + "] not configured.")

    return dataset, test_dataset
def get_dataset(args, config):
    if config.data.dataset == 'CIFAR10':
        if config.data.random_flip:
            dataset = CIFAR10(os.path.join('datasets', 'cifar10'), train=True, download=True,
                              transform=transforms.Compose([
                                  transforms.Resize(config.data.image_size),
                                  transforms.RandomHorizontalFlip(p=0.5),
                                  transforms.ToTensor()
                              ]))
            test_dataset = CIFAR10(os.path.join('datasets', 'cifar10_test'), train=False, download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(config.data.image_size),
                                       transforms.ToTensor()
                                   ]))
        else:
            dataset = CIFAR10(os.path.join('datasets', 'cifar10'), train=True, download=True,
                              transform=transforms.Compose([
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor()
                              ]))
            test_dataset = CIFAR10(os.path.join('datasets', 'cifar10_test'), train=False, download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(config.data.image_size),
                                       transforms.ToTensor()
                                   ]))

    elif config.data.dataset == 'CELEBA':
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]), download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]), download=True)
        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]), download=True)

    elif config.data.dataset == "CELEBA-32px":
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(32),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]), download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(32),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]), download=True)
        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(32),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]), download=True)

    elif config.data.dataset == "CELEBA-8px":
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(8),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]), download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(8),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]), download=True)
        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(8),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]), download=True)

    elif config.data.dataset == 'LSUN':
        train_folder = '{}_train'.format(config.data.category)
        val_folder = '{}_val'.format(config.data.category)
        if config.data.random_flip:
            dataset = LSUN(root=os.path.join('datasets', 'lsun'), classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor(),
                           ]))
        else:
            dataset = LSUN(root=os.path.join('datasets', 'lsun'), classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.ToTensor(),
                           ]))
        test_dataset = LSUN(root=os.path.join('datasets', 'lsun'), classes=[val_folder],
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.CenterCrop(config.data.image_size),
                                transforms.ToTensor(),
                            ]))

    elif config.data.dataset == "FFHQ":
        if config.data.random_flip:
            dataset = FFHQ(path=os.path.join('datasets', 'FFHQ'),
                           transform=transforms.Compose([
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor()
                           ]), resolution=config.data.image_size)
        else:
            dataset = FFHQ(path=os.path.join('datasets', 'FFHQ'),
                           transform=transforms.ToTensor(), resolution=config.data.image_size)

        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = indices[:int(num_items * 0.9)], indices[int(num_items * 0.9):]
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)

    elif config.data.dataset == "MNIST":
        if config.data.random_flip:
            dataset = MNIST(root=os.path.join('datasets', 'MNIST'), train=True, download=True,
                            transform=transforms.Compose([
                                transforms.RandomHorizontalFlip(p=0.5),
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor()
                            ]))
        else:
            dataset = MNIST(root=os.path.join('datasets', 'MNIST'), train=True, download=True,
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor()
                            ]))
        test_dataset = MNIST(root=os.path.join('datasets', 'MNIST'), train=False, download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor()
                             ]))

    elif config.data.dataset == "USPS":
        if config.data.random_flip:
            dataset = USPS(root=os.path.join('datasets', 'USPS'), train=True, download=True,
                           transform=transforms.Compose([
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.Resize(config.data.image_size),
                               transforms.ToTensor()
                           ]))
        else:
            dataset = USPS(root=os.path.join('datasets', 'USPS'), train=True, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.ToTensor()
                           ]))
        test_dataset = USPS(root=os.path.join('datasets', 'USPS'), train=False, download=True,
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor()
                            ]))

    elif config.data.dataset == "USPS-Pad":
        if config.data.random_flip:
            dataset = USPS(root=os.path.join('datasets', 'USPS'), train=True, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(20),  # resize and pad like MNIST
                               transforms.Pad(4),
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.Resize(config.data.image_size),
                               transforms.ToTensor()
                           ]))
        else:
            dataset = USPS(root=os.path.join('datasets', 'USPS'), train=True, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(20),  # resize and pad like MNIST
                               transforms.Pad(4),
                               transforms.Resize(config.data.image_size),
                               transforms.ToTensor()
                           ]))
        test_dataset = USPS(root=os.path.join('datasets', 'USPS'), train=False, download=True,
                            transform=transforms.Compose([
                                transforms.Resize(20),  # resize and pad like MNIST
                                transforms.Pad(4),
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor()
                            ]))

    elif config.data.dataset.upper() == "GAUSSIAN":
        if config.data.num_workers != 0:
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. "
                "Gaussian data is sampled at runtime and doing so with multiple workers may cause a CUDA error.")
        if config.data.isotropic:
            dim = config.data.dim
            rank = config.data.rank
            cov = np.diag(np.pad(np.ones((rank,)), [(0, dim - rank)]))
            mean = np.zeros((dim,))
        else:
            cov = np.array(config.data.cov)
            mean = np.array(config.data.mean)
        shape = config.data.dataset.shape if hasattr(config.data.dataset, "shape") else None
        dataset = Gaussian(device=args.device, cov=cov, mean=mean, shape=shape)
        test_dataset = Gaussian(device=args.device, cov=cov, mean=mean, shape=shape)

    elif config.data.dataset.upper() == "GAUSSIAN-HD":
        if config.data.num_workers != 0:
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. "
                "Gaussian data is sampled at runtime and doing so with multiple workers may cause a CUDA error.")
        cov = np.load(config.data.cov_path)
        mean = np.load(config.data.mean_path)
        dataset = Gaussian(device=args.device, cov=cov, mean=mean)
        test_dataset = Gaussian(device=args.device, cov=cov, mean=mean)

    elif config.data.dataset.upper() == "GAUSSIAN-HD-UNIT":
        # This dataset is to be used when GAUSSIAN with the isotropic option is infeasible due to the high
        # dimensionality of the desired samples. If the dimension is too high, passing a huge covariance matrix is slow.
        if config.data.num_workers != 0:
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. "
                "Gaussian data is sampled at runtime and doing so with multiple workers may cause a CUDA error.")
        shape = config.data.shape if hasattr(config.data, "shape") else None
        dataset = Gaussian(device=args.device, mean=None, cov=None, shape=shape, iid_unit=True)
        test_dataset = Gaussian(device=args.device, mean=None, cov=None, shape=shape, iid_unit=True)

    return dataset, test_dataset
train_config = {
    'num_iter': 1000000,
    'batch_size': 128,
    'image_size': 64,
    'iter_per_tick': 1000,
    'ticks_per_snapshot': 10,
    'device': torch.device("cuda:0"),
}

run_name = 'LSUN_bedroom'
lsun_dir = '../lsun/'
result_dir = 'results/'
if not os.path.isdir(result_dir):
    os.mkdir(result_dir)

dataset = LSUN(root=lsun_dir, classes=['bedroom_train'],
               transform=transforms.Compose([
                   transforms.Resize(train_config['image_size']),
                   transforms.CenterCrop(train_config['image_size']),
                   transforms.ToTensor(),
                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
               ]))
train_config['dataloader'] = DataLoader(dataset, batch_size=train_config['batch_size'],
                                        shuffle=True, num_workers=0)

train_session = TrainSession(run_name, result_dir)
train_session.training_loop(**train_config)
def main(args):
    def maybe_cuda(tensor):
        return tensor.cuda() if args.gpu else tensor

    Img_dir = 'Figs/'
    if not os.path.exists(Img_dir):
        os.makedirs(Img_dir)
    Model_dir = 'Models/'
    if not os.path.exists(Model_dir):
        os.makedirs(Model_dir)
    Data_dir = 'Data/'
    if not os.path.exists(Data_dir):
        os.makedirs(Data_dir)

    train_set = utils.IndexedDataset(
        LSUN(args.dir, classes=[args.cl + '_train'],
             transform=transforms.Compose([
                 transforms.Resize(64),
                 transforms.CenterCrop(64),
                 transforms.ToTensor(),
                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
             ]))
    )
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=8,
        pin_memory=args.gpu,
    )
    val_loader = torch.utils.data.DataLoader(train_set, shuffle=False, batch_size=8 * 8)

    if args.n > 0:
        train_set.base.length = args.n
        train_set.base.indices = [args.n]

    # initialize representation space:
    if args.init == 'pca':
        print('Check if PCA is already calculated...')
        pca_path = 'Data/GLO_pca_init_{}_{}.pt'.format(args.cl, args.d)
        if os.path.isfile(pca_path):
            print('[Latent Init] PCA already calculated before and saved at {}'.format(pca_path))
            Z = torch.load(pca_path)
        else:
            from sklearn.decomposition import PCA

            # first, take a subset of train set to fit the PCA
            X_pca = np.vstack([
                X.cpu().numpy().reshape(len(X), -1)
                for i, (X, _, _) in zip(
                    tqdm(range(args.n_pca // train_loader.batch_size), 'collect data for PCA'),
                    train_loader)
            ])
            print("perform PCA...")
            pca = PCA(n_components=args.d)
            pca.fit(X_pca)
            # then, initialize latent vectors to the pca projections of the complete dataset
            Z = np.empty((len(train_loader.dataset), args.d))
            for X, _, idx in tqdm(train_loader, 'pca projection'):
                Z[idx] = pca.transform(X.cpu().numpy().reshape(len(X), -1))
    elif args.init == 'random':
        Z = np.random.randn(len(train_set), args.d)

    Z = utils.project_l2_ball(Z)

    model_generator = maybe_cuda(models.Generator(args.d))
    loss_fn = utils.LapLoss(max_levels=3) if args.loss == 'lap_l1' else nn.MSELoss()
    zi = maybe_cuda(torch.zeros((args.batch_size, args.d)))
    zi = Variable(zi, requires_grad=True)
    optimizer = SGD([
        {'params': model_generator.parameters(), 'lr': args.lr_g},
        {'params': zi, 'lr': args.lr_z}
    ])

    Xi_val, _, idx_val = next(iter(val_loader))
    utils.imsave(Img_dir + 'target_%s_%s.png' % (args.cl, args.prfx),
                 make_grid(Xi_val.cpu() / 2. + 0.5, nrow=8).numpy().transpose(1, 2, 0))

    for epoch in range(args.e):
        losses = []
        progress = tqdm(total=len(train_loader), desc='epoch % 3d' % epoch)
        for i, (Xi, yi, idx) in enumerate(train_loader):
            Xi = Variable(maybe_cuda(Xi))
            zi.data = maybe_cuda(torch.FloatTensor(Z[idx.numpy()]))

            optimizer.zero_grad()
            rec = model_generator(zi)
            loss = loss_fn(rec, Xi)
            loss.backward()
            optimizer.step()

            Z[idx.numpy()] = utils.project_l2_ball(zi.data.cpu().numpy())

            losses.append(loss.data[0])
            progress.set_postfix({'loss': np.mean(losses[-100:])})
            progress.update()
        progress.close()

        # visualize reconstructions
        rec = model_generator(Variable(maybe_cuda(torch.FloatTensor(Z[idx_val.numpy()]))))
        utils.imsave(Img_dir + '%s_%s_rec_epoch_%03d_%s.png' % (args.cl, args.prfx, epoch, args.init),
                     make_grid(rec.data.cpu() / 2. + 0.5, nrow=8).numpy().transpose(1, 2, 0))

        print('Saving the model : epoch % 3d' % epoch)
        utils.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model_generator.state_dict(),
        }, Model_dir + 'Glo_{}_z_{}_epch_{}_init_{}.pt'.format(args.cl, args.d, epoch, args.init))