def evaluate(args, model_vae, encoder_tokenizer, decoder_tokenizer, table_name, prefix="", subset="test"):
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_output_dir = args.output_dir

    if subset == 'test':
        eval_dataset = load_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=True)
    elif subset == 'train':
        eval_dataset = load_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=False)
    logger.info("***** Running evaluation on {} dataset *****".format(subset))

    if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir)

    args.per_gpu_eval_batch_size = 1
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info(" Num examples = %d", len(eval_dataset))
    logger.info(" Batch size = %d", args.eval_batch_size)

    model_vae.eval()
    model_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae  # Take care of distributed/parallel training

    # Latent-space diagnostics: mutual information, active units, and an
    # importance-weighted estimate of perplexity / NLL / ELBO / KL.
    mi = calc_mi(model_vae, eval_dataloader, args)
    au = calc_au(model_vae, eval_dataloader, delta=0.01, args=args)[0]
    ppl, elbo, nll, kl = calc_iwnll(model_vae, eval_dataloader, args, ns=100)

    result = {
        "perplexity": ppl,
        "elbo": elbo,
        "kl": kl,
        "nll": nll,
        "au": au,
        "mi": mi
    }

    output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info(" %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    row = {
        'PartitionKey': 'MILU_Rule_Rule_Template',
        'RowKey': str(datetime.now()),
        'ExpName': args.ExpName,
        'test_perplexity': str(ppl),
        'test_elbo': str(elbo),
        'test_nll': str(nll),
        'test_au': str(au),
        'test_mi': str(mi)
    }
    # pdb.set_trace()
    # NOTE: `ts` is a table-storage client created elsewhere in the original script.
    ts.insert_entity(table_name, row)

    return result
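# --- Illustrative sketch (not part of the original script) -------------------
# `calc_au` above reports the number of "active units" (Burda et al., 2016):
# a latent dimension counts as active when the variance of its posterior mean
# across the dataset exceeds a small threshold delta. A minimal sketch,
# assuming the encoder returns (mu, logvar) like the GaussianLSTMEncoder used
# elsewhere in this repo; the real calc_au batches this differently.
import torch

def active_units_sketch(model, data_batches, delta=0.01):
    means = []
    with torch.no_grad():
        for batch in data_batches:
            mu, _ = model.encoder(batch)   # (batch, nz)
            means.append(mu)
    means = torch.cat(means, dim=0)        # (N, nz)
    au_var = means.var(dim=0)              # per-dimension variance of posterior means
    return (au_var >= delta).sum().item(), au_var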
def main(args):
    global logging
    debug = (args.reconstruct_from != "" or args.eval)  # don't make exp dir for reconstruction
    logging = create_exp_dir(args.exp_dir, scripts_to_save=None, debug=debug)

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    train_data = MonoTextData(args.train_data, label=args.label)
    vocab = train_data.vocab
    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)
    # sys.stdout.flush()

    log_niter = (len(train_data) // args.batch_size) // 10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    # device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path)
        # curr_state_dict = vae.state_dict()
        # curr_state_dict.update(loaded_state_dict)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

        if args.reset_dec:
            vae.decoder.reset_parameters(model_init, emb_init)

    if args.eval:
        logging('begin evaluation')
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)

            test(vae, test_data_batch, "TEST", args)
            au, au_var = calc_au(vae, test_data_batch)
            logging("%d active units" % au)
            # print(au_var)

            test_data_batch = test_data.create_data_batch(
                batch_size=1, device=device, batch_first=True)
            nll, ppl = calc_iwnll(vae, test_data_batch, args)
            logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))
        return

    if args.reconstruct_from != "":
        print("begin decoding")
        sys.stdout.flush()
        vae.load_state_dict(torch.load(args.reconstruct_from))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)
            # test(vae, test_data_batch, "TEST", args)
            reconstruct(vae, test_data_batch, vocab, args.decoding_strategy, args.reconstruct_to)
        return

    if args.opt == "sgd":
        enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=args.lr, momentum=args.momentum)
        dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=args.lr, momentum=args.momentum)
        opt_dict['lr'] = args.lr
    elif args.opt == "adam":
        enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001)
        dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001)
        opt_dict['lr'] = 0.001
    else:
        raise ValueError("optimizer not supported")

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = 0
    vae.train()
    start = time.time()

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)
    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)
    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)

    # At any point you can hit Ctrl + C to break out of training early.
    # NOTE: `ns`, `clip_grad`, `decay_epoch`, `lr_decay` and `max_decay` are
    # module-level constants defined elsewhere in the original script.
    try:
        for epoch in range(args.epochs):
            report_kl_loss = report_rec_loss = report_loss = 0
            report_num_words = report_num_sents = 0

            for i in np.random.permutation(len(train_data_batch)):
                batch_data = train_data_batch[i]
                batch_size, sent_len = batch_data.size()

                # not predict start symbol
                report_num_words += (sent_len - 1) * batch_size
                report_num_sents += batch_size

                kl_weight = args.beta

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                if args.iw_train_nsamples < 0:
                    loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, nsamples=args.nsamples)
                else:
                    loss, loss_rc, loss_kl = vae.loss_iw(
                        batch_data, kl_weight, nsamples=args.iw_train_nsamples, ns=ns)
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                loss_rc = loss_rc.sum()
                loss_kl = loss_kl.sum()

                enc_optimizer.step()
                dec_optimizer.step()

                report_rec_loss += loss_rc.item()
                report_kl_loss += loss_kl.item()
                report_loss += loss.item() * batch_size

                if iter_ % log_niter == 0:
                    # train_loss = (report_rec_loss + report_kl_loss) / report_num_sents
                    train_loss = report_loss / report_num_sents
                    logging('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f, '
                            'time elapsed %.2fs, kl_weight %.4f' %
                            (epoch, iter_, train_loss,
                             report_kl_loss / report_num_sents,
                             report_rec_loss / report_num_sents,
                             time.time() - start, kl_weight))
                    # sys.stdout.flush()

                    report_rec_loss = report_kl_loss = report_loss = 0
                    report_num_words = report_num_sents = 0

                iter_ += 1

            logging('kl weight %.4f' % kl_weight)

            vae.eval()
            with torch.no_grad():
                loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
                au, au_var = calc_au(vae, val_data_batch)
                logging("%d active units" % au)
                # print(au_var)

            if args.save_ckpt > 0 and epoch <= args.save_ckpt:
                logging('save checkpoint')
                torch.save(vae.state_dict(),
                           os.path.join(args.exp_dir, f'model_ckpt_{epoch}.pt'))

            if loss < best_loss:
                logging('update best loss')
                best_loss = loss
                best_nll = nll
                best_kl = kl
                best_ppl = ppl
                torch.save(vae.state_dict(), args.save_path)

            # Plateau-based schedule: if validation loss stops improving for
            # `decay_epoch` epochs, reload the best checkpoint and decay the lr.
            if loss > opt_dict["best_loss"]:
                opt_dict["not_improved"] += 1
                if opt_dict["not_improved"] >= decay_epoch and epoch >= args.load_best_epoch:
                    opt_dict["best_loss"] = loss
                    opt_dict["not_improved"] = 0
                    opt_dict["lr"] = opt_dict["lr"] * lr_decay
                    vae.load_state_dict(torch.load(args.save_path))
                    logging('new lr: %f' % opt_dict["lr"])
                    decay_cnt += 1
                    enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                              lr=opt_dict["lr"], momentum=args.momentum)
                    dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                              lr=opt_dict["lr"], momentum=args.momentum)
            else:
                opt_dict["not_improved"] = 0
                opt_dict["best_loss"] = loss

            if decay_cnt == max_decay:
                break

            if epoch % args.test_nepoch == 0:
                with torch.no_grad():
                    loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)

            if args.save_latent > 0 and epoch <= args.save_latent:
                visualize_latent(args, epoch, vae, "cuda", test_data)

            vae.train()

    except KeyboardInterrupt:
        logging('-' * 100)
        logging('Exiting from training early')

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))
    vae.eval()
    with torch.no_grad():
        loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
        au, au_var = calc_au(vae, test_data_batch)
        logging("%d active units" % au)
        # print(au_var)

    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
    with torch.no_grad():
        nll, ppl = calc_iwnll(vae, test_data_batch, args)
        logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))
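# --- Illustrative sketch (not part of the original script) -------------------
# `calc_iwnll` above reports an importance-weighted estimate of log p(x)
# (the IWAE bound of Burda et al., 2016):
#   log p(x) >= log (1/K) sum_k p(x, z_k) / q(z_k | x),  z_k ~ q(z | x).
# A minimal sketch under assumed interfaces: `log_joint`, `log_posterior` and
# `sample_posterior` are hypothetical callables standing in for the model
# methods the real implementation uses.
import math
import torch

def iw_nll_sketch(log_joint, log_posterior, sample_posterior, x, ns=100):
    z = sample_posterior(x, ns)                    # (ns, nz) posterior samples
    log_w = log_joint(x, z) - log_posterior(x, z)  # (ns,) log importance weights
    # log-mean-exp of the weights gives the K-sample lower bound on log p(x)
    log_px = torch.logsumexp(log_w, dim=0) - math.log(ns)
    return -log_px                                 # negative log-likelihood estimate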
def main(args):
    global logging
    logging = create_exp_dir(args.exp_dir, scripts_to_save=[])

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    vocab = {}
    with open(args.vocab_file) as fvocab:
        for i, line in enumerate(fvocab):
            vocab[line.strip()] = i
    vocab = VocabEntry(vocab)

    train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab)
    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)
    # sys.stdout.flush()

    log_niter = max(1, (len(train_data) // (args.batch_size * args.update_every)) // 10)

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    # device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path)
        # curr_state_dict = vae.state_dict()
        # curr_state_dict.update(loaded_state_dict)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

    if args.eval:
        logging('begin evaluation')
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            # The original referenced `test_labels_batch` before it was created;
            # build the labeled batches here so `test` can consume them.
            test_data_batch, test_labels_batch = test_data.create_data_batch_labels(
                batch_size=args.batch_size, device=device, batch_first=True)

            test(vae, test_data_batch, test_labels_batch, "TEST", args)
            au, au_var = calc_au(vae, test_data_batch)
            logging("%d active units" % au)
            # print(au_var)

            test_data_batch = test_data.create_data_batch(
                batch_size=1, device=device, batch_first=True)
            calc_iwnll(vae, test_data_batch, args)
        return

    if args.discriminator == "linear":
        discriminator = LinearDiscriminator(args, vae.encoder).to(device)
    elif args.discriminator == "mlp":
        discriminator = MLPDiscriminator(args, vae.encoder).to(device)

    if args.opt == "sgd":
        optimizer = optim.SGD(discriminator.parameters(), lr=args.lr, momentum=args.momentum)
        opt_dict['lr'] = args.lr
    elif args.opt == "adam":
        optimizer = optim.Adam(discriminator.parameters(), lr=0.001)
        opt_dict['lr'] = 0.001
    else:
        raise ValueError("optimizer not supported")

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = 0
    discriminator.train()
    start = time.time()

    kl_weight = args.kl_start
    if args.warm_up > 0:
        anneal_rate = (1.0 - args.kl_start) / (args.warm_up * (len(train_data) / args.batch_size))
    else:
        anneal_rate = 0
    dim_target_kl = args.target_kl / float(args.nz)

    train_data_batch, train_labels_batch = train_data.create_data_batch_labels(
        batch_size=args.batch_size, device=device, batch_first=True)
    val_data_batch, val_labels_batch = val_data.create_data_batch_labels(
        batch_size=128, device=device, batch_first=True)
    test_data_batch, test_labels_batch = test_data.create_data_batch_labels(
        batch_size=128, device=device, batch_first=True)

    acc_cnt = 1
    acc_loss = 0.
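    # Worked example of the linear KL-annealing schedule computed above
    # (illustrative numbers, not from this project's configs): with
    # kl_start = 0.1, warm_up = 10 epochs, 100,000 training sentences and
    # batch_size = 32, anneal_rate = (1.0 - 0.1) / (10 * (100000 / 32))
    # ≈ 2.88e-5, so kl_weight climbs from 0.1 to 1.0 over the first ten
    # epochs. (The weight itself is consumed per batch in the VAE training
    # variants of this loop, not in this discriminator loop.)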
    # NOTE: `clip_grad`, `decay_epoch`, `lr_decay` and `max_decay` are
    # module-level constants defined elsewhere in the original script.
    for epoch in range(args.epochs):
        report_loss = 0
        report_correct = report_num_words = report_num_sents = 0
        acc_batch_size = 0
        optimizer.zero_grad()

        for i in np.random.permutation(len(train_data_batch)):
            batch_data = train_data_batch[i]
            batch_labels = train_labels_batch[i]
            batch_labels = [int(x) for x in batch_labels]

            batch_labels = torch.tensor(batch_labels, dtype=torch.long,
                                        requires_grad=False, device=device)

            batch_size, sent_len = batch_data.size()

            # not predict start symbol
            report_num_words += (sent_len - 1) * batch_size
            report_num_sents += batch_size
            acc_batch_size += batch_size

            # (batch_size)
            loss, correct = discriminator.get_performance(batch_data, batch_labels)

            acc_loss = acc_loss + loss.sum()

            # gradient accumulation: step once every `update_every` mini-batches
            if acc_cnt % args.update_every == 0:
                acc_loss = acc_loss / acc_batch_size
                acc_loss.backward()

                torch.nn.utils.clip_grad_norm_(discriminator.parameters(), clip_grad)

                optimizer.step()
                optimizer.zero_grad()

                acc_cnt = 0
                acc_loss = 0
                acc_batch_size = 0

            acc_cnt += 1
            report_loss += loss.sum().item()
            report_correct += correct

            if iter_ % log_niter == 0:
                # train_loss = (report_rec_loss + report_kl_loss) / report_num_sents
                train_loss = report_loss / report_num_sents
                logging('epoch: %d, iter: %d, avg_loss: %.4f, acc %.4f, '
                        'time %.2fs' %
                        (epoch, iter_, train_loss,
                         report_correct / report_num_sents, time.time() - start))
                # sys.stdout.flush()

            iter_ += 1

        logging('lr {}'.format(opt_dict["lr"]))

        discriminator.eval()
        with torch.no_grad():
            loss, acc = test(discriminator, val_data_batch, val_labels_batch, "VAL", args)
            # print(au_var)

        if loss < best_loss:
            logging('update best loss')
            best_loss = loss
            best_acc = acc
            torch.save(discriminator.state_dict(), args.save_path)

        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch and epoch >= args.load_best_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                discriminator.load_state_dict(torch.load(args.save_path))
                logging('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                if args.opt == "sgd":
                    optimizer = optim.SGD(discriminator.parameters(),
                                          lr=opt_dict["lr"], momentum=args.momentum)
                elif args.opt == "adam":
                    optimizer = optim.Adam(discriminator.parameters(), lr=opt_dict["lr"])
                else:
                    raise ValueError("optimizer not supported")
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, acc = test(discriminator, test_data_batch, test_labels_batch, "TEST", args)

        discriminator.train()

    # evaluate the best checkpoint on the test set
    discriminator.load_state_dict(torch.load(args.save_path))
    discriminator.eval()
    with torch.no_grad():
        loss, acc = test(discriminator, test_data_batch, test_labels_batch, "TEST", args)
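# --- Illustrative sketch (not part of the original script) -------------------
# The loop above accumulates losses over `update_every` mini-batches before a
# single optimizer step, emulating a larger effective batch size. The same
# pattern, stripped to its core; `model`, `loss_fn` and `batches` are
# placeholders, not names from this repo:
import torch

def accumulate_and_step(batches, model, loss_fn, optimizer, update_every=4):
    optimizer.zero_grad()
    acc_loss, acc_examples = 0., 0
    for step, (x, y) in enumerate(batches, start=1):
        acc_loss = acc_loss + loss_fn(model(x), y).sum()  # keep the graph alive
        acc_examples += x.size(0)
        if step % update_every == 0:
            (acc_loss / acc_examples).backward()  # normalize by accumulated examples
            optimizer.step()
            optimizer.zero_grad()
            acc_loss, acc_examples = 0., 0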
def main(args, args_model):
    global logging
    eval_mode = (args.reconstruct_from != "" or args.eval or args.eval_iw_elbo
                 or args.eval_valid_elbo or args.export_avg_loss_per_ts
                 or args.study_pooling)  # don't make exp dir for reconstruction
    logging = create_exp_dir(args.exp_dir, scripts_to_save=None, debug=eval_mode)

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    vocab = {}
    if getattr(args, 'vocab_file', None):
        with open(args.vocab_file, 'r', encoding='utf-8') as fvocab:
            for i, line in enumerate(fvocab):
                vocab[line.strip()] = i
        vocab = VocabEntry(vocab)

    train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab)
    vocab = train_data.vocab
    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)
    # sys.stdout.flush()

    log_niter = max((len(train_data) // args.batch_size) // 10, 1)

    device = torch.device("cuda" if args.cuda else "cpu")

    vae = create_model(vocab, args, args_model, logging, eval_mode)

    if args.eval:
        logging('begin evaluation')
        vae.eval()
        with torch.no_grad():
            # Note: as written, this branch estimates the IW bound on the
            # *validation* split, despite the variable name.
            test_data_batch = val_data.create_data_batch(batch_size=1,
                                                         device=device,
                                                         batch_first=True)
            nll, ppl = calc_iwnll(vae, test_data_batch, args, ns=250)
            logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))
        return

    if args.eval_iw_elbo:
        logging('begin evaluation')
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(batch_size=1,
                                                          device=device,
                                                          batch_first=True)
            nll, ppl = calc_iw_elbo(vae, test_data_batch, args)
            logging('iw ELBo: %.4f, iw PPL*: %.4f' % (nll, ppl))
        return

    if args.eval_valid_elbo:
        logging('begin evaluation on validation set')
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)
        with torch.no_grad():
            loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
            logging('nll: %.4f, iw ppl: %.4f' % (nll, ppl))
        return

    if args.study_pooling:
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            data_batch = train_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)
            model_dir = os.path.dirname(args.load_path)
            archive_npy = os.path.join(model_dir, 'pooling.npy')
            random.shuffle(data_batch)
            # logs = study_pooling(vae, data_batch, "TRAIN", args, min_doc_size=16)
            logs = study_pooling(vae, data_batch, args, min_doc_size=4)
            logs['exp_dir'] = model_dir
            np.save(archive_npy, logs)
        return

    if args.export_avg_loss_per_ts:
        print("MODEL")
        print(vae)
        export_avg_loss_per_ts(
            vae,
            train_data,
            device,
            args.batch_size,
            args.load_path,
            args.export_avg_loss_per_ts,
        )
        return

    if args.reconstruct_from != "":
        print("begin decoding")
        vae.load_state_dict(torch.load(args.reconstruct_from))
        vae.eval()
        with torch.no_grad():
            if args.reconstruct_add_labels_to_source:
                test_data_batch, test_labels_batch = test_data.create_data_batch_labels(
                    batch_size=args.reconstruct_batch_size,
                    device=device,
                    batch_first=True,
                    deterministic=True)
                c = list(zip(test_data_batch, test_labels_batch))
                # random.shuffle(c)
                test_data_batch, test_labels_batch = zip(*c)
            else:
                test_data_batch = test_data.create_data_batch(
                    batch_size=args.reconstruct_batch_size,
                    device=device,
                    batch_first=True)
                test_labels_batch = None

            # random.shuffle(test_data_batch)
            # test(vae, test_data_batch, "TEST", args)
            reconstruct(vae, test_data_batch, vocab, args.decoding_strategy,
                        args.reconstruct_to, test_labels_batch,
                        args.reconstruct_max_examples, args.force_absolute_length,
                        args.no_unk)
        return

    if args.freeze_encoder_exc:
        assert args.enc_type == 'lstm'
        enc_params = vae.encoder.linear.parameters()
    else:
        enc_params = vae.encoder.parameters()
    dec_params = vae.decoder.parameters()

    # `momentum` is only valid for SGD; the original passed it to Adam as
    # well, which raises a TypeError, so branch per optimizer here.
    if args.opt == 'sgd':
        def optimizer_fn_(params):
            return optim.SGD(params, lr=args.lr, momentum=args.momentum)
    elif args.opt == 'adam':
        def optimizer_fn_(params):
            return optim.Adam(params, lr=args.lr)
    else:
        raise ValueError("optimizer not supported")

    enc_optimizer = optimizer_fn_(enc_params)
    dec_optimizer = optimizer_fn_(dec_params)

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0

    vae.train()
    start = time.time()

    kl_weight = args.kl_start
    if args.warm_up > 0:
        anneal_rate = (1.0 - args.kl_start) / (args.warm_up * (len(train_data) / args.batch_size))
    else:
        anneal_rate = 0
    dim_target_kl = args.target_kl / float(args.nz)

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)
    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)
    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)

    # At any point you can hit Ctrl + C to break out of training early.
    # NOTE: `clip_grad`, `decay_epoch`, `lr_decay` and `max_decay` are
    # module-level constants defined elsewhere in the original script.
    try:
        for epoch in range(args.epochs):
            report_kl_loss = report_rec_loss = report_loss = 0
            report_num_words = report_num_sents = 0

            for i in np.random.permutation(len(train_data_batch)):
                batch_data = train_data_batch[i]
                batch_size, sent_len = batch_data.size()

                # not predict start symbol
                report_num_words += (sent_len - 1) * batch_size
                report_num_sents += batch_size

                kl_weight = min(1.0, kl_weight + anneal_rate)

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                if args.fb == 0:
                    loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight,
                                                      nsamples=args.nsamples)
                elif args.fb == 1:
                    # sentence-level free bits: only penalize KL above target_kl
                    loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight,
                                                      nsamples=args.nsamples,
                                                      sum_over_len=False)
                    kl_mask = (loss_kl > args.target_kl).float()
                    loss_rc = loss_rc.sum(-1)
                    loss = loss_rc + kl_mask * kl_weight * loss_kl
                elif args.fb == 2:
                    # dimension-wise free bits: the KL budget target_kl is
                    # spread evenly over the nz latent dimensions
                    mu, logvar = vae.encoder(batch_data)
                    z = vae.encoder.reparameterize(mu, logvar, args.nsamples)
                    loss_kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1)
                    kl_mask = (loss_kl > dim_target_kl).float()
                    fake_loss_kl = (kl_mask * loss_kl).sum(dim=1)
                    loss_rc = vae.decoder.reconstruct_error(batch_data, z).mean(dim=1)
                    loss = loss_rc + kl_weight * fake_loss_kl

                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                loss_rc = loss_rc.sum()
                loss_kl = loss_kl.sum()

                if not args.freeze_encoder:
                    enc_optimizer.step()
                dec_optimizer.step()

                report_rec_loss += loss_rc.item()
                report_kl_loss += loss_kl.item()
                report_loss += loss_rc.item() + loss_kl.item()

                if iter_ % log_niter == 0:
                    train_loss = report_loss / report_num_sents
                    logging('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f, '
                            'time %.2fs, kl_weight %.4f' %
                            (epoch, iter_, train_loss,
                             report_kl_loss / report_num_sents,
                             report_rec_loss / report_num_sents,
                             time.time() - start, kl_weight))

                    report_rec_loss = report_kl_loss = report_loss = 0
                    report_num_words = report_num_sents = 0

                iter_ += 1

            logging('kl weight %.4f' % kl_weight)
            logging('lr {}'.format(opt_dict["lr"]))

            vae.eval()
            with torch.no_grad():
                loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
                au, au_var = calc_au(vae, val_data_batch)
                logging("%d active units" % au)

            if args.save_ckpt > 0 and epoch <= args.save_ckpt:
                logging('save checkpoint')
                torch.save(vae.state_dict(),
                           os.path.join(args.exp_dir, f'model_ckpt_{epoch}.pt'))

            if loss < best_loss:
                logging('update best loss')
                best_loss = loss
                best_nll = nll
                best_kl = kl
                best_ppl = ppl
                torch.save(vae.state_dict(), args.save_path)

            if loss > opt_dict["best_loss"]:
                opt_dict["not_improved"] += 1
                if opt_dict["not_improved"] >= decay_epoch and epoch >= args.load_best_epoch:
                    opt_dict["best_loss"] = loss
                    opt_dict["not_improved"] = 0
                    opt_dict["lr"] = opt_dict["lr"] * lr_decay
                    vae.load_state_dict(torch.load(args.save_path))
                    logging('new lr: %f' % opt_dict["lr"])
                    decay_cnt += 1
                    enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                              lr=opt_dict["lr"],
                                              momentum=args.momentum)
                    dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                              lr=opt_dict["lr"],
                                              momentum=args.momentum)
            else:
                opt_dict["not_improved"] = 0
                opt_dict["best_loss"] = loss

            if decay_cnt == max_decay:
                break

            if args.save_latent > 0 and epoch <= args.save_latent:
                visualize_latent(args, epoch, vae, "cuda", test_data)

            vae.train()

    except KeyboardInterrupt:
        logging('-' * 100)
        logging('Exiting from training early')

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))
    vae.eval()
    with torch.no_grad():
        loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
        au, au_var = calc_au(vae, test_data_batch)