def main(args):
    # set random seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # load binarized image data and wrap it in DataLoaders
    all_data = torch.load(args.data_file)
    x_train, x_val, x_test = all_data
    y_size = 1
    y_train = torch.zeros(x_train.size(0), y_size)
    y_val = torch.zeros(x_val.size(0), y_size)
    y_test = torch.zeros(x_test.size(0), y_size)
    train = torch.utils.data.TensorDataset(x_train, y_train)
    val = torch.utils.data.TensorDataset(x_val, y_val)
    test = torch.utils.data.TensorDataset(x_test, y_test)
    train_loader = torch.utils.data.DataLoader(train, batch_size=args.batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val, batch_size=args.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test, batch_size=args.batch_size, shuffle=True)
    print('Train data: %d batches' % len(train_loader))
    print('Val data: %d batches' % len(val_loader))
    print('Test data: %d batches' % len(test_loader))

    if args.slurm == 0:
        cuda.set_device(args.gpu)
    if args.model == 'autoreg':
        args.latent_feature_map = 0

    # define model or load checkpoint
    if args.train_from == '':
        model = CNNVAE(img_size=args.img_size,
                       latent_dim=args.latent_dim,
                       enc_layers=args.enc_layers,
                       dec_kernel_size=args.dec_kernel_size,
                       dec_layers=args.dec_layers,
                       latent_feature_map=args.latent_feature_map)
    else:
        print('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']
    print("model architecture")
    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
    model.cuda()
    model.train()

    def variational_loss(input, img, model, z=None):
        mean, logvar = input
        z_samples = model._reparameterize(mean, logvar, z)
        preds = model._dec_forward(img, z_samples)
        nll = utils.log_bernoulli_loss(preds, img)
        kl = utils.kl_loss_diag(mean, logvar)
        return nll + args.beta * kl

    update_params = list(model.dec.parameters())
    meta_optimizer = OptimN2N(variational_loss,
                              model,
                              update_params,
                              eps=args.eps,
                              lr=[args.svi_lr1, args.svi_lr2],
                              iters=args.svi_steps,
                              momentum=args.momentum,
                              acc_param_grads=args.train_n2n == 1,
                              max_grad_norm=args.svi_max_grad_norm)

    epoch = 0
    t = 0
    best_val_nll = 1e5
    best_epoch = 0
    loss_stats = []
    if args.warmup == 0:
        args.beta = 1.
    else:
        args.beta = 0.1
    if args.test == 1:
        args.beta = 1
        eval(test_loader, model, meta_optimizer)
        exit()

    # training loop
    while epoch < args.num_epochs:
        start_time = time.time()
        epoch += 1
        print('Starting epoch %d' % epoch)
        train_nll_vae = 0.
        train_nll_autoreg = 0.
        train_kl_vae = 0.
        train_nll_svi = 0.
        train_kl_svi = 0.
        num_examples = 0
        for b, datum in enumerate(train_loader):
            if args.warmup > 0:
                args.beta = min(1, args.beta + 1. / (args.warmup * len(train_loader)))
            img, _ = datum
            img = torch.bernoulli(img)
            batch_size = img.size(0)
            img = Variable(img.cuda())
            t += 1
            optimizer.zero_grad()
            if args.model == 'autoreg':
                preds = model._dec_forward(img, None)
                nll_autoreg = utils.log_bernoulli_loss(preds, img)
                train_nll_autoreg += nll_autoreg.data[0] * batch_size
                nll_autoreg.backward()
            elif args.model == 'svi':
                mean_svi = Variable(0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                                    requires_grad=True)
                logvar_svi = Variable(0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                                      requires_grad=True)
                var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], img,
                                                        t % args.print_every == 0)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model._reparameterize(mean_svi_final.detach(),
                                                  logvar_svi_final.detach())
                preds = model._dec_forward(img, z_samples)
                nll_svi = utils.log_bernoulli_loss(preds, img)
                train_nll_svi += nll_svi.data[0] * batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                train_kl_svi += kl_svi.data[0] * batch_size
                var_loss = nll_svi + args.beta * kl_svi
                var_loss.backward()
            else:
                mean, logvar = model._enc_forward(img)
                z_samples = model._reparameterize(mean, logvar)
                preds = model._dec_forward(img, z_samples)
                nll_vae = utils.log_bernoulli_loss(preds, img)
                train_nll_vae += nll_vae.data[0] * batch_size
                kl_vae = utils.kl_loss_diag(mean, logvar)
                train_kl_vae += kl_vae.data[0] * batch_size
                if args.model == 'vae':
                    vae_loss = nll_vae + args.beta * kl_vae
                    vae_loss.backward(retain_graph=True)
                if args.model == 'savae':
                    var_params = torch.cat([mean, logvar], 1)
                    mean_svi = Variable(mean.data, requires_grad=True)
                    logvar_svi = Variable(logvar.data, requires_grad=True)
                    var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], img,
                                                            t % args.print_every == 0)
                    mean_svi_final, logvar_svi_final = var_params_svi
                    z_samples = model._reparameterize(mean_svi_final, logvar_svi_final)
                    preds = model._dec_forward(img, z_samples)
                    nll_svi = utils.log_bernoulli_loss(preds, img)
                    train_nll_svi += nll_svi.data[0] * batch_size
                    kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                    train_kl_svi += kl_svi.data[0] * batch_size
                    var_loss = nll_svi + args.beta * kl_svi
                    var_loss.backward(retain_graph=True)
                    if args.train_n2n == 0:
                        if args.train_kl == 1:
                            mean_final = mean_svi_final.detach()
                            logvar_final = logvar_svi_final.detach()
                            kl_init_final = utils.kl_loss(mean, logvar, mean_final, logvar_final)
                            kl_init_final.backward(retain_graph=True)
                        else:
                            vae_loss = nll_vae + args.beta * kl_vae
                            var_param_grads = torch.autograd.grad(vae_loss, [mean, logvar],
                                                                  retain_graph=True)
                            var_param_grads = torch.cat(var_param_grads, 1)
                            var_params.backward(var_param_grads, retain_graph=True)
                    else:
                        var_param_grads = meta_optimizer.backward(
                            [mean_svi_final.grad, logvar_svi_final.grad],
                            t % args.print_every == 0)
                        var_param_grads = torch.cat(var_param_grads, 1)
                        var_params.backward(var_param_grads)
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm(model.parameters(), args.max_grad_norm)
            optimizer.step()
            num_examples += batch_size
            if t % args.print_every == 0:
                param_norm = sum([p.norm()**2 for p in model.parameters()]).data[0]**0.5
                print('Iters: %d, Epoch: %d, Batch: %d/%d, LR: %.4f, TrainARNLL: %.2f, '
                      'TrainVAE_NLL: %.2f, TrainVAE_KL: %.4f, TrainVAE_NLLBnd: %.2f, '
                      'TrainSVI_NLL: %.2f, TrainSVI_KL: %.4f, TrainSVI_NLLBnd: %.2f, '
                      '|Param|: %.4f, BestValPerf: %.2f, BestEpoch: %d, Beta: %.3f, '
                      'Throughput: %.2f examples/sec' %
                      (t, epoch, b + 1, len(train_loader), args.lr,
                       train_nll_autoreg / num_examples,
                       train_nll_vae / num_examples,
                       train_kl_vae / num_examples,
                       (train_nll_vae + train_kl_vae) / num_examples,
                       train_nll_svi / num_examples,
                       train_kl_svi / num_examples,
                       (train_nll_svi + train_kl_svi) / num_examples,
                       param_norm, best_val_nll, best_epoch, args.beta,
                       num_examples / (time.time() - start_time)))
        print('--------------------------------')
        print('Checking validation perf...')
        val_nll = eval(val_loader, model, meta_optimizer)
        loss_stats.append(val_nll)
        if val_nll < best_val_nll:
            best_val_nll = val_nll
            best_epoch = epoch
            checkpoint = {
                'args': args.__dict__,
                'model': model,
                'optimizer': optimizer,
                'loss_stats': loss_stats
            }
            print('Saving checkpoint to %s' % args.checkpoint_path)
            torch.save(checkpoint, args.checkpoint_path)
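
# The loss helpers imported from `utils` are used above without their definitions.
# A minimal sketch of what `log_bernoulli_loss` and `kl_loss_diag` are assumed to
# compute (batch-averaged Bernoulli reconstruction NLL and diagonal-Gaussian KL);
# the names below end in `_sketch` because they are illustrative, and the actual
# `utils` implementations may differ in reduction or numerical details.
import torch


def log_bernoulli_loss_sketch(preds, img):
    # Bernoulli negative log-likelihood of the binarized image under the decoder
    # probabilities, summed over pixels and averaged over the batch.
    preds = preds.view(preds.size(0), -1).clamp(1e-5, 1. - 1e-5)
    img = img.view(img.size(0), -1)
    nll = -(img * preds.log() + (1. - img) * (1. - preds).log()).sum(1)
    return nll.mean()


def kl_loss_diag_sketch(mean, logvar):
    # KL( N(mean, diag(exp(logvar))) || N(0, I) ), averaged over the batch.
    kl = -0.5 * (1. + logvar - mean.pow(2) - logvar.exp()).sum(1)
    return kl.mean()
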
def main():
    wandb.init(project="vae-comparison")
    wandb.config.update(args)
    log_step = 0

    # set random seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # set device
    use_gpu = args.use_gpu and torch.cuda.is_available()
    device = torch.device("cuda" if use_gpu else "cpu")
    print("training on {} device".format("cuda" if use_gpu else "cpu"))

    # load dataset
    train_loader, val_loader, test_loader = load_data(dataset=args.dataset,
                                                      batch_size=args.batch_size,
                                                      no_validation=args.no_validation,
                                                      shuffle=args.shuffle,
                                                      data_file=args.data_file)

    # define model or load checkpoint
    if args.train_from == '':
        print('--------------------------------')
        print("initializing new model")
        model = VAE(latent_dim=args.latent_dim)
    else:
        print('--------------------------------')
        print('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']
    print('--------------------------------')
    print("model architecture")
    print(model)

    # set model for training
    model.to(device)
    model.train()

    # define optimizers and their schedulers
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    optimizer_enc = torch.optim.Adam(model.enc.parameters(), lr=args.lr)
    optimizer_dec = torch.optim.Adam(model.dec.parameters(), lr=args.lr)
    lr_lambda = lambda count: 0.9
    lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lr_lambda)
    lr_scheduler_enc = torch.optim.lr_scheduler.MultiplicativeLR(optimizer_enc, lr_lambda=lr_lambda)
    lr_scheduler_dec = torch.optim.lr_scheduler.MultiplicativeLR(optimizer_dec, lr_lambda=lr_lambda)

    # set beta KL scaling parameter
    if args.warmup == 0:
        beta_ten = torch.tensor(1.)
    else:
        beta_ten = torch.tensor(0.1)

    # set savae meta optimizer
    update_params = list(model.dec.parameters())
    meta_optimizer = OptimN2N(utils.variational_loss,
                              model,
                              update_params,
                              beta=beta_ten,
                              eps=args.eps,
                              lr=[args.svi_lr1, args.svi_lr2],
                              iters=args.svi_steps,
                              momentum=args.momentum,
                              acc_param_grads=1,
                              max_grad_norm=args.svi_max_grad_norm)

    # if test flag set, evaluate and exit
    if args.test == 1:
        beta_ten.data.fill_(1.)
        eval(test_loader, model, meta_optimizer, device)
        importance_sampling(data=test_loader,
                            model=model,
                            batch_size=args.batch_size,
                            meta_optimizer=meta_optimizer,
                            device=device,
                            nr_samples=20000,
                            test_mode=True,
                            verbose=True,
                            mode=args.test_type)
        exit()

    # initialize counters and stats
    epoch = 0
    t = 0
    best_val_metric = 100000000
    best_epoch = 0
    loss_stats = []

    # training loop
    C = torch.tensor(0., device=device)
    C_local = torch.zeros(args.batch_size * len(train_loader), device=device)
    epsilon = None
    step = 0
    while epoch < args.num_epochs:
        start_time = time.time()
        epoch += 1
        print('--------------------------------')
        print('starting epoch %d' % epoch)
        train_nll_vae = 0.
        train_kl_vae = 0.
        train_nll_svi = 0.
        train_kl_svi = 0.
        train_cdiv = 0.
        train_nll = 0.
        train_acc_rate = 0.
        num_examples = 0
        count_one_pixels = 0
        for b, datum in enumerate(train_loader):
            t += 1
            if args.warmup > 0:
                beta_ten.data.fill_(
                    torch.min(torch.tensor(1.),
                              beta_ten + 1 / (args.warmup * len(train_loader))).data)
            img, _ = datum
            img = torch.where(img < 0.5, torch.zeros_like(img), torch.ones_like(img))
            if epoch == 1:
                count_one_pixels += torch.sum(img).item()
            img = img.to(device)
            optimizer.zero_grad()
            optimizer_enc.zero_grad()
            optimizer_dec.zero_grad()
            if args.model == 'svi':
                mean_svi = torch.zeros(args.batch_size, args.latent_dim,
                                       requires_grad=True, device=device)
                logvar_svi = torch.zeros(args.batch_size, args.latent_dim,
                                         requires_grad=True, device=device)
                var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], img)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model.reparameterize(mean_svi_final.detach(),
                                                 logvar_svi_final.detach())
                preds = model.dec_forward(z_samples)
                nll_svi = utils.log_bernoulli_loss(preds, img)
                train_nll_svi += nll_svi.item() * args.batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                train_kl_svi += kl_svi.item() * args.batch_size
                var_loss = nll_svi + beta_ten.item() * kl_svi
                var_loss.backward()
                if args.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
            else:
                mean, logvar = model.enc_forward(img)
                z_samples = model.reparameterize(mean, logvar)
                preds = model.dec_forward(z_samples)
                nll_vae = utils.log_bernoulli_loss(preds, img)
                train_nll_vae += nll_vae.item() * args.batch_size
                kl_vae = utils.kl_loss_diag(mean, logvar)
                train_kl_vae += kl_vae.item() * args.batch_size
                if args.model == 'vae':
                    vae_loss = nll_vae + beta_ten.item() * kl_vae
                    vae_loss.backward()
                    optimizer.step()
                if args.model == 'savae':
                    var_params = torch.cat([mean, logvar], 1)
                    mean_svi = mean.clone().detach().requires_grad_(True)
                    logvar_svi = logvar.clone().detach().requires_grad_(True)
                    var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], img)
                    mean_svi_final, logvar_svi_final = var_params_svi
                    z_samples = model.reparameterize(mean_svi_final, logvar_svi_final)
                    preds = model.dec_forward(z_samples)
                    nll_svi = utils.log_bernoulli_loss(preds, img)
                    train_nll_svi += nll_svi.item() * args.batch_size
                    kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                    train_kl_svi += kl_svi.item() * args.batch_size
                    var_loss = nll_svi + beta_ten.item() * kl_svi
                    var_loss.backward(retain_graph=True)
                    var_param_grads = meta_optimizer.backward(
                        [mean_svi_final.grad, logvar_svi_final.grad])
                    var_param_grads = torch.cat(var_param_grads, 1)
                    var_params.backward(var_param_grads)
                    if args.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    optimizer.step()
                if args.model == "cdiv" or args.model == "cdiv_svgd":
                    pxz = utils.log_pxz(preds, img, z_samples)
                    first_term = torch.mean(pxz) + 0.5 * args.latent_dim
                    logqz = utils.log_normal_pdf(z_samples, mean, logvar)
                    if epoch == 7 and b == 0:
                        # switch to local variate control
                        C_local = torch.ones(args.batch_size * len(train_loader),
                                             device=device) * C
                    if args.model == "cdiv":
                        zt, samples, acc_rate, epsilon = hmc.hmc_vae(
                            z_samples.clone().detach().requires_grad_(),
                            model, img, epsilon=epsilon, Burn=0,
                            T=args.num_hmc_iters, adapt=0, L=5)
                        train_acc_rate += torch.mean(acc_rate) * args.batch_size
                    else:
                        mean_all = torch.repeat_interleave(mean, args.num_svgd_particles, 0)
                        logvar_all = torch.repeat_interleave(logvar, args.num_svgd_particles, 0)
                        img_all = torch.repeat_interleave(img, args.num_svgd_particles, 0)
                        z_samples = mean_all + torch.randn(
                            args.num_svgd_particles * args.batch_size, args.latent_dim,
                            device=device) * torch.exp(0.5 * logvar_all)
                        samples = svgd.svgd_batched(args.num_svgd_particles, args.batch_size,
                                                    z_samples, model, img_all.view(-1, 784),
                                                    iter=args.num_svgd_iters)
                        z_ind = torch.randint(low=0, high=args.num_svgd_particles,
                                              size=(args.batch_size,), device=device) + \
                            torch.tensor(args.num_svgd_particles, device=device) * \
                            torch.arange(0, args.batch_size, device=device)
                        zt = samples[z_ind]
                    preds_zt = model.dec_forward(zt)
                    pxzt = utils.log_pxz(preds_zt, img, zt)
                    g_zt = pxzt + torch.sum(0.5 * ((zt - mean)**2) * torch.exp(-logvar), 1)
                    second_term = torch.mean(g_zt)
                    cdiv = -first_term + second_term
                    train_cdiv += cdiv.item() * args.batch_size
                    train_nll += -torch.mean(pxzt).item() * args.batch_size
                    if epoch <= 6:
                        loss = -first_term + torch.mean(
                            torch.sum(0.5 * ((zt - mean)**2) * torch.exp(-logvar), 1) +
                            (g_zt.detach() - C) * logqz)
                        if b == 0:
                            C = torch.mean(g_zt.detach())
                        else:
                            C = 0.9 * C + 0.1 * torch.mean(g_zt.detach())
                    else:
                        control = C_local[b * args.batch_size:(b + 1) * args.batch_size]
                        loss = -first_term + torch.mean(
                            torch.sum(0.5 * ((zt - mean)**2) * torch.exp(-logvar), 1) +
                            (g_zt.detach() - control) * logqz)
                        C_local[b * args.batch_size:(b + 1) * args.batch_size] = \
                            0.9 * C_local[b * args.batch_size:(b + 1) * args.batch_size] + \
                            0.1 * g_zt.detach()
                    loss.backward(retain_graph=True)
                    optimizer_enc.step()
                    optimizer_dec.zero_grad()
                    torch.mean(-utils.log_pxz(preds_zt, img, zt)).backward()
                    optimizer_dec.step()
            if t % 15000 == 0:
                if args.model == "cdiv" or args.model == "cdiv_svgd":
                    lr_scheduler_enc.step()
                    lr_scheduler_dec.step()
                else:
                    lr_scheduler.step()
            num_examples += args.batch_size
            if b and (b + 1) % args.print_every == 0:
                step += 1
                print('--------------------------------')
                print('iteration: %d, epoch: %d, batch: %d/%d' %
                      (t, epoch, b + 1, len(train_loader)))
                if epoch > 1:
                    print('best epoch: %d: %.2f' % (best_epoch, best_val_metric))
                print('throughput: %.2f examples/sec' %
                      (num_examples / (time.time() - start_time)))
                if args.model != 'svi':
                    print('train_VAE_NLL: %.2f, train_VAE_KL: %.4f, train_VAE_NLLBnd: %.2f' %
                          (train_nll_vae / num_examples, train_kl_vae / num_examples,
                           (train_nll_vae + train_kl_vae) / num_examples))
                    wandb.log(
                        {
                            "train_vae_nll": train_nll_vae / num_examples,
                            "train_vae_kl": train_kl_vae / num_examples,
                            "train_vae_nll_bound": (train_nll_vae + train_kl_vae) / num_examples,
                        },
                        step=log_step)
                if args.model == 'svi' or args.model == 'savae':
                    print('train_SVI_NLL: %.2f, train_SVI_KL: %.4f, train_SVI_NLLBnd: %.2f' %
                          (train_nll_svi / num_examples, train_kl_svi / num_examples,
                           (train_nll_svi + train_kl_svi) / num_examples))
                    wandb.log(
                        {
                            "train_svi_nll": train_nll_svi / num_examples,
                            "train_svi_kl": train_kl_svi / num_examples,
                            "train_svi_nll_bound": (train_nll_svi + train_kl_svi) / num_examples,
                        },
                        step=log_step)
                if args.model == "cdiv" or args.model == "cdiv_svgd":
                    print('train_NLL: %.2f, train_CDIV: %.4f' %
                          (train_nll / num_examples, train_cdiv / num_examples))
                    wandb.log(
                        {
                            "train_nll": train_nll / num_examples,
                            "train_cdiv": train_cdiv / num_examples,
                        },
                        step=log_step)
                    if args.model == "cdiv":
                        print('train_average_acc_rate: %.3f' %
                              (train_acc_rate / num_examples))
                        wandb.log(
                            {
                                "train_average_acc_rate": train_acc_rate / num_examples,
                            },
                            step=log_step)
                log_step += 1
        if epoch == 1:
            print('--------------------------------')
            print("count of pixels 1 in training data: {}".format(count_one_pixels))
            wandb.log({"dataset_pixel_check": count_one_pixels}, step=log_step)
        if args.no_validation:
            print('--------------------------------')
            print("[validation disabled!]")
        else:
            val_metric = eval(val_loader, model, meta_optimizer, device, epoch, epsilon,
                              log_step)
        checkpoint = {
            'args': args.__dict__,
            'model': model,
            'loss_stats': loss_stats
        }
        torch.save(checkpoint, args.checkpoint_path + "_last.pt")
        if not args.no_validation:
            loss_stats.append(val_metric)
            if val_metric < best_val_metric:
                best_val_metric = val_metric
                best_epoch = epoch
                print('saving checkpoint to %s' % (args.checkpoint_path + "_best.pt"))
                torch.save(checkpoint, args.checkpoint_path + "_best.pt")
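
# `model.reparameterize` and `utils.log_pxz` are used above without definitions.
# A minimal sketch of the standard forms they are assumed to follow: the Gaussian
# reparameterization trick, and log p(x, z) = log p(x|z) + log p(z) with a Bernoulli
# decoder and a standard-normal prior. The `_sketch` names are illustrative only.
import math
import torch


def reparameterize_sketch(mean, logvar):
    # z = mean + sigma * eps with eps ~ N(0, I); keeps z differentiable w.r.t. the
    # variational parameters.
    std = torch.exp(0.5 * logvar)
    return mean + std * torch.randn_like(std)


def log_pxz_sketch(preds, img, z):
    # per-example log p(x|z) (Bernoulli) plus log p(z) under N(0, I)
    preds = preds.view(preds.size(0), -1).clamp(1e-5, 1. - 1e-5)
    img = img.view(img.size(0), -1)
    log_px_given_z = (img * preds.log() + (1. - img) * (1. - preds).log()).sum(1)
    log_pz = -0.5 * (z.pow(2) + math.log(2 * math.pi)).sum(1)
    return log_px_given_z + log_pz
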
    # body of a variational_loss variant for text (the def line is not included in
    # this fragment); `beta` is the KL weight passed to OptimN2N below
    mean, logvar = input
    z_samples = model._reparameterize(mean, logvar, z)
    preds = model._dec_forward(sents, z_samples)
    nll = sum([criterion(preds[:, l], sents[:, l + 1]) for l in range(preds.size(1))])
    kl = utils.kl_loss_diag(mean, logvar)
    return nll + beta * kl

update_params = list(model.dec.parameters())
meta_optimizer = OptimN2N(variational_loss,
                          model,
                          update_params,
                          beta,
                          eps=args.eps,
                          lr=[args.svi_lr1, args.svi_lr2],
                          iters=args.svi_steps,
                          momentum=args.momentum,
                          acc_param_grads=True,
                          max_grad_norm=args.svi_max_grad_norm)


def evaluation(data, model, meta_optimizer):
    model.dec_linear.eval()
    model.dropout.eval()
    meta_optimizer.beta = 1.0
    num_sents = 0.0
    num_words = 0.0
    total_rec = 0.0
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_data = Dataset(args.train_file)
    val_data = Dataset(args.val_file)
    test_data = Dataset(args.test_file)
    train_sents = train_data.batch_size.sum()
    vocab_size = int(train_data.vocab_size)
    logger.info('Train data: %d batches' % len(train_data))
    logger.info('Val data: %d batches' % len(val_data))
    logger.info('Test data: %d batches' % len(test_data))
    logger.info('Word vocab size: %d' % vocab_size)

    checkpoint_dir = args.checkpoint_dir
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    suffix = "%s_%s.pt" % (args.model, 'cyc')
    checkpoint_path = os.path.join(checkpoint_dir, suffix)

    if args.slurm == 0:
        cuda.set_device(args.gpu)
    if args.train_from == '':
        model = RNNVAE(vocab_size=vocab_size,
                       enc_word_dim=args.enc_word_dim,
                       enc_h_dim=args.enc_h_dim,
                       enc_num_layers=args.enc_num_layers,
                       dec_word_dim=args.dec_word_dim,
                       dec_h_dim=args.dec_h_dim,
                       dec_num_layers=args.dec_num_layers,
                       dec_dropout=args.dec_dropout,
                       latent_dim=args.latent_dim,
                       mode=args.model)
        for param in model.parameters():
            param.data.uniform_(-0.1, 0.1)
    else:
        logger.info('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']
    logger.info("model architecture")
    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    if args.warmup == 0:
        args.beta = 1.
    else:
        args.beta = 0.1
    criterion = nn.NLLLoss()
    model.cuda()
    criterion.cuda()
    model.train()

    def variational_loss(input, sents, model, z=None):
        mean, logvar = input
        z_samples = model._reparameterize(mean, logvar, z)
        preds = model._dec_forward(sents, z_samples)
        nll = sum([criterion(preds[:, l], sents[:, l + 1]) for l in range(preds.size(1))])
        kl = utils.kl_loss_diag(mean, logvar)
        return nll + args.beta * kl

    update_params = list(model.dec.parameters())
    meta_optimizer = OptimN2N(variational_loss,
                              model,
                              update_params,
                              eps=args.eps,
                              lr=[args.svi_lr1, args.svi_lr2],
                              iters=args.svi_steps,
                              momentum=args.momentum,
                              acc_param_grads=args.train_n2n == 1,
                              max_grad_norm=args.svi_max_grad_norm)

    if args.test == 1:
        args.beta = 1
        test_data = Dataset(args.test_file)
        eval(test_data, model, meta_optimizer)
        exit()

    t = 0
    best_val_nll = 1e5
    best_epoch = 0
    val_stats = []
    epoch = 0
    while epoch < args.num_epochs:
        start_time = time.time()
        epoch += 1
        logger.info('Starting epoch %d' % epoch)
        train_nll_vae = 0.
        train_nll_autoreg = 0.
        train_kl_vae = 0.
        train_nll_svi = 0.
        train_kl_svi = 0.
        train_kl_init_final = 0.
        num_sents = 0
        num_words = 0
        b = 0
        tmp = float((epoch - 1) % args.cycle) / args.cycle
        cur_lr = args.lr * 0.5 * (1 + np.cos(tmp * np.pi))
        for param_group in optimizer.param_groups:
            param_group['lr'] = cur_lr
        if (epoch - 1) % args.cycle == 0:
            args.beta = 0.1
            logger.info('KL annealing restart')
        for i in np.random.permutation(len(train_data)):
            if args.warmup > 0:
                args.beta = min(1, args.beta + 1. / (args.warmup * len(train_data)))
            sents, length, batch_size = train_data[i]
            if args.gpu >= 0:
                sents = sents.cuda()
            b += 1
            optimizer.zero_grad()
            if args.model == 'autoreg':
                preds = model._dec_forward(sents, None, True)
                nll_autoreg = sum([criterion(preds[:, l], sents[:, l + 1]) for l in range(length)])
                train_nll_autoreg += nll_autoreg.data[0] * batch_size
                nll_autoreg.backward()
            elif args.model == 'svi':
                mean_svi = Variable(0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                                    requires_grad=True)
                logvar_svi = Variable(0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                                      requires_grad=True)
                var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], sents,
                                                        b % args.print_every == 0)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model._reparameterize(mean_svi_final.detach(),
                                                  logvar_svi_final.detach())
                preds = model._dec_forward(sents, z_samples)
                nll_svi = sum([criterion(preds[:, l], sents[:, l + 1]) for l in range(length)])
                train_nll_svi += nll_svi.data[0] * batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                train_kl_svi += kl_svi.data[0] * batch_size
                var_loss = nll_svi + args.beta * kl_svi
                var_loss.backward(retain_graph=True)
            else:
                mean, logvar = model._enc_forward(sents)
                z_samples = model._reparameterize(mean, logvar)
                preds = model._dec_forward(sents, z_samples)
                nll_vae = sum([criterion(preds[:, l], sents[:, l + 1]) for l in range(length)])
                train_nll_vae += nll_vae.data[0] * batch_size
                kl_vae = utils.kl_loss_diag(mean, logvar)
                train_kl_vae += kl_vae.data[0] * batch_size
                if args.model == 'vae':
                    vae_loss = nll_vae + args.beta * kl_vae
                    vae_loss.backward(retain_graph=True)
                if args.model == 'savae':
                    var_params = torch.cat([mean, logvar], 1)
                    mean_svi = Variable(mean.data, requires_grad=True)
                    logvar_svi = Variable(logvar.data, requires_grad=True)
                    var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], sents,
                                                            b % args.print_every == 0)
                    mean_svi_final, logvar_svi_final = var_params_svi
                    z_samples = model._reparameterize(mean_svi_final, logvar_svi_final)
                    preds = model._dec_forward(sents, z_samples)
                    nll_svi = sum([criterion(preds[:, l], sents[:, l + 1]) for l in range(length)])
                    train_nll_svi += nll_svi.data[0] * batch_size
                    kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                    train_kl_svi += kl_svi.data[0] * batch_size
                    var_loss = nll_svi + args.beta * kl_svi
                    var_loss.backward(retain_graph=True)
                    if args.train_n2n == 0:
                        if args.train_kl == 1:
                            mean_final = mean_svi_final.detach()
                            logvar_final = logvar_svi_final.detach()
                            kl_init_final = utils.kl_loss(mean, logvar, mean_final, logvar_final)
                            train_kl_init_final += kl_init_final.data[0] * batch_size
                            kl_init_final.backward(retain_graph=True)
                        else:
                            vae_loss = nll_vae + args.beta * kl_vae
                            var_param_grads = torch.autograd.grad(vae_loss, [mean, logvar],
                                                                  retain_graph=True)
                            var_param_grads = torch.cat(var_param_grads, 1)
                            var_params.backward(var_param_grads, retain_graph=True)
                    else:
                        var_param_grads = meta_optimizer.backward(
                            [mean_svi_final.grad, logvar_svi_final.grad],
                            b % args.print_every == 0)
                        var_param_grads = torch.cat(var_param_grads, 1)
                        var_params.backward(var_param_grads)
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm(model.parameters(), args.max_grad_norm)
            optimizer.step()
            num_sents += batch_size
            num_words += batch_size * length
            if b % args.print_every == 0:
                param_norm = sum([p.norm()**2 for p in model.parameters()]).data[0]**0.5
                logger.info(
                    'Iters: %d, Epoch: %d, Batch: %d/%d, LR: %.4f, TrainARNLL: %.4f, '
                    'TrainARPPL: %.2f, TrainVAE_NLL: %.4f, TrainVAE_REC: %.4f, '
                    'TrainVAE_KL: %.4f, TrainVAE_PPL: %.2f, TrainSVI_NLL: %.2f, '
                    'TrainSVI_REC: %.2f, TrainSVI_KL: %.4f, TrainSVI_PPL: %.2f, '
                    'KLInitFinal: %.2f, |Param|: %.4f, BestValPerf: %.2f, BestEpoch: %d, '
                    'Beta: %.4f, Throughput: %.2f examples/sec' %
                    (t, epoch, b + 1, len(train_data), cur_lr,
                     train_nll_autoreg / num_sents,
                     np.exp(train_nll_autoreg / num_words),
                     (train_nll_vae + train_kl_vae) / num_sents,
                     train_nll_vae / num_sents,
                     train_kl_vae / num_sents,
                     np.exp((train_nll_vae + train_kl_vae) / num_words),
                     (train_nll_svi + train_kl_svi) / num_sents,
                     train_nll_svi / num_sents,
                     train_kl_svi / num_sents,
                     np.exp((train_nll_svi + train_kl_svi) / num_words),
                     train_kl_init_final / num_sents,
                     param_norm, best_val_nll, best_epoch, args.beta,
                     num_sents / (time.time() - start_time)))
        epoch_train_time = time.time() - start_time
        logger.info('Time Elapsed: %.1fs' % epoch_train_time)

        logger.info('--------------------------------')
        logger.info('Checking validation perf...')
        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Mode', 'Val')
        logger.record_tabular('LR', cur_lr)
        logger.record_tabular('Epoch Train Time', epoch_train_time)
        val_nll = eval(val_data, model, meta_optimizer)
        val_stats.append(val_nll)

        logger.info('--------------------------------')
        logger.info('Checking test perf...')
        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Mode', 'Test')
        logger.record_tabular('LR', cur_lr)
        logger.record_tabular('Epoch Train Time', epoch_train_time)
        test_nll = eval(test_data, model, meta_optimizer)

        if val_nll < best_val_nll:
            best_val_nll = val_nll
            best_epoch = epoch
            model.cpu()
            checkpoint = {
                'args': args.__dict__,
                'model': model,
                'val_stats': val_stats
            }
            logger.info('Save checkpoint to %s' % checkpoint_path)
            torch.save(checkpoint, checkpoint_path)
            model.cuda()
        else:
            if epoch >= args.min_epochs:
                args.decay = 1
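
# The cyclical schedule in the loop above combines a cosine learning-rate decay within
# each cycle with a KL-weight restart at every cycle boundary. The helpers below are a
# standalone sketch of that same schedule, derived directly from the code above
# (`cycle` corresponds to args.cycle); they are illustrative, not part of the repo's API.
import numpy as np


def cyclic_lr_sketch(base_lr, epoch, cycle):
    # cosine decay from base_lr towards 0 across the `cycle` epochs of each cycle
    tmp = float((epoch - 1) % cycle) / cycle
    return base_lr * 0.5 * (1 + np.cos(tmp * np.pi))


def restart_beta_sketch(beta, epoch, cycle, restart_value=0.1):
    # reset the KL weight at the start of every cycle, mirroring 'KL annealing restart'
    return restart_value if (epoch - 1) % cycle == 0 else beta
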
def main(args):
    # set random seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # load data
    train_data = Dataset(args.train_file)
    val_data = Dataset(args.val_file)
    train_sents = train_data.batch_size.sum()
    vocab_size = int(train_data.vocab_size)
    print('Train data: %d batches' % len(train_data))
    print('Val data: %d batches' % len(val_data))
    print('Word vocab size: %d' % vocab_size)

    if args.slurm == 0:
        # cuda.set_device(args.gpu)
        gpu_id = 0
        device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")

    # define model or load checkpoint
    if args.train_from == '':
        model = RNNVAE(vocab_size=vocab_size,
                       enc_word_dim=args.enc_word_dim,
                       enc_h_dim=args.enc_h_dim,
                       enc_num_layers=args.enc_num_layers,
                       dec_word_dim=args.dec_word_dim,
                       dec_h_dim=args.dec_h_dim,
                       dec_num_layers=args.dec_num_layers,
                       dec_dropout=args.dec_dropout,
                       latent_dim=args.latent_dim,
                       mode=args.model)
        for param in model.parameters():
            param.data.uniform_(-0.1, 0.1)
    else:
        print('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']
    print("model architecture")
    print(model)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    if args.warmup == 0:
        args.beta = 1.
    else:
        args.beta = args.kl_start
    criterion = nn.NLLLoss(reduction='none')  # per-token losses, reduced manually below
    # criterion = nn.NLLLoss()
    # model.cuda()
    # criterion.cuda()
    # model = torch.nn.DataParallel(net, device_ids=[0, 1])
    model.to(device)
    criterion.to(device)
    model.train()

    def variational_loss(input, sents, model, z=None):
        mean, logvar = input
        z_samples = model._reparameterize(mean, logvar, z)
        preds = model._dec_forward(sents, z_samples)
        nll = sum([criterion(preds[:, l], sents[:, l + 1]) for l in range(preds.size(1))])
        kl = utils.kl_loss_diag(mean, logvar)
        return nll + args.beta * kl

    update_params = list(model.dec.parameters())
    meta_optimizer = OptimN2N(variational_loss,
                              model,
                              update_params,
                              eps=args.eps,
                              lr=[args.svi_lr1, args.svi_lr2],
                              iters=args.svi_steps,
                              momentum=args.momentum,
                              acc_param_grads=args.train_n2n == 1,
                              max_grad_norm=args.svi_max_grad_norm)

    if args.test == 1:
        args.beta = 1
        test_data = Dataset(args.test_file)
        eval(args, test_data, model, meta_optimizer, device)
        exit()

    t = 0
    best_val_nll = 1e5
    best_epoch = 0
    not_improved = 0  # patience counter for LR decay; initialized so the else-branch below is safe
    val_stats = []
    epoch = 0

    # training loop
    while epoch < args.num_epochs:
        start_time = time.time()
        epoch += 1
        print('Starting epoch %d' % epoch)
        train_nll_vae = 0.
        train_nll_autoreg = 0.
        train_kl_vae = 0.
        train_nll_svi = 0.
        train_kl_svi = 0.
        train_kl_init_final = 0.
        num_sents = 0
        num_words = 0
        b = 0
        for i in np.random.permutation(len(train_data)):
            if args.warmup > 0:
                args.beta = min(1, args.beta + 1. / (args.warmup * len(train_data)))
            sents, length, batch_size = train_data[i]
            length = length.item()
            batch_size = batch_size.item()
            if args.gpu >= 0:
                # sents = sents.cuda()
                sents = sents.to(device)
                # batch_size = batch_size.to(device)
            b += 1
            optimizer.zero_grad()
            if args.model == 'autoreg':
                preds = model._dec_forward(sents, None, True)
                tgt = sents[:, 1:].contiguous()
                nll_autoreg = criterion(preds.view(-1, preds.size(2)),
                                        tgt.view(-1)).view(preds.size(0), -1).sum(-1).mean(0)
                # nll_autoreg = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
                train_nll_autoreg += nll_autoreg.item() * batch_size
                # train_nll_autoreg += nll_autoreg.data[0]*batch_size  # old
                nll_autoreg.backward()
            elif args.model == 'svi':
                # mean_svi = Variable(0.1*torch.zeros(batch_size, args.latent_dim).cuda(), requires_grad=True)
                # logvar_svi = Variable(0.1*torch.zeros(batch_size, args.latent_dim).cuda(), requires_grad=True)
                mean_svi = Variable(0.1 * torch.zeros(batch_size, args.latent_dim).to(device),
                                    requires_grad=True)
                logvar_svi = Variable(0.1 * torch.zeros(batch_size, args.latent_dim).to(device),
                                      requires_grad=True)
                var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], sents,
                                                        b % args.print_every == 0)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model._reparameterize(mean_svi_final.detach(),
                                                  logvar_svi_final.detach())
                preds = model._dec_forward(sents, z_samples)
                tgt = sents[:, 1:].contiguous()
                nll_svi = criterion(preds.view(-1, preds.size(2)),
                                    tgt.view(-1)).view(preds.size(0), -1).sum(-1).mean(0)
                # nll_svi = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
                train_nll_svi += nll_svi.item() * batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                train_kl_svi += kl_svi.item() * batch_size
                var_loss = nll_svi + args.beta * kl_svi
                var_loss.backward(retain_graph=True)
            else:
                mean, logvar = model._enc_forward(sents)
                z_samples = model._reparameterize(mean, logvar)
                preds = model._dec_forward(sents, z_samples)
                tgt = sents[:, 1:].contiguous()
                nll_vae = criterion(preds.view(-1, preds.size(2)),
                                    tgt.view(-1)).view(preds.size(0), -1).sum(-1).mean(0)
                # nll_vae = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
                # train_nll_vae += nll_vae.data[0]*batch_size  # old
                train_nll_vae += nll_vae.item() * batch_size
                kl_vae = utils.kl_loss_diag(mean, logvar)
                # train_kl_vae += kl_vae.data[0]*batch_size  # old
                train_kl_vae += kl_vae.item() * batch_size
                if args.model == 'vae':
                    vae_loss = nll_vae + args.beta * kl_vae
                    vae_loss.backward(retain_graph=True)
                if args.model == 'savae':
                    var_params = torch.cat([mean, logvar], 1)
                    mean_svi = Variable(mean.data, requires_grad=True)
                    logvar_svi = Variable(logvar.data, requires_grad=True)
                    var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], sents,
                                                            b % args.print_every == 0)
                    mean_svi_final, logvar_svi_final = var_params_svi
                    z_samples = model._reparameterize(mean_svi_final, logvar_svi_final)
                    preds = model._dec_forward(sents, z_samples)
                    tgt = sents[:, 1:].contiguous()
                    nll_svi = criterion(preds.view(-1, preds.size(2)),
                                        tgt.view(-1)).view(preds.size(0), -1).sum(-1).mean(0)
                    # nll_svi = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
                    train_nll_svi += nll_svi.item() * batch_size
                    kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                    train_kl_svi += kl_svi.item() * batch_size
                    var_loss = nll_svi + args.beta * kl_svi
                    var_loss.backward(retain_graph=True)
                    if args.train_n2n == 0:
                        if args.train_kl == 1:
                            mean_final = mean_svi_final.detach()
                            logvar_final = logvar_svi_final.detach()
                            kl_init_final = utils.kl_loss(mean, logvar, mean_final, logvar_final)
                            train_kl_init_final += kl_init_final.item() * batch_size
                            kl_init_final.backward(retain_graph=True)
                        else:
                            vae_loss = nll_vae + args.beta * kl_vae
                            var_param_grads = torch.autograd.grad(vae_loss, [mean, logvar],
                                                                  retain_graph=True)
                            var_param_grads = torch.cat(var_param_grads, 1)
                            var_params.backward(var_param_grads, retain_graph=True)
                    else:
                        var_param_grads = meta_optimizer.backward(
                            [mean_svi_final.grad, logvar_svi_final.grad],
                            b % args.print_every == 0)
                        var_param_grads = torch.cat(var_param_grads, 1)
                        var_params.backward(var_param_grads)
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            num_sents += batch_size
            num_words += batch_size * length
            # num_sents = num_sents.item()
            # num_words = num_words.item()
            if b % args.print_every == 0:
                param_norm = sum([p.norm()**2 for p in model.parameters()]).item()**0.5
                print('Iters: %d, Epoch: %d, Batch: %d/%d, LR: %.4f, TrainARPPL: %.2f, '
                      'TrainVAE_PPL: %.2f, TrainVAE_KL: %.4f, TrainVAE_PPLBnd: %.2f, '
                      'TrainSVI_PPL: %.2f, TrainSVI_KL: %.4f, TrainSVI_PPLBnd: %.2f, '
                      'KLInitFinal: %.2f, |Param|: %.4f, BestValPerf: %.2f, BestEpoch: %d, '
                      'Beta: %.4f, Throughput: %.2f examples/sec' %
                      (t, epoch, b + 1, len(train_data), args.lr,
                       np.exp(train_nll_autoreg / num_words),
                       np.exp(train_nll_vae / num_words),
                       train_kl_vae / num_sents,
                       np.exp((train_nll_vae + train_kl_vae) / num_words),
                       np.exp(train_nll_svi / num_words),
                       train_kl_svi / num_sents,
                       np.exp((train_nll_svi + train_kl_svi) / num_words),
                       train_kl_init_final / num_sents,
                       param_norm, best_val_nll, best_epoch, args.beta,
                       num_sents / (time.time() - start_time)))
        print('--------------------------------')
        print('Checking validation perf...')
        val_nll = eval(args, val_data, model, meta_optimizer, device)
        val_stats.append(val_nll)
        # if val_elbo > self.best_val_elbo:
        #     self.not_improved = 0
        #     self.best_val_elbo = val_elbo
        # else:
        #     self.not_improved += 1
        #     if self.not_improved % 5 == 0:
        #         self.current_lr = self.current_lr * self.config.options.lr_decay
        #         print(f'New LR {self.current_lr}')
        #         model.optimizer = torch.optim.SGD(model.parameters(), lr=self.current_lr)
        #         model.enc_optimizer = torch.optim.SGD(model.parameters(), lr=self.current_lr)
        #         model.dec_optimizer = torch.optim.SGD(model.parameters(), lr=self.current_lr)
        if val_nll < best_val_nll:
            not_improved = 0
            best_save = '{}_{}.pt'.format(args.checkpoint_path, best_val_nll)
            if os.path.exists(best_save):
                os.remove(best_save)
            best_val_nll = val_nll
            best_epoch = epoch
            model.cpu()
            checkpoint = {
                'args': args.__dict__,
                'model': model,
                'val_stats': val_stats
            }
            print('Saving checkpoint to %s' % args.checkpoint_path)
            best_save = '{}_{}.pt'.format(args.checkpoint_path, best_val_nll)
            torch.save(checkpoint, best_save)
            # model.cuda()
            model.to(device)
        else:
            not_improved += 1
            if not_improved % 5 == 0:
                not_improved = 0
                args.lr = args.lr * args.lr_decay
                print(f'New LR: {args.lr}')
                for param_group in optimizer.param_groups:
                    param_group['lr'] = args.lr
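
# KL warm-up as used in all of the training loops above: beta grows linearly from its
# start value to 1, one small increment per batch, over `warmup` epochs. A standalone
# sketch of that update (illustrative; the scripts above apply it inline):
def warmup_beta_sketch(beta, warmup_epochs, batches_per_epoch):
    if warmup_epochs == 0:
        return 1.0
    return min(1.0, beta + 1.0 / (warmup_epochs * batches_per_epoch))
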