Example 1
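# Assumed context (not shown in this snippet): project modules `data_utils` and
# `train_helper`, the `vsl_gg` model class, a tensorboard `SummaryWriter`
# (e.g. from tensorboardX), and module-level globals `best_dev_res` / `test_res`.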
def run(e):
    global best_dev_res, test_res

    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)
    dp = data_utils.data_processor(experiment=e)
    data, W = dp.process()

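    # Per-instance prior buffers: each caches a mean/log-variance pair for one of
    # the two latent variables (sizes e.config.zsize and e.config.ysize) and is
    # refreshed from the inferred posteriors during training (see the
    # update_buffer calls below); `freq` presumably controls the update schedule.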
    label_logvar1_buffer = \
        train_helper.prior_buffer(data.train[0], e.config.zsize,
                                  experiment=e,
                                  freq=e.config.ufl,
                                  name="label_logvar1",
                                  init_path=e.config.prior_file)
    label_mean1_buffer = \
        train_helper.prior_buffer(data.train[0], e.config.zsize,
                                  experiment=e,
                                  freq=e.config.ufl,
                                  name="label_mean1",
                                  init_path=e.config.prior_file)

    label_logvar2_buffer = \
        train_helper.prior_buffer(data.train[0], e.config.ysize,
                                  experiment=e,
                                  freq=e.config.ufl,
                                  name="label_logvar2",
                                  init_path=e.config.prior_file)
    label_mean2_buffer = \
        train_helper.prior_buffer(data.train[0], e.config.ysize,
                                  experiment=e,
                                  freq=e.config.ufl,
                                  name="label_mean2",
                                  init_path=e.config.prior_file)

    all_buffer = [
        label_logvar1_buffer, label_mean1_buffer, label_logvar2_buffer,
        label_mean2_buffer
    ]

    e.log.info("labeled buffer size: logvar1: {}, mean1: {}, "
               "logvar2: {}, mean2: {}".format(len(label_logvar1_buffer),
                                               len(label_mean1_buffer),
                                               len(label_logvar2_buffer),
                                               len(label_mean2_buffer)))

    if e.config.use_unlabel:
        unlabel_logvar1_buffer = \
            train_helper.prior_buffer(data.unlabel[0], e.config.zsize,
                                      experiment=e,
                                      freq=e.config.ufu,
                                      name="unlabel_logvar1",
                                      init_path=e.config.prior_file)
        unlabel_mean1_buffer = \
            train_helper.prior_buffer(data.unlabel[0], e.config.zsize,
                                      experiment=e,
                                      freq=e.config.ufu,
                                      name="unlabel_mean1",
                                      init_path=e.config.prior_file)

        unlabel_logvar2_buffer = \
            train_helper.prior_buffer(data.unlabel[0], e.config.ysize,
                                      experiment=e,
                                      freq=e.config.ufu,
                                      name="unlabel_logvar2",
                                      init_path=e.config.prior_file)
        unlabel_mean2_buffer = \
            train_helper.prior_buffer(data.unlabel[0], e.config.ysize,
                                      experiment=e,
                                      freq=e.config.ufu,
                                      name="unlabel_mean2",
                                      init_path=e.config.prior_file)

        all_buffer += [
            unlabel_logvar1_buffer, unlabel_mean1_buffer,
            unlabel_logvar2_buffer, unlabel_mean2_buffer
        ]

        e.log.info("unlabeled buffer size: logvar1: {}, mean1: {}, "
                   "logvar2: {}, mean2: {}".format(len(unlabel_logvar1_buffer),
                                                   len(unlabel_mean1_buffer),
                                                   len(unlabel_logvar2_buffer),
                                                   len(unlabel_mean2_buffer)))

    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    model = vsl_gg(word_vocab_size=len(data.vocab),
                   char_vocab_size=len(data.char_vocab),
                   n_tags=len(data.tag_vocab),
                   embed_dim=e.config.edim if W is None else W.shape[1],
                   embed_init=W,
                   experiment=e)

    e.log.info(model)
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    if e.config.summarize:
        writer = SummaryWriter(e.experiment_dir)

    label_batch = data_utils.minibatcher(word_data=data.train[0],
                                         char_data=data.train[1],
                                         label=data.train[2],
                                         batch_size=e.config.batch_size,
                                         shuffle=True)

    if e.config.use_unlabel:
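        # The `label` argument is effectively a placeholder for unlabeled data:
        # the label field returned by this batcher is discarded in the training
        # loop below (the `_` in the unpacking of next(unlabel_batch)).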
        unlabel_batch = data_utils.minibatcher(
            word_data=data.unlabel[0],
            char_data=data.unlabel[1],
            label=data.unlabel[0],
            batch_size=e.config.unlabel_batch_size,
            shuffle=True)

    evaluator = train_helper.evaluator(data.inv_tag_vocab, model, e)

    e.log.info("Training start ...")
    label_stats = train_helper.tracker(
        ["loss", "logloss", "kl_div", "sup_loss"])
    unlabel_stats = train_helper.tracker(["loss", "logloss", "kl_div"])

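    # Per-iteration flow: anneal the KL temperature, draw a labeled batch
    # (StopIteration is ignored; the batcher presumably reshuffles and continues),
    # look up the cached priors for that batch, run the model, write the newly
    # inferred posteriors back into the buffers, and then optimize. Dev/test
    # evaluation and checkpointing run every e.config.eval_every iterations.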
    for it in range(e.config.n_iter):
        model.train()
        kl_temp = train_helper.get_kl_temp(e.config.klr, it, 1.0)

        try:
            l_data, l_mask, l_char, l_char_mask, l_label, l_ixs = \
                next(label_batch)
        except StopIteration:
            pass

        lp_logvar1 = label_logvar1_buffer[l_ixs]
        lp_mean1 = label_mean1_buffer[l_ixs]
        lp_logvar2 = label_logvar2_buffer[l_ixs]
        lp_mean2 = label_mean2_buffer[l_ixs]

        l_loss, l_logloss, l_kld, sup_loss, \
            lq_mean1, lq_logvar1, lq_mean2, lq_logvar2, _ = \
            model(l_data, l_mask, l_char, l_char_mask,
                  l_label, [lp_mean1, lp_mean2], [lp_logvar1, lp_logvar2],
                  kl_temp)

        label_logvar1_buffer.update_buffer(l_ixs, lq_logvar1, l_mask.sum(-1))
        label_mean1_buffer.update_buffer(l_ixs, lq_mean1, l_mask.sum(-1))

        label_logvar2_buffer.update_buffer(l_ixs, lq_logvar2, l_mask.sum(-1))
        label_mean2_buffer.update_buffer(l_ixs, lq_mean2, l_mask.sum(-1))

        label_stats.update(
            {
                "loss": l_loss,
                "logloss": l_logloss,
                "kl_div": l_kld,
                "sup_loss": sup_loss
            }, l_mask.sum())

        if not e.config.use_unlabel:
            model.optimize(l_loss)

        else:
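            # With unlabeled data enabled, also draw an unlabeled batch and
            # optimize the joint objective, weighting the unlabeled loss by
            # e.config.ur (see model.optimize below).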
            try:
                u_data, u_mask, u_char, u_char_mask, _, u_ixs = \
                    next(unlabel_batch)
            except StopIteration:
                pass

            up_logvar1 = unlabel_logvar1_buffer[u_ixs]
            up_mean1 = unlabel_mean1_buffer[u_ixs]

            up_logvar2 = unlabel_logvar2_buffer[u_ixs]
            up_mean2 = unlabel_mean2_buffer[u_ixs]

            u_loss, u_logloss, u_kld, _, \
                uq_mean1, uq_logvar1, uq_mean2, uq_logvar2, _ = \
                model(u_data, u_mask, u_char, u_char_mask,
                      None, [up_mean1, up_mean2], [up_logvar1, up_logvar2],
                      kl_temp)

            unlabel_logvar1_buffer.update_buffer(u_ixs, uq_logvar1,
                                                 u_mask.sum(-1))
            unlabel_mean1_buffer.update_buffer(u_ixs, uq_mean1, u_mask.sum(-1))

            unlabel_logvar2_buffer.update_buffer(u_ixs, uq_logvar2,
                                                 u_mask.sum(-1))
            unlabel_mean2_buffer.update_buffer(u_ixs, uq_mean2, u_mask.sum(-1))

            unlabel_stats.update(
                {
                    "loss": u_loss,
                    "logloss": u_logloss,
                    "kl_div": u_kld
                }, u_mask.sum())

            model.optimize(l_loss + e.config.ur * u_loss)

        if (it + 1) % e.config.print_every == 0:
            summary = label_stats.summarize(
                "it: {} (max: {}), kl_temp: {:.2f}, labeled".format(
                    it + 1, len(label_batch), kl_temp))
            if e.config.use_unlabel:
                summary += unlabel_stats.summarize(", unlabeled")
            e.log.info(summary)
            if e.config.summarize:
                writer.add_scalar("label/kl_temp", kl_temp, it)
                for name, value in label_stats.stats.items():
                    writer.add_scalar("label/" + name, value, it)
                if e.config.use_unlabel:
                    for name, value in unlabel_stats.stats.items():
                        writer.add_scalar("unlabel/" + name, value, it)
            label_stats.reset()
            unlabel_stats.reset()

        if (it + 1) % e.config.eval_every == 0:

            e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25)

            dev_perf, dev_res = evaluator.evaluate(data.dev)

            e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25)

            if e.config.summarize:
                for n, v in dev_perf.items():
                    writer.add_scalar("dev/" + n, v, it)

            if best_dev_res < dev_res:
                best_dev_res = dev_res

                e.log.info("*" * 25 + " TEST SET EVALUATION " + "*" * 25)

                test_perf, test_res = evaluator.evaluate(data.test)

                e.log.info("*" * 25 + " TEST SET EVALUATION " + "*" * 25)

                model.save(dev_perf=dev_perf,
                           test_perf=test_perf,
                           iteration=it)

                if e.config.save_prior:
                    for buf in all_buffer:
                        buf.save()

                if e.config.summarize:
                    writer.add_scalar("dev/best_result", best_dev_res, it)
                    for n, v in test_perf.items():
                        writer.add_scalar("test/" + n, v, it)
            e.log.info("best dev result: {:.4f}, "
                       "test result: {:.4f}, ".format(best_dev_res, test_res))
            label_stats.reset()
            unlabel_stats.reset()
Example 2
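# Assumed context (not shown in this snippet): project modules `data_utils` and
# `train_helper`, the `vgvae` model class, a tensorboard `SummaryWriter`, the
# `EVAL_YEAR` constant used for dev logging, and module-level globals
# `best_dev_res` / `test_bm_res` / `test_avg_res`.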
def run(e):
    global best_dev_res, test_bm_res, test_avg_res

    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)
    dp = data_utils.data_processor(train_path=e.config.train_file,
                                   eval_path=e.config.eval_file,
                                   experiment=e)
    data, W = dp.process()

    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    model = vgvae(vocab_size=len(data.vocab),
                  embed_dim=e.config.edim if W is None else W.shape[1],
                  embed_init=W,
                  experiment=e)

    start_epoch = true_it = 0
    if e.config.resume:
        start_epoch, _, best_dev_res, test_avg_res = \
            model.load(name="latest")
        if e.config.use_cuda:
            model.cuda()
            e.log.info("transferred model to gpu")
        e.log.info(
            "resumed from previous checkpoint: start epoch: {}, "
            "iteration: {}, best dev res: {:.3f}, test avg res: {:.3f}".format(
                start_epoch, true_it, best_dev_res, test_avg_res))

    e.log.info(model)
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    if e.config.summarize:
        writer = SummaryWriter(e.experiment_dir)

    if e.config.decoder_type.startswith("bag"):
        minibatcher = data_utils.bow_minibatcher
        e.log.info("using BOW batcher")
    else:
        minibatcher = data_utils.minibatcher
        e.log.info("using sequential batcher")

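    # The batcher pairs sentences from data.train_data[0]/[1]; `score_func` is
    # presumably used to mine negative examples within a mega-batch, and
    # `p_scramble` is a scrambling probability. The mega-batch size starts at 0
    # (disabled) unless resuming, and is switched to e.config.mb after epoch 1.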
    train_batch = minibatcher(
        data1=data.train_data[0],
        data2=data.train_data[1],
        vocab_size=len(data.vocab),
        batch_size=e.config.batch_size,
        score_func=model.score,
        shuffle=True,
        mega_batch=0 if not e.config.resume else e.config.mb,
        p_scramble=e.config.ps)

    evaluator = train_helper.evaluator(model, e)

    e.log.info("Training start ...")
    train_stats = train_helper.tracker([
        "loss", "vmf_kl", "gauss_kl", "rec_logloss", "para_logloss", "wploss",
        "dp_loss"
    ])

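    # Training loop: after epoch 1 the mega-batch size is raised to e.config.mb,
    # each step optimizes the combined loss (vMF KL, Gaussian KL, reconstruction
    # and paraphrase log-losses, word-position and paraphrase-discrimination
    # terms), and dev/test evaluation plus checkpointing run every
    # e.config.eval_every steps and at epoch boundaries.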
    for epoch in range(start_epoch, e.config.n_epoch):
        if epoch > 1 and train_batch.mega_batch != e.config.mb:
            train_batch.mega_batch = e.config.mb
            train_batch._reset()
        e.log.info("current mega batch: {}".format(train_batch.mega_batch))
        for it, (s1, m1, s2, m2, t1, tm1, t2, tm2,
                 n1, nm1, nt1, ntm1, n2, nm2, nt2, ntm2, _) in \
                enumerate(train_batch):
            true_it = it + 1 + epoch * len(train_batch)

            loss, vkl, gkl, rec_logloss, para_logloss, wploss, dploss = \
                model(s1, m1, s2, m2, t1, tm1, t2, tm2,
                      n1, nm1, nt1, ntm1, n2, nm2, nt2, ntm2,
                      e.config.vmkl, e.config.gmkl,
                      epoch > 1 and e.config.dratio and e.config.mb > 1)
            model.optimize(loss)

            train_stats.update(
                {
                    "loss": loss,
                    "vmf_kl": vkl,
                    "gauss_kl": gkl,
                    "para_logloss": para_logloss,
                    "rec_logloss": rec_logloss,
                    "wploss": wploss,
                    "dp_loss": dploss
                }, len(s1))

            if (true_it + 1) % e.config.print_every == 0 or \
                    (true_it + 1) % len(train_batch) == 0:
                summarization = train_stats.summarize(
                    "epoch: {}, it: {} (max: {}), kl_temp: {:.2E}|{:.2E}".
                    format(epoch, it, len(train_batch), e.config.vmkl,
                           e.config.gmkl))
                e.log.info(summarization)
                if e.config.summarize:
                    for name, value in train_stats.stats.items():
                        writer.add_scalar("train/" + name, value, true_it)
                train_stats.reset()

            if (true_it + 1) % e.config.eval_every == 0 or \
                    (true_it + 1) % len(train_batch) == 0:

                e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25)

                dev_stats, _, dev_res, _ = evaluator.evaluate(
                    data.dev_data, 'pred')

                e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25)

                if e.config.summarize:
                    writer.add_scalar("dev/pearson", dev_stats[EVAL_YEAR][1],
                                      true_it)
                    writer.add_scalar("dev/spearman", dev_stats[EVAL_YEAR][2],
                                      true_it)

                if best_dev_res < dev_res:
                    best_dev_res = dev_res

                    e.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)

                    test_stats, test_bm_res, test_avg_res, test_avg_s = \
                        evaluator.evaluate(data.test_data, 'pred')

                    e.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)
                    e.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)

                    tz_stats, tz_bm_res, tz_avg_res, tz_avg_s = \
                        evaluator.evaluate(data.test_data, 'predz')
                    e.log.info(
                        "Summary - benchmark: {:.4f}, test avg: {:.4f}".format(
                            tz_bm_res, tz_avg_res))
                    e.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)

                    model.save(dev_avg=best_dev_res,
                               dev_perf=dev_stats,
                               test_avg=test_avg_res,
                               test_perf=test_stats,
                               iteration=true_it,
                               epoch=epoch)

                    if e.config.summarize:
                        for year, stats in test_stats.items():
                            writer.add_scalar("test/{}_pearson".format(year),
                                              stats[1], true_it)
                            writer.add_scalar("test/{}_spearman".format(year),
                                              stats[2], true_it)

                        writer.add_scalar("test/avg_pearson", test_avg_res,
                                          true_it)
                        writer.add_scalar("test/avg_spearman", test_avg_s,
                                          true_it)
                        writer.add_scalar("test/STSBenchmark_pearson",
                                          test_bm_res, true_it)
                        writer.add_scalar("dev/best_pearson", best_dev_res,
                                          true_it)

                        writer.add_scalar("testz/avg_pearson", tz_avg_res,
                                          true_it)
                        writer.add_scalar("testz/avg_spearman", tz_avg_s,
                                          true_it)
                        writer.add_scalar("testz/STSBenchmark_pearson",
                                          tz_bm_res, true_it)
                train_stats.reset()
                e.log.info("best dev result: {:.4f}, "
                           "STSBenchmark result: {:.4f}, "
                           "test average result: {:.4f}".format(
                               best_dev_res, test_bm_res, test_avg_res))

        model.save(dev_avg=best_dev_res,
                   dev_perf=dev_stats,
                   test_avg=test_avg_res,
                   test_perf=test_stats,
                   iteration=true_it,
                   epoch=epoch + 1,
                   name="latest")

    e.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)

    test_stats, test_bm_res, test_avg_res, test_avg_s = \
        evaluator.evaluate(data.test_data, 'pred')

    e.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)
    e.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)

    tz_stats, tz_bm_res, tz_avg_res, tz_avg_s = \
        evaluator.evaluate(data.test_data, 'predz')
    e.log.info("Summary - benchmark: {:.4f}, test avg: {:.4f}".format(
        tz_bm_res, tz_avg_res))
    e.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)
Example 3
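# Assumed context (not shown in this snippet): numpy as `np`, project modules
# `data_utils`, `train_helper` and `models`, a tensorboard `SummaryWriter`, the
# constants `MAX_NSENT` / `MAX_NPARA` / `MAX_NLV`, and module-level globals
# `test_bm_res` / `test_avg_res`.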
def run(e):
    global test_bm_res, test_avg_res

    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)
    dp = data_utils.data_processor(train_path=e.config.train_path,
                                   eval_path=e.config.eval_path,
                                   experiment=e)
    data, W = dp.process()

    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    if e.config.model.lower() == "basic":
        model_class = models.basic_model
    elif e.config.model.lower() == "pos":
        model_class = models.pos_model
    elif e.config.model.lower() == "quant_pos":
        model_class = models.quantize_pos_model
    elif e.config.model.lower() == "quant_pos_reg":
        model_class = models.quantize_pos_regression_model
    elif e.config.model.lower() == "quant_attn_pos1":
        model_class = models.quantize_attn_pos_model1
    elif e.config.model.lower() == "quant_attn_pos2":
        model_class = models.quantize_attn_pos_model2
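    else:
        # Hypothetical guard (not in the original code): fail fast on an
        # unrecognized e.config.model value instead of hitting an undefined
        # model_class below.
        raise ValueError("unsupported model type: {}".format(e.config.model))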

    model = model_class(vocab_size=len(data.sent_vocab),
                        doc_title_vocab_size=len(data.doc_title_vocab),
                        sec_title_vocab_size=len(data.sec_title_vocab),
                        embed_dim=e.config.edim if W is None else W.shape[1],
                        embed_init=W,
                        max_nsent=MAX_NSENT,
                        max_npara=MAX_NPARA,
                        max_nlv=MAX_NLV,
                        experiment=e)

    start_epoch = it = n_finish_file = 0
    todo_file = list(range(len(data.train_data)))

    if e.config.resume:
        start_epoch, it, test_bm_res, test_avg_res, todo_file = \
            model.load()
        if e.config.use_cuda:
            model.cuda()
            e.log.info("transferred model to gpu")
        e.log.info(
            "resumed from previous checkpoint: start epoch: {}, "
            "iteration: {}, test benchmark {:.3f}, test avg res: {:.3f}".
            format(start_epoch, it, test_bm_res, test_avg_res))

    e.log.info(model)
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    if e.config.summarize:
        writer = SummaryWriter(e.experiment_dir)

    evaluator = train_helper.evaluator(model, e)

    e.log.info("Training start ...")
    train_stats = train_helper.tracker([
        "loss", "prev_logloss", "next_logloss", "para_loss", "sent_loss",
        "level_loss", "doc_title_loss", "sec_title_loss"
    ])

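    # Each epoch walks the training shards in random order: a file index is
    # sampled from `todo_file`, its minibatches are consumed, the index is
    # removed, and a checkpoint is written after every finished file so the
    # run can resume mid-epoch.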
    for epoch in range(start_epoch, e.config.n_epoch):
        while len(todo_file):
            file_idx = np.random.randint(len(todo_file))
            train_file = data.train_data[todo_file[file_idx]]

            with data_utils.minibatcher(
                    train_file=train_file,
                    sent_vocab=data.sent_vocab,
                    doc_title_vocab=data.doc_title_vocab,
                    sec_title_vocab=data.sec_title_vocab,
                    batch_size=e.config.batch_size,
                    max_len=e.config.max_len,
                    max_nsent=MAX_NSENT,
                    max_npara=MAX_NPARA,
                    bow=e.config.decoder_type.lower() == "bag_of_words",
                    log=e.log) as train_batch:

                for doc_id, para_id, pmask, _, sent_id, \
                        smask, lv, s, m, t, tm, t2, tm2, dt, st in \
                        train_batch:
                    it += 1

                    loss, logloss1, logloss2, para_loss, sent_loss, \
                        level_loss, doc_title_loss, sec_title_loss = model(
                            s, m, t, tm, t2, tm2, doc_id, para_id,
                            pmask, sent_id, smask, lv, dt, st)

                    model.optimize(loss)

                    train_stats.update(
                        {
                            "loss": loss,
                            "prev_logloss": logloss1,
                            "next_logloss": logloss2,
                            "para_loss": para_loss,
                            "sent_loss": sent_loss,
                            "level_loss": level_loss,
                            "doc_title_loss": doc_title_loss,
                            "sec_title_loss": sec_title_loss
                        }, len(s))

                    if it % e.config.print_every == 0:
                        summarization = train_stats.summarize(
                            "epoch: {}, it: {}".format(epoch, it))
                        e.log.info(summarization)
                        if e.config.summarize:
                            for name, value in train_stats.stats.items():
                                writer.add_scalar("train/" + name, value, it)
                        train_stats.reset()

                    if it % e.config.eval_every == 0:

                        e.log.info("*" * 25 + " STS EVAL " + "*" * 25)

                        test_stats, test_bm_res, test_avg_res, test_avg_s = \
                            evaluator.evaluate(data.test_data, 'score_sts')

                        e.log.info("*" * 25 + " STS EVAL " + "*" * 25)

                        # model.save(
                        #     test_avg=test_avg_res,
                        #     test_bm=test_bm_res,
                        #     todo_file=train_batch.todo_file,
                        #     it=it,
                        #     epoch=epoch)

                        if e.config.summarize:
                            for year, stats in test_stats.items():
                                writer.add_scalar(
                                    "test/{}_pearson".format(year), stats[1],
                                    it)
                                writer.add_scalar(
                                    "test/{}_spearman".format(year), stats[2],
                                    it)

                            writer.add_scalar("test/avg_pearson", test_avg_res,
                                              it)
                            writer.add_scalar("test/avg_spearman", test_avg_s,
                                              it)
                            writer.add_scalar("test/STSBenchmark_pearson",
                                              test_bm_res, it)

                        e.log.info("STSBenchmark result: {:.4f}, "
                                   "test average result: {:.4f}".format(
                                       test_bm_res, test_avg_res))

            del todo_file[file_idx]

            n_finish_file += 1
            model.save(test_avg=test_avg_res,
                       test_bm=test_bm_res,
                       todo_file=todo_file,
                       it=it,
                       epoch=epoch if len(todo_file) else epoch + 1)

            time_per_file = e.elapsed_time / n_finish_file
            time_in_need = time_per_file * (e.config.n_epoch - epoch - 1) * \
                len(data.train_data) + time_per_file * len(todo_file)
            e.log.info("elapsed time: {:.2f}(h), "
                       "#finished file: {}, #todo file: {}, "
                       "time per file: {:.2f}(h), "
                       "time needed to finish: {:.2f}(h)".format(
                           e.elapsed_time, n_finish_file, len(todo_file),
                           time_per_file, time_in_need))

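            # Presumably a guard for time-limited sessions (~4 hours): exit
            # before the next file would exceed the budget and rely on resume.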
            if time_per_file + e.elapsed_time > 3.8 \
                    and e.config.auto_disconnect:
                exit(1)

        todo_file = list(range(len(data.train_data)))
Example 4
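# Assumed context (not shown in this snippet): project modules `data_utils`,
# `train_helper` and `models`, a tensorboard `SummaryWriter`, and module-level
# globals `best_dev_bleu` / `test_bleu`.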
def run(e):
    global best_dev_bleu, test_bleu

    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)
    dp = data_utils.data_processor(train_path=e.config.train_path,
                                   experiment=e)
    data, W = dp.process()

    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    model = models.vgvae(vocab_size=len(data.vocab),
                         embed_dim=e.config.edim if W is None else W.shape[1],
                         embed_init=W,
                         experiment=e)

    start_epoch = true_it = 0
    best_dev_stats = test_stats = None
    if e.config.resume:
        start_epoch, _, best_dev_bleu, test_bleu = \
            model.load(name="latest")
        e.log.info(
            "resumed from previous checkpoint: start epoch: {}, "
            "iteration: {}, best dev bleu: {:.3f}, test bleu: {:.3f}, ".format(
                start_epoch, true_it, best_dev_bleu, test_bleu))

    e.log.info(model)
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    if e.config.summarize:
        writer = SummaryWriter(e.experiment_dir)

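    # The batcher yields paired sentences with their tag tensors; `p_replace`
    # (e.config.wr) and `p_scramble` (e.config.ps) appear to control word
    # replacement and word-order scrambling as input noising.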
    train_batch = data_utils.minibatcher(data1=data.train_data[0],
                                         tag1=data.train_tag[0],
                                         data2=data.train_data[1],
                                         tag2=data.train_tag[1],
                                         tag_bucket=data.tag_bucket,
                                         vocab_size=len(data.vocab),
                                         batch_size=e.config.batch_size,
                                         shuffle=True,
                                         p_replace=e.config.wr,
                                         p_scramble=e.config.ps)

    dev_eval = train_helper.evaluator(e.config.dev_inp_path,
                                      e.config.dev_ref_path, model, data.vocab,
                                      data.inv_vocab, e)
    test_eval = train_helper.evaluator(e.config.test_inp_path,
                                       e.config.test_ref_path, model,
                                       data.vocab, data.inv_vocab, e)

    e.log.info("Training start ...")
    train_stats = train_helper.tracker([
        "loss", "vmf_kl", "gauss_kl", "rec_logloss", "para_logloss", "wploss"
    ])

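    # Training loop: each step optimizes the vMF/Gaussian KL terms (weighted by
    # e.config.vmkl / e.config.gmkl) plus reconstruction, paraphrase and
    # word-position losses; dev BLEU is evaluated periodically and a new best
    # triggers test evaluation and a checkpoint.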
    for epoch in range(start_epoch, e.config.n_epoch):
        for it, (s1, sr1, m1, s2, sr2, m2, t1, tm1, t2, tm2, _) in \
                enumerate(train_batch):
            true_it = it + 1 + epoch * len(train_batch)

            loss, kl, kl2, rec_logloss, para_logloss, wploss = \
                model(s1, sr1, m1, s2, sr2, m2, t1, tm1,
                      t2, tm2, e.config.vmkl, e.config.gmkl)
            model.optimize(loss)
            train_stats.update(
                {
                    "loss": loss,
                    "vmf_kl": kl,
                    "gauss_kl": kl2,
                    "para_logloss": para_logloss,
                    "rec_logloss": rec_logloss,
                    "wploss": wploss
                }, len(s1))

            if (true_it + 1) % e.config.print_every == 0 or \
                    (true_it + 1) % len(train_batch) == 0:
                summarization = train_stats.summarize(
                    "epoch: {}, it: {} (max: {}), kl_temp(v|g): {:.2E}|{:.2E}".
                    format(epoch, it, len(train_batch), e.config.vmkl,
                           e.config.gmkl))
                e.log.info(summarization)
                if e.config.summarize:
                    for name, value in train_stats.stats.items():
                        writer.add_scalar("train/" + name, value, true_it)
                train_stats.reset()

            if (true_it + 1) % e.config.eval_every == 0 or \
                    (true_it + 1) % len(train_batch) == 0:
                e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25)

                dev_stats, dev_bleu = dev_eval.evaluate("gen_dev")

                e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25)

                if e.config.summarize:
                    for name, value in dev_stats.items():
                        writer.add_scalar("dev/" + name, value, true_it)

                if best_dev_bleu < dev_bleu:
                    best_dev_bleu = dev_bleu
                    best_dev_stats = dev_stats

                    e.log.info("*" * 25 + " TEST SET EVALUATION " + "*" * 25)

                    test_stats, test_bleu = test_eval.evaluate("gen_test")

                    e.log.info("*" * 25 + " TEST SET EVALUATION " + "*" * 25)

                    model.save(dev_bleu=best_dev_bleu,
                               dev_stats=best_dev_stats,
                               test_bleu=test_bleu,
                               test_stats=test_stats,
                               iteration=true_it,
                               epoch=epoch)

                    if e.config.summarize:
                        for name, value in test_stats.items():
                            writer.add_scalar("test/" + name, value, true_it)

                e.log.info("best dev bleu: {:.4f}, test bleu: {:.4f}".format(
                    best_dev_bleu, test_bleu))

        model.save(dev_bleu=best_dev_bleu,
                   dev_stats=best_dev_stats,
                   test_bleu=test_bleu,
                   test_stats=test_stats,
                   iteration=true_it,
                   epoch=epoch + 1,
                   name="latest")

        time_per_epoch = (e.elapsed_time / (epoch - start_epoch + 1))
        time_in_need = time_per_epoch * (e.config.n_epoch - epoch - 1)
        e.log.info("elapsed time: {:.2f}(h), "
                   "time per epoch: {:.2f}(h), "
                   "time needed to finish: {:.2f}(h)".format(
                       e.elapsed_time, time_per_epoch, time_in_need))

        if time_per_epoch + e.elapsed_time > 3.7 and e.config.auto_disconnect:
            exit(1)

    test_gen_stats, test_res = test_eval.evaluate("gen_test")