Example #1
def main():
    # parse command line arguments
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('-d', '--dataset', default='shapes', type=str, help='dataset name',
        choices=['shapes', 'faces'])
    parser.add_argument('-dist', default='normal', type=str, choices=['normal', 'laplace', 'flow'])
    parser.add_argument('-n', '--num-epochs', default=50, type=int, help='number of training epochs')
    parser.add_argument('-b', '--batch-size', default=2048, type=int, help='batch size')
    parser.add_argument('-l', '--learning-rate', default=1e-3, type=float, help='learning rate')
    parser.add_argument('-z', '--latent-dim', default=10, type=int, help='size of latent dimension')
    parser.add_argument('--beta', default=1, type=float, help='ELBO penalty term')
    parser.add_argument('--tcvae', action='store_true')
    parser.add_argument('--exclude-mutinfo', action='store_true')
    parser.add_argument('--beta-anneal', action='store_true')
    parser.add_argument('--lambda-anneal', action='store_true')
    parser.add_argument('--mss', action='store_true', help='use the improved minibatch estimator')
    parser.add_argument('--conv', action='store_true')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--visdom', action='store_true', help='whether plotting in visdom is desired')
    parser.add_argument('--save', default='test1')
    parser.add_argument('--log_freq', default=200, type=int, help='num iterations per log')
    args = parser.parse_args()

    # torch.cuda.set_device(args.gpu)

    # data loader
    train_loader = setup_data_loaders(args, use_cuda=True)

    # setup the VAE
    if args.dist == 'normal':
        prior_dist = dist.Normal()
        q_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
        q_dist = dist.Laplace()
    elif args.dist == 'flow':
        prior_dist = FactorialNormalizingFlow(dim=args.latent_dim, nsteps=32)
        q_dist = dist.Normal()

    vae = VAE(z_dim=args.latent_dim, use_cuda=True, prior_dist=prior_dist, q_dist=q_dist,
        include_mutinfo=not args.exclude_mutinfo, tcvae=args.tcvae, conv=args.conv, mss=args.mss)

    # setup the optimizer
    optimizer = optim.Adam(vae.parameters(), lr=args.learning_rate)

    # setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=4500)

    train_elbo = []

    # training loop
    dataset_size = len(train_loader.dataset)
    num_iterations = len(train_loader) * args.num_epochs
    iteration = 0
    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()
    while iteration < num_iterations:
        for i, x in enumerate(train_loader):
            iteration += 1
            batch_time = time.time()
            vae.train()
            anneal_kl(args, vae, iteration)
            optimizer.zero_grad()
            # transfer to GPU
            x = x.cuda(non_blocking=True)  # 'async' is a reserved word since Python 3.7
            # wrap the mini-batch in a PyTorch Variable
            x = Variable(x)
            # do ELBO gradient and accumulate loss
            obj, elbo = vae.elbo(x, dataset_size)
            if utils.isnan(obj).any():
                raise ValueError('NaN spotted in objective.')
            obj.mean().mul(-1).backward()
            print("obj value: ", obj.mean().mul(-1).cpu())
            elbo_running_mean.update(elbo.mean().item())
            optimizer.step()

            # report training diagnostics
            if iteration % args.log_freq == 0:
                train_elbo.append(elbo_running_mean.avg)
                print('[iteration %03d] time: %.2f \tbeta %.2f \tlambda %.2f training ELBO: %.4f (%.4f)' % (
                    iteration, time.time() - batch_time, vae.beta, vae.lamb,
                    elbo_running_mean.val, elbo_running_mean.avg))

                vae.eval()

                # plot training and test ELBOs
                if args.visdom:
                    display_samples(vae, x, vis)
                    plot_elbo(train_elbo, vis)

                utils.save_checkpoint({
                    'state_dict': vae.state_dict(),
                    'args': args}, args.save, 0)
                eval('plot_vs_gt_' + args.dataset)(vae, train_loader.dataset,
                    os.path.join(args.save, 'gt_vs_latent_{:05d}.png'.format(iteration)))

    # Report statistics after training
    vae.eval()
    utils.save_checkpoint({
        'state_dict': vae.state_dict(),
        'args': args}, args.save, 0)
    dataset_loader = DataLoader(train_loader.dataset, batch_size=10, num_workers=1, shuffle=False)
    logpx, dependence, information, dimwise_kl, analytical_cond_kl, marginal_entropies, joint_entropy = \
        elbo_decomposition(vae, dataset_loader)
    torch.save({
        'logpx': logpx,
        'dependence': dependence,
        'information': information,
        'dimwise_kl': dimwise_kl,
        'analytical_cond_kl': analytical_cond_kl,
        'marginal_entropies': marginal_entropies,
        'joint_entropy': joint_entropy
    }, os.path.join(args.save, 'elbo_decomposition.pth'))
    eval('plot_vs_gt_' + args.dataset)(vae, dataset_loader.dataset, os.path.join(args.save, 'gt_vs_latent.png'))
    return vae
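
The loop above smooths the per-batch ELBO with `utils.RunningAverageMeter()`. Below is a minimal sketch of such a meter, matching the `.update()` / `.val` / `.avg` usage in these examples; the exponential-smoothing constant is an assumption, not necessarily the repository's actual value.

class RunningAverageMeter:
    """Track the most recent value and an exponential moving average."""

    def __init__(self, momentum=0.97):  # smoothing constant is an assumption
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0.0

    def update(self, val):
        # The first update seeds the average; later updates blend in new values.
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1.0 - self.momentum)
        self.val = val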
Example #2
def main():
    # parse command line arguments
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument(
        '-d',
        '--dataset',
        default='shapes',
        type=str,
        help='dataset name',
        choices=['shapes', 'faces', 'celeba', 'cars3d', '3dchairs'])
    parser.add_argument('-dist',
                        default='normal',
                        type=str,
                        choices=['normal', 'lpnorm', 'lpnested'])
    parser.add_argument('-n',
                        '--num-epochs',
                        default=50,
                        type=int,
                        help='number of training epochs')
    parser.add_argument(
        '--num-iterations',
        default=0,
        type=int,
        help='number of iterations (overrides number of epochs if >0)')
    parser.add_argument('-b',
                        '--batch-size',
                        default=2048,
                        type=int,
                        help='batch size')
    parser.add_argument('-l',
                        '--learning-rate',
                        default=1e-3,
                        type=float,
                        help='learning rate')
    parser.add_argument('-z',
                        '--latent-dim',
                        default=10,
                        type=int,
                        help='size of latent dimension')
    parser.add_argument('-p',
                        '--pnorm',
                        default=4.0 / 3.0,
                        type=float,
                        help='p value of the Lp-norm')
    parser.add_argument(
        '--pnested',
        default='',
        type=str,
        help=
        'nested list representation of the Lp-nested prior, e.g. [2.1, [ [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ], [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ], [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ] ] ]'
    )
    parser.add_argument(
        '--isa',
        default='',
        type=str,
        help=
        'shorthand notation of ISA Lp-nested norm, e.g. [2.1, [(2.2, 4), (2.2, 4), (2.2, 4)]]'
    )
    parser.add_argument('--p0', default=2.0, type=float, help='p0 of ISA')
    parser.add_argument('--p1', default=2.1, type=float, help='p1 of ISA')
    parser.add_argument('--n1', default=6, type=int, help='n1 of ISA')
    parser.add_argument('--p2', default=2.1, type=float, help='p2 of ISA')
    parser.add_argument('--n2', default=6, type=int, help='n2 of ISA')
    parser.add_argument('--p3', default=2.1, type=float, help='p3 of ISA')
    parser.add_argument('--n3', default=6, type=int, help='n3 of ISA')
    parser.add_argument('--scale',
                        default=1.0,
                        type=float,
                        help='scale of LpNested distribution')
    parser.add_argument('--q-dist',
                        default='normal',
                        type=str,
                        choices=['normal', 'laplace'])
    parser.add_argument('--x-dist',
                        default='bernoulli',
                        type=str,
                        choices=['normal', 'bernoulli'])
    parser.add_argument('--beta',
                        default=1,
                        type=float,
                        help='ELBO penalty term')
    parser.add_argument('--tcvae', action='store_true')
    parser.add_argument('--exclude-mutinfo', action='store_true')
    parser.add_argument('--beta-anneal', action='store_true')
    parser.add_argument('--lambda-anneal', action='store_true')
    parser.add_argument('--mss',
                        action='store_true',
                        help='use the improved minibatch estimator')
    parser.add_argument('--conv', action='store_true')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--visdom',
                        action='store_true',
                        help='whether plotting in visdom is desired')
    parser.add_argument('--save', default='test1')
    parser.add_argument('--id', default='1')
    parser.add_argument(
        '--seed',
        default=-1,
        type=int,
        help=
        'seed for pytorch and numpy random number generator to allow reproducibility (default/-1: use random seed)'
    )
    parser.add_argument('--log_freq',
                        default=200,
                        type=int,
                        help='num iterations per log')
    parser.add_argument('--use-mse-loss', action='store_true')
    parser.add_argument('--mse-sigma',
                        default=0.01,
                        type=float,
                        help='sigma of mean squared error loss')
    parser.add_argument('--dip', action='store_true', help='use DIP-VAE')
    parser.add_argument('--dip-type',
                        default=1,
                        type=int,
                        help='DIP type (1 or 2)')
    parser.add_argument('--lambda-od',
                        default=2.0,
                        type=float,
                        help='DIP: lambda weight off-diagonal')
    parser.add_argument('--clip',
                        default=0.0,
                        type=float,
                        help='Gradient clipping (0 disabled)')
    parser.add_argument('--test', action='store_true', help='run test')
    parser.add_argument(
        '--trainingsetsize',
        default=0,
        type=int,
        help='Subsample the trainingset (0 use original training data)')
    args = parser.parse_args()

    # initialize seeds for reproducibility
    if not args.seed == -1:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.gpu != -1:
        print('Using CUDA device {}'.format(args.gpu))
        torch.cuda.set_device(args.gpu)
        use_cuda = True
    else:
        print('CUDA disabled')
        use_cuda = False

    # data loader
    train_loader = setup_data_loaders(args.dataset,
                                      args.batch_size,
                                      use_cuda=use_cuda,
                                      len_subset=args.trainingsetsize)

    # setup the VAE
    if args.dist == 'normal':
        prior_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
    elif args.dist == 'lpnested':
        if not args.isa == '':
            pnested = parseISA(ast.literal_eval(args.isa))
        elif not args.pnested == '':
            pnested = ast.literal_eval(args.pnested)
        else:
            pnested = parseISA([
                args.p0,
                [(args.p1, args.n1), (args.p2, args.n2), (args.p3, args.n3)]
            ])

        print('using Lp-nested prior, pnested = ({}) {}'.format(
            type(pnested), pnested))
        prior_dist = LpNestedAdapter(p=pnested, scale=args.scale)
        args.latent_dim = prior_dist.dimz()
        print('using Lp-nested prior, changed latent dimension to {}'.format(
            args.latent_dim))
    elif args.dist == 'lpnorm':
        prior_dist = LpNestedAdapter(p=[args.pnorm, [[1.0]] * args.latent_dim],
                                     scale=args.scale)

    if args.q_dist == 'normal':
        q_dist = dist.Normal()
    elif args.q_dist == 'laplace':
        q_dist = dist.Laplace()

    if args.x_dist == 'normal':
        x_dist = dist.Normal(sigma=args.mse_sigma)
    elif args.x_dist == 'bernoulli':
        x_dist = dist.Bernoulli()

    if args.dip_type == 1:
        lambda_d = 10.0 * args.lambda_od
    else:
        lambda_d = args.lambda_od

    vae = VAE(z_dim=args.latent_dim,
              use_cuda=use_cuda,
              prior_dist=prior_dist,
              q_dist=q_dist,
              x_dist=x_dist,
              include_mutinfo=not args.exclude_mutinfo,
              tcvae=args.tcvae,
              conv=args.conv,
              mss=args.mss,
              dataset=args.dataset,
              mse_sigma=args.mse_sigma,
              DIP=args.dip,
              DIP_type=args.dip_type,
              lambda_od=args.lambda_od,
              lambda_d=lambda_d)

    # setup the optimizer
    optimizer = optim.Adam([{
        'params': vae.parameters()
    }, {
        'params': prior_dist.parameters(),
        'lr': 5e-4
    }],
                           lr=args.learning_rate)

    # setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=4500)

    train_elbo = []

    # training loop
    dataset_size = len(train_loader.dataset)
    if args.num_iterations == 0:
        num_iterations = len(train_loader) * args.num_epochs
    else:
        num_iterations = args.num_iterations
    iteration = 0
    obj_best_snapshot = float('-inf')
    best_checkpoint_updated = False

    trainingcurve_filename = os.path.join(args.save, 'trainingcurve.csv')
    if not os.path.exists(trainingcurve_filename):
        with open(trainingcurve_filename, 'w') as fd:
            fd.write(
                'iteration,num_iterations,time,elbo_running_mean_val,elbo_running_mean_avg\n'
            )

    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()
    gradients = []  # populated after the first backward pass; used in the NaN diagnostics
    nan_detected = False
    while iteration < num_iterations and not nan_detected:
        for i, x in enumerate(train_loader):
            iteration += 1
            batch_time = time.time()
            vae.train()
            anneal_kl(args, vae, iteration)
            optimizer.zero_grad()
            # transfer to GPU
            if use_cuda:
                x = x.cuda()  # async=True)
            # wrap the mini-batch in a PyTorch Variable
            x = Variable(x)
            # do ELBO gradient and accumulate loss
            #with autograd.detect_anomaly():
            obj, elbo, logpx = vae.elbo(prior_dist,
                                        x,
                                        dataset_size,
                                        use_mse_loss=args.use_mse_loss,
                                        mse_sigma=args.mse_sigma)
            if utils.isnan(obj).any():
                print('NaN spotted in objective.')
                print('lpnested: {}'.format(prior_dist.prior.p))
                if gradients:
                    print("gradient abs max {}".format(
                        max([g.abs().max() for g in gradients])))
                #raise ValueError('NaN spotted in objective.')
                nan_detected = True
                break
            elbo_running_mean.update(elbo.mean().item())

            # save checkpoint of best ELBO
            if obj.mean().item() > obj_best_snapshot:
                obj_best_snapshot = obj.mean().item()
                best_checkpoint = {
                    'state_dict': vae.state_dict(),
                    'state_dict_prior_dist': prior_dist.state_dict(),
                    'args': args,
                    'iteration': iteration,
                    'obj': obj_best_snapshot,
                    'elbo': elbo.mean().item(),
                    'logpx': logpx.mean().item()
                }
                best_checkpoint_updated = True

            #with autograd.detect_anomaly():
            obj.mean().mul(-1).backward()

            # collect the gradient tensors themselves for the diagnostics below
            gradients = [p.grad for p in vae.parameters() if p.grad is not None]

            if args.clip > 0:
                torch.nn.utils.clip_grad_norm_(vae.parameters(), args.clip)

            optimizer.step()

            # report training diagnostics
            if iteration % args.log_freq == 0:
                train_elbo.append(elbo_running_mean.avg)
                time_ = time.time() - batch_time
                print(
                    '[iteration %03d/%03d] time: %.2f \tbeta %.2f \tlambda %.2f \tobj %.4f \tlogpx %.4f training ELBO: %.4f (%.4f)'
                    % (iteration, num_iterations, time_, vae.beta, vae.lamb,
                       obj.mean().item(), logpx.mean().item(),
                       elbo_running_mean.val, elbo_running_mean.avg))

                p0, p1list = backwardsParseISA(prior_dist.prior.p)
                print('lpnested: {}, {}'.format(p0, p1list))
                print("gradient abs max {}".format(
                    max([g.abs().max() for g in gradients])))

                with open(os.path.join(args.save, 'trainingcurve.csv'),
                          'a') as fd:
                    fd.write('{},{},{},{},{}\n'.format(iteration,
                                                       num_iterations, time_,
                                                       elbo_running_mean.val,
                                                       elbo_running_mean.avg))

                if best_checkpoint_updated:
                    print(
                        'Update best checkpoint [iteration %03d] training ELBO: %.4f'
                        % (best_checkpoint['iteration'],
                           best_checkpoint['elbo']))
                    utils.save_checkpoint(best_checkpoint, args.save, 0)
                    best_checkpoint_updated = False

                vae.eval()
                prior_dist.eval()

                # plot training and test ELBOs
                if args.visdom:
                    if args.dataset == 'celeba':
                        num_channels = 3
                    else:
                        num_channels = 1
                    display_samples(vae, prior_dist, x, vis, num_channels)
                    plot_elbo(train_elbo, vis)

                if iteration % (10 * args.log_freq) == 0:
                    utils.save_checkpoint(
                        {
                            'state_dict': vae.state_dict(),
                            'state_dict_prior_dist': prior_dist.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'args': args,
                            'iteration': iteration,
                            'obj': obj.mean().item(),
                            'torch_random_state': torch.get_rng_state(),
                            'numpy_random_state': np.random.get_state()
                        },
                        args.save,
                        prefix='latest-optimizer-model-')
                    if not args.dataset == 'celeba' and not args.dataset == '3dchairs':
                        eval('plot_vs_gt_' + args.dataset)(
                            vae, train_loader.dataset,
                            os.path.join(
                                args.save,
                                'gt_vs_latent_{:05d}.png'.format(iteration)))

    # Report statistics of best snapshot after training
    vae.load_state_dict(best_checkpoint['state_dict'])
    prior_dist.load_state_dict(best_checkpoint['state_dict_prior_dist'])

    vae.eval()
    prior_dist.eval()

    if args.dataset == 'shapes':
        data_set = dset.Shapes()
    elif args.dataset == 'faces':
        data_set = dset.Faces()
    elif args.dataset == 'cars3d':
        data_set = dset.Cars3d()
    elif args.dataset == 'celeba':
        data_set = dset.CelebA()
    elif args.dataset == '3dchairs':
        data_set = dset.Chairs()
    else:
        raise ValueError('Unknown dataset ' + str(args.dataset))

    print("loaded dataset {} of size {}".format(args.dataset, len(data_set)))

    dataset_loader = DataLoader(data_set,
                                batch_size=1000,
                                num_workers=0,
                                shuffle=False)

    logpx, dependence, information, dimwise_kl, analytical_cond_kl, elbo_marginal_entropies, elbo_joint_entropy = \
        elbo_decomposition(vae, prior_dist, dataset_loader)
    torch.save(
        {
            'args': args,
            'logpx': logpx,
            'dependence': dependence,
            'information': information,
            'dimwise_kl': dimwise_kl,
            'analytical_cond_kl': analytical_cond_kl,
            'marginal_entropies': elbo_marginal_entropies,
            'joint_entropy': elbo_joint_entropy
        }, os.path.join(args.save, 'elbo_decomposition.pth'))
    print('logpx: {:.2f}'.format(logpx))
    if not args.dataset == 'celeba' and not args.dataset == '3dchairs':
        eval('plot_vs_gt_' + args.dataset)(vae, dataset_loader.dataset,
                                           os.path.join(
                                               args.save, 'gt_vs_latent.png'))

        metric, metric_marginal_entropies, metric_cond_entropies = eval(
            'disentanglement_metrics.mutual_info_metric_' + args.dataset)(
                vae, dataset_loader.dataset)
        torch.save(
            {
                'args': args,
                'metric': metric,
                'marginal_entropies': metric_marginal_entropies,
                'cond_entropies': metric_cond_entropies,
            }, os.path.join(args.save, 'disentanglement_metric.pth'))
        print('MIG: {:.2f}'.format(metric))

        if args.dist == 'lpnested':
            p0, p1list = backwardsParseISA(prior_dist.prior.p)
            print('p0: {}'.format(p0))
            print('p1: {}'.format(p1list))
            torch.save(
                {
                    'args': args,
                    'logpx': logpx,
                    'dependence': dependence,
                    'information': information,
                    'dimwise_kl': dimwise_kl,
                    'analytical_cond_kl': analytical_cond_kl,
                    'elbo_marginal_entropies': elbo_marginal_entropies,
                    'elbo_joint_entropy': elbo_joint_entropy,
                    'metric': metric,
                    'metric_marginal_entropies': metric_marginal_entropies,
                    'metric_cond_entropies': metric_cond_entropies,
                    'p0': p0,
                    'p1': p1list
                }, os.path.join(args.save, 'combined_data.pth'))
        else:
            torch.save(
                {
                    'args': args,
                    'logpx': logpx,
                    'dependence': dependence,
                    'information': information,
                    'dimwise_kl': dimwise_kl,
                    'analytical_cond_kl': analytical_cond_kl,
                    'elbo_marginal_entropies': elbo_marginal_entropies,
                    'elbo_joint_entropy': elbo_joint_entropy,
                    'metric': metric,
                    'metric_marginal_entropies': metric_marginal_entropies,
                    'metric_cond_entropies': metric_cond_entropies,
                }, os.path.join(args.save, 'combined_data.pth'))

        if args.dist == 'lpnested':
            if args.dataset == 'shapes':
                eval('plot_vs_gt_' + args.dataset)(
                    vae,
                    dataset_loader.dataset,
                    os.path.join(args.save, 'gt_vs_grouped_latent.png'),
                    eval_subspaces=True)

                metric_subspaces, metric_marginal_entropies_subspaces, metric_cond_entropies_subspaces = eval(
                    'disentanglement_metrics.mutual_info_metric_' +
                    args.dataset)(vae,
                                  dataset_loader.dataset,
                                  eval_subspaces=True)
                torch.save(
                    {
                        'args': args,
                        'metric': metric_subspaces,
                        'marginal_entropies':
                        metric_marginal_entropies_subspaces,
                        'cond_entropies': metric_cond_entropies_subspaces,
                    },
                    os.path.join(args.save,
                                 'disentanglement_metric_subspaces.pth'))
                print('MIG grouped by subspaces: {:.2f}'.format(
                    metric_subspaces))

                torch.save(
                    {
                        'args': args,
                        'logpx': logpx,
                        'dependence': dependence,
                        'information': information,
                        'dimwise_kl': dimwise_kl,
                        'analytical_cond_kl': analytical_cond_kl,
                        'elbo_marginal_entropies': elbo_marginal_entropies,
                        'elbo_joint_entropy': elbo_joint_entropy,
                        'metric': metric,
                        'metric_marginal_entropies': metric_marginal_entropies,
                        'metric_cond_entropies': metric_cond_entropies,
                        'metric_subspaces': metric_subspaces,
                        'metric_marginal_entropies_subspaces':
                        metric_marginal_entropies_subspaces,
                        'metric_cond_entropies_subspaces':
                        metric_cond_entropies_subspaces,
                        'p0': p0,
                        'p1': p1list
                    }, os.path.join(args.save, 'combined_data.pth'))

    return vae
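
This example trains the learnable Lp-nested prior with its own learning rate by passing two parameter groups to Adam. Here is a self-contained sketch of that optimizer pattern, with toy linear modules standing in for `vae` and `prior_dist` (the modules and shapes are illustrative, not the repository's).

import torch
from torch import nn, optim

vae_stub = nn.Linear(10, 10)    # stands in for the VAE
prior_stub = nn.Linear(10, 10)  # stands in for the learnable prior

# A per-group 'lr' overrides the default passed to the constructor,
# so the prior parameters take smaller steps than the VAE parameters.
optimizer = optim.Adam(
    [{'params': vae_stub.parameters()},
     {'params': prior_stub.parameters(), 'lr': 5e-4}],
    lr=1e-3)

loss = (vae_stub(torch.randn(4, 10)).pow(2).mean()
        + prior_stub(torch.randn(4, 10)).pow(2).mean())
loss.backward()
optimizer.step()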
Example #3
def evaluate(args, outputdir, vae, dataset, prefix=''):
    if os.path.exists(os.path.join(outputdir, prefix + 'combined_data.pth')):
        return

    # Report statistics
    vae.eval()
    dataset_loader = DataLoader(dataset,
                                batch_size=1000,
                                num_workers=0,
                                shuffle=True)

    if not os.path.exists(outputdir):
        os.makedirs(outputdir)

    logpx, dependence, information, dimwise_kl, analytical_cond_kl, elbo_marginal_entropies, elbo_joint_entropy = \
        elbo_decomposition(vae, dataset_loader)
    torch.save(
        {
            'args': args,
            'logpx': logpx,
            'dependence': dependence,
            'information': information,
            'dimwise_kl': dimwise_kl,
            'analytical_cond_kl': analytical_cond_kl,
            'marginal_entropies': elbo_marginal_entropies,
            'joint_entropy': elbo_joint_entropy
        }, os.path.join(outputdir, prefix + 'elbo_decomposition.pth'))

    metric, metric_marginal_entropies, metric_cond_entropies = eval(
        'disentanglement_metrics.mutual_info_metric_' + args.dataset)(
            vae, dataset_loader.dataset, eval_subspaces=False)
    torch.save(
        {
            'args': args,
            'metric': metric,
            'marginal_entropies': metric_marginal_entropies,
            'cond_entropies': metric_cond_entropies,
        }, os.path.join(outputdir, prefix + 'disentanglement_metric.pth'))
    print('logpx: {:.2f}'.format(logpx))
    print('MIG: {:.2f}'.format(metric))

    eval('plot_vs_gt_' + args.dataset)(vae, dataset_loader.dataset,
                                       os.path.join(
                                           outputdir,
                                           prefix + 'gt_vs_latent.png'))

    torch.save(
        {
            'args': args,
            'logpx': logpx,
            'dependence': dependence,
            'information': information,
            'dimwise_kl': dimwise_kl,
            'analytical_cond_kl': analytical_cond_kl,
            'elbo_marginal_entropies': elbo_marginal_entropies,
            'elbo_joint_entropy': elbo_joint_entropy,
            'metric': metric,
            'metric_marginal_entropies': metric_marginal_entropies,
            'metric_cond_entropies': metric_cond_entropies,
        }, os.path.join(outputdir, prefix + 'combined_data.pth'))

    if args.dist == 'lpnested':
        eval('plot_vs_gt_' + args.dataset)(vae,
                                           dataset_loader.dataset,
                                           os.path.join(
                                               outputdir,
                                               'gt_vs_grouped_latent.png'),
                                           eval_subspaces=True)

        metric_subspaces, metric_marginal_entropies_subspaces, metric_cond_entropies_subspaces = eval(
            'disentanglement_metrics.mutual_info_metric_' + args.dataset)(
                vae, dataset_loader.dataset, eval_subspaces=True)
        torch.save(
            {
                'args': args,
                'metric': metric_subspaces,
                'marginal_entropies': metric_marginal_entropies_subspaces,
                'cond_entropies': metric_cond_entropies_subspaces,
            }, os.path.join(outputdir, 'disentanglement_metric_subspaces.pth'))
        print('MIG grouped by subspaces: {:.2f}'.format(metric_subspaces))

        torch.save(
            {
                'args': args,
                'logpx': logpx,
                'dependence': dependence,
                'information': information,
                'dimwise_kl': dimwise_kl,
                'analytical_cond_kl': analytical_cond_kl,
                'elbo_marginal_entropies': elbo_marginal_entropies,
                'elbo_joint_entropy': elbo_joint_entropy,
                'metric': metric,
                'metric_marginal_entropies': metric_marginal_entropies,
                'metric_cond_entropies': metric_cond_entropies,
                'metric_subspaces': metric_subspaces,
                'metric_marginal_entropies_subspaces':
                metric_marginal_entropies_subspaces,
                'metric_cond_entropies_subspaces':
                metric_cond_entropies_subspaces
            }, os.path.join(outputdir, 'combined_data.pth'))
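
evaluate() writes its results with torch.save as plain dictionaries, so they can be reloaded later for analysis. A short sketch of reading the combined file back; the path is a placeholder for whatever outputdir and prefix were used, and on recent PyTorch versions torch.load may additionally need weights_only=False because the dict contains an argparse.Namespace.

import torch

# Keys follow the torch.save call in evaluate() above.
data = torch.load('outputdir/combined_data.pth', map_location='cpu')
print('log p(x): {:.2f}'.format(data['logpx']))
print('MIG:      {:.2f}'.format(data['metric']))
print('dimension-wise KL:', data['dimwise_kl'])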
Example #4
    def load_model_and_dataset(checkpt_filename):
        print('Loading model and dataset.')
        checkpt = torch.load(checkpt_filename, map_location=lambda storage, loc: storage)
        args = checkpt['args']
        state_dict = checkpt['state_dict']

        # model
        if not hasattr(args, 'dist') or args.dist == 'normal':
            prior_dist = dist.Normal()
            q_dist = dist.Normal()
        elif args.dist == 'laplace':
            prior_dist = dist.Laplace()
            q_dist = dist.Laplace()
        elif args.dist == 'flow':
            prior_dist = flows.FactorialNormalizingFlow(dim=args.latent_dim, nsteps=4)
            q_dist = dist.Normal()
        vae = VAE(z_dim=args.latent_dim, use_cuda=True, prior_dist=prior_dist, q_dist=q_dist, conv=args.conv)
        vae.load_state_dict(state_dict, strict=False)

        # dataset loader
        loader = setup_data_loaders(args)
        return vae, loader, args

    z_inds = list(map(int, args.zs.split(','))) if args.zs is not None else None
    torch.cuda.set_device(args.gpu)
    vae, dataset_loader, cpargs = load_model_and_dataset(args.checkpt)
    if args.elbo_decomp:
        elbo_decomposition(vae, dataset_loader)
    eval('plot_vs_gt_' + cpargs.dataset)(vae, dataset_loader.dataset, args.save, z_inds)
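
load_model_and_dataset() uses map_location=lambda storage, loc: storage so a checkpoint saved from a GPU run can be deserialized onto the CPU. A minimal sketch of the same remapping; the file name is a placeholder.

import torch

# Both calls load a CUDA-saved checkpoint onto the CPU; the second is shorthand.
checkpt = torch.load('checkpt.pth', map_location=lambda storage, loc: storage)
checkpt = torch.load('checkpt.pth', map_location='cpu')

# The result is whatever dict was passed to torch.save,
# e.g. {'state_dict': ..., 'args': ...} in these examples.
print(checkpt.keys())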
Example #5
def main():
    # parse command line arguments
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('-d', '--dataset', default='celeba', type=str, help='dataset name',
        choices=['celeba'])
    parser.add_argument('-dist', default='normal', type=str, choices=['normal', 'laplace', 'flow'])
    parser.add_argument('-n', '--num-epochs', default=50, type=int, help='number of training epochs')
    parser.add_argument('-b', '--batch-size', default=2048, type=int, help='batch size')
    parser.add_argument('-l', '--learning-rate', default=1e-3, type=float, help='learning rate')
    parser.add_argument('-z', '--latent-dim', default=100, type=int, help='size of latent dimension')
    parser.add_argument('--beta', default=1, type=float, help='ELBO penalty term')
    parser.add_argument('--beta_sens', default=20, type=float, help='Relative importance of predicting sensitive attributes')
    #parser.add_argument('--sens_idx', default=[13, 15, 20], type=list, help='Relative importance of predicting sensitive attributes')
    parser.add_argument('--tcvae', action='store_true')
    parser.add_argument('--exclude-mutinfo', action='store_true')
    parser.add_argument('--beta-anneal', action='store_true')
    parser.add_argument('--lambda-anneal', action='store_true')
    parser.add_argument('--mss', action='store_true', help='use the improved minibatch estimator')
    parser.add_argument('--conv', action='store_true')
    parser.add_argument('--clf_samps', action='store_true')
    parser.add_argument('--clf_means', action='store_false', dest='clf_samps')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--visdom', action='store_true', help='whether plotting in visdom is desired')
    parser.add_argument('--save', default='betatcvae-celeba')
    parser.add_argument('--log_freq', default=200, type=int, help='num iterations per log')
    parser.add_argument('--audit', action='store_true',
            help='after each epoch, audit the repr wrt fair clf task')
    args = parser.parse_args()
    print(args)
    
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    writer = SummaryWriter(args.save)
    writer.add_text('args', json.dumps(vars(args), sort_keys=True, indent=4))

    log_file = os.path.join(args.save, 'train.log')
    if os.path.exists(log_file):
        os.remove(log_file)

    print(vars(args))
    print(vars(args), file=open(log_file, 'w'))

    torch.cuda.set_device(args.gpu)

    # data loader
    loaders = setup_data_loaders(args, use_cuda=True)

    # setup the VAE
    if args.dist == 'normal':
        prior_dist = dist.Normal()
        q_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
        q_dist = dist.Laplace()
    elif args.dist == 'flow':
        prior_dist = FactorialNormalizingFlow(dim=args.latent_dim, nsteps=32)
        q_dist = dist.Normal()

    x_dist = dist.Normal() if args.dataset == 'celeba' else dist.Bernoulli()
    a_dist = dist.Bernoulli()
    vae = SensVAE(z_dim=args.latent_dim, use_cuda=True, prior_dist=prior_dist, 
            q_dist=q_dist, include_mutinfo=not args.exclude_mutinfo, 
            tcvae=args.tcvae, conv=args.conv, mss=args.mss, 
            n_chan=3 if args.dataset == 'celeba' else 1, sens_idx=SENS_IDX,
            x_dist=x_dist, a_dist=a_dist, clf_samps=args.clf_samps)

    if args.audit:
        audit_label_fn = get_label_fn(
                dict(data=dict(name='celeba', label_fn='H'))
                )
        audit_repr_fns = dict()
        audit_attr_fns = dict()
        audit_models = dict()
        audit_train_metrics = dict()
        audit_validation_metrics = dict()
        for attr_fn_name in CELEBA_SENS_IDX.keys():
            model = MLPClassifier(args.latent_dim, 1000, 2)
            model.cuda()
            audit_models[attr_fn_name] = model
            audit_repr_fns[attr_fn_name] = get_repr_fn(
                dict(data=dict(
                    name='celeba', repr_fn='remove_all', attr_fn=attr_fn_name))
                )
            audit_attr_fns[attr_fn_name] = get_attr_fn(
                dict(data=dict(name='celeba', attr_fn=attr_fn_name))
                )

    # setup the optimizer
    optimizer = optim.Adam(vae.parameters(), lr=args.learning_rate)
    if args.audit:
        Adam = optim.Adam
        audit_optimizers = dict()
        for k, v in audit_models.items():
            audit_optimizers[k] = Adam(v.parameters(), lr=args.learning_rate)


    # setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=3776)

    train_elbo = []
    train_tc = []

    # training loop
    dataset_size = len(loaders['train'].dataset)
    num_iterations = len(loaders['train']) * args.num_epochs
    iteration = 0
    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()
    tc_running_mean = utils.RunningAverageMeter()
    clf_acc_meters = {'clf_acc{}'.format(s): utils.RunningAverageMeter() for s in vae.sens_idx}

    val_elbo_running_mean = utils.RunningAverageMeter()
    val_tc_running_mean = utils.RunningAverageMeter()
    val_clf_acc_meters = {'val_clf_acc{}'.format(s): utils.RunningAverageMeter() for s in vae.sens_idx}


    while iteration < num_iterations:
        bar = tqdm(range(len(loaders['train'])))
        for i, (x, a) in enumerate(loaders['train']):
            bar.update()
            iteration += 1
            batch_time = time.time()
            vae.train()
            #anneal_kl(args, vae, iteration)  # TODO try annealing beta/beta_sens
            vae.beta = args.beta
            vae.beta_sens = args.beta_sens
            optimizer.zero_grad()
            # transfer to GPU
            x = x.cuda(non_blocking=True)  # 'async' is a reserved word since Python 3.7
            a = a.float()
            a = a.cuda(non_blocking=True)
            # wrap the mini-batch in a PyTorch Variable
            x = Variable(x)
            a = Variable(a)
            # do ELBO gradient and accumulate loss
            obj, elbo, metrics = vae.elbo(x, a, dataset_size)
            if utils.isnan(obj).any():
                raise ValueError('NaN spotted in objective.')
            obj.mean().mul(-1).backward()
            elbo_running_mean.update(elbo.mean().data.item())
            tc_running_mean.update(metrics['tc'])
            for (s, meter), (_, acc) in zip(clf_acc_meters.items(), metrics.items()):
                clf_acc_meters[s].update(acc.data.item())
            optimizer.step()

            if args.audit:
                for model in audit_models.values():
                    model.train()
                # now re-encode x and take a step to train each audit classifier
                for opt in audit_optimizers.values():
                    opt.zero_grad()
                with torch.no_grad():
                    zs, z_params = vae.encode(x)
                    if args.clf_samps:
                        z = zs
                    else:
                        z_mu = z_params.select(-1, 0)
                        z = z_mu
                    a_all = a
                for subgroup, model in audit_models.items():
                    # noise out sensitive dims of latent code
                    z_ = z.clone()
                    a_all_ = a_all.clone()
                    # subsample to just sens attr of interest for this subgroup
                    a_ = audit_attr_fns[subgroup](a_all_)
                    # noise out sensitive dims for this subgroup
                    z_ = audit_repr_fns[subgroup](z_, None, None)
                    y_ = audit_label_fn(a_all_).long()

                    loss, _, metrics = model(z_, y_, a_)
                    loss.backward()
                    audit_optimizers[subgroup].step()
                    metrics_dict = {}
                    metrics_dict.update(loss=loss.detach().item())
                    for k, v in metrics.items():
                        if v.numel() > 1:
                            k += '-avg'
                            v = v.float().mean()
                        metrics_dict.update({k:v.detach().item()})
                    audit_train_metrics[subgroup] = metrics_dict

            # report training diagnostics
            if iteration % args.log_freq == 0:
                if args.audit:
                    for subgroup, metrics in audit_train_metrics.items():
                        for metric_name, metric_value in metrics.items():
                            writer.add_scalar(
                                    '{}/{}'.format(subgroup, metric_name),
                                    metric_value, iteration)

                train_elbo.append(elbo_running_mean.avg)
                writer.add_scalar('train_elbo', elbo_running_mean.avg, iteration)
                train_tc.append(tc_running_mean.avg)
                writer.add_scalar('train_tc', tc_running_mean.avg, iteration)
                msg = '[iteration %03d] time: %.2f \tbeta %.2f \tlambda %.2f training ELBO: %.4f (%.4f) training TC %.4f (%.4f)' % (
                    iteration, time.time() - batch_time, vae.beta, vae.lamb,
                    elbo_running_mean.val, elbo_running_mean.avg,
                    tc_running_mean.val, tc_running_mean.avg)
                for k, v in clf_acc_meters.items():
                    msg += ' {}: {:.2f}'.format(k, v.avg)
                    writer.add_scalar(k, v.avg, iteration)
                print(msg)
                print(msg, file=open(log_file, 'a'))

                vae.eval()
                ################################################################
                # evaluate validation metrics on vae and auditors
                for x, a in loaders['validation']:
                    # transfer to GPU
                    x = x.cuda(non_blocking=True)
                    a = a.float()
                    a = a.cuda(non_blocking=True)
                    # wrap the mini-batch in a PyTorch Variable
                    x = Variable(x)
                    a = Variable(a)
                    # do ELBO gradient and accumulate loss
                    obj, elbo, metrics = vae.elbo(x, a, dataset_size)
                    if utils.isnan(obj).any():
                        raise ValueError('NaN spotted in objective.')
                    #
                    val_elbo_running_mean.update(elbo.mean().data.item())
                    val_tc_running_mean.update(metrics['tc'])
                    for (s, meter), (_, acc) in zip(
                            val_clf_acc_meters.items(), metrics.items()):
                        val_clf_acc_meters[s].update(acc.data.item())

                if args.audit:
                    for model in audit_models.values():
                        model.eval()
                    with torch.no_grad():
                        zs, z_params = vae.encode(x)
                        if args.clf_samps:
                            z = zs
                        else:
                            z_mu = z_params.select(-1, 0)
                            z = z_mu
                        a_all = a
                    for subgroup, model in audit_models.items():
                        # noise out sensitive dims of latent code
                        z_ = z.clone()
                        a_all_ = a_all.clone()
                        # subsample to just sens attr of interest for this subgroup
                        a_ = audit_attr_fns[subgroup](a_all_)
                        # noise out sensitive dims for this subgroup
                        z_ = audit_repr_fns[subgroup](z_, None, None)
                        y_ = audit_label_fn(a_all_).long()

                        # evaluation only: no backward pass or optimizer step on validation data
                        with torch.no_grad():
                            loss, _, metrics = model(z_, y_, a_)
                        metrics_dict = {}
                        metrics_dict.update(val_loss=loss.detach().item())
                        for k, v in metrics.items():
                            k = 'val_' + k  # denote a validation metric
                            if v.numel() > 1:
                                k += '-avg'
                                v = v.float().mean()
                            metrics_dict.update({k:v.detach().item()})
                        audit_validation_metrics[subgroup] = metrics_dict

                # after iterating through validation set, write summaries
                for subgroup, metrics in audit_validation_metrics.items():
                    for metric_name, metric_value in metrics.items():
                        writer.add_scalar(
                                '{}/{}'.format(subgroup, metric_name),
                                metric_value, iteration)
                writer.add_scalar('val_elbo', val_elbo_running_mean.avg, iteration)
                writer.add_scalar('val_tc', val_tc_running_mean.avg, iteration)
                for k, v in val_clf_acc_meters.items():
                    writer.add_scalar(k, v.avg, iteration)

                ################################################################
                # finally, plot training and test ELBOs
                if args.visdom:
                    display_samples(vae, x, vis)
                    plot_elbo(train_elbo, vis)
                    plot_tc(train_tc, vis)

                utils.save_checkpoint({
                    'state_dict': vae.state_dict(),
                    'args': args}, args.save, iteration // len(loaders['train']))
                eval('plot_vs_gt_' + args.dataset)(vae, loaders['train'].dataset,
                    os.path.join(args.save, 'gt_vs_latent_{:05d}.png'.format(iteration)))

    # Report statistics after training
    vae.eval()
    utils.save_checkpoint({
        'state_dict': vae.state_dict(),
        'args': args}, args.save, 0)
    dataset_loader = DataLoader(loaders['train'].dataset, batch_size=1000, num_workers=1, shuffle=False)
    if False:
        logpx, dependence, information, dimwise_kl, analytical_cond_kl, marginal_entropies, joint_entropy = \
            elbo_decomposition(vae, dataset_loader)
        torch.save({
            'logpx': logpx,
            'dependence': dependence,
            'information': information,
            'dimwise_kl': dimwise_kl,
            'analytical_cond_kl': analytical_cond_kl,
            'marginal_entropies': marginal_entropies,
            'joint_entropy': joint_entropy
        }, os.path.join(args.save, 'elbo_decomposition.pth'))
    eval('plot_vs_gt_' + args.dataset)(vae, dataset_loader.dataset, os.path.join(args.save, 'gt_vs_latent.png'))

    for file in [open(os.path.join(args.save, 'done'), 'w'), sys.stdout]:
        print('done', file=file)

    return vae
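
The GPU transfers in these examples originally used x.cuda(async=True); since async became a reserved word in Python 3.7, the keyword is non_blocking=True, and the copy only overlaps with computation when the DataLoader serves batches from pinned host memory. A small sketch of that pairing; the tensor shapes are arbitrary.

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1024, 3, 64, 64))
# pin_memory=True keeps host batches in page-locked memory so that
# .cuda(non_blocking=True) can overlap the host-to-device copy with compute.
loader = DataLoader(dataset, batch_size=256, pin_memory=True)

for (x,) in loader:
    if torch.cuda.is_available():
        x = x.cuda(non_blocking=True)
    # ... forward/backward pass on x ...
    break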