def run(e):
    global best_dev_res, test_bm_res, test_avg_res

    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)
    dp = data_utils.data_processor(
        train_path=e.config.train_file,
        eval_path=e.config.eval_file,
        experiment=e)
    data, W = dp.process()
    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)

    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)
    model = vgvae(
        vocab_size=len(data.vocab),
        embed_dim=e.config.edim if W is None else W.shape[1],
        embed_init=W,
        experiment=e)

    start_epoch = true_it = 0
    if e.config.resume:
        start_epoch, _, best_dev_res, test_avg_res = \
            model.load(name="latest")
        if e.config.use_cuda:
            model.cuda()
            e.log.info("transferred model to gpu")
        e.log.info(
            "resumed from previous checkpoint: start epoch: {}, "
            "iteration: {}, best dev res: {:.3f}, test avg res: {:.3f}".format(
                start_epoch, true_it, best_dev_res, test_avg_res))

    e.log.info(model)
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    if e.config.summarize:
        writer = SummaryWriter(e.experiment_dir)

    if e.config.decoder_type.startswith("bag"):
        minibatcher = data_utils.bow_minibatcher
        e.log.info("using BOW batcher")
    else:
        minibatcher = data_utils.minibatcher
        e.log.info("using sequential batcher")

    train_batch = minibatcher(
        data1=data.train_data[0],
        data2=data.train_data[1],
        vocab_size=len(data.vocab),
        batch_size=e.config.batch_size,
        score_func=model.score,
        shuffle=True,
        mega_batch=0 if not e.config.resume else e.config.mb,
        p_scramble=e.config.ps)

    evaluator = train_helper.evaluator(model, e)

    e.log.info("Training start ...")

    train_stats = train_helper.tracker([
        "loss", "vmf_kl", "gauss_kl",
        "rec_logloss", "para_logloss", "wploss", "dp_loss"])

    for epoch in range(start_epoch, e.config.n_epoch):
        # mega-batching (used to mine negative examples for the
        # discriminative paraphrase loss) is only enabled once epoch > 1
        if epoch > 1 and train_batch.mega_batch != e.config.mb:
            train_batch.mega_batch = e.config.mb
            train_batch._reset()
        e.log.info("current mega batch: {}".format(train_batch.mega_batch))

        # s*/t*: input and target word ids with masks m*/tm*;
        # n*: negative examples (with masks and targets) from the mega-batch
        for it, (s1, m1, s2, m2, t1, tm1, t2, tm2,
                 n1, nm1, nt1, ntm1, n2, nm2, nt2, ntm2, _) in \
                enumerate(train_batch):
            true_it = it + 1 + epoch * len(train_batch)

            loss, vkl, gkl, rec_logloss, para_logloss, wploss, dploss = \
                model(s1, m1, s2, m2, t1, tm1, t2, tm2,
                      n1, nm1, nt1, ntm1, n2, nm2, nt2, ntm2,
                      e.config.vmkl, e.config.gmkl,
                      epoch > 1 and e.config.dratio and e.config.mb > 1)
            model.optimize(loss)

            train_stats.update(
                {"loss": loss,
                 "vmf_kl": vkl,
                 "gauss_kl": gkl,
                 "para_logloss": para_logloss,
                 "rec_logloss": rec_logloss,
                 "wploss": wploss,
                 "dp_loss": dploss},
                len(s1))

            if (true_it + 1) % e.config.print_every == 0 or \
                    (true_it + 1) % len(train_batch) == 0:
                summarization = train_stats.summarize(
                    "epoch: {}, it: {} (max: {}), kl_temp: {:.2E}|{:.2E}".format(
                        epoch, it, len(train_batch),
                        e.config.vmkl, e.config.gmkl))
                e.log.info(summarization)
                if e.config.summarize:
                    for name, value in train_stats.stats.items():
                        writer.add_scalar("train/" + name, value, true_it)
                train_stats.reset()

            if (true_it + 1) % e.config.eval_every == 0 or \
                    (true_it + 1) % len(train_batch) == 0:

                e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25)
                dev_stats, _, dev_res, _ = evaluator.evaluate(
                    data.dev_data, 'pred')
                e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25)

                if e.config.summarize:
                    writer.add_scalar("dev/pearson",
                                      dev_stats[EVAL_YEAR][1], true_it)
                    writer.add_scalar("dev/spearman",
                                      dev_stats[EVAL_YEAR][2], true_it)

                if best_dev_res < dev_res:
                    best_dev_res = dev_res

                    e.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)
                    test_stats, test_bm_res, test_avg_res, test_avg_s = \
                        evaluator.evaluate(data.test_data, 'pred')
                    e.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)

                    e.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)
                    tz_stats, tz_bm_res, tz_avg_res, tz_avg_s = \
                        evaluator.evaluate(data.test_data, 'predz')
                    e.log.info(
                        "Summary - benchmark: {:.4f}, test avg: {:.4f}".format(
                            tz_bm_res, tz_avg_res))
                    e.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)

                    model.save(
                        dev_avg=best_dev_res,
                        dev_perf=dev_stats,
                        test_avg=test_avg_res,
                        test_perf=test_stats,
                        iteration=true_it,
                        epoch=epoch)

                    if e.config.summarize:
                        for year, stats in test_stats.items():
                            writer.add_scalar(
                                "test/{}_pearson".format(year),
                                stats[1], true_it)
                            writer.add_scalar(
                                "test/{}_spearman".format(year),
                                stats[2], true_it)
                        writer.add_scalar("test/avg_pearson",
                                          test_avg_res, true_it)
                        writer.add_scalar("test/avg_spearman",
                                          test_avg_s, true_it)
                        writer.add_scalar("test/STSBenchmark_pearson",
                                          test_bm_res, true_it)
                        writer.add_scalar("dev/best_pearson",
                                          best_dev_res, true_it)
                        writer.add_scalar("testz/avg_pearson",
                                          tz_avg_res, true_it)
                        writer.add_scalar("testz/avg_spearman",
                                          tz_avg_s, true_it)
                        writer.add_scalar("testz/STSBenchmark_pearson",
                                          tz_bm_res, true_it)

                train_stats.reset()
                e.log.info(
                    "best dev result: {:.4f}, "
                    "STSBenchmark result: {:.4f}, "
                    "test average result: {:.4f}".format(
                        best_dev_res, test_bm_res, test_avg_res))

        model.save(
            dev_avg=best_dev_res,
            dev_perf=dev_stats,
            test_avg=test_avg_res,
            test_perf=test_stats,
            iteration=true_it,
            epoch=epoch + 1,
            name="latest")

    e.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)
    test_stats, test_bm_res, test_avg_res, test_avg_s = \
        evaluator.evaluate(data.test_data, 'pred')
    e.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)

    e.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)
    tz_stats, tz_bm_res, tz_avg_res, tz_avg_s = \
        evaluator.evaluate(data.test_data, 'predz')
    e.log.info("Summary - benchmark: {:.4f}, test avg: {:.4f}".format(
        tz_bm_res, tz_avg_res))
    e.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)
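# A minimal driver sketch (assumption) showing how run() is presumably
# invoked: it mirrors how the encoding snippet below enters the
# train_helper.experiment context manager. Here get_config() is a
# hypothetical stand-in for the repo's real argument parsing, which is not
# part of this listing:
#
#   if __name__ == "__main__":
#       cfg = get_config()
#       with train_helper.experiment(cfg, cfg.save_prefix) as e:
#           run(e)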
# load a trained checkpoint and its config, then rebuild the model for encoding
save_dict = torch.load(
    args.save_file, map_location=lambda storage, loc: storage)
config = save_dict['config']
checkpoint = save_dict['state_dict']
config.debug = True

with open(args.vocab_file, "rb") as fp:
    W, vocab = pickle.load(fp)

with train_helper.experiment(config, config.save_prefix) as e:
    e.log.info("vocab loaded from: {}".format(args.vocab_file))
    model = models.vgvae(
        vocab_size=len(vocab),
        embed_dim=e.config.edim if W is None else W.shape[1],
        embed_init=W,
        experiment=e)
    model.eval()
    model.load(checkpointed_state_dict=checkpoint)
    e.log.info(model)

    def encode(d):
        global vocab, batch_size
        # map each sentence to word ids (0 = unknown token)
        new_d = [[vocab.get(w, 0) for w in s.split(" ")] for s in d]
        all_y_vecs = []
        all_z_vecs = []
        for s1, m1, s2, m2, _, _, _, _, \
                _, _, _, _, _, _, _, _, _ in \
                tqdm(data_utils.minibatcher(
                    data1=np.array(new_d),
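# Usage sketch for encode() (assumption: the loop above accumulates the
# semantic vectors in all_y_vecs and the syntactic vectors in all_z_vecs and
# returns them; the real script may instead write the vectors to disk):
#
#   y_vecs, z_vecs = encode(["a man is playing a guitar",
#                            "the guitar is played by a man"])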
def run(e):
    global best_dev_bleu, test_bleu

    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)
    dp = data_utils.data_processor(
        train_path=e.config.train_path,
        experiment=e)
    data, W = dp.process()
    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)

    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)
    model = models.vgvae(
        vocab_size=len(data.vocab),
        embed_dim=e.config.edim if W is None else W.shape[1],
        embed_init=W,
        experiment=e)

    start_epoch = true_it = 0
    best_dev_stats = test_stats = None
    if e.config.resume:
        start_epoch, _, best_dev_bleu, test_bleu = \
            model.load(name="latest")
        e.log.info(
            "resumed from previous checkpoint: start epoch: {}, "
            "iteration: {}, best dev bleu: {:.3f}, test bleu: {:.3f}, ".format(
                start_epoch, true_it, best_dev_bleu, test_bleu))

    e.log.info(model)
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    if e.config.summarize:
        writer = SummaryWriter(e.experiment_dir)

    # sr1/sr2 are presumably copies of s1/s2 with words swapped for others
    # from the same tag bucket, with probability e.config.wr
    train_batch = data_utils.minibatcher(
        data1=data.train_data[0],
        tag1=data.train_tag[0],
        data2=data.train_data[1],
        tag2=data.train_tag[1],
        tag_bucket=data.tag_bucket,
        vocab_size=len(data.vocab),
        batch_size=e.config.batch_size,
        shuffle=True,
        p_replace=e.config.wr,
        p_scramble=e.config.ps)

    dev_eval = train_helper.evaluator(
        e.config.dev_inp_path, e.config.dev_ref_path,
        model, data.vocab, data.inv_vocab, e)
    test_eval = train_helper.evaluator(
        e.config.test_inp_path, e.config.test_ref_path,
        model, data.vocab, data.inv_vocab, e)

    e.log.info("Training start ...")

    train_stats = train_helper.tracker([
        "loss", "vmf_kl", "gauss_kl",
        "rec_logloss", "para_logloss", "wploss"])

    for epoch in range(start_epoch, e.config.n_epoch):
        for it, (s1, sr1, m1, s2, sr2, m2, t1, tm1, t2, tm2, _) in \
                enumerate(train_batch):
            true_it = it + 1 + epoch * len(train_batch)

            loss, kl, kl2, rec_logloss, para_logloss, wploss = \
                model(s1, sr1, m1, s2, sr2, m2, t1, tm1, t2, tm2,
                      e.config.vmkl, e.config.gmkl)
            model.optimize(loss)

            train_stats.update(
                {"loss": loss,
                 "vmf_kl": kl,
                 "gauss_kl": kl2,
                 "para_logloss": para_logloss,
                 "rec_logloss": rec_logloss,
                 "wploss": wploss},
                len(s1))

            if (true_it + 1) % e.config.print_every == 0 or \
                    (true_it + 1) % len(train_batch) == 0:
                summarization = train_stats.summarize(
                    "epoch: {}, it: {} (max: {}), kl_temp(v|g): {:.2E}|{:.2E}".format(
                        epoch, it, len(train_batch),
                        e.config.vmkl, e.config.gmkl))
                e.log.info(summarization)
                if e.config.summarize:
                    for name, value in train_stats.stats.items():
                        writer.add_scalar("train/" + name, value, true_it)
                train_stats.reset()

            if (true_it + 1) % e.config.eval_every == 0 or \
                    (true_it + 1) % len(train_batch) == 0:

                e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25)
                dev_stats, dev_bleu = dev_eval.evaluate("gen_dev")
                e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25)

                if e.config.summarize:
                    for name, value in dev_stats.items():
                        writer.add_scalar("dev/" + name, value, true_it)

                if best_dev_bleu < dev_bleu:
                    best_dev_bleu = dev_bleu
                    best_dev_stats = dev_stats

                    e.log.info("*" * 25 + " TEST SET EVALUATION " + "*" * 25)
                    test_stats, test_bleu = test_eval.evaluate("gen_test")
                    e.log.info("*" * 25 + " TEST SET EVALUATION " + "*" * 25)

                    model.save(
                        dev_bleu=best_dev_bleu,
                        dev_stats=best_dev_stats,
                        test_bleu=test_bleu,
                        test_stats=test_stats,
                        iteration=true_it,
                        epoch=epoch)

                    if e.config.summarize:
                        for name, value in test_stats.items():
                            writer.add_scalar("test/" + name, value, true_it)

                e.log.info("best dev bleu: {:.4f}, test bleu: {:.4f}".format(
                    best_dev_bleu, test_bleu))

        model.save(
            dev_bleu=best_dev_bleu,
            dev_stats=best_dev_stats,
            test_bleu=test_bleu,
            test_stats=test_stats,
            iteration=true_it,
            epoch=epoch + 1,
            name="latest")

        # report timing and optionally exit before hitting a wall-clock limit
        time_per_epoch = (e.elapsed_time / (epoch - start_epoch + 1))
        time_in_need = time_per_epoch * (e.config.n_epoch - epoch - 1)
        e.log.info("elapsed time: {:.2f}(h), "
                   "time per epoch: {:.2f}(h), "
                   "time needed to finish: {:.2f}(h)".format(
                       e.elapsed_time, time_per_epoch, time_in_need))
        if time_per_epoch + e.elapsed_time > 3.7 and e.config.auto_disconnect:
            exit(1)

    test_gen_stats, test_res = test_eval.evaluate("gen_test")
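# For reference, a minimal sketch of what train_helper.tracker (used in both
# training loops above) presumably does: keep a count-weighted running average
# for each named statistic, exposed through .stats, .summarize(prefix), and
# .reset(). This is an assumption for illustration; the repo's real tracker
# may differ in detail.
class tracker_sketch:
    def __init__(self, names):
        self.names = names
        self.reset()

    def update(self, values, count):
        # accumulate each statistic, weighted by the number of examples
        for name, value in values.items():
            self.sums[name] += float(value) * count
        self.count += count

    @property
    def stats(self):
        # current per-example averages
        return {name: self.sums[name] / max(self.count, 1)
                for name in self.names}

    def summarize(self, prefix=""):
        return prefix + ", " + ", ".join(
            "{}: {:.4f}".format(name, value)
            for name, value in self.stats.items())

    def reset(self):
        self.sums = {name: 0.0 for name in self.names}
        self.count = 0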