def main():
    # parse command line arguments
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('-d', '--dataset', default='shapes', type=str, help='dataset name', choices=['shapes', 'faces'])
    parser.add_argument('-dist', default='normal', type=str, choices=['normal', 'laplace', 'flow'])
    parser.add_argument('-n', '--num-epochs', default=50, type=int, help='number of training epochs')
    parser.add_argument('-b', '--batch-size', default=2048, type=int, help='batch size')
    parser.add_argument('-l', '--learning-rate', default=1e-3, type=float, help='learning rate')
    parser.add_argument('-z', '--latent-dim', default=10, type=int, help='size of latent dimension')
    parser.add_argument('--beta', default=1, type=float, help='ELBO penalty term')
    parser.add_argument('--tcvae', action='store_true')
    parser.add_argument('--exclude-mutinfo', action='store_true')
    parser.add_argument('--beta-anneal', action='store_true')
    parser.add_argument('--lambda-anneal', action='store_true')
    parser.add_argument('--mss', action='store_true', help='use the improved minibatch estimator')
    parser.add_argument('--conv', action='store_true')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--visdom', action='store_true', help='whether plotting in visdom is desired')
    parser.add_argument('--save', default='test1')
    parser.add_argument('--log_freq', default=200, type=int, help='num iterations per log')
    args = parser.parse_args()

    # torch.cuda.set_device(args.gpu)

    # data loader
    train_loader = setup_data_loaders(args, use_cuda=True)

    # setup the VAE
    if args.dist == 'normal':
        prior_dist = dist.Normal()
        q_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
        q_dist = dist.Laplace()
    elif args.dist == 'flow':
        prior_dist = FactorialNormalizingFlow(dim=args.latent_dim, nsteps=32)
        q_dist = dist.Normal()
    vae = VAE(z_dim=args.latent_dim, use_cuda=True, prior_dist=prior_dist, q_dist=q_dist,
              include_mutinfo=not args.exclude_mutinfo, tcvae=args.tcvae, conv=args.conv, mss=args.mss)

    # setup the optimizer
    optimizer = optim.Adam(vae.parameters(), lr=args.learning_rate)

    # setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=4500)

    train_elbo = []

    # training loop
    dataset_size = len(train_loader.dataset)
    num_iterations = len(train_loader) * args.num_epochs
    iteration = 0
    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()
    while iteration < num_iterations:
        for i, x in enumerate(train_loader):
            iteration += 1
            batch_time = time.time()
            vae.train()
            anneal_kl(args, vae, iteration)
            optimizer.zero_grad()
            # transfer to GPU (`non_blocking` replaces the removed `async` keyword argument)
            x = x.cuda(non_blocking=True)
            # wrap the mini-batch in a PyTorch Variable
            x = Variable(x)

            # do ELBO gradient and accumulate loss
            obj, elbo = vae.elbo(x, dataset_size)
            if utils.isnan(obj).any():
                raise ValueError('NaN spotted in objective.')
            obj.mean().mul(-1).backward()
            print("obj value: ", obj.mean().mul(-1).cpu())
            elbo_running_mean.update(elbo.mean().item())
            optimizer.step()

            # report training diagnostics
            if iteration % args.log_freq == 0:
                train_elbo.append(elbo_running_mean.avg)
                print('[iteration %03d] time: %.2f \tbeta %.2f \tlambda %.2f training ELBO: %.4f (%.4f)' % (
                    iteration, time.time() - batch_time, vae.beta, vae.lamb,
                    elbo_running_mean.val, elbo_running_mean.avg))

                vae.eval()

                # plot training and test ELBOs
                if args.visdom:
                    display_samples(vae, x, vis)
                    plot_elbo(train_elbo, vis)

                utils.save_checkpoint({
                    'state_dict': vae.state_dict(),
                    'args': args}, args.save, 0)
                eval('plot_vs_gt_' + args.dataset)(
                    vae, train_loader.dataset,
                    os.path.join(args.save, 'gt_vs_latent_{:05d}.png'.format(iteration)))

    # Report statistics after training
    vae.eval()
    utils.save_checkpoint({
        'state_dict': vae.state_dict(),
        'args': args}, args.save, 0)
    dataset_loader = DataLoader(train_loader.dataset, batch_size=10, num_workers=1, shuffle=False)
    logpx, dependence, information, dimwise_kl, analytical_cond_kl, marginal_entropies, joint_entropy = \
        elbo_decomposition(vae, dataset_loader)
    torch.save({
        'logpx': logpx,
        'dependence': dependence,
        'information': information,
        'dimwise_kl': dimwise_kl,
        'analytical_cond_kl': analytical_cond_kl,
        'marginal_entropies': marginal_entropies,
        'joint_entropy': joint_entropy
    }, os.path.join(args.save, 'elbo_decomposition.pth'))
    eval('plot_vs_gt_' + args.dataset)(vae, dataset_loader.dataset,
                                       os.path.join(args.save, 'gt_vs_latent.png'))
    return vae
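# The `vae.elbo(x, dataset_size)` call above is implemented elsewhere in this repo.  For
# reference, the following is a minimal sketch (not the repo's implementation) of how the
# beta-TCVAE objective combines the terms of the KL decomposition
#     KL(q(z|x) || p(z)) = index-code mutual information + total correlation + dimension-wise KL,
# where only the total correlation is scaled by beta and the mutual-information term can be
# dropped (`include_mutinfo=False`, cf. --exclude-mutinfo above).  All arguments are assumed
# to be per-sample Monte Carlo estimates of the respective terms.
def modified_elbo_sketch(logpx, mutual_info, total_correlation, dimwise_kl,
                         beta=1.0, include_mutinfo=True):
    """Return the penalized ELBO maximized by beta-TCVAE (hypothetical helper)."""
    penalty = beta * total_correlation + dimwise_kl
    if include_mutinfo:
        penalty = penalty + mutual_info
    return logpx - penalty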
def main():
    # parse command line arguments
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('-d', '--dataset', default='shapes', type=str, help='dataset name', choices=['shapes', 'faces', 'celeba', 'cars3d', '3dchairs'])
    parser.add_argument('-dist', default='normal', type=str, choices=['normal', 'lpnorm', 'lpnested'])
    parser.add_argument('-n', '--num-epochs', default=50, type=int, help='number of training epochs')
    parser.add_argument('--num-iterations', default=0, type=int, help='number of iterations (overrides number of epochs if >0)')
    parser.add_argument('-b', '--batch-size', default=2048, type=int, help='batch size')
    parser.add_argument('-l', '--learning-rate', default=1e-3, type=float, help='learning rate')
    parser.add_argument('-z', '--latent-dim', default=10, type=int, help='size of latent dimension')
    parser.add_argument('-p', '--pnorm', default=4.0 / 3.0, type=float, help='p value of the Lp-norm')
    parser.add_argument('--pnested', default='', type=str,
                        help='nested list representation of the Lp-nested prior, e.g. '
                             '[2.1, [ [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ], [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ], [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ] ] ]')
    parser.add_argument('--isa', default='', type=str,
                        help='shorthand notation of ISA Lp-nested norm, e.g. [2.1, [(2.2, 4), (2.2, 4), (2.2, 4)]]')
    parser.add_argument('--p0', default=2.0, type=float, help='p0 of ISA')
    parser.add_argument('--p1', default=2.1, type=float, help='p1 of ISA')
    parser.add_argument('--n1', default=6, type=int, help='n1 of ISA')
    parser.add_argument('--p2', default=2.1, type=float, help='p2 of ISA')
    parser.add_argument('--n2', default=6, type=int, help='n2 of ISA')
    parser.add_argument('--p3', default=2.1, type=float, help='p3 of ISA')
    parser.add_argument('--n3', default=6, type=int, help='n3 of ISA')
    parser.add_argument('--scale', default=1.0, type=float, help='scale of LpNested distribution')
    parser.add_argument('--q-dist', default='normal', type=str, choices=['normal', 'laplace'])
    parser.add_argument('--x-dist', default='bernoulli', type=str, choices=['normal', 'bernoulli'])
    parser.add_argument('--beta', default=1, type=float, help='ELBO penalty term')
    parser.add_argument('--tcvae', action='store_true')
    parser.add_argument('--exclude-mutinfo', action='store_true')
    parser.add_argument('--beta-anneal', action='store_true')
    parser.add_argument('--lambda-anneal', action='store_true')
    parser.add_argument('--mss', action='store_true', help='use the improved minibatch estimator')
    parser.add_argument('--conv', action='store_true')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--visdom', action='store_true', help='whether plotting in visdom is desired')
    parser.add_argument('--save', default='test1')
    parser.add_argument('--id', default='1')
    parser.add_argument('--seed', default=-1, type=int,
                        help='seed for pytorch and numpy random number generator to allow reproducibility (default/-1: use random seed)')
    parser.add_argument('--log_freq', default=200, type=int, help='num iterations per log')
    parser.add_argument('--use-mse-loss', action='store_true')
    parser.add_argument('--mse-sigma', default=0.01, type=float, help='sigma of mean squared error loss')
    parser.add_argument('--dip', action='store_true', help='use DIP-VAE')
    parser.add_argument('--dip-type', default=1, type=int, help='DIP type (1 or 2)')
    parser.add_argument('--lambda-od', default=2.0, type=float, help='DIP: lambda weight off-diagonal')
    parser.add_argument('--clip', default=0.0, type=float, help='Gradient clipping (0 disabled)')
    parser.add_argument('--test', action='store_true', help='run test')
    parser.add_argument('--trainingsetsize', default=0, type=int, help='Subsample the trainingset (0 use original training data)')
    args = parser.parse_args()

    # initialize seeds for reproducibility
    if not args.seed == -1:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.gpu != -1:
        print('Using CUDA device {}'.format(args.gpu))
        torch.cuda.set_device(args.gpu)
        use_cuda = True
    else:
        print('CUDA disabled')
        use_cuda = False

    # data loader
    train_loader = setup_data_loaders(args.dataset, args.batch_size, use_cuda=use_cuda,
                                      len_subset=args.trainingsetsize)

    # setup the VAE
    if args.dist == 'normal':
        prior_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
    elif args.dist == 'lpnested':
        if not args.isa == '':
            pnested = parseISA(ast.literal_eval(args.isa))
        elif not args.pnested == '':
            pnested = ast.literal_eval(args.pnested)
        else:
            pnested = parseISA([args.p0, [(args.p1, args.n1), (args.p2, args.n2), (args.p3, args.n3)]])
        print('using Lp-nested prior, pnested = ({}) {}'.format(type(pnested), pnested))
        prior_dist = LpNestedAdapter(p=pnested, scale=args.scale)
        args.latent_dim = prior_dist.dimz()
        print('using Lp-nested prior, changed latent dimension to {}'.format(args.latent_dim))
    elif args.dist == 'lpnorm':
        prior_dist = LpNestedAdapter(p=[args.pnorm, [[1.0]] * args.latent_dim], scale=args.scale)

    if args.q_dist == 'normal':
        q_dist = dist.Normal()
    elif args.q_dist == 'laplace':
        q_dist = dist.Laplace()

    if args.x_dist == 'normal':
        x_dist = dist.Normal(sigma=args.mse_sigma)
    elif args.x_dist == 'bernoulli':
        x_dist = dist.Bernoulli()

    if args.dip_type == 1:
        lambda_d = 10.0 * args.lambda_od
    else:
        lambda_d = args.lambda_od

    vae = VAE(z_dim=args.latent_dim, use_cuda=use_cuda, prior_dist=prior_dist, q_dist=q_dist, x_dist=x_dist,
              include_mutinfo=not args.exclude_mutinfo, tcvae=args.tcvae, conv=args.conv, mss=args.mss,
              dataset=args.dataset, mse_sigma=args.mse_sigma, DIP=args.dip, DIP_type=args.dip_type,
              lambda_od=args.lambda_od, lambda_d=lambda_d)

    # setup the optimizer (the prior parameters get their own learning rate)
    optimizer = optim.Adam([
        {'params': vae.parameters()},
        {'params': prior_dist.parameters(), 'lr': 5e-4}
    ], lr=args.learning_rate)

    # setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=4500)

    train_elbo = []

    # training loop
    dataset_size = len(train_loader.dataset)
    if args.num_iterations == 0:
        num_iterations = len(train_loader) * args.num_epochs
    else:
        num_iterations = args.num_iterations
    iteration = 0
    obj_best_snapshot = float('-inf')
    best_checkpoint_updated = False

    trainingcurve_filename = os.path.join(args.save, 'trainingcurve.csv')
    if not os.path.exists(trainingcurve_filename):
        with open(trainingcurve_filename, 'w') as fd:
            fd.write('iteration,num_iterations,time,elbo_running_mean_val,elbo_running_mean_avg\n')

    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()
    nan_detected = False
    gradients = []  # parameters with gradients; empty until the first backward pass
    while iteration < num_iterations and not nan_detected:
        for i, x in enumerate(train_loader):
            iteration += 1
            batch_time = time.time()
            vae.train()
            anneal_kl(args, vae, iteration)
            optimizer.zero_grad()
            # transfer to GPU
            if use_cuda:
                x = x.cuda()
            # wrap the mini-batch in a PyTorch Variable
            x = Variable(x)

            # do ELBO gradient and accumulate loss
            # with autograd.detect_anomaly():
            obj, elbo, logpx = vae.elbo(prior_dist, x, dataset_size,
                                        use_mse_loss=args.use_mse_loss, mse_sigma=args.mse_sigma)
            if utils.isnan(obj).any():
                print('NaN spotted in objective.')
                print('lpnested: {}'.format(prior_dist.prior.p))
                if gradients:
                    print("gradient abs max {}".format(max([p.grad.abs().max() for p in gradients])))
                # raise ValueError('NaN spotted in objective.')
                nan_detected = True
                break
            elbo_running_mean.update(elbo.mean().item())

            # save checkpoint of best ELBO
            if obj.mean().item() > obj_best_snapshot:
                obj_best_snapshot = obj.mean().item()
                best_checkpoint = {
                    'state_dict': vae.state_dict(),
                    'state_dict_prior_dist': prior_dist.state_dict(),
                    'args': args,
                    'iteration': iteration,
                    'obj': obj_best_snapshot,
                    'elbo': elbo.mean().item(),
                    'logpx': logpx.mean().item()
                }
                best_checkpoint_updated = True

            # with autograd.detect_anomaly():
            obj.mean().mul(-1).backward()
            gradients = list(filter(lambda p: p.grad is not None, vae.parameters()))
            if args.clip > 0:
                torch.nn.utils.clip_grad_norm_(vae.parameters(), args.clip)
            optimizer.step()

            # report training diagnostics
            if iteration % args.log_freq == 0:
                train_elbo.append(elbo_running_mean.avg)
                time_ = time.time() - batch_time
                print('[iteration %03d/%03d] time: %.2f \tbeta %.2f \tlambda %.2f \tobj %.4f \tlogpx %.4f training ELBO: %.4f (%.4f)' % (
                    iteration, num_iterations, time_, vae.beta, vae.lamb,
                    obj.mean().item(), logpx.mean().item(),
                    elbo_running_mean.val, elbo_running_mean.avg))
                p0, p1list = backwardsParseISA(prior_dist.prior.p)
                print('lpnested: {}, {}'.format(p0, p1list))
                print("gradient abs max {}".format(max([p.grad.abs().max() for p in gradients])))

                with open(os.path.join(args.save, 'trainingcurve.csv'), 'a') as fd:
                    fd.write('{},{},{},{},{}\n'.format(iteration, num_iterations, time_,
                                                       elbo_running_mean.val, elbo_running_mean.avg))

                if best_checkpoint_updated:
                    print('Update best checkpoint [iteration %03d] training ELBO: %.4f' % (
                        best_checkpoint['iteration'], best_checkpoint['elbo']))
                    utils.save_checkpoint(best_checkpoint, args.save, 0)
                    best_checkpoint_updated = False

                vae.eval()
                prior_dist.eval()

                # plot training and test ELBOs
                if args.visdom:
                    if args.dataset == 'celeba':
                        num_channels = 3
                    else:
                        num_channels = 1
                    display_samples(vae, prior_dist, x, vis, num_channels)
                    plot_elbo(train_elbo, vis)

                if iteration % (10 * args.log_freq) == 0:
                    utils.save_checkpoint({
                        'state_dict': vae.state_dict(),
                        'state_dict_prior_dist': prior_dist.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'args': args,
                        'iteration': iteration,
                        'obj': obj.mean().item(),
                        'torch_random_state': torch.get_rng_state(),
                        'numpy_random_state': np.random.get_state()
                    }, args.save, prefix='latest-optimizer-model-')
                    if not args.dataset == 'celeba' and not args.dataset == '3dchairs':
                        eval('plot_vs_gt_' + args.dataset)(
                            vae, train_loader.dataset,
                            os.path.join(args.save, 'gt_vs_latent_{:05d}.png'.format(iteration)))

    # Report statistics of best snapshot after training
    vae.load_state_dict(best_checkpoint['state_dict'])
    prior_dist.load_state_dict(best_checkpoint['state_dict_prior_dist'])
    vae.eval()
    prior_dist.eval()

    if args.dataset == 'shapes':
        data_set = dset.Shapes()
    elif args.dataset == 'faces':
        data_set = dset.Faces()
    elif args.dataset == 'cars3d':
        data_set = dset.Cars3d()
    elif args.dataset == 'celeba':
        data_set = dset.CelebA()
    elif args.dataset == '3dchairs':
        data_set = dset.Chairs()
    else:
        raise ValueError('Unknown dataset ' + str(args.dataset))
    print("loaded dataset {} of size {}".format(args.dataset, len(data_set)))
    dataset_loader = DataLoader(data_set, batch_size=1000, num_workers=0, shuffle=False)

    logpx, dependence, information, dimwise_kl, analytical_cond_kl, elbo_marginal_entropies, elbo_joint_entropy = \
        elbo_decomposition(vae, prior_dist, dataset_loader)
    torch.save({
        'args': args,
        'logpx': logpx,
        'dependence': dependence,
        'information': information,
        'dimwise_kl': dimwise_kl,
        'analytical_cond_kl': analytical_cond_kl,
        'marginal_entropies': elbo_marginal_entropies,
        'joint_entropy': elbo_joint_entropy
    }, os.path.join(args.save, 'elbo_decomposition.pth'))
    print('logpx: {:.2f}'.format(logpx))

    if not args.dataset == 'celeba' and not args.dataset == '3dchairs':
        eval('plot_vs_gt_' + args.dataset)(vae, dataset_loader.dataset,
                                           os.path.join(args.save, 'gt_vs_latent.png'))

    metric, metric_marginal_entropies, metric_cond_entropies = eval(
        'disentanglement_metrics.mutual_info_metric_' + args.dataset)(vae, dataset_loader.dataset)
    torch.save({
        'args': args,
        'metric': metric,
        'marginal_entropies': metric_marginal_entropies,
        'cond_entropies': metric_cond_entropies,
    }, os.path.join(args.save, 'disentanglement_metric.pth'))
    print('MIG: {:.2f}'.format(metric))

    if args.dist == 'lpnested':
        p0, p1list = backwardsParseISA(prior_dist.prior.p)
        print('p0: {}'.format(p0))
        print('p1: {}'.format(p1list))
        torch.save({
            'args': args,
            'logpx': logpx,
            'dependence': dependence,
            'information': information,
            'dimwise_kl': dimwise_kl,
            'analytical_cond_kl': analytical_cond_kl,
            'elbo_marginal_entropies': elbo_marginal_entropies,
            'elbo_joint_entropy': elbo_joint_entropy,
            'metric': metric,
            'metric_marginal_entropies': metric_marginal_entropies,
            'metric_cond_entropies': metric_cond_entropies,
            'p0': p0,
            'p1': p1list
        }, os.path.join(args.save, 'combined_data.pth'))
    else:
        torch.save({
            'args': args,
            'logpx': logpx,
            'dependence': dependence,
            'information': information,
            'dimwise_kl': dimwise_kl,
            'analytical_cond_kl': analytical_cond_kl,
            'elbo_marginal_entropies': elbo_marginal_entropies,
            'elbo_joint_entropy': elbo_joint_entropy,
            'metric': metric,
            'metric_marginal_entropies': metric_marginal_entropies,
            'metric_cond_entropies': metric_cond_entropies,
        }, os.path.join(args.save, 'combined_data.pth'))

    if args.dist == 'lpnested':
        if args.dataset == 'shapes':
            eval('plot_vs_gt_' + args.dataset)(vae, dataset_loader.dataset,
                                               os.path.join(args.save, 'gt_vs_grouped_latent.png'),
                                               eval_subspaces=True)

            metric_subspaces, metric_marginal_entropies_subspaces, metric_cond_entropies_subspaces = eval(
                'disentanglement_metrics.mutual_info_metric_' + args.dataset)(
                    vae, dataset_loader.dataset, eval_subspaces=True)
            torch.save({
                'args': args,
                'metric': metric_subspaces,
                'marginal_entropies': metric_marginal_entropies_subspaces,
                'cond_entropies': metric_cond_entropies_subspaces,
            }, os.path.join(args.save, 'disentanglement_metric_subspaces.pth'))
            print('MIG grouped by subspaces: {:.2f}'.format(metric_subspaces))

            torch.save({
                'args': args,
                'logpx': logpx,
                'dependence': dependence,
                'information': information,
                'dimwise_kl': dimwise_kl,
                'analytical_cond_kl': analytical_cond_kl,
                'elbo_marginal_entropies': elbo_marginal_entropies,
                'elbo_joint_entropy': elbo_joint_entropy,
                'metric': metric,
                'metric_marginal_entropies': metric_marginal_entropies,
                'metric_cond_entropies': metric_cond_entropies,
                'metric_subspaces': metric_subspaces,
                'metric_marginal_entropies_subspaces': metric_marginal_entropies_subspaces,
                'metric_cond_entropies_subspaces': metric_cond_entropies_subspaces,
                'p0': p0,
                'p1': p1list
            }, os.path.join(args.save, 'combined_data.pth'))

    return vae
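# The `--pnested` / `--isa` options above describe an Lp-nested function f(z) in the sense of
# Sinz & Bethge: an inner node [p, [child_1, ..., child_k]] combines the values of its children
# as (sum_i f_i(z)^p)^(1/p), and each leaf consumes one latent coordinate.  The helper below is
# a hypothetical sketch (not the repo's parseISA / LpNestedAdapter) that evaluates such a nested
# list on a vector, assuming the single number stored at a leaf ([1.0] in the examples above) is
# only a placeholder and the leaf simply contributes |z_j|.
def lp_nested_value_sketch(node, z, index=0):
    """Evaluate an Lp-nested function given as a nested list; returns (value, next_index)."""
    if len(node) == 1:  # leaf: consume one coordinate
        return abs(z[index]), index + 1
    p, children = node
    total = 0.0
    for child in children:
        value, index = lp_nested_value_sketch(child, z, index)
        total += value ** p
    return total ** (1.0 / p), index

# Example: a plain Lp-norm over three dimensions, matching the `lpnorm` branch above:
#     value, _ = lp_nested_value_sketch([4.0 / 3.0, [[1.0], [1.0], [1.0]]], [0.5, -1.0, 2.0])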
def evaluate(args, outputdir, vae, dataset, prefix=''):
    if os.path.exists(os.path.join(outputdir, prefix + 'combined_data.pth')):
        return

    # Report statistics
    vae.eval()
    dataset_loader = DataLoader(dataset, batch_size=1000, num_workers=0, shuffle=True)

    if not os.path.exists(outputdir):
        os.makedirs(outputdir)

    logpx, dependence, information, dimwise_kl, analytical_cond_kl, elbo_marginal_entropies, elbo_joint_entropy = \
        elbo_decomposition(vae, dataset_loader)
    torch.save({
        'args': args,
        'logpx': logpx,
        'dependence': dependence,
        'information': information,
        'dimwise_kl': dimwise_kl,
        'analytical_cond_kl': analytical_cond_kl,
        'marginal_entropies': elbo_marginal_entropies,
        'joint_entropy': elbo_joint_entropy
    }, os.path.join(outputdir, prefix + 'elbo_decomposition.pth'))

    metric, metric_marginal_entropies, metric_cond_entropies = eval(
        'disentanglement_metrics.mutual_info_metric_' + args.dataset)(
            vae, dataset_loader.dataset, eval_subspaces=False)
    torch.save({
        'args': args,
        'metric': metric,
        'marginal_entropies': metric_marginal_entropies,
        'cond_entropies': metric_cond_entropies,
    }, os.path.join(outputdir, prefix + 'disentanglement_metric.pth'))
    print('logpx: {:.2f}'.format(logpx))
    print('MIG: {:.2f}'.format(metric))

    eval('plot_vs_gt_' + args.dataset)(vae, dataset_loader.dataset,
                                       os.path.join(outputdir, prefix + 'gt_vs_latent.png'))

    torch.save({
        'args': args,
        'logpx': logpx,
        'dependence': dependence,
        'information': information,
        'dimwise_kl': dimwise_kl,
        'analytical_cond_kl': analytical_cond_kl,
        'elbo_marginal_entropies': elbo_marginal_entropies,
        'elbo_joint_entropy': elbo_joint_entropy,
        'metric': metric,
        'metric_marginal_entropies': metric_marginal_entropies,
        'metric_cond_entropies': metric_cond_entropies,
    }, os.path.join(outputdir, prefix + 'combined_data.pth'))

    if args.dist == 'lpnested':
        eval('plot_vs_gt_' + args.dataset)(vae, dataset_loader.dataset,
                                           os.path.join(outputdir, 'gt_vs_grouped_latent.png'),
                                           eval_subspaces=True)

        metric_subspaces, metric_marginal_entropies_subspaces, metric_cond_entropies_subspaces = eval(
            'disentanglement_metrics.mutual_info_metric_' + args.dataset)(
                vae, dataset_loader.dataset, eval_subspaces=True)
        torch.save({
            'args': args,
            'metric': metric_subspaces,
            'marginal_entropies': metric_marginal_entropies_subspaces,
            'cond_entropies': metric_cond_entropies_subspaces,
        }, os.path.join(outputdir, 'disentanglement_metric_subspaces.pth'))
        print('MIG grouped by subspaces: {:.2f}'.format(metric_subspaces))

        torch.save({
            'args': args,
            'logpx': logpx,
            'dependence': dependence,
            'information': information,
            'dimwise_kl': dimwise_kl,
            'analytical_cond_kl': analytical_cond_kl,
            'elbo_marginal_entropies': elbo_marginal_entropies,
            'elbo_joint_entropy': elbo_joint_entropy,
            'metric': metric,
            'metric_marginal_entropies': metric_marginal_entropies,
            'metric_cond_entropies': metric_cond_entropies,
            'metric_subspaces': metric_subspaces,
            'metric_marginal_entropies_subspaces': metric_marginal_entropies_subspaces,
            'metric_cond_entropies_subspaces': metric_cond_entropies_subspaces
        }, os.path.join(outputdir, 'combined_data.pth'))
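# For reference, the MIG ("mutual information gap") printed by evaluate() follows Chen et al.
# (2018): for each ground-truth factor v_k, take the gap between the two largest mutual
# informations I(z_j; v_k) over latent dimensions, normalize by the factor entropy H(v_k), and
# average over factors.  A minimal sketch (hypothetical helper, not the repo's
# disentanglement_metrics implementation):
def mig_sketch(mi_matrix, factor_entropies):
    """mi_matrix: [num_latents, num_factors] mutual informations; factor_entropies: [num_factors]."""
    import numpy as np
    mi = np.asarray(mi_matrix, dtype=float)
    h = np.asarray(factor_entropies, dtype=float)
    sorted_mi = np.sort(mi, axis=0)[::-1]  # descending along the latent axis
    gaps = (sorted_mi[0] - sorted_mi[1]) / h
    return gaps.mean()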
def load_model_and_dataset(checkpt_filename):
    print('Loading model and dataset.')
    checkpt = torch.load(checkpt_filename, map_location=lambda storage, loc: storage)
    args = checkpt['args']
    state_dict = checkpt['state_dict']

    # model
    if not hasattr(args, 'dist') or args.dist == 'normal':
        prior_dist = dist.Normal()
        q_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
        q_dist = dist.Laplace()
    elif args.dist == 'flow':
        prior_dist = flows.FactorialNormalizingFlow(dim=args.latent_dim, nsteps=4)
        q_dist = dist.Normal()
    vae = VAE(z_dim=args.latent_dim, use_cuda=True, prior_dist=prior_dist, q_dist=q_dist, conv=args.conv)
    vae.load_state_dict(state_dict, strict=False)

    # dataset loader
    loader = setup_data_loaders(args)
    return vae, loader, args


# script driver (assumes an `args` namespace with `zs`, `gpu`, `checkpt`, `elbo_decomp` and
# `save` fields parsed elsewhere in this file)
z_inds = list(map(int, args.zs.split(','))) if args.zs is not None else None
torch.cuda.set_device(args.gpu)
vae, dataset_loader, cpargs = load_model_and_dataset(args.checkpt)
if args.elbo_decomp:
    elbo_decomposition(vae, dataset_loader)
eval('plot_vs_gt_' + cpargs.dataset)(vae, dataset_loader.dataset, args.save, z_inds)
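# Hypothetical usage sketch for load_model_and_dataset: the checkpoint path below is a
# placeholder, not a file name produced by this repo.
def _example_inspect_checkpoint(checkpt_filename='test1/checkpt-0000.pth'):
    vae, loader, train_args = load_model_and_dataset(checkpt_filename)
    vae.eval()
    print('latent dimension:', train_args.latent_dim)
    return vae, loader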
def main():
    # parse command line arguments
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('-d', '--dataset', default='celeba', type=str, help='dataset name', choices=['celeba'])
    parser.add_argument('-dist', default='normal', type=str, choices=['normal', 'laplace', 'flow'])
    parser.add_argument('-n', '--num-epochs', default=50, type=int, help='number of training epochs')
    parser.add_argument('-b', '--batch-size', default=2048, type=int, help='batch size')
    parser.add_argument('-l', '--learning-rate', default=1e-3, type=float, help='learning rate')
    parser.add_argument('-z', '--latent-dim', default=100, type=int, help='size of latent dimension')
    parser.add_argument('--beta', default=1, type=float, help='ELBO penalty term')
    parser.add_argument('--beta_sens', default=20, type=float, help='Relative importance of predicting sensitive attributes')
    # parser.add_argument('--sens_idx', default=[13, 15, 20], type=list, help='Relative importance of predicting sensitive attributes')
    parser.add_argument('--tcvae', action='store_true')
    parser.add_argument('--exclude-mutinfo', action='store_true')
    parser.add_argument('--beta-anneal', action='store_true')
    parser.add_argument('--lambda-anneal', action='store_true')
    parser.add_argument('--mss', action='store_true', help='use the improved minibatch estimator')
    parser.add_argument('--conv', action='store_true')
    parser.add_argument('--clf_samps', action='store_true')
    parser.add_argument('--clf_means', action='store_false', dest='clf_samps')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--visdom', action='store_true', help='whether plotting in visdom is desired')
    parser.add_argument('--save', default='betatcvae-celeba')
    parser.add_argument('--log_freq', default=200, type=int, help='num iterations per log')
    parser.add_argument('--audit', action='store_true', help='after each epoch, audit the repr wrt fair clf task')
    args = parser.parse_args()
    print(args)

    if not os.path.exists(args.save):
        os.makedirs(args.save)
    writer = SummaryWriter(args.save)
    writer.add_text('args', json.dumps(vars(args), sort_keys=True, indent=4))
    log_file = os.path.join(args.save, 'train.log')
    if os.path.exists(log_file):
        os.remove(log_file)
    print(vars(args))
    print(vars(args), file=open(log_file, 'w'))

    torch.cuda.set_device(args.gpu)

    # data loader
    loaders = setup_data_loaders(args, use_cuda=True)

    # setup the VAE
    if args.dist == 'normal':
        prior_dist = dist.Normal()
        q_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
        q_dist = dist.Laplace()
    elif args.dist == 'flow':
        prior_dist = FactorialNormalizingFlow(dim=args.latent_dim, nsteps=32)
        q_dist = dist.Normal()
    x_dist = dist.Normal() if args.dataset == 'celeba' else dist.Bernoulli()
    a_dist = dist.Bernoulli()

    vae = SensVAE(z_dim=args.latent_dim, use_cuda=True, prior_dist=prior_dist, q_dist=q_dist,
                  include_mutinfo=not args.exclude_mutinfo, tcvae=args.tcvae, conv=args.conv, mss=args.mss,
                  n_chan=3 if args.dataset == 'celeba' else 1, sens_idx=SENS_IDX,
                  x_dist=x_dist, a_dist=a_dist, clf_samps=args.clf_samps)

    if args.audit:
        audit_label_fn = get_label_fn(dict(data=dict(name='celeba', label_fn='H')))
        audit_repr_fns = dict()
        audit_attr_fns = dict()
        audit_models = dict()
        audit_train_metrics = dict()
        audit_validation_metrics = dict()
        for attr_fn_name in CELEBA_SENS_IDX.keys():
            model = MLPClassifier(args.latent_dim, 1000, 2)
            model.cuda()
            audit_models[attr_fn_name] = model
            audit_repr_fns[attr_fn_name] = get_repr_fn(
                dict(data=dict(name='celeba', repr_fn='remove_all', attr_fn=attr_fn_name)))
            audit_attr_fns[attr_fn_name] = get_attr_fn(
                dict(data=dict(name='celeba', attr_fn=attr_fn_name)))

    # setup the optimizer
    optimizer = optim.Adam(vae.parameters(), lr=args.learning_rate)
    if args.audit:
        Adam = optim.Adam
        audit_optimizers = dict()
        for k, v in audit_models.items():
            audit_optimizers[k] = Adam(v.parameters(), lr=args.learning_rate)

    # setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=3776)

    train_elbo = []
    train_tc = []

    # training loop
    dataset_size = len(loaders['train'].dataset)
    num_iterations = len(loaders['train']) * args.num_epochs
    iteration = 0
    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()
    tc_running_mean = utils.RunningAverageMeter()
    clf_acc_meters = {'clf_acc{}'.format(s): utils.RunningAverageMeter() for s in vae.sens_idx}
    val_elbo_running_mean = utils.RunningAverageMeter()
    val_tc_running_mean = utils.RunningAverageMeter()
    val_clf_acc_meters = {'val_clf_acc{}'.format(s): utils.RunningAverageMeter() for s in vae.sens_idx}

    while iteration < num_iterations:
        bar = tqdm(range(len(loaders['train'])))
        for i, (x, a) in enumerate(loaders['train']):
            bar.update()
            iteration += 1
            batch_time = time.time()
            vae.train()
            # anneal_kl(args, vae, iteration)  # TODO try annealing beta/beta_sens
            vae.beta = args.beta
            vae.beta_sens = args.beta_sens
            optimizer.zero_grad()
            # transfer to GPU (`non_blocking` replaces the removed `async` keyword argument)
            x = x.cuda(non_blocking=True)
            a = a.float()
            a = a.cuda(non_blocking=True)
            # wrap the mini-batch in a PyTorch Variable
            x = Variable(x)
            a = Variable(a)

            # do ELBO gradient and accumulate loss
            obj, elbo, metrics = vae.elbo(x, a, dataset_size)
            if utils.isnan(obj).any():
                raise ValueError('NaN spotted in objective.')
            obj.mean().mul(-1).backward()
            elbo_running_mean.update(elbo.mean().data.item())
            tc_running_mean.update(metrics['tc'])
            for (s, meter), (_, acc) in zip(clf_acc_meters.items(), metrics.items()):
                clf_acc_meters[s].update(acc.data.item())
            optimizer.step()

            if args.audit:
                for model in audit_models.values():
                    model.train()
                # now re-encode x and take a step to train each audit classifier
                for opt in audit_optimizers.values():
                    opt.zero_grad()
                with torch.no_grad():
                    zs, z_params = vae.encode(x)
                if args.clf_samps:
                    z = zs
                else:
                    z_mu = z_params.select(-1, 0)
                    z = z_mu
                a_all = a
                for subgroup, model in audit_models.items():
                    # noise out sensitive dims of latent code
                    z_ = z.clone()
                    a_all_ = a_all.clone()
                    # subsample to just sens attr of interest for this subgroup
                    a_ = audit_attr_fns[subgroup](a_all_)
                    # noise out sensitive dims for this subgroup
                    z_ = audit_repr_fns[subgroup](z_, None, None)
                    y_ = audit_label_fn(a_all_).long()
                    loss, _, metrics = model(z_, y_, a_)
                    loss.backward()
                    audit_optimizers[subgroup].step()

                    metrics_dict = {}
                    metrics_dict.update(loss=loss.detach().item())
                    for k, v in metrics.items():
                        if v.numel() > 1:
                            k += '-avg'
                            v = v.float().mean()
                        metrics_dict.update({k: v.detach().item()})
                    audit_train_metrics[subgroup] = metrics_dict

            # report training diagnostics
            if iteration % args.log_freq == 0:
                if args.audit:
                    for subgroup, metrics in audit_train_metrics.items():
                        for metric_name, metric_value in metrics.items():
                            writer.add_scalar('{}/{}'.format(subgroup, metric_name), metric_value, iteration)
                train_elbo.append(elbo_running_mean.avg)
                writer.add_scalar('train_elbo', elbo_running_mean.avg, iteration)
                train_tc.append(tc_running_mean.avg)
                writer.add_scalar('train_tc', tc_running_mean.avg, iteration)
                msg = '[iteration %03d] time: %.2f \tbeta %.2f \tlambda %.2f training ELBO: %.4f (%.4f) training TC %.4f (%.4f)' % (
                    iteration, time.time() - batch_time, vae.beta, vae.lamb,
                    elbo_running_mean.val, elbo_running_mean.avg, tc_running_mean.val, tc_running_mean.avg)
                for k, v in clf_acc_meters.items():
                    msg += ' {}: {:.2f}'.format(k, v.avg)
                    writer.add_scalar(k, v.avg, iteration)
                print(msg)
                print(msg, file=open(log_file, 'a'))

                vae.eval()

                ################################################################
                # evaluate validation metrics on vae and auditors
                for x, a in loaders['validation']:
                    # transfer to GPU
                    x = x.cuda(non_blocking=True)
                    a = a.float()
                    a = a.cuda(non_blocking=True)
                    # wrap the mini-batch in a PyTorch Variable
                    x = Variable(x)
                    a = Variable(a)

                    # compute the ELBO on the validation batch (no gradient step here)
                    obj, elbo, metrics = vae.elbo(x, a, dataset_size)
                    if utils.isnan(obj).any():
                        raise ValueError('NaN spotted in objective.')
                    val_elbo_running_mean.update(elbo.mean().data.item())
                    val_tc_running_mean.update(metrics['tc'])
                    for (s, meter), (_, acc) in zip(val_clf_acc_meters.items(), metrics.items()):
                        val_clf_acc_meters[s].update(acc.data.item())

                    if args.audit:
                        for model in audit_models.values():
                            model.eval()
                        with torch.no_grad():
                            zs, z_params = vae.encode(x)
                        if args.clf_samps:
                            z = zs
                        else:
                            z_mu = z_params.select(-1, 0)
                            z = z_mu
                        a_all = a
                        for subgroup, model in audit_models.items():
                            # noise out sensitive dims of latent code
                            z_ = z.clone()
                            a_all_ = a_all.clone()
                            # subsample to just sens attr of interest for this subgroup
                            a_ = audit_attr_fns[subgroup](a_all_)
                            # noise out sensitive dims for this subgroup
                            z_ = audit_repr_fns[subgroup](z_, None, None)
                            y_ = audit_label_fn(a_all_).long()
                            # the auditors are only evaluated here; no backward/optimizer step
                            # is taken on validation data
                            loss, _, metrics = model(z_, y_, a_)

                            metrics_dict = {}
                            metrics_dict.update(val_loss=loss.detach().item())
                            for k, v in metrics.items():
                                k = 'val_' + k  # denote a validation metric
                                if v.numel() > 1:
                                    k += '-avg'
                                    v = v.float().mean()
                                metrics_dict.update({k: v.detach().item()})
                            audit_validation_metrics[subgroup] = metrics_dict

                # after iterating through validation set, write summaries
                for subgroup, metrics in audit_validation_metrics.items():
                    for metric_name, metric_value in metrics.items():
                        writer.add_scalar('{}/{}'.format(subgroup, metric_name), metric_value, iteration)
                writer.add_scalar('val_elbo', val_elbo_running_mean.avg, iteration)
                writer.add_scalar('val_tc', val_tc_running_mean.avg, iteration)
                for k, v in val_clf_acc_meters.items():
                    writer.add_scalar(k, v.avg, iteration)
                ################################################################

                # finally, plot training and test ELBOs
                if args.visdom:
                    display_samples(vae, x, vis)
                    plot_elbo(train_elbo, vis)
                    plot_tc(train_tc, vis)

                utils.save_checkpoint({
                    'state_dict': vae.state_dict(),
                    'args': args}, args.save, iteration // len(loaders['train']))
                eval('plot_vs_gt_' + args.dataset)(
                    vae, loaders['train'].dataset,
                    os.path.join(args.save, 'gt_vs_latent_{:05d}.png'.format(iteration)))

    # Report statistics after training
    vae.eval()
    utils.save_checkpoint({
        'state_dict': vae.state_dict(),
        'args': args}, args.save, 0)
    dataset_loader = DataLoader(loaders['train'].dataset, batch_size=1000, num_workers=1, shuffle=False)
    if False:
        logpx, dependence, information, dimwise_kl, analytical_cond_kl, marginal_entropies, joint_entropy = \
            elbo_decomposition(vae, dataset_loader)
        torch.save({
            'logpx': logpx,
            'dependence': dependence,
            'information': information,
            'dimwise_kl': dimwise_kl,
            'analytical_cond_kl': analytical_cond_kl,
            'marginal_entropies': marginal_entropies,
            'joint_entropy': joint_entropy
        }, os.path.join(args.save, 'elbo_decomposition.pth'))
        eval('plot_vs_gt_' + args.dataset)(vae, dataset_loader.dataset,
                                           os.path.join(args.save, 'gt_vs_latent.png'))
    for file in [open(os.path.join(args.save, 'done'), 'w'), sys.stdout]:
        print('done', file=file)
    return vae
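# Example invocation of this training script, using only the flags defined in the parser above
# (the script file name "train_celeba_sensvae.py" is a placeholder, not the repo's actual name):
#
#     python train_celeba_sensvae.py --tcvae --conv --beta 5 --beta_sens 20 \
#         --audit --visdom --save betatcvae-celeba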