Code Example #1
def run(e):
    global best_dev_res, test_bm_res, test_avg_res
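    # the running bests live at module level so that run() can resume;
    # they are assumed to be initialized (e.g. to 0) by the caller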

    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)
    dp = data_utils.data_processor(train_path=e.config.train_file,
                                   eval_path=e.config.eval_file,
                                   experiment=e)
    data, W = dp.process()

    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    model = vgvae(vocab_size=len(data.vocab),
                  embed_dim=e.config.edim if W is None else W.shape[1],
                  embed_init=W,
                  experiment=e)

    start_epoch = true_it = 0
    # guard the end-of-epoch save below in case no evaluation has run yet
    dev_stats = test_stats = None
    if e.config.resume:
        start_epoch, true_it, best_dev_res, test_avg_res = \
            model.load(name="latest")
        if e.config.use_cuda:
            model.cuda()
            e.log.info("transferred model to gpu")
        e.log.info(
            "resumed from previous checkpoint: start epoch: {}, "
            "iteration: {}, best dev res: {:.3f}, test avg res: {:.3f}".format(
                start_epoch, true_it, best_dev_res, test_avg_res))

    e.log.info(model)
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    if e.config.summarize:
        writer = SummaryWriter(e.experiment_dir)

    if e.config.decoder_type.startswith("bag"):
        minibatcher = data_utils.bow_minibatcher
        e.log.info("using BOW batcher")
    else:
        minibatcher = data_utils.minibatcher
        e.log.info("using sequential batcher")

    train_batch = minibatcher(
        data1=data.train_data[0],
        data2=data.train_data[1],
        vocab_size=len(data.vocab),
        batch_size=e.config.batch_size,
        score_func=model.score,
        shuffle=True,
        mega_batch=0 if not e.config.resume else e.config.mb,
        p_scramble=e.config.ps)
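    # assumption: score_func=model.score lets the batcher rank candidates
    # within a mega-batch to mine the negative examples (the n*/nt*
    # tensors unpacked in the training loop below)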

    evaluator = train_helper.evaluator(model, e)

    e.log.info("Training start ...")
    train_stats = train_helper.tracker([
        "loss", "vmf_kl", "gauss_kl", "rec_logloss", "para_logloss", "wploss",
        "dp_loss"
    ])

    for epoch in range(start_epoch, e.config.n_epoch):
        if epoch > 1 and train_batch.mega_batch != e.config.mb:
            train_batch.mega_batch = e.config.mb
            train_batch._reset()
        e.log.info("current mega batch: {}".format(train_batch.mega_batch))
        for it, (s1, m1, s2, m2, t1, tm1, t2, tm2,
                 n1, nm1, nt1, ntm1, n2, nm2, nt2, ntm2, _) in \
                enumerate(train_batch):
            true_it = it + 1 + epoch * len(train_batch)

            loss, vkl, gkl, rec_logloss, para_logloss, wploss, dploss = \
                model(s1, m1, s2, m2, t1, tm1, t2, tm2,
                      n1, nm1, nt1, ntm1, n2, nm2, nt2, ntm2,
                      e.config.vmkl, e.config.gmkl,
                      epoch > 1 and e.config.dratio and e.config.mb > 1)
            model.optimize(loss)

            train_stats.update(
                {
                    "loss": loss,
                    "vmf_kl": vkl,
                    "gauss_kl": gkl,
                    "para_logloss": para_logloss,
                    "rec_logloss": rec_logloss,
                    "wploss": wploss,
                    "dp_loss": dploss
                }, len(s1))

            if (true_it + 1) % e.config.print_every == 0 or \
                    (true_it + 1) % len(train_batch) == 0:
                summarization = train_stats.summarize(
                    "epoch: {}, it: {} (max: {}), kl_temp(v|g): {:.2E}|{:.2E}".
                    format(epoch, it, len(train_batch), e.config.vmkl,
                           e.config.gmkl))
                e.log.info(summarization)
                if e.config.summarize:
                    for name, value in train_stats.stats.items():
                        writer.add_scalar("train/" + name, value, true_it)
                train_stats.reset()

            if (true_it + 1) % e.config.eval_every == 0 or \
                    (true_it + 1) % len(train_batch) == 0:

                e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25)

                dev_stats, _, dev_res, _ = evaluator.evaluate(
                    data.dev_data, 'pred')

                e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25)

                if e.config.summarize:
                    writer.add_scalar("dev/pearson", dev_stats[EVAL_YEAR][1],
                                      true_it)
                    writer.add_scalar("dev/spearman", dev_stats[EVAL_YEAR][2],
                                      true_it)

                if best_dev_res < dev_res:
                    best_dev_res = dev_res

                    e.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)

                    test_stats, test_bm_res, test_avg_res, test_avg_s = \
                        evaluator.evaluate(data.test_data, 'pred')

                    e.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)
                    e.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)

                    tz_stats, tz_bm_res, tz_avg_res, tz_avg_s = \
                        evaluator.evaluate(data.test_data, 'predz')
                    e.log.info(
                        "Summary - benchmark: {:.4f}, test avg: {:.4f}".format(
                            tz_bm_res, tz_avg_res))
                    e.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)

                    model.save(dev_avg=best_dev_res,
                               dev_perf=dev_stats,
                               test_avg=test_avg_res,
                               test_perf=test_stats,
                               iteration=true_it,
                               epoch=epoch)

                    if e.config.summarize:
                        for year, stats in test_stats.items():
                            writer.add_scalar("test/{}_pearson".format(year),
                                              stats[1], true_it)
                            writer.add_scalar("test/{}_spearman".format(year),
                                              stats[2], true_it)

                        writer.add_scalar("test/avg_pearson", test_avg_res,
                                          true_it)
                        writer.add_scalar("test/avg_spearman", test_avg_s,
                                          true_it)
                        writer.add_scalar("test/STSBenchmark_pearson",
                                          test_bm_res, true_it)
                        writer.add_scalar("dev/best_pearson", best_dev_res,
                                          true_it)

                        writer.add_scalar("testz/avg_pearson", tz_avg_res,
                                          true_it)
                        writer.add_scalar("testz/avg_spearman", tz_avg_s,
                                          true_it)
                        writer.add_scalar("testz/STSBenchmark_pearson",
                                          tz_bm_res, true_it)
                train_stats.reset()
                e.log.info("best dev result: {:.4f}, "
                           "STSBenchmark result: {:.4f}, "
                           "test average result: {:.4f}".format(
                               best_dev_res, test_bm_res, test_avg_res))

        model.save(dev_avg=best_dev_res,
                   dev_perf=dev_stats,
                   test_avg=test_avg_res,
                   test_perf=test_stats,
                   iteration=true_it,
                   epoch=epoch + 1,
                   name="latest")

    e.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)

    test_stats, test_bm_res, test_avg_res, test_avg_s = \
        evaluator.evaluate(data.test_data, 'pred')

    e.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)
    e.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)

    tz_stats, tz_bm_res, tz_avg_res, tz_avg_s = \
        evaluator.evaluate(data.test_data, 'predz')
    e.log.info("Summary - benchmark: {:.4f}, test avg: {:.4f}".format(
        tz_bm_res, tz_avg_res))
    e.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)
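The tracker object above is only referenced, not defined. Below is a minimal sketch of how it plausibly behaves (an assumption for illustration, not the repository's actual implementation): it keeps per-example running means of the named losses, weighted by batch size, renders them in summarize(), and clears them in reset().

class tracker:
    def __init__(self, names):
        self.names = names
        self.reset()

    def reset(self):
        self._sums = {name: 0.0 for name in self.names}
        self._count = 0

    def update(self, values, batch_size):
        # `values` maps a tracked name to the scalar loss of this batch
        for name, value in values.items():
            self._sums[name] += float(value) * batch_size
        self._count += batch_size

    @property
    def stats(self):
        # running means since the last reset()
        denom = max(self._count, 1)
        return {name: s / denom for name, s in self._sums.items()}

    def summarize(self, prefix=""):
        return prefix + " " + ", ".join(
            "{}: {:.4f}".format(name, value)
            for name, value in self.stats.items())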
Code Example #2

import pickle

import numpy as np
import torch
from tqdm import tqdm

# project-local modules from the same repository (assumed importable here):
import data_utils
import models
import train_helper

# `args` is assumed to come from an argparse parser defined elsewhere in
# the script (providing save_file, vocab_file, and the batch_size used by
# encode() below)

save_dict = torch.load(args.save_file,
                       map_location=lambda storage, loc: storage)

config = save_dict['config']
checkpoint = save_dict['state_dict']
config.debug = True

with open(args.vocab_file, "rb") as fp:
    W, vocab = pickle.load(fp)

with train_helper.experiment(config, config.save_prefix) as e:
    e.log.info("vocab loaded from: {}".format(args.vocab_file))
    model = models.vgvae(vocab_size=len(vocab),
                         embed_dim=e.config.edim if W is None else W.shape[1],
                         embed_init=W,
                         experiment=e)
    model.eval()
    model.load(checkpointed_state_dict=checkpoint)
    e.log.info(model)

    def encode(d):
        global vocab, batch_size
        # map whitespace tokens to vocab ids (index 0 assumed to be <unk>)
        new_d = [[vocab.get(w, 0) for w in s.split(" ")] for s in d]
        all_y_vecs = []
        all_z_vecs = []

        for s1, m1, s2, m2, _, _, _, _, \
                _, _, _, _, _, _, _, _, _ in \
                tqdm(data_utils.minibatcher(
                    data1=np.array(new_d),
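The excerpt ends mid-call; judging from the two accumulators, encode presumably collects a semantic (y) and a syntactic (z) vector per input sentence. A hedged usage sketch under that assumption:

# assumption: encode returns (all_y_vecs, all_z_vecs) for raw sentences
sentences = ["a man is playing a guitar", "a dog runs across the field"]
y_vecs, z_vecs = encode(sentences)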
Code Example #3
def run(e):
    global best_dev_bleu, test_bleu

    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)
    dp = data_utils.data_processor(train_path=e.config.train_path,
                                   experiment=e)
    data, W = dp.process()

    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    model = models.vgvae(vocab_size=len(data.vocab),
                         embed_dim=e.config.edim if W is None else W.shape[1],
                         embed_init=W,
                         experiment=e)

    start_epoch = true_it = 0
    best_dev_stats = test_stats = None
    if e.config.resume:
        start_epoch, true_it, best_dev_bleu, test_bleu = \
            model.load(name="latest")
        e.log.info(
            "resumed from previous checkpoint: start epoch: {}, "
            "iteration: {}, best dev bleu: {:.3f}, test bleu: {:.3f}".format(
                start_epoch, true_it, best_dev_bleu, test_bleu))

    e.log.info(model)
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    if e.config.summarize:
        writer = SummaryWriter(e.experiment_dir)

    train_batch = data_utils.minibatcher(data1=data.train_data[0],
                                         tag1=data.train_tag[0],
                                         data2=data.train_data[1],
                                         tag2=data.train_tag[1],
                                         tag_bucket=data.tag_bucket,
                                         vocab_size=len(data.vocab),
                                         batch_size=e.config.batch_size,
                                         shuffle=True,
                                         p_replace=e.config.wr,
                                         p_scramble=e.config.ps)

    dev_eval = train_helper.evaluator(e.config.dev_inp_path,
                                      e.config.dev_ref_path, model, data.vocab,
                                      data.inv_vocab, e)
    test_eval = train_helper.evaluator(e.config.test_inp_path,
                                       e.config.test_ref_path, model,
                                       data.vocab, data.inv_vocab, e)

    e.log.info("Training start ...")
    train_stats = train_helper.tracker([
        "loss", "vmf_kl", "gauss_kl", "rec_logloss", "para_logloss", "wploss"
    ])

    for epoch in range(start_epoch, e.config.n_epoch):
        for it, (s1, sr1, m1, s2, sr2, m2, t1, tm1, t2, tm2, _) in \
                enumerate(train_batch):
            true_it = it + 1 + epoch * len(train_batch)

            loss, kl, kl2, rec_logloss, para_logloss, wploss = \
                model(s1, sr1, m1, s2, sr2, m2, t1, tm1,
                      t2, tm2, e.config.vmkl, e.config.gmkl)
            model.optimize(loss)
            train_stats.update(
                {
                    "loss": loss,
                    "vmf_kl": kl,
                    "gauss_kl": kl2,
                    "para_logloss": para_logloss,
                    "rec_logloss": rec_logloss,
                    "wploss": wploss
                }, len(s1))

            if (true_it + 1) % e.config.print_every == 0 or \
                    (true_it + 1) % len(train_batch) == 0:
                summarization = train_stats.summarize(
                    "epoch: {}, it: {} (max: {}), kl_temp(v|g): {:.2E}|{:.2E}".
                    format(epoch, it, len(train_batch), e.config.vmkl,
                           e.config.gmkl))
                e.log.info(summarization)
                if e.config.summarize:
                    for name, value in train_stats.stats.items():
                        writer.add_scalar("train/" + name, value, true_it)
                train_stats.reset()

            if (true_it + 1) % e.config.eval_every == 0 or \
                    (true_it + 1) % len(train_batch) == 0:
                e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25)

                dev_stats, dev_bleu = dev_eval.evaluate("gen_dev")

                e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25)

                if e.config.summarize:
                    for name, value in dev_stats.items():
                        writer.add_scalar("dev/" + name, value, true_it)

                if best_dev_bleu < dev_bleu:
                    best_dev_bleu = dev_bleu
                    best_dev_stats = dev_stats

                    e.log.info("*" * 25 + " TEST SET EVALUATION " + "*" * 25)

                    test_stats, test_bleu = test_eval.evaluate("gen_test")

                    e.log.info("*" * 25 + " TEST SET EVALUATION " + "*" * 25)

                    model.save(dev_bleu=best_dev_bleu,
                               dev_stats=best_dev_stats,
                               test_bleu=test_bleu,
                               test_stats=test_stats,
                               iteration=true_it,
                               epoch=epoch)

                    if e.config.summarize:
                        for name, value in test_stats.items():
                            writer.add_scalar("test/" + name, value, true_it)

                e.log.info("best dev bleu: {:.4f}, test bleu: {:.4f}".format(
                    best_dev_bleu, test_bleu))

        model.save(dev_bleu=best_dev_bleu,
                   dev_stats=best_dev_stats,
                   test_bleu=test_bleu,
                   test_stats=test_stats,
                   iteration=true_it,
                   epoch=epoch + 1,
                   name="latest")

        time_per_epoch = (e.elapsed_time / (epoch - start_epoch + 1))
        time_in_need = time_per_epoch * (e.config.n_epoch - epoch - 1)
        e.log.info("elapsed time: {:.2f}(h), "
                   "time per epoch: {:.2f}(h), "
                   "time needed to finish: {:.2f}(h)".format(
                       e.elapsed_time, time_per_epoch, time_in_need))

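        # assumption: e.elapsed_time is in hours (matching the "(h)" units
        # above); if another epoch would overrun a ~3.7h compute window,
        # exit early so training can resume from the "latest" checkpoint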
        if time_per_epoch + e.elapsed_time > 3.7 and e.config.auto_disconnect:
            exit(1)

    test_gen_stats, test_res = test_eval.evaluate("gen_test")
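
Neither run() defines its experiment handle; Code Example #2 shows train_helper.experiment used as a context manager, so a driver for either training script might look like the following sketch (assumptions: `config` is built elsewhere, e.g. from argparse, and the running-best globals start at zero).

best_dev_bleu = test_bleu = 0.0

if __name__ == "__main__":
    with train_helper.experiment(config, config.save_prefix) as e:
        run(e)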