Example #1
def variational_loss(input, img, model, z=None):
    mean, logvar = input
    z_samples = model._reparameterize(mean, logvar, z)
    preds = model._dec_forward(img, z_samples)
    nll = utils.log_bernoulli_loss(preds, img)
    kl = utils.kl_loss_diag(mean, logvar)
    return nll + args.beta * kl
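The loss above relies on two helpers from the project's utils module. A minimal sketch of what log_bernoulli_loss and kl_loss_diag are assumed to compute: the per-batch mean of the Bernoulli reconstruction NLL (assuming the decoder returns logits) and of the diagonal-Gaussian KL against a standard normal prior.

import torch
import torch.nn.functional as F

def log_bernoulli_loss(preds, img):
    # Bernoulli reconstruction NLL, summed over pixels, averaged over the batch.
    # Assumes `preds` are decoder logits with the same shape as `img`.
    batch_size = img.size(0)
    bce = F.binary_cross_entropy_with_logits(preds.view(batch_size, -1),
                                             img.view(batch_size, -1),
                                             reduction='sum')
    return bce / batch_size

def kl_loss_diag(mean, logvar):
    # KL(N(mean, diag(exp(logvar))) || N(0, I)), averaged over the batch.
    kl = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp(), dim=1)
    return kl.mean()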
Example #2
def variational_loss(input, sents, model, z=None):
    mean, logvar = input
    z_samples = model._reparameterize(mean, logvar, z)
    preds = model._dec_forward(sents, z_samples)
    nll = sum([criterion(preds[:, l], sents[:, l + 1])
               for l in range(preds.size(1))])
    kl = utils.kl_loss_diag(mean, logvar)
    return nll + args.beta * kl
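model._reparameterize is the usual reparameterization trick; a sketch of the assumed behaviour, where the optional z argument lets the caller reuse a fixed noise sample instead of drawing a fresh one:

import torch

def _reparameterize(self, mean, logvar, z=None):
    # z_sample = mean + eps * std with eps ~ N(0, I); a provided `z` is used as the noise.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std) if z is None else z
    return mean + eps * std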
Example #3
def get_agg_kl(data, model, meta_optimizer):
    model.eval()
    criterion = nn.NLLLoss().cuda()
    means = []
    logvars = []
    all_z = []
    for i in range(len(data)):
        sents, length, batch_size = data[i]
        if args.gpu >= 0:
            sents = sents.cuda()
        mean, logvar = model._enc_forward(sents)
        z_samples = model._reparameterize(mean, logvar)
        if args.model == 'savae':
            mean_svi = Variable(mean.data, requires_grad=True)
            logvar_svi = Variable(logvar.data, requires_grad=True)
            var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi],
                                                    sents)
            mean_svi_final, logvar_svi_final = var_params_svi
            z_samples = model._reparameterize(mean_svi_final, logvar_svi_final)
            preds = model._dec_forward(sents, z_samples)
            nll_svi = sum([
                criterion(preds[:, l], sents[:, l + 1]) for l in range(length)
            ])
            kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
            mean, logvar = mean_svi_final, logvar_svi_final
        means.append(mean.data)
        logvars.append(logvar.data)
        all_z.append(z_samples.data)
    means = torch.cat(means, 0)
    logvars = torch.cat(logvars, 0)
    all_z = torch.cat(all_z, 0)
    N = float(means.size(0))
    mean_prior = torch.zeros(1, means.size(1)).cuda()
    logvar_prior = torch.zeros(1, means.size(1)).cuda()
    agg_kl = 0.
    count = 0.
    for i in range(all_z.size(0)):
        z_i = all_z[i].unsqueeze(0).expand_as(means)
        log_agg_density = utils.log_gaussian(z_i, means,
                                             logvars)  # log q(z|x) for all x
        log_q = utils.logsumexp(log_agg_density, 0)
        log_q = -np.log(N) + log_q
        log_p = utils.log_gaussian(all_z[i].unsqueeze(0), mean_prior,
                                   logvar_prior)
        agg_kl += log_q.sum() - log_p.sum()
        count += 1
    mean_var = means.var(0)  # variance of posterior means over the full dataset
    print('active units', (mean_var > 0.02).float().sum())
    print(mean_var)

    return agg_kl / count
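get_agg_kl estimates KL(q_agg(z) || p(z)), where the aggregate posterior q_agg(z) is approximated by averaging q(z|x) over the N examples. Minimal sketches of the utils.log_gaussian and utils.logsumexp helpers it assumes: a diagonal-Gaussian log-density summed over latent dimensions, and a numerically stable log-sum-exp.

import math
import torch

def log_gaussian(z, mean, logvar):
    # log N(z; mean, diag(exp(logvar))), summed over the last (latent) dimension.
    return -0.5 * torch.sum(
        math.log(2 * math.pi) + logvar + (z - mean).pow(2) / logvar.exp(), dim=-1)

def logsumexp(x, dim):
    # Numerically stable log(sum(exp(x))) along `dim`.
    m, _ = torch.max(x, dim=dim, keepdim=True)
    return (m + torch.log(torch.sum(torch.exp(x - m), dim=dim,
                                    keepdim=True))).squeeze(dim)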
Example #4
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    all_data = torch.load(args.data_file)
    x_train, x_val, x_test = all_data
    y_size = 1
    y_train = torch.zeros(x_train.size(0), y_size)
    y_val = torch.zeros(x_val.size(0), y_size)
    y_test = torch.zeros(x_test.size(0), y_size)
    train = torch.utils.data.TensorDataset(x_train, y_train)
    val = torch.utils.data.TensorDataset(x_val, y_val)
    test = torch.utils.data.TensorDataset(x_test, y_test)

    train_loader = torch.utils.data.DataLoader(train,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val,
                                             batch_size=args.batch_size,
                                             shuffle=True)
    test_loader = torch.utils.data.DataLoader(test,
                                              batch_size=args.batch_size,
                                              shuffle=True)
    print('Train data: %d batches' % len(train_loader))
    print('Val data: %d batches' % len(val_loader))
    print('Test data: %d batches' % len(test_loader))
    if args.slurm == 0:
        cuda.set_device(args.gpu)
    if args.model == 'autoreg':
        args.latent_feature_map = 0
    if args.train_from == '':
        model = CNNVAE(img_size=args.img_size,
                       latent_dim=args.latent_dim,
                       enc_layers=args.enc_layers,
                       dec_kernel_size=args.dec_kernel_size,
                       dec_layers=args.dec_layers,
                       latent_feature_map=args.latent_feature_map)
    else:
        print('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']
    print("model architecture")
    print(model)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 betas=(0.9, 0.999))

    model.cuda()
    model.train()

    def variational_loss(input, img, model, z=None):
        mean, logvar = input
        z_samples = model._reparameterize(mean, logvar, z)
        preds = model._dec_forward(img, z_samples)
        nll = utils.log_bernoulli_loss(preds, img)
        kl = utils.kl_loss_diag(mean, logvar)
        return nll + args.beta * kl

    update_params = list(model.dec.parameters())
    meta_optimizer = OptimN2N(variational_loss,
                              model,
                              update_params,
                              eps=args.eps,
                              lr=[args.svi_lr1, args.svi_lr2],
                              iters=args.svi_steps,
                              momentum=args.momentum,
                              acc_param_grads=args.train_n2n == 1,
                              max_grad_norm=args.svi_max_grad_norm)
    epoch = 0
    t = 0
    best_val_nll = 1e5
    best_epoch = 0
    loss_stats = []
    if args.warmup == 0:
        args.beta = 1.
    else:
        args.beta = 0.1

    if args.test == 1:
        args.beta = 1
        eval(test_loader, model, meta_optimizer)
        exit()

    while epoch < args.num_epochs:
        start_time = time.time()
        epoch += 1
        print('Starting epoch %d' % epoch)
        train_nll_vae = 0.
        train_nll_autoreg = 0.
        train_kl_vae = 0.
        train_nll_svi = 0.
        train_kl_svi = 0.
        num_examples = 0
        for b, datum in enumerate(train_loader):
            if args.warmup > 0:
                args.beta = min(
                    1, args.beta + 1. / (args.warmup * len(train_loader)))
            img, _ = datum
            img = torch.bernoulli(img)
            batch_size = img.size(0)
            img = Variable(img.cuda())
            t += 1
            optimizer.zero_grad()
            if args.model == 'autoreg':
                preds = model._dec_forward(img, None)
                nll_autoreg = utils.log_bernoulli_loss(preds, img)
                train_nll_autoreg += nll_autoreg.data[0] * batch_size
                nll_autoreg.backward()
            elif args.model == 'svi':
                mean_svi = Variable(
                    0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                    requires_grad=True)
                logvar_svi = Variable(
                    0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                    requires_grad=True)
                var_params_svi = meta_optimizer.forward(
                    [mean_svi, logvar_svi], img, t % args.print_every == 0)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model._reparameterize(mean_svi_final.detach(),
                                                  logvar_svi_final.detach())
                preds = model._dec_forward(img, z_samples)
                nll_svi = utils.log_bernoulli_loss(preds, img)
                train_nll_svi += nll_svi.data[0] * batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                train_kl_svi += kl_svi.data[0] * batch_size
                var_loss = nll_svi + args.beta * kl_svi
                var_loss.backward()
            else:
                mean, logvar = model._enc_forward(img)
                z_samples = model._reparameterize(mean, logvar)
                preds = model._dec_forward(img, z_samples)
                nll_vae = utils.log_bernoulli_loss(preds, img)
                train_nll_vae += nll_vae.data[0] * batch_size
                kl_vae = utils.kl_loss_diag(mean, logvar)
                train_kl_vae += kl_vae.data[0] * batch_size
                if args.model == 'vae':
                    vae_loss = nll_vae + args.beta * kl_vae
                    vae_loss.backward(retain_graph=True)

                if args.model == 'savae':
                    var_params = torch.cat([mean, logvar], 1)
                    mean_svi = Variable(mean.data, requires_grad=True)
                    logvar_svi = Variable(logvar.data, requires_grad=True)

                    var_params_svi = meta_optimizer.forward(
                        [mean_svi, logvar_svi], img, t % args.print_every == 0)
                    mean_svi_final, logvar_svi_final = var_params_svi
                    z_samples = model._reparameterize(mean_svi_final,
                                                      logvar_svi_final)
                    preds = model._dec_forward(img, z_samples)
                    nll_svi = utils.log_bernoulli_loss(preds, img)
                    train_nll_svi += nll_svi.data[0] * batch_size
                    kl_svi = utils.kl_loss_diag(mean_svi_final,
                                                logvar_svi_final)
                    train_kl_svi += kl_svi.data[0] * batch_size
                    var_loss = nll_svi + args.beta * kl_svi
                    var_loss.backward(retain_graph=True)
                    if args.train_n2n == 0:
                        if args.train_kl == 1:
                            mean_final = mean_svi_final.detach()
                            logvar_final = logvar_svi_final.detach()
                            kl_init_final = utils.kl_loss(
                                mean, logvar, mean_final, logvar_final)
                            kl_init_final.backward(retain_graph=True)
                        else:
                            vae_loss = nll_vae + args.beta * kl_vae
                            var_param_grads = torch.autograd.grad(
                                vae_loss, [mean, logvar], retain_graph=True)
                            var_param_grads = torch.cat(var_param_grads, 1)
                            var_params.backward(var_param_grads,
                                                retain_graph=True)
                    else:
                        var_param_grads = meta_optimizer.backward(
                            [mean_svi_final.grad, logvar_svi_final.grad],
                            t % args.print_every == 0)
                        var_param_grads = torch.cat(var_param_grads, 1)
                        var_params.backward(var_param_grads)
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm(model.parameters(),
                                              args.max_grad_norm)
            optimizer.step()
            num_examples += batch_size
            if t % args.print_every == 0:
                param_norm = sum([p.norm()**2
                                  for p in model.parameters()]).data[0]**0.5
                print(
                    'Iters: %d, Epoch: %d, Batch: %d/%d, LR: %.4f, TrainARNLL: %.2f, TrainVAE_NLL: %.2f, TrainVAE_KL: %.4f, TrainVAE_NLLBnd: %.2f, TrainSVI_NLL: %.2f, TrainSVI_KL: %.4f, TrainSVI_NLLBnd: %.2f, |Param|: %.4f, BestValPerf: %.2f, BestEpoch: %d, Beta: %.3f, Throughput: %.2f examples/sec'
                    %
                    (t, epoch, b + 1, len(train_loader), args.lr,
                     train_nll_autoreg / num_examples,
                     train_nll_vae / num_examples, train_kl_vae / num_examples,
                     (train_nll_vae + train_kl_vae) / num_examples,
                     train_nll_svi / num_examples, train_kl_svi / num_examples,
                     (train_nll_svi + train_kl_svi) / num_examples, param_norm,
                     best_val_nll, best_epoch, args.beta, num_examples /
                     (time.time() - start_time)))
        print('--------------------------------')
        print('Checking validation perf...')
        val_nll = eval(val_loader, model, meta_optimizer)
        loss_stats.append(val_nll)
        if val_nll < best_val_nll:
            best_val_nll = val_nll
            best_epoch = epoch
            checkpoint = {
                'args': args.__dict__,
                'model': model,
                'optimizer': optimizer,
                'loss_stats': loss_stats
            }
            print('Saving checkpoint to %s' % args.checkpoint_path)
            torch.save(checkpoint, args.checkpoint_path)
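The β warm-up in the training loop above raises the KL weight linearly from 0.1 to 1 over args.warmup epochs, one small increment per batch. The same schedule, restated as a standalone helper (the function name is illustrative, not part of the original code):

def kl_warmup_beta(beta, warmup_epochs, batches_per_epoch):
    # Linear KL warm-up: beta grows by 1 / (warmup_epochs * batches_per_epoch)
    # each batch and is capped at 1; warmup_epochs == 0 disables annealing.
    if warmup_epochs == 0:
        return 1.0
    return min(1.0, beta + 1.0 / (warmup_epochs * batches_per_epoch))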
Example #5
def eval(data, model, meta_optimizer):
    model.eval()
    num_examples = 0
    total_nll_autoreg = 0.
    total_nll_vae = 0.
    total_kl_vae = 0.
    total_nll_svi = 0.
    total_kl_svi = 0.
    for datum in data:
        img, _ = datum
        batch_size = img.size(0)
        img = Variable(img.cuda())
        if args.model == 'autoreg':
            preds = model._dec_forward(img, None)
            nll_autoreg = utils.log_bernoulli_loss(preds, img)
            total_nll_autoreg += nll_autoreg.data[0] * batch_size
        elif args.model == 'svi':
            mean_svi = Variable(
                0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                requires_grad=True)
            logvar_svi = Variable(
                0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                requires_grad=True)
            var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi],
                                                    img)
            mean_svi_final, logvar_svi_final = var_params_svi
            z_samples = model._reparameterize(mean_svi_final.detach(),
                                              logvar_svi_final.detach())
            preds = model._dec_forward(img, z_samples)
            nll_svi = utils.log_bernoulli_loss(preds, img)
            total_nll_svi += nll_svi.data[0] * batch_size
            kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
            total_kl_svi += kl_svi.data[0] * batch_size
        else:
            mean, logvar = model._enc_forward(img)
            z_samples = model._reparameterize(mean, logvar)
            preds = model._dec_forward(img, z_samples)
            nll_vae = utils.log_bernoulli_loss(preds, img)
            total_nll_vae += nll_vae.data[0] * batch_size
            kl_vae = utils.kl_loss_diag(mean, logvar)
            total_kl_vae += kl_vae.data[0] * batch_size
            if args.model == 'savae':
                mean_svi = Variable(mean.data, requires_grad=True)
                logvar_svi = Variable(logvar.data, requires_grad=True)
                var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi],
                                                        img)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model._reparameterize(mean_svi_final,
                                                  logvar_svi_final)
                preds = model._dec_forward(img, z_samples.detach())
                nll_svi = utils.log_bernoulli_loss(preds, img)
                total_nll_svi += nll_svi.data[0] * batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                total_kl_svi += kl_svi.data[0] * batch_size
                mean, logvar = mean_svi_final, logvar_svi_final
        num_examples += batch_size

    nll_autoreg = total_nll_autoreg / num_examples
    nll_vae = total_nll_vae / num_examples
    kl_vae = total_kl_vae / num_examples
    nll_bound_vae = (total_nll_vae + total_kl_vae) / num_examples
    nll_svi = total_nll_svi / num_examples
    kl_svi = total_kl_svi / num_examples
    nll_bound_svi = (total_nll_svi + total_kl_svi) / num_examples
    print(
        'AR NLL: %.4f, VAE NLL: %.4f, VAE KL: %.4f, VAE NLL BOUND: %.4f, SVI NLL: %.4f, SVI KL: %.4f, SVI NLL BOUND: %.4f'
        % (nll_autoreg, nll_vae, kl_vae, nll_bound_vae, nll_svi, kl_svi,
           nll_bound_svi))
    model.train()
    if args.model == 'autoreg':
        return nll_autoreg
    elif args.model == 'vae':
        return nll_bound_vae
    elif args.model == 'savae' or args.model == 'svi':
        return nll_bound_svi
Example #6
def main():
    wandb.init(project="vae-comparison")
    wandb.config.update(args)
    log_step = 0

    # set random seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # set device
    use_gpu = args.use_gpu and torch.cuda.is_available()
    device = torch.device("cuda" if use_gpu else "cpu")
    print("training on {} device".format("cuda" if use_gpu else "cpu"))

    # load dataset
    train_loader, val_loader, test_loader = load_data(
        dataset=args.dataset,
        batch_size=args.batch_size,
        no_validation=args.no_validation,
        shuffle=args.shuffle,
        data_file=args.data_file)

    # define model or load checkpoint
    if args.train_from == '':
        print('--------------------------------')
        print("initializing new model")
        model = VAE(latent_dim=args.latent_dim)

    else:
        print('--------------------------------')
        print('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']

    print('--------------------------------')
    print("model architecture")
    print(model)

    # set model for training
    model.to(device)
    model.train()

    # define optimizers and their schedulers
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    optimizer_enc = torch.optim.Adam(model.enc.parameters(), lr=args.lr)
    optimizer_dec = torch.optim.Adam(model.dec.parameters(), lr=args.lr)
    lr_lambda = lambda count: 0.9
    lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
        optimizer, lr_lambda=lr_lambda)
    lr_scheduler_enc = torch.optim.lr_scheduler.MultiplicativeLR(
        optimizer_enc, lr_lambda=lr_lambda)
    lr_scheduler_dec = torch.optim.lr_scheduler.MultiplicativeLR(
        optimizer_dec, lr_lambda=lr_lambda)

    # set beta KL scaling parameter
    if args.warmup == 0:
        beta_ten = torch.tensor(1.)
    else:
        beta_ten = torch.tensor(0.1)

    # set savae meta optimizer
    update_params = list(model.dec.parameters())
    meta_optimizer = OptimN2N(utils.variational_loss,
                              model,
                              update_params,
                              beta=beta_ten,
                              eps=args.eps,
                              lr=[args.svi_lr1, args.svi_lr2],
                              iters=args.svi_steps,
                              momentum=args.momentum,
                              acc_param_grads=1,
                              max_grad_norm=args.svi_max_grad_norm)

    # if test flag set, evaluate and exit
    if args.test == 1:
        beta_ten.data.fill_(1.)
        eval(test_loader, model, meta_optimizer, device)
        importance_sampling(data=test_loader,
                            model=model,
                            batch_size=args.batch_size,
                            meta_optimizer=meta_optimizer,
                            device=device,
                            nr_samples=20000,
                            test_mode=True,
                            verbose=True,
                            mode=args.test_type)
        exit()

    # initialize counters and stats
    epoch = 0
    t = 0
    best_val_metric = 100000000
    best_epoch = 0
    loss_stats = []
    # training loop
    C = torch.tensor(0., device=device)
    C_local = torch.zeros(args.batch_size * len(train_loader), device=device)
    epsilon = None
    step = 0
    while epoch < args.num_epochs:

        start_time = time.time()
        epoch += 1

        print('--------------------------------')
        print('starting epoch %d' % epoch)
        train_nll_vae = 0.
        train_kl_vae = 0.
        train_nll_svi = 0.
        train_kl_svi = 0.
        train_cdiv = 0.
        train_nll = 0.
        train_acc_rate = 0.
        num_examples = 0
        count_one_pixels = 0

        for b, datum in enumerate(train_loader):
            t += 1

            if args.warmup > 0:
                beta_ten.data.fill_(
                    torch.min(torch.tensor(1.), beta_ten + 1 /
                              (args.warmup * len(train_loader))).data)

            img, _ = datum
            img = torch.where(img < 0.5, torch.zeros_like(img),
                              torch.ones_like(img))
            if epoch == 1:
                count_one_pixels += torch.sum(img).item()
            img = img.to(device)

            optimizer.zero_grad()
            optimizer_enc.zero_grad()
            optimizer_dec.zero_grad()

            if args.model == 'svi':
                mean_svi = torch.zeros(args.batch_size,
                                       args.latent_dim,
                                       requires_grad=True,
                                       device=device)
                logvar_svi = torch.zeros(args.batch_size,
                                         args.latent_dim,
                                         requires_grad=True,
                                         device=device)
                var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi],
                                                        img)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model.reparameterize(mean_svi_final.detach(),
                                                 logvar_svi_final.detach())
                preds = model.dec_forward(z_samples)
                nll_svi = utils.log_bernoulli_loss(preds, img)
                train_nll_svi += nll_svi.item() * args.batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                train_kl_svi += kl_svi.item() * args.batch_size
                var_loss = nll_svi + beta_ten.item() * kl_svi
                var_loss.backward()

                if args.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()

            else:
                mean, logvar = model.enc_forward(img)
                z_samples = model.reparameterize(mean, logvar)
                preds = model.dec_forward(z_samples)
                nll_vae = utils.log_bernoulli_loss(preds, img)
                train_nll_vae += nll_vae.item() * args.batch_size
                kl_vae = utils.kl_loss_diag(mean, logvar)
                train_kl_vae += kl_vae.item() * args.batch_size

                if args.model == 'vae':
                    vae_loss = nll_vae + beta_ten.item() * kl_vae
                    vae_loss.backward()

                    optimizer.step()

                if args.model == 'savae':
                    var_params = torch.cat([mean, logvar], 1)
                    mean_svi = mean.clone().detach().requires_grad_(True)
                    logvar_svi = logvar.clone().detach().requires_grad_(True)

                    var_params_svi = meta_optimizer.forward(
                        [mean_svi, logvar_svi], img)
                    mean_svi_final, logvar_svi_final = var_params_svi

                    z_samples = model.reparameterize(mean_svi_final,
                                                     logvar_svi_final)
                    preds = model.dec_forward(z_samples)
                    nll_svi = utils.log_bernoulli_loss(preds, img)
                    train_nll_svi += nll_svi.item() * args.batch_size
                    kl_svi = utils.kl_loss_diag(mean_svi_final,
                                                logvar_svi_final)
                    train_kl_svi += kl_svi.item() * args.batch_size
                    var_loss = nll_svi + beta_ten.item() * kl_svi
                    var_loss.backward(retain_graph=True)
                    var_param_grads = meta_optimizer.backward(
                        [mean_svi_final.grad, logvar_svi_final.grad])
                    var_param_grads = torch.cat(var_param_grads, 1)
                    var_params.backward(var_param_grads)

                    if args.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)
                    optimizer.step()

                if args.model == "cdiv" or args.model == "cdiv_svgd":

                    pxz = utils.log_pxz(preds, img, z_samples)
                    first_term = torch.mean(pxz) + 0.5 * args.latent_dim
                    logqz = utils.log_normal_pdf(z_samples, mean, logvar)

                    if epoch == 7 and b == 0:  # switch to local control variates
                        C_local = torch.ones(
                            args.batch_size * len(train_loader),
                            device=device) * C

                    if args.model == "cdiv":
                        zt, samples, acc_rate, epsilon = hmc.hmc_vae(
                            z_samples.clone().detach().requires_grad_(),
                            model,
                            img,
                            epsilon=epsilon,
                            Burn=0,
                            T=args.num_hmc_iters,
                            adapt=0,
                            L=5)
                        train_acc_rate += torch.mean(
                            acc_rate) * args.batch_size
                    else:
                        mean_all = torch.repeat_interleave(
                            mean, args.num_svgd_particles, 0)
                        logvar_all = torch.repeat_interleave(
                            logvar, args.num_svgd_particles, 0)
                        img_all = torch.repeat_interleave(
                            img, args.num_svgd_particles, 0)
                        z_samples = mean_all + torch.randn(
                            args.num_svgd_particles * args.batch_size,
                            args.latent_dim,
                            device=device) * torch.exp(0.5 * logvar_all)
                        samples = svgd.svgd_batched(args.num_svgd_particles,
                                                    args.batch_size,
                                                    z_samples,
                                                    model,
                                                    img_all.view(-1, 784),
                                                    iter=args.num_svgd_iters)
                        z_ind = torch.randint(low=0, high=args.num_svgd_particles, size=(args.batch_size,),
                                              device=device) + \
                                torch.tensor(args.num_svgd_particles, device=device) * \
                                torch.arange(0, args.batch_size, device=device)
                        zt = samples[z_ind]

                    preds_zt = model.dec_forward(zt)

                    pxzt = utils.log_pxz(preds_zt, img, zt)
                    g_zt = pxzt + torch.sum(
                        0.5 * ((zt - mean)**2) * torch.exp(-logvar), 1)

                    second_term = torch.mean(g_zt)
                    cdiv = -first_term + second_term
                    train_cdiv += cdiv.item() * args.batch_size
                    train_nll += -torch.mean(pxzt).item() * args.batch_size

                    if epoch <= 6:
                        loss = -first_term + torch.mean(
                            torch.sum(
                                0.5 *
                                ((zt - mean)**2) * torch.exp(-logvar), 1) +
                            (g_zt.detach() - C) * logqz)
                        if b == 0:
                            C = torch.mean(g_zt.detach())
                        else:
                            C = 0.9 * C + 0.1 * torch.mean(g_zt.detach())
                    else:
                        control = C_local[b * args.batch_size:(b + 1) *
                                          args.batch_size]
                        loss = -first_term + torch.mean(
                            torch.sum(
                                0.5 *
                                ((zt - mean)**2) * torch.exp(-logvar), 1) +
                            (g_zt.detach() - control) * logqz)
                        C_local[b * args.batch_size:(b + 1) * args.batch_size] = \
                            0.9 * C_local[b * args.batch_size:(b + 1) * args.batch_size] + 0.1 * g_zt.detach()

                    loss.backward(retain_graph=True)
                    optimizer_enc.step()

                    optimizer_dec.zero_grad()
                    torch.mean(-utils.log_pxz(preds_zt, img, zt)).backward()
                    optimizer_dec.step()

            if t % 15000 == 0:
                if args.model == "cdiv" or args.model == "cdiv_svgd":
                    lr_scheduler_enc.step()
                    lr_scheduler_dec.step()
                else:
                    lr_scheduler.step()

            num_examples += args.batch_size
            if b and (b + 1) % args.print_every == 0:
                step += 1

                print('--------------------------------')
                print('iteration: %d, epoch: %d, batch: %d/%d' %
                      (t, epoch, b + 1, len(train_loader)))
                if epoch > 1:
                    print('best epoch: %d: %.2f' %
                          (best_epoch, best_val_metric))
                print('throughput: %.2f examples/sec' %
                      (num_examples / (time.time() - start_time)))

                if args.model != 'svi':
                    print(
                        'train_VAE_NLL: %.2f, train_VAE_KL: %.4f, train_VAE_NLLBnd: %.2f'
                        % (train_nll_vae / num_examples,
                           train_kl_vae / num_examples,
                           (train_nll_vae + train_kl_vae) / num_examples))
                    wandb.log(
                        {
                            "train_vae_nll":
                            train_nll_vae / num_examples,
                            "train_vae_kl":
                            train_kl_vae / num_examples,
                            "train_vae_nll_bound":
                            (train_nll_vae + train_kl_vae) / num_examples,
                        },
                        step=log_step)

                if args.model == 'svi' or args.model == 'savae':
                    print(
                        'train_SVI_NLL: %.2f, train_SVI_KL: %.4f, train_SVI_NLLBnd: %.2f'
                        % (train_nll_svi / num_examples,
                           train_kl_svi / num_examples,
                           (train_nll_svi + train_kl_svi) / num_examples))
                    wandb.log(
                        {
                            "train_svi_nll":
                            train_nll_svi / num_examples,
                            "train_svi_kl":
                            train_kl_svi / num_examples,
                            "train_svi_nll_bound":
                            (train_nll_svi + train_kl_svi) / num_examples,
                        },
                        step=log_step)

                if args.model == "cdiv" or args.model == "cdiv_svgd":
                    print(
                        'train_NLL: %.2f, train_CDIV: %.4f' %
                        (train_nll / num_examples, train_cdiv / num_examples))
                    wandb.log(
                        {
                            "train_nll": train_nll / num_examples,
                            "train_cdiv": train_cdiv / num_examples,
                        },
                        step=log_step)

                    if args.model == "cdiv":
                        print('train_average_acc_rate: %.3f' %
                              (train_acc_rate / num_examples))
                        wandb.log(
                            {
                                "train_average_acc_rate":
                                train_acc_rate / num_examples,
                            },
                            step=log_step)
                log_step += 1

        if epoch == 1:
            print('--------------------------------')
            print("count of pixels 1 in training data: {}".format(
                count_one_pixels))
            wandb.log({"dataset_pixel_check": count_one_pixels}, step=log_step)
        if args.no_validation:
            print('--------------------------------')
            print("[validation disabled!]")
        else:
            val_metric = eval(val_loader, model, meta_optimizer, device, epoch,
                              epsilon, log_step)

        checkpoint = {
            'args': args.__dict__,
            'model': model,
            'loss_stats': loss_stats
        }
        torch.save(checkpoint, args.checkpoint_path + "_last.pt")
        if not args.no_validation:
            loss_stats.append(val_metric)
            if val_metric < best_val_metric:
                best_val_metric = val_metric
                best_epoch = epoch
                print('saving checkpoint to %s' %
                      (args.checkpoint_path + "_best.pt"))
                torch.save(checkpoint, args.checkpoint_path + "_best.pt")
Example #7
def eval(data,
         model,
         meta_optimizer,
         device,
         epoch=0,
         epsilon=None,
         log_step=0):
    print("********************************")
    print("validation epoch {}".format(epoch))

    num_examples = 0
    total_nll_vae = 0.
    total_kl_vae = 0.
    total_nll_svi = 0.
    total_kl_svi = 0.
    total_cdiv = 0.
    total_nll = 0.

    mean_llh = importance_sampling(data=data,
                                   model=model,
                                   batch_size=args.batch_size,
                                   meta_optimizer=meta_optimizer,
                                   device=device,
                                   nr_samples=10,
                                   test_mode=False,
                                   verbose=False,
                                   mode="vae",
                                   log_step=log_step)
    model.eval()

    for datum in data:
        img_pre, _ = datum
        batch_size = args.batch_size
        img = img_pre.to(device)
        img = torch.where(img < 0.5, torch.zeros_like(img),
                          torch.ones_like(img))

        if args.model == 'svi':
            mean_svi = 0.1 * torch.zeros(batch_size,
                                         model.latent_dim,
                                         device=device,
                                         requires_grad=True)
            logvar_svi = 0.1 * torch.zeros(batch_size,
                                           model.latent_dim,
                                           device=device,
                                           requires_grad=True)
            var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi],
                                                    img)
            mean_svi_final, logvar_svi_final = var_params_svi
            z_samples = model.reparameterize(mean_svi_final.detach(),
                                             logvar_svi_final.detach())
            preds = model.dec_forward(z_samples)
            nll_svi = utils.log_bernoulli_loss(preds, img)
            total_nll_svi += nll_svi.item() * batch_size
            kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
            total_kl_svi += kl_svi.item() * batch_size
        else:
            mean, logvar = model.enc_forward(img)
            z_samples = model.reparameterize(mean, logvar)
            preds = model.dec_forward(z_samples)
            nll_vae = utils.log_bernoulli_loss(preds, img)
            total_nll_vae += nll_vae.item() * batch_size
            kl_vae = utils.kl_loss_diag(mean, logvar)
            total_kl_vae += kl_vae.item() * batch_size
            if args.model == 'cdiv' or args.model == "cdiv_svgd":

                pxz = utils.log_pxz(preds, img, z_samples)
                first_term = torch.mean(pxz) + 0.5 * model.latent_dim

                if args.model == "cdiv":
                    zt, samples, acc_rate, epsilon = hmc.hmc_vae(
                        z_samples,
                        model,
                        img,
                        epsilon=epsilon,
                        Burn=0,
                        T=args.num_hmc_iters,
                        adapt=0,
                        L=5)
                else:
                    mean_all = torch.repeat_interleave(mean,
                                                       args.num_svgd_particles,
                                                       0)
                    logvar_all = torch.repeat_interleave(
                        logvar, args.num_svgd_particles, 0)
                    img_all = torch.repeat_interleave(img,
                                                      args.num_svgd_particles,
                                                      0)
                    z_samples = mean_all + torch.randn(
                        args.num_svgd_particles * args.batch_size,
                        args.latent_dim,
                        device=device) * torch.exp(0.5 * logvar_all)
                    samples = svgd.svgd_batched(args.num_svgd_particles,
                                                args.batch_size,
                                                z_samples,
                                                model,
                                                img_all.view(-1, 784),
                                                iter=args.num_svgd_iters)
                    z_ind = torch.randint(low=0, high=args.num_svgd_particles, size=(args.batch_size,),
                                          device=device) + \
                            torch.tensor(args.num_svgd_particles, device=device) * \
                            torch.arange(0, args.batch_size, device=device)
                    zt = samples[z_ind]

                preds_zt = model.dec_forward(zt)
                preds = preds_zt

                pxzt = utils.log_pxz(preds_zt, img, zt)
                g_zt = pxzt + torch.sum(
                    0.5 * ((zt - mean)**2) * torch.exp(-logvar), 1)

                second_term = torch.mean(g_zt)
                cdiv = -first_term + second_term
                total_cdiv += cdiv.item() * batch_size
                total_nll += utils.log_bernoulli_loss(preds_zt,
                                                      img).item() * batch_size

            if args.model == 'savae':
                mean_svi = mean.data.clone().detach().requires_grad_(True)
                logvar_svi = logvar.data.clone().detach().requires_grad_(True)
                var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi],
                                                        img)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model.reparameterize(mean_svi_final,
                                                 logvar_svi_final)
                preds = model.dec_forward(z_samples.detach())
                nll_svi = utils.log_bernoulli_loss(preds, img)
                total_nll_svi += nll_svi.item() * batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                total_kl_svi += kl_svi.item() * batch_size

        num_examples += batch_size

    n = min(img.size(0), 8)
    comparison = torch.cat([
        img_pre[:n],
        torch.sigmoid(preds.view(-1, 1, args.img_size[1],
                                 args.img_size[2])).cpu()[:n]
    ])

    example_images = wandb.Image(comparison,
                                 caption="images epoch {}".format(epoch))
    wandb.log({"validation images": example_images}, step=log_step)

    nll_vae = total_nll_vae / num_examples
    kl_vae = total_kl_vae / num_examples
    nll_bound_vae = (total_nll_vae + total_kl_vae) / num_examples

    nll_svi = total_nll_svi / num_examples
    kl_svi = total_kl_svi / num_examples
    nll_bound_svi = (total_nll_svi + total_kl_svi) / num_examples

    total_cdiv = total_cdiv / num_examples
    total_nll = total_nll / num_examples

    val_metric = -1

    if args.model != 'svi':
        print('val_VAE_NLL: %.2f, val_VAE_KL: %.4f, val_VAE_NLLBnd: %.2f' %
              (nll_vae, kl_vae, nll_bound_vae))
        wandb.log(
            {
                "val_vae_nll": nll_vae,
                "val_vae_kl": kl_vae,
                "val_vae_nll_bound": nll_bound_vae,
            },
            step=log_step)
        val_metric = nll_bound_vae
        if args.model == "vae":
            wandb.log({"bernoulli_loss": nll_vae}, step=log_step)

    if args.model == 'svi' or args.model == 'savae':
        print('val_SVI_NLL: %.2f, val_SVI_KL: %.4f, val_SVI_NLLBnd: %.2f' %
              (nll_svi, kl_svi, nll_bound_svi))
        wandb.log(
            {
                "val_svi_nll": nll_svi,
                "val_svi_kl": kl_svi,
                "val_svi_nll_bound": nll_bound_svi,
            },
            step=log_step)
        val_metric = nll_bound_svi
        wandb.log({"bernoulli_loss": nll_svi}, step=log_step)

    if args.model == "cdiv" or args.model == "cdiv_svgd":
        print('val_NLL: %.2f, val_CDIV: %.4f' % (total_nll, total_cdiv))
        wandb.log({
            "val_nll": total_nll,
            "val_cdiv": total_cdiv,
        },
                  step=log_step)
        val_metric = total_cdiv
        wandb.log({"bernoulli_loss": total_nll}, step=log_step)

    wandb.log({"val_metric": val_metric}, step=log_step)

    model.train()
    return val_metric
Example #8
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_data = Dataset(args.train_file)
    val_data = Dataset(args.val_file)
    test_data = Dataset(args.test_file)
    train_sents = train_data.batch_size.sum()
    vocab_size = int(train_data.vocab_size)
    logger.info('Train data: %d batches' % len(train_data))
    logger.info('Val data: %d batches' % len(val_data))
    logger.info('Test data: %d batches' % len(test_data))
    logger.info('Word vocab size: %d' % vocab_size)

    checkpoint_dir = args.checkpoint_dir
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    suffix = "%s_%s.pt" % (args.model, 'cyc')
    checkpoint_path = os.path.join(checkpoint_dir, suffix)

    if args.slurm == 0:
        cuda.set_device(args.gpu)
    if args.train_from == '':
        model = RNNVAE(vocab_size=vocab_size,
                       enc_word_dim=args.enc_word_dim,
                       enc_h_dim=args.enc_h_dim,
                       enc_num_layers=args.enc_num_layers,
                       dec_word_dim=args.dec_word_dim,
                       dec_h_dim=args.dec_h_dim,
                       dec_num_layers=args.dec_num_layers,
                       dec_dropout=args.dec_dropout,
                       latent_dim=args.latent_dim,
                       mode=args.model)
        for param in model.parameters():
            param.data.uniform_(-0.1, 0.1)
    else:
        logger.info('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']

    logger.info("model architecture")
    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    if args.warmup == 0:
        args.beta = 1.
    else:
        args.beta = 0.1

    criterion = nn.NLLLoss()
    model.cuda()
    criterion.cuda()
    model.train()

    def variational_loss(input, sents, model, z=None):
        mean, logvar = input
        z_samples = model._reparameterize(mean, logvar, z)
        preds = model._dec_forward(sents, z_samples)
        nll = sum([
            criterion(preds[:, l], sents[:, l + 1])
            for l in range(preds.size(1))
        ])
        kl = utils.kl_loss_diag(mean, logvar)
        return nll + args.beta * kl

    update_params = list(model.dec.parameters())
    meta_optimizer = OptimN2N(variational_loss,
                              model,
                              update_params,
                              eps=args.eps,
                              lr=[args.svi_lr1, args.svi_lr2],
                              iters=args.svi_steps,
                              momentum=args.momentum,
                              acc_param_grads=args.train_n2n == 1,
                              max_grad_norm=args.svi_max_grad_norm)
    if args.test == 1:
        args.beta = 1
        test_data = Dataset(args.test_file)
        eval(test_data, model, meta_optimizer)
        exit()

    t = 0
    best_val_nll = 1e5
    best_epoch = 0
    val_stats = []
    epoch = 0
    while epoch < args.num_epochs:
        start_time = time.time()
        epoch += 1
        logger.info('Starting epoch %d' % epoch)
        train_nll_vae = 0.
        train_nll_autoreg = 0.
        train_kl_vae = 0.
        train_nll_svi = 0.
        train_kl_svi = 0.
        train_kl_init_final = 0.
        num_sents = 0
        num_words = 0
        b = 0

        tmp = float((epoch - 1) % args.cycle) / args.cycle
        cur_lr = args.lr * 0.5 * (1 + np.cos(tmp * np.pi))
        for param_group in optimizer.param_groups:
            param_group['lr'] = cur_lr

        if (epoch - 1) % args.cycle == 0:
            args.beta = 0.1
            logger.info('KL annealing restart')

        for i in np.random.permutation(len(train_data)):
            if args.warmup > 0:
                args.beta = min(
                    1, args.beta + 1. / (args.warmup * len(train_data)))

            sents, length, batch_size = train_data[i]
            if args.gpu >= 0:
                sents = sents.cuda()
            b += 1

            optimizer.zero_grad()
            if args.model == 'autoreg':
                preds = model._dec_forward(sents, None, True)
                nll_autoreg = sum([
                    criterion(preds[:, l], sents[:, l + 1])
                    for l in range(length)
                ])
                train_nll_autoreg += nll_autoreg.data[0] * batch_size
                nll_autoreg.backward()
            elif args.model == 'svi':
                mean_svi = Variable(
                    0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                    requires_grad=True)
                logvar_svi = Variable(
                    0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                    requires_grad=True)
                var_params_svi = meta_optimizer.forward(
                    [mean_svi, logvar_svi], sents, b % args.print_every == 0)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model._reparameterize(mean_svi_final.detach(),
                                                  logvar_svi_final.detach())
                preds = model._dec_forward(sents, z_samples)
                nll_svi = sum([
                    criterion(preds[:, l], sents[:, l + 1])
                    for l in range(length)
                ])
                train_nll_svi += nll_svi.data[0] * batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                train_kl_svi += kl_svi.data[0] * batch_size
                var_loss = nll_svi + args.beta * kl_svi
                var_loss.backward(retain_graph=True)
            else:
                mean, logvar = model._enc_forward(sents)
                z_samples = model._reparameterize(mean, logvar)
                preds = model._dec_forward(sents, z_samples)
                nll_vae = sum([
                    criterion(preds[:, l], sents[:, l + 1])
                    for l in range(length)
                ])
                train_nll_vae += nll_vae.data[0] * batch_size
                kl_vae = utils.kl_loss_diag(mean, logvar)
                train_kl_vae += kl_vae.data[0] * batch_size
                if args.model == 'vae':
                    vae_loss = nll_vae + args.beta * kl_vae
                    vae_loss.backward(retain_graph=True)
                if args.model == 'savae':
                    var_params = torch.cat([mean, logvar], 1)
                    mean_svi = Variable(mean.data, requires_grad=True)
                    logvar_svi = Variable(logvar.data, requires_grad=True)
                    var_params_svi = meta_optimizer.forward(
                        [mean_svi, logvar_svi], sents,
                        b % args.print_every == 0)
                    mean_svi_final, logvar_svi_final = var_params_svi
                    z_samples = model._reparameterize(mean_svi_final,
                                                      logvar_svi_final)
                    preds = model._dec_forward(sents, z_samples)
                    nll_svi = sum([
                        criterion(preds[:, l], sents[:, l + 1])
                        for l in range(length)
                    ])
                    train_nll_svi += nll_svi.data[0] * batch_size
                    kl_svi = utils.kl_loss_diag(mean_svi_final,
                                                logvar_svi_final)
                    train_kl_svi += kl_svi.data[0] * batch_size
                    var_loss = nll_svi + args.beta * kl_svi
                    var_loss.backward(retain_graph=True)
                    if args.train_n2n == 0:
                        if args.train_kl == 1:
                            mean_final = mean_svi_final.detach()
                            logvar_final = logvar_svi_final.detach()
                            kl_init_final = utils.kl_loss(
                                mean, logvar, mean_final, logvar_final)
                            train_kl_init_final += kl_init_final.data[
                                0] * batch_size
                            kl_init_final.backward(retain_graph=True)
                        else:
                            vae_loss = nll_vae + args.beta * kl_vae
                            var_param_grads = torch.autograd.grad(
                                vae_loss, [mean, logvar], retain_graph=True)
                            var_param_grads = torch.cat(var_param_grads, 1)
                            var_params.backward(var_param_grads,
                                                retain_graph=True)
                    else:
                        var_param_grads = meta_optimizer.backward(
                            [mean_svi_final.grad, logvar_svi_final.grad],
                            b % args.print_every == 0)
                        var_param_grads = torch.cat(var_param_grads, 1)
                        var_params.backward(var_param_grads)
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm(model.parameters(),
                                              args.max_grad_norm)
            optimizer.step()
            num_sents += batch_size
            num_words += batch_size * length

            if b % args.print_every == 0:
                param_norm = sum([p.norm()**2
                                  for p in model.parameters()]).data[0]**0.5
                logger.info(
                    'Iters: %d, Epoch: %d, Batch: %d/%d, LR: %.4f, TrainARNLL: %.4f, TrainARPPL: %.2f, TrainVAE_NLL: %.4f, TrainVAE_REC: %.4f, TrainVAE_KL: %.4f, TrainVAE_PPL: %.2f, TrainSVI_NLL: %.2f, TrainSVI_REC: %.2f, TrainSVI_KL: %.4f, TrainSVI_PPL: %.2f, KLInitFinal: %.2f, |Param|: %.4f, BestValPerf: %.2f, BestEpoch: %d, Beta: %.4f, Throughput: %.2f examples/sec'
                    % (t, epoch, b + 1, len(train_data), cur_lr,
                       train_nll_autoreg / num_sents,
                       np.exp(train_nll_autoreg / num_words),
                       (train_nll_vae + train_kl_vae) / num_sents,
                       train_nll_vae / num_sents, train_kl_vae / num_sents,
                       np.exp((train_nll_vae + train_kl_vae) / num_words),
                       (train_nll_svi + train_kl_svi) / num_sents,
                       train_nll_svi / num_sents, train_kl_svi / num_sents,
                       np.exp((train_nll_svi + train_kl_svi) / num_words),
                       train_kl_init_final / num_sents, param_norm,
                       best_val_nll, best_epoch, args.beta, num_sents /
                       (time.time() - start_time)))

        epoch_train_time = time.time() - start_time
        logger.info('Time Elapsed: %.1fs' % epoch_train_time)

        logger.info('--------------------------------')
        logger.info('Checking validation perf...')
        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Mode', 'Val')
        logger.record_tabular('LR', cur_lr)
        logger.record_tabular('Epoch Train Time', epoch_train_time)
        val_nll = eval(val_data, model, meta_optimizer)
        val_stats.append(val_nll)

        logger.info('--------------------------------')
        logger.info('Checking test perf...')
        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Mode', 'Test')
        logger.record_tabular('LR', cur_lr)
        logger.record_tabular('Epoch Train Time', epoch_train_time)
        test_nll = eval(test_data, model, meta_optimizer)

        if val_nll < best_val_nll:
            best_val_nll = val_nll
            best_epoch = epoch
            model.cpu()
            checkpoint = {
                'args': args.__dict__,
                'model': model,
                'val_stats': val_stats
            }
            logger.info('Save checkpoint to %s' % checkpoint_path)
            torch.save(checkpoint, checkpoint_path)
            model.cuda()
        else:
            if epoch >= args.min_epochs:
                args.decay = 1
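Example #8 restarts KL annealing every args.cycle epochs and pairs it with a cosine learning-rate cycle (hence the 'cyc' checkpoint suffix). The per-epoch learning-rate schedule from the loop above, extracted into a helper for clarity (the helper name is illustrative):

import numpy as np

def cyclical_lr(base_lr, epoch, cycle):
    # Cosine decay from base_lr toward 0, restarted every `cycle` epochs.
    tmp = float((epoch - 1) % cycle) / cycle
    return base_lr * 0.5 * (1 + np.cos(tmp * np.pi))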
Example #9
def eval(data, model, meta_optimizer):

    model.eval()
    criterion = nn.NLLLoss().cuda()
    num_sents = 0
    num_words = 0
    total_nll_autoreg = 0.
    total_nll_vae = 0.
    total_kl_vae = 0.
    total_nll_svi = 0.
    total_kl_svi = 0.
    best_svi_loss = 0.
    for i in range(len(data)):
        sents, length, batch_size = data[i]
        num_words += batch_size * length
        num_sents += batch_size
        if args.gpu >= 0:
            sents = sents.cuda()
        if args.model == 'autoreg':
            preds = model._dec_forward(sents, None, True)
            nll_autoreg = sum([
                criterion(preds[:, l], sents[:, l + 1]) for l in range(length)
            ])
            total_nll_autoreg += nll_autoreg.data[0] * batch_size
        elif args.model == 'svi':
            mean_svi = Variable(
                0.1 * torch.randn(batch_size, args.latent_dim).cuda(),
                requires_grad=True)
            logvar_svi = Variable(
                0.1 * torch.randn(batch_size, args.latent_dim).cuda(),
                requires_grad=True)
            var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi],
                                                    sents)
            mean_svi_final, logvar_svi_final = var_params_svi
            z_samples = model._reparameterize(mean_svi_final.detach(),
                                              logvar_svi_final.detach())
            preds = model._dec_forward(sents, z_samples)
            nll_svi = sum([
                criterion(preds[:, l], sents[:, l + 1]) for l in range(length)
            ])
            total_nll_svi += nll_svi.data[0] * batch_size
            kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
            total_kl_svi += kl_svi.data[0] * batch_size
            mean, logvar = mean_svi_final, logvar_svi_final
        else:
            mean, logvar = model._enc_forward(sents)
            z_samples = model._reparameterize(mean, logvar)
            preds = model._dec_forward(sents, z_samples)
            nll_vae = sum([
                criterion(preds[:, l], sents[:, l + 1]) for l in range(length)
            ])
            total_nll_vae += nll_vae.data[0] * batch_size
            kl_vae = utils.kl_loss_diag(mean, logvar)
            total_kl_vae += kl_vae.data[0] * batch_size
            if args.model == 'savae':
                mean_svi = Variable(mean.data, requires_grad=True)
                logvar_svi = Variable(logvar.data, requires_grad=True)
                var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi],
                                                        sents)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model._reparameterize(mean_svi_final,
                                                  logvar_svi_final)
                preds = model._dec_forward(sents, z_samples)
                nll_svi = sum([
                    criterion(preds[:, l], sents[:, l + 1])
                    for l in range(length)
                ])
                total_nll_svi += nll_svi.data[0] * batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                total_kl_svi += kl_svi.data[0] * batch_size
                mean, logvar = mean_svi_final, logvar_svi_final

    nll_autoreg = total_nll_autoreg / num_sents
    ppl_autoreg = np.exp(total_nll_autoreg / num_words)
    nll_vae = (total_nll_vae + total_kl_vae) / num_sents
    rec_vae = total_nll_vae / num_sents
    kl_vae = total_kl_vae / num_sents
    ppl_bound_vae = np.exp((total_nll_vae + total_kl_vae) / num_words)
    nll_svi = (total_nll_svi + total_kl_svi) / num_sents
    rec_svi = total_nll_svi / num_sents
    kl_svi = total_kl_svi / num_sents
    ppl_bound_svi = np.exp((total_nll_svi + total_kl_svi) / num_words)

    logger.record_tabular('AR NLL', nll_autoreg)
    logger.record_tabular('AR PPL', ppl_autoreg)
    logger.record_tabular('VAE NLL', nll_vae)
    logger.record_tabular('VAE REC', rec_vae)
    logger.record_tabular('VAE KL', kl_vae)
    logger.record_tabular('VAE PPL', ppl_bound_vae)
    logger.record_tabular('SVI NLL', nll_svi)
    logger.record_tabular('SVI REC', rec_svi)
    logger.record_tabular('SVI KL', kl_svi)
    logger.record_tabular('SVI PPL', ppl_bound_svi)
    logger.dump_tabular()
    logger.info(
        'AR NLL: %.4f, AR PPL: %.4f, VAE NLL: %.4f, VAE REC: %.4f, VAE KL: %.4f, VAE PPL: %.4f, SVI NLL: %.4f, SVI REC: %.4f, SVI KL: %.4f, SVI PPL: %.4f'
        % (nll_autoreg, ppl_autoreg, nll_vae, rec_vae, kl_vae, ppl_bound_vae,
           nll_svi, rec_svi, kl_svi, ppl_bound_svi))
    model.train()
    if args.model == 'autoreg':
        return ppl_autoreg
    elif args.model == 'vae':
        return ppl_bound_vae
    elif args.model == 'savae' or args.model == 'svi':
        return ppl_bound_svi
Example #10
0
def eval(data, model, meta_optimizer, agg_kl=0):
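    # Variant of eval that also tracks the KL per latent dimension (kl_dim) to
    # inspect active units, and prints the per-sentence ELBO for the VAE and SVI paths.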
    model.eval()
    criterion = nn.NLLLoss().cuda()
    num_sents = 0
    num_words = 0
    total_nll_autoreg = 0.
    total_nll_vae = 0.
    total_kl_vae = 0.
    total_nll_svi = 0.
    total_kl_svi = 0.
    best_svi_loss = 0.
    total_kl_dim = 0
    for i in range(len(data)):
        sents, length, batch_size = data[i]
        num_words += batch_size * length
        num_sents += batch_size
        if args.gpu >= 0:
            sents = sents.cuda()
        if args.model == 'autoreg':
            preds = model._dec_forward(sents, None, True)
            nll_autoreg = sum([
                criterion(preds[:, l], sents[:, l + 1]) for l in range(length)
            ])
            total_nll_autoreg += nll_autoreg.data[0] * batch_size
        elif args.model == 'svi':
            mean_svi = Variable(
                0.1 * torch.randn(batch_size, args.latent_dim).cuda(),
                requires_grad=True)
            logvar_svi = Variable(
                0.1 * torch.randn(batch_size, args.latent_dim).cuda(),
                requires_grad=True)
            var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi],
                                                    sents)
            mean_svi_final, logvar_svi_final = var_params_svi
            z_samples = model._reparameterize(mean_svi_final.detach(),
                                              logvar_svi_final.detach())
            preds = model._dec_forward(sents, z_samples)
            nll_svi = sum([
                criterion(preds[:, l], sents[:, l + 1]) for l in range(length)
            ])
            total_nll_svi += nll_svi.data[0] * batch_size
            kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
            total_kl_svi += kl_svi.data[0] * batch_size
            mean, logvar = mean_svi_final, logvar_svi_final
        else:
            mean, logvar = model._enc_forward(sents)
            z_samples = model._reparameterize(mean, logvar)
            preds = model._dec_forward(sents, z_samples)
            nll_vae = sum([
                criterion(preds[:, l], sents[:, l + 1]) for l in range(length)
            ])
            total_nll_vae += nll_vae.data[0] * batch_size
            kl_vae = utils.kl_loss_diag(mean, logvar)
            kl_dim = utils.kl_loss_dim(mean, logvar)
            total_kl_dim += kl_dim.sum(0).data
            total_kl_vae += kl_vae.data[0] * batch_size
            if args.model == 'savae':
                mean_svi = Variable(mean.data, requires_grad=True)
                logvar_svi = Variable(logvar.data, requires_grad=True)
                var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi],
                                                        sents)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model._reparameterize(mean_svi_final,
                                                  logvar_svi_final)
                preds = model._dec_forward(sents, z_samples)
                nll_svi = sum([
                    criterion(preds[:, l], sents[:, l + 1])
                    for l in range(length)
                ])
                total_nll_svi += nll_svi.data[0] * batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                total_kl_svi += kl_svi.data[0] * batch_size
                mean, logvar = mean_svi_final, logvar_svi_final
    ppl_autoreg = np.exp(total_nll_autoreg / num_words)
    ppl_vae = np.exp(total_nll_vae / num_words)
    kl_vae = total_kl_vae / num_sents
    ppl_bound_vae = np.exp((total_nll_vae + total_kl_vae) / num_words)
    ppl_svi = np.exp(total_nll_svi / num_words)
    kl_svi = total_kl_svi / num_sents
    kl_dim = total_kl_dim / num_sents
    print(['%.4f' % e for e in list(kl_dim)])
    ppl_bound_svi = np.exp((total_nll_svi + total_kl_svi) / num_words)
    print('elbo vae', (total_nll_vae + total_kl_vae) / num_sents)
    print('elbo svi', (total_nll_svi + total_kl_svi) / num_sents)

    print(
        'AR PPL: %.4f, VAE PPL: %.4f, VAE KL: %.4f, VAE PPL BOUND: %.4f, SVI PPL: %.4f, SVI KL: %.4f, SVI PPL BOUND: %.4f'
        % (ppl_autoreg, ppl_vae, kl_vae, ppl_bound_vae, ppl_svi, kl_svi,
           ppl_bound_svi))
    model.train()
    if args.model == 'autoreg':
        return ppl_autoreg
    elif args.model == 'vae':
        return ppl_bound_vae
    elif args.model == 'savae' or args.model == 'svi':
        return ppl_bound_svi
Example #11
0
def main(args):
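    # Training driver: builds the data, model, SGD optimizer, and the OptimN2N
    # meta-optimizer used for SVI refinement, then runs the epoch loop with
    # KL warm-up, checkpointing on best validation perplexity, and LR decay.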
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_data = Dataset(args.train_file)
    val_data = Dataset(args.val_file)
    train_sents = train_data.batch_size.sum()
    vocab_size = int(train_data.vocab_size)
    print('Train data: %d batches' % len(train_data))
    print('Val data: %d batches' % len(val_data))
    print('Word vocab size: %d' % vocab_size)
    if args.slurm == 0:
        # cuda.set_device(args.gpu)
        gpu_id = 0
        device = torch.device(
            f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
    if args.train_from == '':
        model = RNNVAE(vocab_size=vocab_size,
                       enc_word_dim=args.enc_word_dim,
                       enc_h_dim=args.enc_h_dim,
                       enc_num_layers=args.enc_num_layers,
                       dec_word_dim=args.dec_word_dim,
                       dec_h_dim=args.dec_h_dim,
                       dec_num_layers=args.dec_num_layers,
                       dec_dropout=args.dec_dropout,
                       latent_dim=args.latent_dim,
                       mode=args.model)
        for param in model.parameters():
            param.data.uniform_(-0.1, 0.1)
    else:
        print('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']

    print("model architecture")
    print(model)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

    if args.warmup == 0:
        args.beta = 1.
    else:
        args.beta = args.kl_start

    criterion = nn.NLLLoss(reduction='none')
    # criterion = nn.NLLLoss()
    # model.cuda()
    # criterion.cuda()
    # model = torch.nn.DataParallel(net, device_ids=[0, 1])
    model.to(device)
    criterion.to(device)
    model.train()

    def variational_loss(input, sents, model, z=None):
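        # Inner-loop objective used by OptimN2N: reconstruction NLL plus
        # beta-weighted KL for a given set of variational parameters.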
        mean, logvar = input
        z_samples = model._reparameterize(mean, logvar, z)
        preds = model._dec_forward(sents, z_samples)
        nll = sum([
            criterion(preds[:, l], sents[:, l + 1])
            for l in range(preds.size(1))
        ])
        kl = utils.kl_loss_diag(mean, logvar)
        return nll + args.beta * kl

    update_params = list(model.dec.parameters())
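    # OptimN2N runs the inner SVI updates; with acc_param_grads enabled
    # (train_n2n == 1) it also backpropagates through those updates so the
    # decoder parameters receive gradients from the refined variational parameters.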
    meta_optimizer = OptimN2N(variational_loss,
                              model,
                              update_params,
                              eps=args.eps,
                              lr=[args.svi_lr1, args.svi_lr2],
                              iters=args.svi_steps,
                              momentum=args.momentum,
                              acc_param_grads=args.train_n2n == 1,
                              max_grad_norm=args.svi_max_grad_norm)
    if args.test == 1:
        args.beta = 1
        test_data = Dataset(args.test_file)
        eval(args, test_data, model, meta_optimizer, device)
        exit()

    t = 0
    best_val_nll = 1e5
    best_epoch = 0
    not_improved = 0
    val_stats = []
    epoch = 0
    while epoch < args.num_epochs:
        start_time = time.time()
        epoch += 1
        print('Starting epoch %d' % epoch)
        train_nll_vae = 0.
        train_nll_autoreg = 0.
        train_kl_vae = 0.
        train_nll_svi = 0.
        train_kl_svi = 0.
        train_kl_init_final = 0.
        num_sents = 0
        num_words = 0
        b = 0

        for i in np.random.permutation(len(train_data)):
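            # KL warm-up: anneal beta linearly from kl_start towards 1 over the
            # first `warmup` epochs.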
            if args.warmup > 0:
                args.beta = min(
                    1, args.beta + 1. / (args.warmup * len(train_data)))

            sents, length, batch_size = train_data[i]
            length = length.item()
            batch_size = batch_size.item()

            if args.gpu >= 0:
                # sents = sents.cuda()
                sents = sents.to(device)
                # batch_size = batch_size.to(device)
            b += 1

            optimizer.zero_grad()
            if args.model == 'autoreg':
                preds = model._dec_forward(sents, None, True)
                tgt = sents[:, 1:].contiguous()
                nll_autoreg = criterion(preds.view(-1, preds.size(2)),
                                        tgt.view(-1)).view(preds.size(0),
                                                           -1).sum(-1).mean(0)
                # nll_autoreg = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
                train_nll_autoreg += nll_autoreg.item() * batch_size
                # train_nll_autoreg += nll_autoreg.data[0]*batch_size #old
                nll_autoreg.backward()
            elif args.model == 'svi':
                # mean_svi = Variable(0.1*torch.zeros(batch_size, args.latent_dim).cuda(), requires_grad = True)
                # logvar_svi = Variable(0.1*torch.zeros(batch_size, args.latent_dim).cuda(), requires_grad = True)
                mean_svi = Variable(
                    0.1 * torch.zeros(batch_size, args.latent_dim).to(device),
                    requires_grad=True)
                logvar_svi = Variable(
                    0.1 * torch.zeros(batch_size, args.latent_dim).to(device),
                    requires_grad=True)
                var_params_svi = meta_optimizer.forward(
                    [mean_svi, logvar_svi], sents, b % args.print_every == 0)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model._reparameterize(mean_svi_final.detach(),
                                                  logvar_svi_final.detach())
                preds = model._dec_forward(sents, z_samples)
                tgt = sents[:, 1:].contiguous()
                nll_svi = criterion(preds.view(-1, preds.size(2)),
                                    tgt.view(-1)).view(preds.size(0),
                                                       -1).sum(-1).mean(0)
                # nll_svi = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
                train_nll_svi += nll_svi.item() * batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                train_kl_svi += kl_svi.item() * batch_size
                var_loss = nll_svi + args.beta * kl_svi
                var_loss.backward(retain_graph=True)
            else:
                mean, logvar = model._enc_forward(sents)
                z_samples = model._reparameterize(mean, logvar)
                preds = model._dec_forward(sents, z_samples)
                tgt = sents[:, 1:].contiguous()
                nll_vae = criterion(preds.view(-1, preds.size(2)),
                                    tgt.view(-1)).view(preds.size(0),
                                                       -1).sum(-1).mean(0)
                # nll_vae = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
                # train_nll_vae += nll_vae.data[0]*batch_size#old
                train_nll_vae += nll_vae.item() * batch_size
                kl_vae = utils.kl_loss_diag(mean, logvar)
                # train_kl_vae += kl_vae.data[0]*batch_size#old
                train_kl_vae += kl_vae.item() * batch_size
                if args.model == 'vae':
                    vae_loss = nll_vae + args.beta * kl_vae
                    vae_loss.backward(retain_graph=True)
                if args.model == 'savae':
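                    # SA-VAE: refine the encoder's variational parameters with SVI,
                    # decode from the refined posterior, and backprop that loss into
                    # the decoder (encoder gradients are routed below).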
                    var_params = torch.cat([mean, logvar], 1)
                    mean_svi = Variable(mean.data, requires_grad=True)
                    logvar_svi = Variable(logvar.data, requires_grad=True)
                    var_params_svi = meta_optimizer.forward(
                        [mean_svi, logvar_svi], sents,
                        b % args.print_every == 0)
                    mean_svi_final, logvar_svi_final = var_params_svi
                    z_samples = model._reparameterize(mean_svi_final,
                                                      logvar_svi_final)
                    preds = model._dec_forward(sents, z_samples)
                    tgt = sents[:, 1:].contiguous()
                    nll_svi = criterion(preds.view(-1, preds.size(2)),
                                        tgt.view(-1)).view(preds.size(0),
                                                           -1).sum(-1).mean(0)
                    # nll_svi = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
                    train_nll_svi += nll_svi.item() * batch_size
                    kl_svi = utils.kl_loss_diag(mean_svi_final,
                                                logvar_svi_final)
                    train_kl_svi += kl_svi.item() * batch_size
                    var_loss = nll_svi + args.beta * kl_svi
                    var_loss.backward(retain_graph=True)
                    if args.train_n2n == 0:
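                        # Encoder gradients without backprop through SVI: either match
                        # the encoder output to the refined posterior with a KL term
                        # (train_kl == 1), or reuse the ELBO gradients with respect to
                        # the initial variational parameters.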
                        if args.train_kl == 1:
                            mean_final = mean_svi_final.detach()
                            logvar_final = logvar_svi_final.detach()
                            kl_init_final = utils.kl_loss(
                                mean, logvar, mean_final, logvar_final)
                            train_kl_init_final += kl_init_final.item() * batch_size
                            kl_init_final.backward(retain_graph=True)
                        else:
                            vae_loss = nll_vae + args.beta * kl_vae
                            var_param_grads = torch.autograd.grad(
                                vae_loss, [mean, logvar], retain_graph=True)
                            var_param_grads = torch.cat(var_param_grads, 1)
                            var_params.backward(var_param_grads,
                                                retain_graph=True)
                    else:
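                        # train_n2n == 1: backpropagate through the SVI refinement steps
                        # themselves to obtain gradients for the encoder parameters.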
                        var_param_grads = meta_optimizer.backward(
                            [mean_svi_final.grad, logvar_svi_final.grad],
                            b % args.print_every == 0)
                        var_param_grads = torch.cat(var_param_grads, 1)
                        var_params.backward(var_param_grads)
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
            optimizer.step()
            num_sents += batch_size
            num_words += batch_size * length
            # num_sents = num_sents.item()
            # num_words = num_words.item()
            if b % args.print_every == 0:
                param_norm = sum([p.norm()**2
                                  for p in model.parameters()]).item()**0.5
                print(
                    'Iters: %d, Epoch: %d, Batch: %d/%d, LR: %.4f, TrainARPPL: %.2f, TrainVAE_PPL: %.2f, TrainVAE_KL: %.4f, TrainVAE_PPLBnd: %.2f, TrainSVI_PPL: %.2f, TrainSVI_KL: %.4f, TrainSVI_PPLBnd: %.2f, KLInitFinal: %.2f, |Param|: %.4f, BestValPerf: %.2f, BestEpoch: %d, Beta: %.4f, Throughput: %.2f examples/sec'
                    %
                    (t, epoch, b + 1, len(train_data), args.lr,
                     np.exp(train_nll_autoreg / num_words),
                     np.exp(
                         train_nll_vae / num_words), train_kl_vae / num_sents,
                     np.exp((train_nll_vae + train_kl_vae) / num_words),
                     np.exp(
                         train_nll_svi / num_words), train_kl_svi / num_sents,
                     np.exp((train_nll_svi + train_kl_svi) / num_words),
                     train_kl_init_final / num_sents, param_norm, best_val_nll,
                     best_epoch, args.beta, num_sents /
                     (time.time() - start_time)))

        print('--------------------------------')
        print('Checking validation perf...')
        val_nll = eval(args, val_data, model, meta_optimizer, device)
        val_stats.append(val_nll)

        # if val_elbo > self.best_val_elbo:
        #     self.not_improved = 0
        #     self.best_val_elbo = val_elbo
        # else:
        #     self.not_improved += 1
        #     if self.not_improved % 5 == 0:
        #         self.current_lr = self.current_lr * self.config.options.lr_decay
        #         print(f'New LR {self.current_lr}')
        #         model.optimizer = torch.optim.SGD(model.parameters(), lr=self.current_lr)
        #         model.enc_optimizer = torch.optim.SGD(model.parameters(), lr=self.current_lr)
        #         model.dec_optimizer = torch.optim.SGD(model.parameters(), lr=self.current_lr)

        if val_nll < best_val_nll:
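            # New best validation perplexity: remove the previous best checkpoint file,
            # then save the current model (moved to CPU so the checkpoint loads anywhere).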
            not_improved = 0
            best_save = '{}_{}.pt'.format(args.checkpoint_path, best_val_nll)
            if os.path.exists(best_save):
                os.remove(best_save)

            best_val_nll = val_nll
            best_epoch = epoch
            model.cpu()
            checkpoint = {
                'args': args.__dict__,
                'model': model,
                'val_stats': val_stats
            }
            best_save = '{}_{}.pt'.format(args.checkpoint_path, best_val_nll)
            print('Saving checkpoint to %s' % best_save)
            torch.save(checkpoint, best_save)

            # model.cuda()
            model.to(device)
        else:
            not_improved += 1
            if not_improved % 5 == 0:
                not_improved = 0
                args.lr = args.lr * args.lr_decay
                print(f'New LR: {args.lr}')
                for param_group in optimizer.param_groups:
                    param_group['lr'] = args.lr
Example #12
0
def eval(args, data, model, meta_optimizer, device):
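    # Evaluation used with the device-aware training loop above. Only the dropout and
    # output layers are switched to eval mode here; for the VAE it also calls calc_iw
    # (defined elsewhere in this file), presumably an importance-weighted likelihood estimate.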
    model.dropout.eval()
    model.dec_linear[0].eval()
    # model.eval()
    # print(model.dropout.training)
    # print(model.dec_linear[0].training)

    # criterion = nn.NLLLoss().cuda()
    criterion = nn.NLLLoss().to(device)
    if args.model == 'vae':
        calc_iw(args, data, model, meta_optimizer, criterion, device)
    num_sents = 0
    num_words = 0
    total_nll_autoreg = 0.
    total_nll_vae = 0.
    total_kl_vae = 0.
    total_nll_svi = 0.
    total_kl_svi = 0.
    best_svi_loss = 0.
    for i in range(len(data)):
        sents, length, batch_size = data[i]
        length = length.item()
        batch_size = batch_size.item()
        num_words += batch_size * length
        num_sents += batch_size
        if args.gpu >= 0:
            # sents = sents.cuda()
            sents = sents.to(device)
        if args.model == 'autoreg':
            preds = model._dec_forward(sents, None, True)
            nll_autoreg = sum([
                criterion(preds[:, l], sents[:, l + 1]) for l in range(length)
            ])
            # total_nll_autoreg += nll_autoreg.data[0]*batch_size #old
            total_nll_autoreg += nll_autoreg.item() * batch_size
        elif args.model == 'svi':
            # mean_svi = Variable(0.1*torch.randn(batch_size, args.latent_dim).cuda(), requires_grad = True)
            # logvar_svi = Variable(0.1*torch.randn(batch_size, args.latent_dim).cuda(), requires_grad = True)
            mean_svi = Variable(
                0.1 * torch.randn(batch_size, args.latent_dim).to(device),
                requires_grad=True)
            logvar_svi = Variable(
                0.1 * torch.randn(batch_size, args.latent_dim).to(device),
                requires_grad=True)
            var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi],
                                                    sents)
            mean_svi_final, logvar_svi_final = var_params_svi
            z_samples = model._reparameterize(mean_svi_final.detach(),
                                              logvar_svi_final.detach())
            preds = model._dec_forward(sents, z_samples)
            nll_svi = sum([
                criterion(preds[:, l], sents[:, l + 1]) for l in range(length)
            ])
            total_nll_svi += nll_svi.item() * batch_size
            kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
            total_kl_svi += kl_svi.item() * batch_size
            mean, logvar = mean_svi_final, logvar_svi_final
        else:
            mean, logvar = model._enc_forward(sents)
            z_samples = model._reparameterize(mean, logvar)
            preds = model._dec_forward(sents, z_samples)
            nll_vae = sum([
                criterion(preds[:, l], sents[:, l + 1]) for l in range(length)
            ])
            # total_nll_vae += nll_vae.data[0]*batch_size#old
            total_nll_vae += nll_vae.item() * batch_size
            kl_vae = utils.kl_loss_diag(mean, logvar)
            # total_kl_vae += kl_vae.data[0]*batch_size#old
            total_kl_vae += kl_vae.item() * batch_size
            if args.model == 'savae':
                mean_svi = Variable(mean.data, requires_grad=True)
                logvar_svi = Variable(logvar.data, requires_grad=True)
                var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi],
                                                        sents)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model._reparameterize(mean_svi_final,
                                                  logvar_svi_final)
                preds = model._dec_forward(sents, z_samples)
                nll_svi = sum([
                    criterion(preds[:, l], sents[:, l + 1])
                    for l in range(length)
                ])
                total_nll_svi += nll_svi.item() * batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                total_kl_svi += kl_svi.item() * batch_size
                mean, logvar = mean_svi_final, logvar_svi_final
    # num_words = num_words.item()
    # num_sents = num_sents.item()
    ppl_autoreg = np.exp(total_nll_autoreg / num_words)
    ppl_vae = np.exp(total_nll_vae / num_words)
    kl_vae = total_kl_vae / num_sents
    ppl_bound_vae = np.exp((total_nll_vae + total_kl_vae) / num_words)
    ppl_svi = np.exp(total_nll_svi / num_words)
    kl_svi = total_kl_svi / num_sents
    ppl_bound_svi = np.exp((total_nll_svi + total_kl_svi) / num_words)
    print("num_words", num_words)
    print("num_sents", num_sents)

    if args.test == 1:
        log_path = args.checkpoint_path + '_log_test'
    else:
        log_path = args.checkpoint_path + '_log_val'
    eval_line = 'AR PPL: %.4f, VAE PPL: %.4f, VAE KL: %.4f, VAE PPL BOUND: %.4f, SVI PPL: %.4f, SVI KL: %.4f, SVI PPL BOUND: %.4f\n' % (
        ppl_autoreg, ppl_vae, kl_vae, ppl_bound_vae, ppl_svi, kl_svi,
        ppl_bound_svi)
    with open(log_path, 'a') as f:
        f.write(eval_line)

    print(
        'AR PPL: %.4f, VAE PPL: %.4f, VAE KL: %.4f, VAE PPL BOUND: %.4f, SVI PPL: %.4f, SVI KL: %.4f, SVI PPL BOUND: %.4f'
        % (ppl_autoreg, ppl_vae, kl_vae, ppl_bound_vae, ppl_svi, kl_svi,
           ppl_bound_svi))
    model.train()
    if args.model == 'autoreg':
        return ppl_autoreg
    elif args.model == 'vae':
        return ppl_bound_vae
    elif args.model == 'savae' or args.model == 'svi':
        return ppl_bound_svi