def setup_data_loaders(args, use_cuda=False):
    """Build train/validation/test DataLoaders for the dataset named in ``args``.

    Args:
        args: parsed arguments providing ``dataset`` and ``batch_size``.
        use_cuda: when True, pin host memory so batches copy to GPU faster.

    Returns:
        dict mapping 'train' / 'validation' / 'test' to a shuffled DataLoader.

    Raises:
        ValueError: if ``args.dataset`` is not 'celeba' (the only supported set).
    """
    if args.dataset == 'celeba':
        datasets = {
            'train': dset.CelebA(mode='train'),
            'validation': dset.CelebA(mode='validation', train=False),
            'test': dset.CelebA(mode='test', train=False),
        }
        # (removed a dead `train_set = dset.CelebA(mode='train')` that was
        # assigned here but never used)
    else:
        raise ValueError('Unknown dataset ' + str(args.dataset))
    # pin_memory only pays off when batches are moved to the GPU
    kwargs = {'num_workers': 4, 'pin_memory': use_cuda}
    loaders = {
        k: DataLoader(dataset=v, batch_size=args.batch_size, shuffle=True, **kwargs)
        for k, v in datasets.items()
    }
    return loaders
def load_model_and_dataset(checkpt_filename):
    """Restore a trained SensVAE and its data loaders from a checkpoint file.

    Args:
        checkpt_filename: path to a checkpoint containing 'args' and 'state_dict'.

    Returns:
        Tuple ``(vae, loader, test_loader, args)`` where ``vae`` is in eval mode.

    Raises:
        ValueError: if the checkpoint's ``args.dist`` is not one of
            'normal', 'laplace', 'flow'.
    """
    checkpt = torch.load(checkpt_filename)
    args = checkpt['args']
    state_dict = checkpt['state_dict']

    # backwards compatibility: older checkpoints predate the conv flag
    if not hasattr(args, 'conv'):
        args.conv = False

    # pick prior / variational posterior families from the saved args;
    # previously an unknown dist fell through silently and crashed later
    # with a NameError on prior_dist — fail loudly instead.
    if args.dist == 'normal':
        prior_dist = dist.Normal()
        q_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
        q_dist = dist.Laplace()
    elif args.dist == 'flow':
        prior_dist = flows.FactorialNormalizingFlow(dim=args.latent_dim, nsteps=32)
        q_dist = dist.Normal()
    else:
        raise ValueError('Unknown distribution ' + str(args.dist))

    # celeba images are modeled as Gaussian, everything else as Bernoulli
    x_dist = dist.Normal() if args.dataset == 'celeba' else dist.Bernoulli()
    a_dist = dist.Bernoulli()

    vae = SensVAE(z_dim=args.latent_dim, use_cuda=True, prior_dist=prior_dist,
                  q_dist=q_dist, include_mutinfo=not args.exclude_mutinfo,
                  tcvae=args.tcvae, conv=args.conv, mss=args.mss,
                  n_chan=3 if args.dataset == 'celeba' else 1,
                  sens_idx=SENS_IDX, x_dist=x_dist, a_dist=a_dist)
    # strict=False tolerates checkpoints missing newly-added parameters
    vae.load_state_dict(state_dict, strict=False)
    vae.beta = args.beta
    vae.beta_sens = args.beta_sens
    vae.eval()

    # dataset loaders
    # NOTE(review): this file defines setup_data_loaders twice with different
    # signatures; this call matches the (args, use_cuda) variant — verify
    # which definition is actually in scope at runtime.
    loader = setup_data_loaders(args, use_cuda=True)

    # dedicated unshuffled test loader
    test_set = dset.CelebA(mode='test')
    kwargs = {'num_workers': 4, 'pin_memory': True}
    test_loader = DataLoader(dataset=test_set, batch_size=args.batch_size,
                             shuffle=False, **kwargs)
    return vae, loader, test_loader, args
def setup_data_loaders(args_dataset, batch_size, use_cuda=False, len_subset=0):
    """Construct a shuffled training DataLoader for the named dataset.

    Args:
        args_dataset: one of 'shapes', 'faces', 'cars3d', 'celeba', '3dchairs'.
        batch_size: minibatch size.
        use_cuda: when True, pin host memory for faster GPU transfer.
        len_subset: optional training-subset size (only used by 'shapes').

    Returns:
        A DataLoader over the chosen training set.

    Raises:
        ValueError: for an unrecognized dataset name.
    """
    # lazy constructors: nothing is instantiated until the name is validated
    factories = {
        'shapes': lambda: dset.Shapes(len_subset=len_subset),
        'faces': lambda: dset.Faces(),
        'cars3d': lambda: dset.Cars3d(),
        'celeba': lambda: dset.CelebA(),
        '3dchairs': lambda: dset.Chairs(),
    }
    if args_dataset not in factories:
        raise ValueError('Unknown dataset ' + str(args_dataset))
    train_set = factories[args_dataset]()
    return DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True,
                      num_workers=4, pin_memory=use_cuda)
