Example #1
def setup_data_loaders(args, use_cuda=False):
    if args.dataset == 'shapes':
        train_set = dset.Shapes()
    elif args.dataset == 'faces':
        train_set = dset.Faces()
    else:
        raise ValueError('Unknown dataset ' + str(args.dataset))

    kwargs = {'num_workers': 4, 'pin_memory': use_cuda}
    train_loader = DataLoader(dataset=train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              **kwargs)
    return train_loader
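A minimal usage sketch for this variant. It assumes `DataLoader` is `torch.utils.data.DataLoader`, `dset` is the project's dataset module, and the parsed arguments carry `dataset` and `batch_size`; the `argparse.Namespace` below is only a hypothetical stand-in for the real command-line parsing:

import argparse

# hypothetical stand-in for the parsed command-line arguments
args = argparse.Namespace(dataset='shapes', batch_size=2048)
train_loader = setup_data_loaders(args, use_cuda=False)
for x in train_loader:
    break  # each x is one shuffled mini-batch from the 'shapes' dataset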
Example #2
def setup_data_loaders(args_dataset, batch_size, use_cuda=False, len_subset=0):
    if args_dataset == 'shapes':
        train_set = dset.Shapes(len_subset=len_subset)
    elif args_dataset == 'faces':
        train_set = dset.Faces()
    elif args_dataset == 'cars3d':
        train_set = dset.Cars3d()
    elif args_dataset == 'celeba':
        train_set = dset.CelebA()
    elif args_dataset == '3dchairs':
        train_set = dset.Chairs()
    else:
        raise ValueError('Unknown dataset ' + str(args_dataset))

    kwargs = {'num_workers': 4, 'pin_memory': use_cuda}
    train_loader = DataLoader(dataset=train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              **kwargs)
    return train_loader
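The same call pattern for this variant, which takes the dataset name and batch size directly and can optionally subsample the training data. This assumes `dset` and `DataLoader` are in scope as above, and that `dset.Shapes` interprets `len_subset` as the number of examples to keep (the value 5000 below is only illustrative):

# full training set
train_loader = setup_data_loaders('shapes', batch_size=2048, use_cuda=False)

# subsampled training set (only the 'shapes' branch forwards len_subset)
small_loader = setup_data_loaders('shapes', batch_size=256, use_cuda=False, len_subset=5000)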
Example #3
def main():
    # parse command line arguments
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument(
        '-d',
        '--dataset',
        default='shapes',
        type=str,
        help='dataset name',
        choices=['shapes', 'faces', 'celeba', 'cars3d', '3dchairs'])
    parser.add_argument('-dist',
                        default='normal',
                        type=str,
                        choices=['normal', 'lpnorm', 'lpnested'])
    parser.add_argument('-n',
                        '--num-epochs',
                        default=50,
                        type=int,
                        help='number of training epochs')
    parser.add_argument(
        '--num-iterations',
        default=0,
        type=int,
        help='number of iterations (overrides number of epochs if >0)')
    parser.add_argument('-b',
                        '--batch-size',
                        default=2048,
                        type=int,
                        help='batch size')
    parser.add_argument('-l',
                        '--learning-rate',
                        default=1e-3,
                        type=float,
                        help='learning rate')
    parser.add_argument('-z',
                        '--latent-dim',
                        default=10,
                        type=int,
                        help='size of latent dimension')
    parser.add_argument('-p',
                        '--pnorm',
                        default=4.0 / 3.0,
                        type=float,
                        help='p value of the Lp-norm')
    parser.add_argument(
        '--pnested',
        default='',
        type=str,
        help=
        'nested list representation of the Lp-nested prior, e.g. [2.1, [ [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ], [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ], [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ] ] ]'
    )
    parser.add_argument(
        '--isa',
        default='',
        type=str,
        help=
        'shorthand notation of ISA Lp-nested norm, e.g. [2.1, [(2.2, 4), (2.2, 4), (2.2, 4)]]'
    )
    parser.add_argument('--p0', default=2.0, type=float, help='p0 of ISA')
    parser.add_argument('--p1', default=2.1, type=float, help='p1 of ISA')
    parser.add_argument('--n1', default=6, type=int, help='n1 of ISA')
    parser.add_argument('--p2', default=2.1, type=float, help='p2 of ISA')
    parser.add_argument('--n2', default=6, type=int, help='n2 of ISA')
    parser.add_argument('--p3', default=2.1, type=float, help='p3 of ISA')
    parser.add_argument('--n3', default=6, type=int, help='n3 of ISA')
    parser.add_argument('--scale',
                        default=1.0,
                        type=float,
                        help='scale of LpNested distribution')
    parser.add_argument('--q-dist',
                        default='normal',
                        type=str,
                        choices=['normal', 'laplace'])
    parser.add_argument('--x-dist',
                        default='bernoulli',
                        type=str,
                        choices=['normal', 'bernoulli'])
    parser.add_argument('--beta',
                        default=1,
                        type=float,
                        help='ELBO penalty term')
    parser.add_argument('--tcvae', action='store_true')
    parser.add_argument('--exclude-mutinfo', action='store_true')
    parser.add_argument('--beta-anneal', action='store_true')
    parser.add_argument('--lambda-anneal', action='store_true')
    parser.add_argument('--mss',
                        action='store_true',
                        help='use the improved minibatch estimator')
    parser.add_argument('--conv', action='store_true')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--visdom',
                        action='store_true',
                        help='whether plotting in visdom is desired')
    parser.add_argument('--save', default='test1')
    parser.add_argument('--id', default='1')
    parser.add_argument(
        '--seed',
        default=-1,
        type=int,
        help=
        'seed for pytorch and numpy random number generator to allow reproducibility (default/-1: use random seed)'
    )
    parser.add_argument('--log_freq',
                        default=200,
                        type=int,
                        help='num iterations per log')
    parser.add_argument('--use-mse-loss', action='store_true')
    parser.add_argument('--mse-sigma',
                        default=0.01,
                        type=float,
                        help='sigma of mean squared error loss')
    parser.add_argument('--dip', action='store_true', help='use DIP-VAE')
    parser.add_argument('--dip-type',
                        default=1,
                        type=int,
                        help='DIP type (1 or 2)')
    parser.add_argument('--lambda-od',
                        default=2.0,
                        type=float,
                        help='DIP: lambda weight off-diagonal')
    parser.add_argument('--clip',
                        default=0.0,
                        type=float,
                        help='Gradient clipping (0 disabled)')
    parser.add_argument('--test', action='store_true', help='run test')
    parser.add_argument(
        '--trainingsetsize',
        default=0,
        type=int,
        help='Subsample the trainingset (0 use original training data)')
    args = parser.parse_args()

    # initialize seeds for reproducibility
    if not args.seed == -1:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.gpu != -1:
        print('Using CUDA device {}'.format(args.gpu))
        torch.cuda.set_device(args.gpu)
        use_cuda = True
    else:
        print('CUDA disabled')
        use_cuda = False

    # data loader
    train_loader = setup_data_loaders(args.dataset,
                                      args.batch_size,
                                      use_cuda=use_cuda,
                                      len_subset=args.trainingsetsize)

    # setup the VAE
    if args.dist == 'normal':
        prior_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
    elif args.dist == 'lpnested':
        if not args.isa == '':
            pnested = parseISA(ast.literal_eval(args.isa))
        elif not args.pnested == '':
            pnested = ast.literal_eval(args.pnested)
        else:
            pnested = parseISA([
                args.p0,
                [(args.p1, args.n1), (args.p2, args.n2), (args.p3, args.n3)]
            ])

        print('using Lp-nested prior, pnested = ({}) {}'.format(
            type(pnested), pnested))
        prior_dist = LpNestedAdapter(p=pnested, scale=args.scale)
        args.latent_dim = prior_dist.dimz()
        print('using Lp-nested prior, changed latent dimension to {}'.format(
            args.latent_dim))
    elif args.dist == 'lpnorm':
        prior_dist = LpNestedAdapter(p=[args.pnorm, [[1.0]] * args.latent_dim],
                                     scale=args.scale)

    if args.q_dist == 'normal':
        q_dist = dist.Normal()
    elif args.q_dist == 'laplace':
        q_dist = dist.Laplace()

    if args.x_dist == 'normal':
        x_dist = dist.Normal(sigma=args.mse_sigma)
    elif args.x_dist == 'bernoulli':
        x_dist = dist.Bernoulli()

    if args.dip_type == 1:
        lambda_d = 10.0 * args.lambda_od
    else:
        lambda_d = args.lambda_od

    vae = VAE(z_dim=args.latent_dim,
              use_cuda=use_cuda,
              prior_dist=prior_dist,
              q_dist=q_dist,
              x_dist=x_dist,
              include_mutinfo=not args.exclude_mutinfo,
              tcvae=args.tcvae,
              conv=args.conv,
              mss=args.mss,
              dataset=args.dataset,
              mse_sigma=args.mse_sigma,
              DIP=args.dip,
              DIP_type=args.dip_type,
              lambda_od=args.lambda_od,
              lambda_d=lambda_d)

    # setup the optimizer
    optimizer = optim.Adam([{
        'params': vae.parameters()
    }, {
        'params': prior_dist.parameters(),
        'lr': 5e-4
    }],
                           lr=args.learning_rate)

    # setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=4500)

    train_elbo = []

    # training loop
    dataset_size = len(train_loader.dataset)
    if args.num_iterations == 0:
        num_iterations = len(train_loader) * args.num_epochs
    else:
        num_iterations = args.num_iterations
    iteration = 0
    obj_best_snapshot = float('-inf')
    best_checkpoint_updated = False

    trainingcurve_filename = os.path.join(args.save, 'trainingcurve.csv')
    if not os.path.exists(trainingcurve_filename):
        with open(trainingcurve_filename, 'w') as fd:
            fd.write(
                'iteration,num_iterations,time,elbo_running_mean_val,elbo_running_mean_avg\n'
            )

    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()
    nan_detected = False
    gradients = []  # populated after the first backward pass; used in the NaN/log diagnostics
    while iteration < num_iterations and not nan_detected:
        for i, x in enumerate(train_loader):
            iteration += 1
            batch_time = time.time()
            vae.train()
            anneal_kl(args, vae, iteration)
            optimizer.zero_grad()
            # transfer to GPU
            if use_cuda:
                x = x.cuda()  # async=True)
            # wrap the mini-batch in a PyTorch Variable
            x = Variable(x)
            # do ELBO gradient and accumulate loss
            #with autograd.detect_anomaly():
            obj, elbo, logpx = vae.elbo(prior_dist,
                                        x,
                                        dataset_size,
                                        use_mse_loss=args.use_mse_loss,
                                        mse_sigma=args.mse_sigma)
            if utils.isnan(obj).any():
                print('NaN spotted in objective.')
                print('lpnested: {}'.format(prior_dist.prior.p))
                if gradients:  # gradients from the previous iteration, if any
                    print("gradient abs max {}".format(
                        max([g.abs().max() for g in gradients])))
                #raise ValueError('NaN spotted in objective.')
                nan_detected = True
                break
            elbo_running_mean.update(elbo.mean().item())

            # save checkpoint of best ELBO
            if obj.mean().item() > obj_best_snapshot:
                obj_best_snapshot = obj.mean().item()
                best_checkpoint = {
                    'state_dict': vae.state_dict(),
                    'state_dict_prior_dist': prior_dist.state_dict(),
                    'args': args,
                    'iteration': iteration,
                    'obj': obj_best_snapshot,
                    'elbo': elbo.mean().item(),
                    'logpx': logpx.mean().item()
                }
                best_checkpoint_updated = True

            #with autograd.detect_anomaly():
            obj.mean().mul(-1).backward()

            # keep the gradients that were just computed for the diagnostics below
            gradients = [p.grad for p in vae.parameters() if p.grad is not None]

            if args.clip > 0:
                torch.nn.utils.clip_grad_norm_(vae.parameters(), args.clip)

            optimizer.step()

            # report training diagnostics
            if iteration % args.log_freq == 0:
                train_elbo.append(elbo_running_mean.avg)
                time_ = time.time() - batch_time
                print(
                    '[iteration %03d/%03d] time: %.2f \tbeta %.2f \tlambda %.2f \tobj %.4f \tlogpx %.4f training ELBO: %.4f (%.4f)'
                    % (iteration, num_iterations, time_, vae.beta, vae.lamb,
                       obj.mean().item(), logpx.mean().item(),
                       elbo_running_mean.val, elbo_running_mean.avg))

                p0, p1list = backwardsParseISA(prior_dist.prior.p)
                print('lpnested: {}, {}'.format(p0, p1list))
                print("gradient abs max {}".format(
                    max([g.abs().max() for g in gradients])))

                with open(os.path.join(args.save, 'trainingcurve.csv'),
                          'a') as fd:
                    fd.write('{},{},{},{},{}\n'.format(iteration,
                                                       num_iterations, time_,
                                                       elbo_running_mean.val,
                                                       elbo_running_mean.avg))

                if best_checkpoint_updated:
                    print(
                        'Update best checkpoint [iteration %03d] training ELBO: %.4f'
                        % (best_checkpoint['iteration'],
                           best_checkpoint['elbo']))
                    utils.save_checkpoint(best_checkpoint, args.save, 0)
                    best_checkpoint_updated = False

                vae.eval()
                prior_dist.eval()

                # plot training and test ELBOs
                if args.visdom:
                    if args.dataset == 'celeba':
                        num_channels = 3
                    else:
                        num_channels = 1
                    display_samples(vae, prior_dist, x, vis, num_channels)
                    plot_elbo(train_elbo, vis)

                if iteration % (10 * args.log_freq) == 0:
                    utils.save_checkpoint(
                        {
                            'state_dict': vae.state_dict(),
                            'state_dict_prior_dist': prior_dist.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'args': args,
                            'iteration': iteration,
                            'obj': obj.mean().item(),
                            'torch_random_state': torch.get_rng_state(),
                            'numpy_random_state': np.random.get_state()
                        },
                        args.save,
                        prefix='latest-optimizer-model-')
                    if not args.dataset == 'celeba' and not args.dataset == '3dchairs':
                        eval('plot_vs_gt_' + args.dataset)(
                            vae, train_loader.dataset,
                            os.path.join(
                                args.save,
                                'gt_vs_latent_{:05d}.png'.format(iteration)))

    # Report statistics of best snapshot after training
    vae.load_state_dict(best_checkpoint['state_dict'])
    prior_dist.load_state_dict(best_checkpoint['state_dict_prior_dist'])

    vae.eval()
    prior_dist.eval()

    if args.dataset == 'shapes':
        data_set = dset.Shapes()
    elif args.dataset == 'faces':
        data_set = dset.Faces()
    elif args.dataset == 'cars3d':
        data_set = dset.Cars3d()
    elif args.dataset == 'celeba':
        data_set = dset.CelebA()
    elif args.dataset == '3dchairs':
        data_set = dset.Chairs()
    else:
        raise ValueError('Unknown dataset ' + str(args.dataset))

    print("loaded dataset {} of size {}".format(args.dataset, len(data_set)))

    dataset_loader = DataLoader(data_set,
                                batch_size=1000,
                                num_workers=0,
                                shuffle=False)

    logpx, dependence, information, dimwise_kl, analytical_cond_kl, elbo_marginal_entropies, elbo_joint_entropy = \
        elbo_decomposition(vae, prior_dist, dataset_loader)
    torch.save(
        {
            'args': args,
            'logpx': logpx,
            'dependence': dependence,
            'information': information,
            'dimwise_kl': dimwise_kl,
            'analytical_cond_kl': analytical_cond_kl,
            'marginal_entropies': elbo_marginal_entropies,
            'joint_entropy': elbo_joint_entropy
        }, os.path.join(args.save, 'elbo_decomposition.pth'))
    print('logpx: {:.2f}'.format(logpx))
    if not args.dataset == 'celeba' and not args.dataset == '3dchairs':
        eval('plot_vs_gt_' + args.dataset)(vae, dataset_loader.dataset,
                                           os.path.join(
                                               args.save, 'gt_vs_latent.png'))

        metric, metric_marginal_entropies, metric_cond_entropies = eval(
            'disentanglement_metrics.mutual_info_metric_' + args.dataset)(
                vae, dataset_loader.dataset)
        torch.save(
            {
                'args': args,
                'metric': metric,
                'marginal_entropies': metric_marginal_entropies,
                'cond_entropies': metric_cond_entropies,
            }, os.path.join(args.save, 'disentanglement_metric.pth'))
        print('MIG: {:.2f}'.format(metric))

        if args.dist == 'lpnested':
            p0, p1list = backwardsParseISA(prior_dist.prior.p)
            print('p0: {}'.format(p0))
            print('p1: {}'.format(p1list))
            torch.save(
                {
                    'args': args,
                    'logpx': logpx,
                    'dependence': dependence,
                    'information': information,
                    'dimwise_kl': dimwise_kl,
                    'analytical_cond_kl': analytical_cond_kl,
                    'elbo_marginal_entropies': elbo_marginal_entropies,
                    'elbo_joint_entropy': elbo_joint_entropy,
                    'metric': metric,
                    'metric_marginal_entropies': metric_marginal_entropies,
                    'metric_cond_entropies': metric_cond_entropies,
                    'p0': p0,
                    'p1': p1list
                }, os.path.join(args.save, 'combined_data.pth'))
        else:
            torch.save(
                {
                    'args': args,
                    'logpx': logpx,
                    'dependence': dependence,
                    'information': information,
                    'dimwise_kl': dimwise_kl,
                    'analytical_cond_kl': analytical_cond_kl,
                    'elbo_marginal_entropies': elbo_marginal_entropies,
                    'elbo_joint_entropy': elbo_joint_entropy,
                    'metric': metric,
                    'metric_marginal_entropies': metric_marginal_entropies,
                    'metric_cond_entropies': metric_cond_entropies,
                }, os.path.join(args.save, 'combined_data.pth'))

        if args.dist == 'lpnested':
            if args.dataset == 'shapes':
                eval('plot_vs_gt_' + args.dataset)(
                    vae,
                    dataset_loader.dataset,
                    os.path.join(args.save, 'gt_vs_grouped_latent.png'),
                    eval_subspaces=True)

                metric_subspaces, metric_marginal_entropies_subspaces, metric_cond_entropies_subspaces = eval(
                    'disentanglement_metrics.mutual_info_metric_' +
                    args.dataset)(vae,
                                  dataset_loader.dataset,
                                  eval_subspaces=True)
                torch.save(
                    {
                        'args': args,
                        'metric': metric_subspaces,
                        'marginal_entropies':
                        metric_marginal_entropies_subspaces,
                        'cond_entropies': metric_cond_entropies_subspaces,
                    },
                    os.path.join(args.save,
                                 'disentanglement_metric_subspaces.pth'))
                print('MIG grouped by subspaces: {:.2f}'.format(
                    metric_subspaces))

                torch.save(
                    {
                        'args': args,
                        'logpx': logpx,
                        'dependence': dependence,
                        'information': information,
                        'dimwise_kl': dimwise_kl,
                        'analytical_cond_kl': analytical_cond_kl,
                        'elbo_marginal_entropies': elbo_marginal_entropies,
                        'elbo_joint_entropy': elbo_joint_entropy,
                        'metric': metric,
                        'metric_marginal_entropies': metric_marginal_entropies,
                        'metric_cond_entropies': metric_cond_entropies,
                        'metric_subspaces': metric_subspaces,
                        'metric_marginal_entropies_subspaces':
                        metric_marginal_entropies_subspaces,
                        'metric_cond_entropies_subspaces':
                        metric_cond_entropies_subspaces,
                        'p0': p0,
                        'p1': p1list
                    }, os.path.join(args.save, 'combined_data.pth'))

    return vae
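The `--isa` and `--pnested` help strings above describe the same prior in two notations: a shorthand of (p, n) groups and the fully expanded nested list with one `[1.0]` leaf per dimension. A small sketch of that expansion is given below; `expand_isa` is a hypothetical helper that only mirrors the two example strings from the help texts, and the project's own `parseISA` may build a different internal representation:

def expand_isa(isa):
    # isa has the shorthand form [p0, [(p1, n1), (p2, n2), ...]]
    p0, groups = isa
    return [p0, [[p, [[1.0] for _ in range(n)]] for p, n in groups]]

print(expand_isa([2.1, [(2.2, 4), (2.2, 4), (2.2, 4)]]))
# -> [2.1, [[2.2, [[1.0], [1.0], [1.0], [1.0]]],
#           [2.2, [[1.0], [1.0], [1.0], [1.0]]],
#           [2.2, [[1.0], [1.0], [1.0], [1.0]]]]]

Note that when neither `--isa` nor `--pnested` is given, main() builds the same kind of shorthand from `--p0`, `--p1`/`--n1`, `--p2`/`--n2` and `--p3`/`--n3` before passing it to parseISA.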
Example #4
def setup_data_loaders(dataset, batch_size=2048, use_cuda=True):
    if dataset == 'shapes':
        train_set = dset.Shapes()
    elif dataset == 'faces':
        train_set = dset.Faces()
    else:
        raise ValueError('Unknown dataset ' + str(dataset))
    # wrap the dataset in a loader so that batch_size and use_cuda are actually used
    kwargs = {'num_workers': 4, 'pin_memory': use_cuda}
    return DataLoader(train_set, batch_size=batch_size, shuffle=True, **kwargs)
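As with the earlier variants, a call is then simply (again assuming `dset` and `torch.utils.data.DataLoader` are importable in the project):

train_loader = setup_data_loaders('faces', batch_size=1024, use_cuda=True)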