def main():
    """Train a (Lp-nested / TC-)VAE from command-line arguments.

    Parses hyperparameters, builds the prior/posterior/likelihood
    distributions and the VAE, runs the training loop with periodic
    logging and checkpointing, then evaluates the best snapshot
    (ELBO decomposition and, where ground-truth factors exist,
    disentanglement metrics) and saves all results under ``args.save``.

    Returns:
        The trained VAE (weights of the best-objective snapshot).
    """
    # parse command line arguments
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('-d', '--dataset', default='shapes', type=str,
                        help='dataset name',
                        choices=['shapes', 'faces', 'celeba', 'cars3d', '3dchairs'])
    parser.add_argument('-dist', default='normal', type=str,
                        choices=['normal', 'lpnorm', 'lpnested'])
    parser.add_argument('-n', '--num-epochs', default=50, type=int,
                        help='number of training epochs')
    parser.add_argument('--num-iterations', default=0, type=int,
                        help='number of iterations (overrides number of epochs if >0)')
    parser.add_argument('-b', '--batch-size', default=2048, type=int, help='batch size')
    parser.add_argument('-l', '--learning-rate', default=1e-3, type=float,
                        help='learning rate')
    parser.add_argument('-z', '--latent-dim', default=10, type=int,
                        help='size of latent dimension')
    parser.add_argument('-p', '--pnorm', default=4.0 / 3.0, type=float,
                        help='p value of the Lp-norm')
    parser.add_argument('--pnested', default='', type=str,
                        help='nested list representation of the Lp-nested prior, e.g. [2.1, [ [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ], [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ], [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ] ] ]')
    parser.add_argument('--isa', default='', type=str,
                        help='shorthand notation of ISA Lp-nested norm, e.g. [2.1, [(2.2, 4), (2.2, 4), (2.2, 4)]]')
    parser.add_argument('--p0', default=2.0, type=float, help='p0 of ISA')
    parser.add_argument('--p1', default=2.1, type=float, help='p1 of ISA')
    parser.add_argument('--n1', default=6, type=int, help='n1 of ISA')
    parser.add_argument('--p2', default=2.1, type=float, help='p2 of ISA')
    parser.add_argument('--n2', default=6, type=int, help='n2 of ISA')
    parser.add_argument('--p3', default=2.1, type=float, help='p3 of ISA')
    parser.add_argument('--n3', default=6, type=int, help='n3 of ISA')
    parser.add_argument('--scale', default=1.0, type=float,
                        help='scale of LpNested distribution')
    parser.add_argument('--q-dist', default='normal', type=str,
                        choices=['normal', 'laplace'])
    parser.add_argument('--x-dist', default='bernoulli', type=str,
                        choices=['normal', 'bernoulli'])
    parser.add_argument('--beta', default=1, type=float, help='ELBO penalty term')
    parser.add_argument('--tcvae', action='store_true')
    parser.add_argument('--exclude-mutinfo', action='store_true')
    parser.add_argument('--beta-anneal', action='store_true')
    parser.add_argument('--lambda-anneal', action='store_true')
    parser.add_argument('--mss', action='store_true',
                        help='use the improved minibatch estimator')
    parser.add_argument('--conv', action='store_true')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--visdom', action='store_true',
                        help='whether plotting in visdom is desired')
    parser.add_argument('--save', default='test1')
    parser.add_argument('--id', default='1')
    parser.add_argument('--seed', default=-1, type=int,
                        help='seed for pytorch and numpy random number generator to allow reproducibility (default/-1: use random seed)')
    parser.add_argument('--log_freq', default=200, type=int,
                        help='num iterations per log')
    parser.add_argument('--use-mse-loss', action='store_true')
    parser.add_argument('--mse-sigma', default=0.01, type=float,
                        help='sigma of mean squared error loss')
    parser.add_argument('--dip', action='store_true', help='use DIP-VAE')
    parser.add_argument('--dip-type', default=1, type=int, help='DIP type (1 or 2)')
    parser.add_argument('--lambda-od', default=2.0, type=float,
                        help='DIP: lambda weight off-diagonal')
    parser.add_argument('--clip', default=0.0, type=float,
                        help='Gradient clipping (0 disabled)')
    parser.add_argument('--test', action='store_true', help='run test')
    parser.add_argument('--trainingsetsize', default=0, type=int,
                        help='Subsample the trainingset (0 use original training data)')
    args = parser.parse_args()

    # initialize seeds for reproducibility (-1 keeps a random seed)
    if not args.seed == -1:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.gpu != -1:
        print('Using CUDA device {}'.format(args.gpu))
        torch.cuda.set_device(args.gpu)
        use_cuda = True
    else:
        print('CUDA disabled')
        use_cuda = False

    # data loader
    train_loader = setup_data_loaders(args.dataset, args.batch_size,
                                      use_cuda=use_cuda,
                                      len_subset=args.trainingsetsize)

    # setup the VAE prior
    if args.dist == 'normal':
        prior_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
    elif args.dist == 'lpnested':
        # three ways to specify the nested norm, in order of precedence:
        # --isa shorthand, --pnested literal, or the --p0/--p1/--n1/... scalars
        if not args.isa == '':
            pnested = parseISA(ast.literal_eval(args.isa))
        elif not args.pnested == '':
            pnested = ast.literal_eval(args.pnested)
        else:
            pnested = parseISA([
                args.p0,
                [(args.p1, args.n1), (args.p2, args.n2), (args.p3, args.n3)]
            ])
        print('using Lp-nested prior, pnested = ({}) {}'.format(
            type(pnested), pnested))
        prior_dist = LpNestedAdapter(p=pnested, scale=args.scale)
        # the nested-norm structure fixes the latent dimensionality
        args.latent_dim = prior_dist.dimz()
        print('using Lp-nested prior, changed latent dimension to {}'.format(
            args.latent_dim))
    elif args.dist == 'lpnorm':
        prior_dist = LpNestedAdapter(p=[args.pnorm, [[1.0]] * args.latent_dim],
                                     scale=args.scale)

    # variational posterior family
    if args.q_dist == 'normal':
        q_dist = dist.Normal()
    elif args.q_dist == 'laplace':
        q_dist = dist.Laplace()

    # observation likelihood
    if args.x_dist == 'normal':
        x_dist = dist.Normal(sigma=args.mse_sigma)
    elif args.x_dist == 'bernoulli':
        x_dist = dist.Bernoulli()

    # DIP-VAE type 1 weights the diagonal term 10x the off-diagonal weight
    if args.dip_type == 1:
        lambda_d = 10.0 * args.lambda_od
    else:
        lambda_d = args.lambda_od

    vae = VAE(z_dim=args.latent_dim, use_cuda=use_cuda, prior_dist=prior_dist,
              q_dist=q_dist, x_dist=x_dist,
              include_mutinfo=not args.exclude_mutinfo, tcvae=args.tcvae,
              conv=args.conv, mss=args.mss, dataset=args.dataset,
              mse_sigma=args.mse_sigma, DIP=args.dip, DIP_type=args.dip_type,
              lambda_od=args.lambda_od, lambda_d=lambda_d)

    # setup the optimizer; the prior's own parameters use a fixed, smaller lr
    optimizer = optim.Adam([
        {'params': vae.parameters()},
        {'params': prior_dist.parameters(), 'lr': 5e-4},
    ], lr=args.learning_rate)

    # setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=4500)

    train_elbo = []

    # training loop bookkeeping
    dataset_size = len(train_loader.dataset)
    if args.num_iterations == 0:
        num_iterations = len(train_loader) * args.num_epochs
    else:
        num_iterations = args.num_iterations
    iteration = 0
    obj_best_snapshot = float('-inf')
    best_checkpoint = None  # set on the first non-NaN improvement
    best_checkpoint_updated = False

    trainingcurve_filename = os.path.join(args.save, 'trainingcurve.csv')
    if not os.path.exists(trainingcurve_filename):
        with open(trainingcurve_filename, 'w') as fd:
            fd.write('iteration,num_iterations,time,elbo_running_mean_val,elbo_running_mean_avg\n')

    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()

    # empty until the first backward pass; the guards below avoid the
    # NameError the original hit when a NaN objective occurred before
    # gradients was ever assigned.
    gradients = []
    nan_detected = False
    while iteration < num_iterations and not nan_detected:
        for i, x in enumerate(train_loader):
            iteration += 1
            batch_time = time.time()
            vae.train()
            anneal_kl(args, vae, iteration)
            optimizer.zero_grad()
            # transfer to GPU
            if use_cuda:
                x = x.cuda()
            x = Variable(x)

            # do ELBO gradient and accumulate loss
            obj, elbo, logpx = vae.elbo(prior_dist, x, dataset_size,
                                        use_mse_loss=args.use_mse_loss,
                                        mse_sigma=args.mse_sigma)
            if utils.isnan(obj).any():
                # dump diagnostics and abort training
                print('NaN spotted in objective.')
                # only Lp-nested adapters expose .prior; guard so the dump
                # itself cannot crash for normal/laplace priors
                if hasattr(prior_dist, 'prior'):
                    print('lpnested: {}'.format(prior_dist.prior.p))
                if gradients:
                    # NOTE(review): gradients holds parameter tensors (filtered
                    # on p.grad is not None), so this prints parameter
                    # magnitudes, not gradient magnitudes — confirm intent.
                    print("gradient abs max {}".format(
                        max([g.abs().max() for g in gradients])))
                nan_detected = True
                break
            elbo_running_mean.update(elbo.mean().item())

            # save checkpoint of best objective so far
            if obj.mean().item() > obj_best_snapshot:
                obj_best_snapshot = obj.mean().item()
                best_checkpoint = {
                    'state_dict': vae.state_dict(),
                    'state_dict_prior_dist': prior_dist.state_dict(),
                    'args': args,
                    'iteration': iteration,
                    'obj': obj_best_snapshot,
                    'elbo': elbo.mean().item(),
                    'logpx': logpx.mean().item()
                }
                best_checkpoint_updated = True

            # minimize -obj
            obj.mean().mul(-1).backward()
            gradients = list(
                filter(lambda p: p.grad is not None, vae.parameters()))
            if args.clip > 0:
                torch.nn.utils.clip_grad_norm_(vae.parameters(), args.clip)
            optimizer.step()

            # report training diagnostics
            if iteration % args.log_freq == 0:
                train_elbo.append(elbo_running_mean.avg)
                time_ = time.time() - batch_time
                print(
                    '[iteration %03d/%03d] time: %.2f \tbeta %.2f \tlambda %.2f \tobj %.4f \tlogpx %.4f training ELBO: %.4f (%.4f)'
                    % (iteration, num_iterations, time_, vae.beta, vae.lamb,
                       obj.mean().item(), logpx.mean().item(),
                       elbo_running_mean.val, elbo_running_mean.avg))
                # backwardsParseISA assumes an Lp-nested prior; guarded so
                # logging does not crash for normal/laplace priors
                if hasattr(prior_dist, 'prior'):
                    p0, p1list = backwardsParseISA(prior_dist.prior.p)
                    print('lpnested: {}, {}'.format(p0, p1list))
                if gradients:
                    print("gradient abs max {}".format(
                        max([g.abs().max() for g in gradients])))

                with open(os.path.join(args.save, 'trainingcurve.csv'), 'a') as fd:
                    fd.write('{},{},{},{},{}\n'.format(
                        iteration, num_iterations, time_,
                        elbo_running_mean.val, elbo_running_mean.avg))

                if best_checkpoint_updated:
                    print('Update best checkpoint [iteration %03d] training ELBO: %.4f'
                          % (best_checkpoint['iteration'], best_checkpoint['elbo']))
                    utils.save_checkpoint(best_checkpoint, args.save, 0)
                    best_checkpoint_updated = False

                vae.eval()
                prior_dist.eval()

                # plot training and test ELBOs
                if args.visdom:
                    num_channels = 3 if args.dataset == 'celeba' else 1
                    display_samples(vae, prior_dist, x, vis, num_channels)
                    plot_elbo(train_elbo, vis)

            # periodically save a full resumable checkpoint (optimizer + RNG)
            if iteration % (10 * args.log_freq) == 0:
                utils.save_checkpoint(
                    {
                        'state_dict': vae.state_dict(),
                        'state_dict_prior_dist': prior_dist.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'args': args,
                        'iteration': iteration,
                        'obj': obj.mean().item(),
                        'torch_random_state': torch.get_rng_state(),
                        'numpy_random_state': np.random.get_state()
                    },
                    args.save,
                    prefix='latest-optimizer-model-')
                if not args.dataset == 'celeba' and not args.dataset == '3dchairs':
                    # dispatch by dataset name; choices are restricted by
                    # argparse, so eval cannot receive untrusted input here
                    eval('plot_vs_gt_' + args.dataset)(
                        vae, train_loader.dataset,
                        os.path.join(args.save,
                                     'gt_vs_latent_{:05d}.png'.format(iteration)))

    # Report statistics of best snapshot after training
    if best_checkpoint is None:
        # can only happen if a NaN aborted the very first batch
        raise RuntimeError('training produced no checkpoint (NaN before the first snapshot?)')
    vae.load_state_dict(best_checkpoint['state_dict'])
    prior_dist.load_state_dict(best_checkpoint['state_dict_prior_dist'])
    vae.eval()
    prior_dist.eval()

    # full (unsubsampled) dataset for evaluation
    if args.dataset == 'shapes':
        data_set = dset.Shapes()
    elif args.dataset == 'faces':
        data_set = dset.Faces()
    elif args.dataset == 'cars3d':
        data_set = dset.Cars3d()
    elif args.dataset == 'celeba':
        data_set = dset.CelebA()
    elif args.dataset == '3dchairs':
        data_set = dset.Chairs()
    else:
        raise ValueError('Unknown dataset ' + str(args.dataset))
    print("loaded dataset {} of size {}".format(args.dataset, len(data_set)))

    dataset_loader = DataLoader(data_set, batch_size=1000, num_workers=0,
                                shuffle=False)

    logpx, dependence, information, dimwise_kl, analytical_cond_kl, elbo_marginal_entropies, elbo_joint_entropy = \
        elbo_decomposition(vae, prior_dist, dataset_loader)
    torch.save(
        {
            'args': args,
            'logpx': logpx,
            'dependence': dependence,
            'information': information,
            'dimwise_kl': dimwise_kl,
            'analytical_cond_kl': analytical_cond_kl,
            'marginal_entropies': elbo_marginal_entropies,
            'joint_entropy': elbo_joint_entropy
        }, os.path.join(args.save, 'elbo_decomposition.pth'))
    print('logpx: {:.2f}'.format(logpx))

    # disentanglement metrics exist only for datasets with ground-truth factors
    if not args.dataset == 'celeba' and not args.dataset == '3dchairs':
        eval('plot_vs_gt_' + args.dataset)(
            vae, dataset_loader.dataset,
            os.path.join(args.save, 'gt_vs_latent.png'))

        metric, metric_marginal_entropies, metric_cond_entropies = eval(
            'disentanglement_metrics.mutual_info_metric_' + args.dataset)(
                vae, dataset_loader.dataset)
        torch.save(
            {
                'args': args,
                'metric': metric,
                'marginal_entropies': metric_marginal_entropies,
                'cond_entropies': metric_cond_entropies,
            }, os.path.join(args.save, 'disentanglement_metric.pth'))
        print('MIG: {:.2f}'.format(metric))

        # one combined results dict, extended in place for lpnested runs
        # (same saved content as the previously duplicated dict literals)
        combined = {
            'args': args,
            'logpx': logpx,
            'dependence': dependence,
            'information': information,
            'dimwise_kl': dimwise_kl,
            'analytical_cond_kl': analytical_cond_kl,
            'elbo_marginal_entropies': elbo_marginal_entropies,
            'elbo_joint_entropy': elbo_joint_entropy,
            'metric': metric,
            'metric_marginal_entropies': metric_marginal_entropies,
            'metric_cond_entropies': metric_cond_entropies,
        }
        if args.dist == 'lpnested':
            p0, p1list = backwardsParseISA(prior_dist.prior.p)
            print('p0: {}'.format(p0))
            print('p1: {}'.format(p1list))
            combined['p0'] = p0
            combined['p1'] = p1list
        torch.save(combined, os.path.join(args.save, 'combined_data.pth'))

        # subspace-grouped metrics (only meaningful for lpnested on shapes)
        if args.dist == 'lpnested' and args.dataset == 'shapes':
            eval('plot_vs_gt_' + args.dataset)(
                vae, dataset_loader.dataset,
                os.path.join(args.save, 'gt_vs_grouped_latent.png'),
                eval_subspaces=True)
            metric_subspaces, metric_marginal_entropies_subspaces, metric_cond_entropies_subspaces = eval(
                'disentanglement_metrics.mutual_info_metric_' + args.dataset)(
                    vae, dataset_loader.dataset, eval_subspaces=True)
            torch.save(
                {
                    'args': args,
                    'metric': metric_subspaces,
                    'marginal_entropies': metric_marginal_entropies_subspaces,
                    'cond_entropies': metric_cond_entropies_subspaces,
                },
                os.path.join(args.save, 'disentanglement_metric_subspaces.pth'))
            print('MIG grouped by subspaces: {:.2f}'.format(metric_subspaces))
            combined.update({
                'metric_subspaces': metric_subspaces,
                'metric_marginal_entropies_subspaces': metric_marginal_entropies_subspaces,
                'metric_cond_entropies_subspaces': metric_cond_entropies_subspaces,
            })
            torch.save(combined, os.path.join(args.save, 'combined_data.pth'))

    return vae