Example #1
def add_slot_embs_to_slu_embs(slot_embs_file, slu_embs_file):
    import pickle
    import numpy as np
    # slot_list and domain2slot are assumed to be module-level label definitions
    # in src.slu.datareader (they are used below but not defined in this snippet)
    from src.slu.datareader import datareader, slot_list, domain2slot
    from src.utils import init_experiment
    from config import get_params

    with open(slot_embs_file, "rb") as f:
        slot_embs_dict = pickle.load(f)
    slu_embs = np.load(slu_embs_file)

    params = get_params()
    logger = init_experiment(params, logger_filename=params.logger_filename)
    _, vocab = datareader(use_label_encoder=True)

    new_slu_embs = np.zeros(
        (vocab.n_words, 400))  # 400: word + char level embs

    # copy previous embeddings
    prev_length = len(slu_embs)
    new_slu_embs[:prev_length, :] = slu_embs

    for slot_name in slot_list:
        emb = None
        index = vocab.word2index[slot_name]
        if index < prev_length: continue
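        # search each domain's slot list for this slot name and copy its per-domain embedding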
        for domain, slot_embs in slot_embs_dict.items():
            slot_list_based_on_domain = domain2slot[domain]
            if slot_name in slot_list_based_on_domain:
                slot_index = slot_list_based_on_domain.index(slot_name)
                emb = slot_embs[slot_index]
                break
        assert emb is not None
        new_slu_embs[index] = emb

    np.save("../data/snips/emb/slu_word_char_embs_with_slotembs.npy",
            new_slu_embs)
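
All of these examples follow the same bootstrap pattern: parse the hyperparameters with get_params(), then call init_experiment to set up logging before any data is loaded. A minimal sketch of that shared skeleton, using only names that already appear in the examples (treating get_params as an argparse-style parser is an assumption):

from config import get_params
from src.utils import init_experiment

params = get_params()  # hyperparameters such as epoch, batch_size, tgt_dm
logger = init_experiment(params, logger_filename=params.logger_filename)
logger.info("experiment initialized")
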
Example #2
File: slu_main.py Project: zliucr/coach
def main(params):
    # initialize experiment
    logger = init_experiment(params, logger_filename=params.logger_filename)
    
    # get dataloader
    dataloader_tr, dataloader_val, dataloader_test, vocab = get_dataloader(params.tgt_dm, params.batch_size, params.tr, params.n_samples)

    # build model
    binary_slutagger = BinarySLUTagger(params, vocab)
    slotname_predictor = SlotNamePredictor(params)
    binary_slutagger, slotname_predictor = binary_slutagger.cuda(), slotname_predictor.cuda()
    if params.tr:
        sent_repre_generator = SentRepreGenerator(params, vocab)
        sent_repre_generator = sent_repre_generator.cuda()
        slu_trainer = SLUTrainer(params, binary_slutagger, slotname_predictor, sent_repre_generator=sent_repre_generator)
    else:
        slu_trainer = SLUTrainer(params, binary_slutagger, slotname_predictor)
    
    for e in range(params.epoch):
        logger.info("============== epoch {} ==============".format(e+1))
        loss_bin_list, loss_slotname_list = [], []
        if params.tr:
            loss_tem0_list, loss_tem1_list = [], []
        pbar = tqdm(enumerate(dataloader_tr), total=len(dataloader_tr))
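        # with template regularization (params.tr), each batch additionally carries templates and their lengths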
        if params.tr:
            for i, (X, lengths, y_bin, y_final, y_dm, templates, tem_lengths) in pbar:
                X, lengths, templates, tem_lengths = X.cuda(), lengths.cuda(), templates.cuda(), tem_lengths.cuda()
                loss_bin, loss_slotname, loss_tem0, loss_tem1 = slu_trainer.train_step(X, lengths, y_bin, y_final, y_dm, templates=templates, tem_lengths=tem_lengths, epoch=e)
                loss_bin_list.append(loss_bin)
                loss_slotname_list.append(loss_slotname)
                loss_tem0_list.append(loss_tem0)
                loss_tem1_list.append(loss_tem1)

                pbar.set_description("(Epoch {}) LOSS BIN:{:.4f} LOSS SLOT:{:.4f} LOSS TEM0:{:.4f} LOSS TEM1:{:.4f}".format((e+1), np.mean(loss_bin_list), np.mean(loss_slotname_list), np.mean(loss_tem0_list), np.mean(loss_tem1_list)))
        else:
            for i, (X, lengths, y_bin, y_final, y_dm) in pbar:
                X, lengths = X.cuda(), lengths.cuda()
                loss_bin, loss_slotname = slu_trainer.train_step(X, lengths, y_bin, y_final, y_dm)
                loss_bin_list.append(loss_bin)
                loss_slotname_list.append(loss_slotname)
                pbar.set_description("(Epoch {}) LOSS BIN:{:.4f} LOSS SLOT:{:.4f}".format((e+1), np.mean(loss_bin_list), np.mean(loss_slotname_list)))
            
        if params.tr:
            logger.info("Finish training epoch {}. LOSS BIN:{:.4f} LOSS SLOT:{:.4f} LOSS TEM0:{:.4f} LOSS TEM1:{:.4f}".format((e+1), np.mean(loss_bin_list), np.mean(loss_slotname_list), np.mean(loss_tem0_list), np.mean(loss_tem1_list)))
        else:
            logger.info("Finish training epoch {}. LOSS BIN:{:.4f} LOSS SLOT:{:.4f}".format((e+1), np.mean(loss_bin_list), np.mean(loss_slotname_list)))

        logger.info("============== Evaluate Epoch {} ==============".format(e+1))
        bin_f1, final_f1, stop_training_flag = slu_trainer.evaluate(dataloader_val, istestset=False)
        logger.info("Eval on dev set. Binary Slot-F1: {:.4f}. Final Slot-F1: {:.4f}.".format(bin_f1, final_f1))

        bin_f1, final_f1, stop_training_flag = slu_trainer.evaluate(dataloader_test, istestset=True)
        logger.info("Eval on test set. Binary Slot-F1: {:.4f}. Final Slot-F1: {:.4f}.".format(bin_f1, final_f1))

        if stop_training_flag:
            break
Example #3
def train(params, lang="en"):
    # initialize experiment
    logger = init_experiment(params, logger_filename=params.logger_filename)

    # dataloader
    dataloader_tr, dataloader_val, dataloader_test, vocab = get_dataloader(
        params, lang=lang)

    # build model
    lstm = Lstm(params, vocab)
    intent_predictor = IntentPredictor(params)
    slot_predictor = SlotPredictor(params)
    lstm.cuda()
    intent_predictor.cuda()
    slot_predictor.cuda()

    # build trainer
    dialog_trainer = DialogTrainer(params, lstm, intent_predictor,
                                   slot_predictor)

    for e in range(params.epoch):
        logger.info("============== epoch %d ==============" % e)
        intent_loss_list, slot_loss_list = [], []

        pbar = tqdm(enumerate(dataloader_tr), total=len(dataloader_tr))
        for i, (X, lengths, y1, y2) in pbar:
            # y2 stays on CPU because its length differs for each sequence
            X, lengths, y1 = X.cuda(), lengths.cuda(), y1.cuda()
            intent_loss, slot_loss = dialog_trainer.train_step(
                e, X, lengths, y1, y2)
            intent_loss_list.append(intent_loss)
            slot_loss_list.append(slot_loss)

            pbar.set_description(
                "(Epoch {}) INTENT LOSS:{:.4f} SLOT LOSS:{:.4f}".format(
                    e, np.mean(intent_loss_list), np.mean(slot_loss_list)))

        logger.info(
            "Finish training epoch %d. Intent loss: %.4f. Slot loss: %.4f" %
            (e, np.mean(intent_loss_list), np.mean(slot_loss_list)))

        logger.info("============== Evaluate %d ==============" % e)

        intent_acc, slot_f1, stop_training_flag = dialog_trainer.evaluate(
            dataloader_val)
        logger.info(
            "Intent ACC: %.4f (Best Acc: %.4f). Slot F1: %.4f. (Best F1: %.4f)"
            % (intent_acc, dialog_trainer.best_intent_acc, slot_f1,
               dialog_trainer.best_slot_f1))

        if stop_training_flag:
            break

    logger.info("============== Final Test ==============")
    intent_acc, slot_f1, _ = dialog_trainer.evaluate(dataloader_test,
                                                     istestset=True)
    logger.info("Intent ACC: %.4f. Slot F1: %.4f." % (intent_acc, slot_f1))
Example #4
def gen_embs_for_vocab():
    import numpy as np
    from src.datareader import datareader
    from src.utils import load_embedding, init_experiment
    from config import get_params
    params = get_params()
    logger = init_experiment(params, logger_filename=params.logger_filename)

    _, vocab = datareader()
    embedding = load_embedding(vocab, 300, "/data/sh/glove.6B.300d.txt",
                               "/data/sh/coachdata/snips/emb/oov_embs.txt")
    np.save("/data/sh/coachdata/snips/emb/slu_embs.npy", embedding)
Example #5
def gen_embs_for_vocab():
    import numpy as np
    from src.slu.datareader import datareader
    from src.utils import load_embedding, init_experiment
    from config import get_params

    params = get_params()
    logger = init_experiment(params, logger_filename=params.logger_filename)

    _, vocab = datareader()
    embedding = load_embedding(vocab, 300, "PATH_OF_THE_WIKI_EN_VEC",
                               "../data/snips/emb/oov_embs.txt")
    np.save("../data/snips/emb/slu_embs.npy", embedding)
Example #6
def run_baseline(params):
    # initialize experiment
    logger = init_experiment(params, logger_filename=params.logger_filename)

    # get dataloader
    dataloader_tr, dataloader_val, dataloader_test, vocab = get_dataloader(
        params.batch_size,
        bilstmcrf=params.bilstmcrf,
        n_samples=params.n_samples)

    # build model
    if params.bilstmcrf:
        ner_tagger = BiLSTMCRFTagger(params, vocab)
        ner_tagger.cuda()
        baseline_trainer = BiLSTMCRFTrainer(params, ner_tagger)
    else:
        concept_tagger = ConceptTagger(params, vocab)
        concept_tagger.cuda()
        baseline_trainer = BaselineTrainer(params, concept_tagger)

    for e in range(params.epoch):
        logger.info("============== epoch {} ==============".format(e + 1))
        loss_list = []
        pbar = tqdm(enumerate(dataloader_tr), total=len(dataloader_tr))
        for i, (X, lengths, y) in pbar:
            X, lengths = X.cuda(), lengths.cuda()

            loss = baseline_trainer.train_step(X, lengths, y)
            loss_list.append(loss)
            pbar.set_description("(Epoch {}) LOSS:{:.4f}".format(
                (e + 1), np.mean(loss_list)))

        logger.info("Finish training epoch {}. LOSS:{:.4f}".format(
            (e + 1), np.mean(loss_list)))

        logger.info(
            "============== Evaluate Epoch {} ==============".format(e + 1))
        f1_score, stop_training_flag = baseline_trainer.evaluate(
            dataloader_val, istestset=False)
        logger.info("Eval on dev set. Entity-F1: {:.4f}.".format(f1_score))

        f1_score, stop_training_flag = baseline_trainer.evaluate(
            dataloader_test, istestset=True)
        logger.info("Eval on test set. Entity-F1: {:.4f}.".format(f1_score))

        if stop_training_flag:
            break
Example #7
def transfer(params, trans_lang):
    # initialize experiment
    logger = init_experiment(params, logger_filename=params.logger_filename)
    logger.info("============== Evaluate Zero-Shot on %s ==============" %
                trans_lang)

    # dataloader
    _, _, dataloader_test, vocab = get_dataloader(params, lang=trans_lang)

    # get word embedding
    emb_file = params.emb_file_es if trans_lang == "es" else params.emb_file_th
    embedding = load_embedding(vocab, params.emb_dim, emb_file)

    # evaluate zero-shot
    evaluate_transfer = EvaluateTransfer(params, dataloader_test, embedding,
                                         vocab.n_words)
    intent_acc, slot_f1 = evaluate_transfer.evaluate()
    logger.info("Intent ACC: %.4f. Slot F1: %.4f." % (intent_acc, slot_f1))
Example #8
def combine_word_with_char_embs_for_vocab(wordembs_file):
    import numpy as np
    from src.slu.datareader import datareader
    from src.utils import init_experiment
    from config import get_params
    import torchtext
    char_ngram_model = torchtext.vocab.CharNGram()
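    # torchtext's CharNGram vectors are 100-dimensional; concatenated with the
    # 300-dim word vectors below, they form the 400-dim rows of word_char_embs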

    params = get_params()
    logger = init_experiment(params, logger_filename=params.logger_filename)

    _, vocab = datareader()
    embedding = np.load(wordembs_file)

    word_char_embs = np.zeros((vocab.n_words, 400))
    for index, word in vocab.index2word.items():
        word_emb = embedding[index]
        char_emb = char_ngram_model[word].squeeze(0).numpy()
        word_char_embs[index] = np.concatenate((word_emb, char_emb), axis=-1)

    np.save("../data/snips/emb/slu_word_char_embs.npy", word_char_embs)
Example #9
def main(params):
    # initialize experiment
    logger = init_experiment(params, logger_filename=params.logger_filename)

    if params.pretr_la_enc and params.ckpt_labelenc == "":
        dataloader_pretr, vocab_en = dataloader4pretr(params)
        slu_model4pretr = ModelSLU4Pretr(params, vocab_en)
        slu_model4pretr.cuda()
        trainer4pretr = SLUTrainer(params, slu_model4pretr, pretrain_flag=True)

        # pretraining label encoder
        logger.info("============== Pretraining Label Encoder ==============")
        for e in range(params.pretr_epoch):
            logger.info("============== epoch %d ==============" % e)
            pbar = tqdm(enumerate(dataloader_pretr),
                        total=len(dataloader_pretr))
            intent_loss_list, slot_loss_list, cos_loss_list = [], [], []
            for i, (x_1, padded_y2_1, lengths_1, y1_1, y2_1, x_2, padded_y2_2,
                    lengths_2, y1_2, y2_2) in pbar:
                x_1, lengths_1, y1_1 = x_1.cuda(), lengths_1.cuda(), y1_1.cuda()
                x_2, lengths_2, y1_2 = x_2.cuda(), lengths_2.cuda(), y1_2.cuda()

                padded_y2_1, padded_y2_2 = padded_y2_1.cuda(), padded_y2_2.cuda()
                intent_loss_1, slot_loss_1, intent_loss_2, slot_loss_2, cos_loss = trainer4pretr.pretr_label_encoder_step(
                    e, x_1, lengths_1, x_2, lengths_2, y1_1, y2_1, y1_2, y2_2,
                    padded_y2_1, padded_y2_2)

                intent_loss_list.append(intent_loss_1)
                intent_loss_list.append(intent_loss_2)
                slot_loss_list.append(slot_loss_1)
                slot_loss_list.append(slot_loss_2)
                cos_loss_list.append(cos_loss)

                pbar.set_description(
                    "(Epoch {}) INTENT:{:.3f} SLOT:{:.3f} COS:{:.3f}".format(
                        (e + 1), np.mean(intent_loss_list),
                        np.mean(slot_loss_list), np.mean(cos_loss_list)))

            logger.info(
                "(Finished epoch {}) INTENT:{:.3f} SLOT:{:.3f} COS:{:.3f}".
                format((e + 1), np.mean(intent_loss_list),
                       np.mean(slot_loss_list), np.mean(cos_loss_list)))

        label_encoder_saved_path = os.path.join(params.dump_path,
                                                "label_encoder.pth")
        logger.info("Saving label encoder to %s" % label_encoder_saved_path)
        torch.save(slu_model4pretr.label_encoder, label_encoder_saved_path)

    # get dataloader and vocabulary
    dataloader_tr, dataloader_val, dataloader_test, vocab_en, vocab_trans = get_dataloader(
        params)

    # build model
    if params.adv:
        slu_model = ModelSLU4Adv(params, vocab_en, vocab_trans)
    else:
        slu_model = ModelSLU(params, vocab_en, vocab_trans)
    slu_model.cuda()

    if params.pretr_la_enc:
        # copy label encoder
        if params.adv:
            if params.ckpt_labelenc != "":
                logger.info("Loading label encoder from %s" %
                            params.ckpt_labelenc)
                pretrained_label_encoder = torch.load(params.ckpt_labelenc)
                pretrained_label_encoder = pretrained_label_encoder.cuda()
                slu_model.model.label_encoder = pretrained_label_encoder
            else:
                slu_model.model.label_encoder = slu_model4pretr.label_encoder
        else:
            if params.ckpt_labelenc != "":
                logger.info("Loading label encoder from %s" %
                            params.ckpt_labelenc)
                pretrained_label_encoder = torch.load(params.ckpt_labelenc)
                pretrained_label_encoder = pretrained_label_encoder.cuda()
                slu_model.label_encoder = pretrained_label_encoder
            else:
                slu_model.label_encoder = slu_model4pretr.label_encoder

    def get_learnable_params(module):
        return [p for p in module.parameters() if p.requires_grad]

    model_params = get_learnable_params(slu_model)
    print("model parameters: %d" % sum(p.numel() for p in model_params))
    # build trainer
    slu_trainer = SLUTrainer(params, slu_model)

    logger.info("============== Start training ==============")
    for e in range(params.epoch):
        logger.info("============== epoch %d ==============" % e)

        pbar = tqdm(enumerate(dataloader_tr), total=len(dataloader_tr))
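        # four training configurations below: target-language-only vs. joint training,
        # optionally with label regularization (la_reg) and adversarial losses (adv, intent_adv)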
        if not params.tar_only:

            intent_loss_en_list, slot_loss_en_list, intent_loss_trans_list, slot_loss_trans_list, cos_loss_list = [], [], [], [], []
            if params.adv:
                model_en_adv_loss_list, model_trans_adv_loss_list, udg_en_adv_loss_list, udg_trans_adv_loss_list = [], [], [], []
                if params.intent_adv:
                    model_en_adv_intent_loss_list, model_trans_adv_intent_loss_list, udg_en_adv_intent_loss_list, udg_trans_adv_intent_loss_list = [], [], [], []

            for i, (x_en, padded_y2_en, lengths_en, y1_en, y2_en, x_trans,
                    padded_y2_trans, lengths_trans, y1_trans,
                    y2_trans) in pbar:
                x_en, lengths_en, y1_en = x_en.cuda(), lengths_en.cuda(), y1_en.cuda()
                x_trans, lengths_trans, y1_trans = (
                    x_trans.cuda(), lengths_trans.cuda(), y1_trans.cuda())
                if not params.la_reg:
                    if params.adv:
                        # adversarial lvm
                        if params.intent_adv:
                            intent_loss_en, slot_loss_en, intent_loss_trans, slot_loss_trans, model_en_adv_loss, model_trans_adv_loss, udg_en_adv_loss, udg_trans_adv_loss, model_en_adv_intent_loss, model_trans_adv_intent_loss, udg_en_adv_intent_loss, udg_trans_adv_intent_loss = slu_trainer.joint_train_step(
                                e, x_en, lengths_en, x_trans, lengths_trans,
                                y1_en, y2_en, y1_trans, y2_trans)

                            model_en_adv_intent_loss_list.append(
                                model_en_adv_intent_loss)
                            model_trans_adv_intent_loss_list.append(
                                model_trans_adv_intent_loss)
                            udg_en_adv_intent_loss_list.append(
                                udg_en_adv_intent_loss)
                            udg_trans_adv_intent_loss_list.append(
                                udg_trans_adv_intent_loss)
                        else:
                            intent_loss_en, slot_loss_en, intent_loss_trans, slot_loss_trans, model_en_adv_loss, model_trans_adv_loss, udg_en_adv_loss, udg_trans_adv_loss = slu_trainer.joint_train_step(
                                e, x_en, lengths_en, x_trans, lengths_trans,
                                y1_en, y2_en, y1_trans, y2_trans)

                        model_en_adv_loss_list.append(model_en_adv_loss)
                        model_trans_adv_loss_list.append(model_trans_adv_loss)
                        udg_en_adv_loss_list.append(udg_en_adv_loss)
                        udg_trans_adv_loss_list.append(udg_trans_adv_loss)
                    else:
                        intent_loss_en, slot_loss_en, intent_loss_trans, slot_loss_trans = slu_trainer.joint_train_step(
                            e, x_en, lengths_en, x_trans, lengths_trans, y1_en,
                            y2_en, y1_trans, y2_trans)
                else:
                    padded_y2_en, padded_y2_trans = (
                        padded_y2_en.cuda(), padded_y2_trans.cuda())
                    if params.adv:
                        # adversarial lvm
                        if params.intent_adv:
                            intent_loss_en, slot_loss_en, intent_loss_trans, slot_loss_trans, cos_loss, model_en_adv_loss, model_trans_adv_loss, udg_en_adv_loss, udg_trans_adv_loss, model_en_adv_intent_loss, model_trans_adv_intent_loss, udg_en_adv_intent_loss, udg_trans_adv_intent_loss = slu_trainer.joint_train_step(
                                e, x_en, lengths_en, x_trans, lengths_trans,
                                y1_en, y2_en, y1_trans, y2_trans, padded_y2_en,
                                padded_y2_trans)

                            model_en_adv_intent_loss_list.append(
                                model_en_adv_intent_loss)
                            model_trans_adv_intent_loss_list.append(
                                model_trans_adv_intent_loss)
                            udg_en_adv_intent_loss_list.append(
                                udg_en_adv_intent_loss)
                            udg_trans_adv_intent_loss_list.append(
                                udg_trans_adv_intent_loss)
                        else:
                            intent_loss_en, slot_loss_en, intent_loss_trans, slot_loss_trans, cos_loss, model_en_adv_loss, model_trans_adv_loss, udg_en_adv_loss, udg_trans_adv_loss = slu_trainer.joint_train_step(
                                e, x_en, lengths_en, x_trans, lengths_trans,
                                y1_en, y2_en, y1_trans, y2_trans, padded_y2_en,
                                padded_y2_trans)

                        model_en_adv_loss_list.append(model_en_adv_loss)
                        model_trans_adv_loss_list.append(model_trans_adv_loss)
                        udg_en_adv_loss_list.append(udg_en_adv_loss)
                        udg_trans_adv_loss_list.append(udg_trans_adv_loss)
                    else:
                        intent_loss_en, slot_loss_en, intent_loss_trans, slot_loss_trans, cos_loss = slu_trainer.joint_train_step(
                            e, x_en, lengths_en, x_trans, lengths_trans, y1_en,
                            y2_en, y1_trans, y2_trans, padded_y2_en,
                            padded_y2_trans)

                    cos_loss_list.append(cos_loss)

                intent_loss_en_list.append(intent_loss_en)
                slot_loss_en_list.append(slot_loss_en)
                intent_loss_trans_list.append(intent_loss_trans)
                slot_loss_trans_list.append(slot_loss_trans)

                if not params.la_reg:
                    if params.adv:
                        if params.intent_adv:
                            pbar.set_description(
                                "(E{})I1:{:.1f}S1:{:.1f}I2:{:.1f}S2:{:.1F}G1:{:.2f}G2:{:.2f}G3:{:.2f}G4:{:.2f}D1:{:.2f}D2:{:.2f}D3:{:.2f}D4:{:.2f}"
                                .format(
                                    (e + 1), np.mean(intent_loss_en_list),
                                    np.mean(slot_loss_en_list),
                                    np.mean(intent_loss_trans_list),
                                    np.mean(slot_loss_trans_list),
                                    np.mean(model_en_adv_loss_list),
                                    np.mean(model_trans_adv_loss_list),
                                    np.mean(model_en_adv_intent_loss_list),
                                    np.mean(model_trans_adv_intent_loss_list),
                                    np.mean(udg_en_adv_loss_list),
                                    np.mean(udg_trans_adv_loss_list),
                                    np.mean(udg_en_adv_intent_loss_list),
                                    np.mean(udg_trans_adv_intent_loss_list)))
                        else:
                            pbar.set_description(
                                "(E{}) I1:{:.4f} S1:{:.4f} I2:{:.4f} S2:{:.4F} G1:{:.4f} G2:{:.4f} D1:{:.4f} D2:{:.4f}"
                                .format((e + 1), np.mean(intent_loss_en_list),
                                        np.mean(slot_loss_en_list),
                                        np.mean(intent_loss_trans_list),
                                        np.mean(slot_loss_trans_list),
                                        np.mean(model_en_adv_loss_list),
                                        np.mean(model_trans_adv_loss_list),
                                        np.mean(udg_en_adv_loss_list),
                                        np.mean(udg_trans_adv_loss_list)))
                    else:
                        pbar.set_description(
                            "(Epoch {}) EN_INTENT:{:.4f} EN_SLOT:{:.4f} TRANS_INTENT:{:.4f} TRANS_SLOT:{:.4F}"
                            .format((e + 1), np.mean(intent_loss_en_list),
                                    np.mean(slot_loss_en_list),
                                    np.mean(intent_loss_trans_list),
                                    np.mean(slot_loss_trans_list)))
                else:
                    if params.adv:
                        if params.intent_adv:
                            pbar.set_description(
                                "(E{})I1:{:.1f}S1:{:.1f}I2:{:.1f}S2:{:.1F}C:{:.1f}G1:{:.2f}G2:{:.2f}G3:{:.2f}G4:{:.2f}D1:{:.2f}D2:{:.2f}D3:{:.2f}D4:{:.2f}"
                                .format(
                                    (e + 1), np.mean(intent_loss_en_list),
                                    np.mean(slot_loss_en_list),
                                    np.mean(intent_loss_trans_list),
                                    np.mean(slot_loss_trans_list),
                                    np.mean(cos_loss_list),
                                    np.mean(model_en_adv_loss_list),
                                    np.mean(model_trans_adv_loss_list),
                                    np.mean(model_en_adv_intent_loss_list),
                                    np.mean(model_trans_adv_intent_loss_list),
                                    np.mean(udg_en_adv_loss_list),
                                    np.mean(udg_trans_adv_loss_list),
                                    np.mean(udg_en_adv_intent_loss_list),
                                    np.mean(udg_trans_adv_intent_loss_list)))
                        else:
                            pbar.set_description(
                                "(E{}) I1:{:.3f} S1:{:.3f} I2:{:.3f} S2:{:.3f} C:{:.3f} G1:{:.3f} G2:{:.3f} D1:{:.3f} D2:{:.3f}"
                                .format((e + 1), np.mean(intent_loss_en_list),
                                        np.mean(slot_loss_en_list),
                                        np.mean(intent_loss_trans_list),
                                        np.mean(slot_loss_trans_list),
                                        np.mean(cos_loss_list),
                                        np.mean(model_en_adv_loss_list),
                                        np.mean(model_trans_adv_loss_list),
                                        np.mean(udg_en_adv_loss_list),
                                        np.mean(udg_trans_adv_loss_list)))
                    else:
                        pbar.set_description(
                            "(Epoch {}) EN_INTENT:{:.3f} EN_SLOT:{:.3f} TR_INTENT:{:.3f} TR_SLOT:{:.3f} COS:{:.3f}"
                            .format((e + 1), np.mean(intent_loss_en_list),
                                    np.mean(slot_loss_en_list),
                                    np.mean(intent_loss_trans_list),
                                    np.mean(slot_loss_trans_list),
                                    np.mean(cos_loss_list)))
        else:
            intent_loss_list, slot_loss_list = [], []
            for i, (X, lengths, y1, y2) in pbar:
                X, lengths, y1 = X.cuda(), lengths.cuda(), y1.cuda()
                intent_loss, slot_loss = slu_trainer.single_train_step(
                    X, lengths, y1, y2)

                intent_loss_list.append(intent_loss)
                slot_loss_list.append(slot_loss)

                pbar.set_description(
                    "(Epoch {}) INTENT LOSS:{:.4f} SLOT LOSS:{:.4f}".format(
                        (e + 1), np.mean(intent_loss_list),
                        np.mean(slot_loss_list)))

        if not params.tar_only:
            if not params.la_reg:
                if params.adv:
                    if params.intent_adv:
                        logger.info(
                            "(E{})I1:{:.1f}S1:{:.1f}I2:{:.1f}S2:{:.1f}G1:{:.2f}G2:{:.2f}G3:{:.2f}G4:{:.2f}D1:{:.2f}D2:{:.2f}D3:{:.2f}D4:{:.2f}"
                            .format((e + 1), np.mean(intent_loss_en_list),
                                    np.mean(slot_loss_en_list),
                                    np.mean(intent_loss_trans_list),
                                    np.mean(slot_loss_trans_list),
                                    np.mean(model_en_adv_loss_list),
                                    np.mean(model_trans_adv_loss_list),
                                    np.mean(model_en_adv_intent_loss_list),
                                    np.mean(model_trans_adv_intent_loss_list),
                                    np.mean(udg_en_adv_loss_list),
                                    np.mean(udg_trans_adv_loss_list),
                                    np.mean(udg_en_adv_intent_loss_list),
                                    np.mean(udg_trans_adv_intent_loss_list)))
                    else:
                        logger.info(
                            "(E{}) I1:{:.4f} S1:{:.4f} I2:{:.4f} S2:{:.4F} G1:{:.4f} G2:{:.4f} D1:{:.4f} D2:{:.4f}"
                            .format((e + 1), np.mean(intent_loss_en_list),
                                    np.mean(slot_loss_en_list),
                                    np.mean(intent_loss_trans_list),
                                    np.mean(slot_loss_trans_list),
                                    np.mean(model_en_adv_loss_list),
                                    np.mean(model_trans_adv_loss_list),
                                    np.mean(udg_en_adv_loss_list),
                                    np.mean(udg_trans_adv_loss_list)))
                else:
                    logger.info(
                        "Finish training epoch {} EN_INTENT:{:.4f} EN_SLOT:{:.4f} TRANS_INTENT:{:.4f} TRANS_SLOT:{:.4f}"
                        .format((e + 1), np.mean(intent_loss_en_list),
                                np.mean(slot_loss_en_list),
                                np.mean(intent_loss_trans_list),
                                np.mean(slot_loss_trans_list)))
            else:
                if params.adv:
                    if params.intent_adv:
                        logger.info(
                            "(E{})I1:{:.1f}S1:{:.1f}I2:{:.1f}S2:{:.1f}C:{:.1f}G1:{:.2f}G2:{:.2f}G3:{:.2f}G4:{:.2f}D1:{:.2f}D2:{:.2f}D3:{:.2f}D4:{:.2f}"
                            .format((e + 1), np.mean(intent_loss_en_list),
                                    np.mean(slot_loss_en_list),
                                    np.mean(intent_loss_trans_list),
                                    np.mean(slot_loss_trans_list),
                                    np.mean(cos_loss_list),
                                    np.mean(model_en_adv_loss_list),
                                    np.mean(model_trans_adv_loss_list),
                                    np.mean(model_en_adv_intent_loss_list),
                                    np.mean(model_trans_adv_intent_loss_list),
                                    np.mean(udg_en_adv_loss_list),
                                    np.mean(udg_trans_adv_loss_list),
                                    np.mean(udg_en_adv_intent_loss_list),
                                    np.mean(udg_trans_adv_intent_loss_list)))
                    else:
                        logger.info(
                            "(E{}) I1:{:.4f} S1:{:.4f} I2:{:.4f} S2:{:.4f} C:{:.4f} G1:{:.4f} G2:{:.4f} D1:{:.4f} D2:{:.4f}"
                            .format((e + 1), np.mean(intent_loss_en_list),
                                    np.mean(slot_loss_en_list),
                                    np.mean(intent_loss_trans_list),
                                    np.mean(slot_loss_trans_list),
                                    np.mean(cos_loss_list),
                                    np.mean(model_en_adv_loss_list),
                                    np.mean(model_trans_adv_loss_list),
                                    np.mean(udg_en_adv_loss_list),
                                    np.mean(udg_trans_adv_loss_list)))
                else:
                    logger.info(
                        "Finish training epoch {} EN_INTENT:{:.3f} EN_SLOT:{:.3f} TR_INTENT:{:.3f} TR_SLOT:{:.3f} COS:{:.3f}"
                        .format((e + 1), np.mean(intent_loss_en_list),
                                np.mean(slot_loss_en_list),
                                np.mean(intent_loss_trans_list),
                                np.mean(slot_loss_trans_list),
                                np.mean(cos_loss_list)))
        else:
            logger.info(
                "Finish training epoch {} INTENT LOSS:{:.4f} SLOT LOSS:{:.4f}".
                format((e + 1), np.mean(intent_loss_list),
                       np.mean(slot_loss_list)))

        logger.info(
            "============== Evaluate Epoch {} ==============".format(e + 1))
        intent_acc, slot_f1, stop_training_flag = slu_trainer.evaluate(
            dataloader_val, istestset=False)

        logger.info(
            "Dev Set: Intent ACC:{:.4f} (Best Acc:{:.4f}). Slot F1:{:.4f}. (Best F1:{:.4f})"
            .format(intent_acc, slu_trainer.best_intent_acc, slot_f1,
                    slu_trainer.best_slot_f1))

        intent_acc, slot_f1, _ = slu_trainer.evaluate(dataloader_test,
                                                      istestset=True)

        logger.info(
            "Test set: Intent ACC:{:.4f} (Best Acc:{:.4f}). Slot F1:{:.4f}. (Best F1:{:.4f})"
            .format(intent_acc, slu_trainer.best_intent_acc, slot_f1,
                    slu_trainer.best_slot_f1))

        if stop_training_flag:
            break
Example #10
def train_nlu(params):
    # initialize experiment
    logger = init_experiment(params, logger_filename=params.logger_filename)

    # dataloader
    dataloader_tr, dataloader_val, dataloader_test, vocab_en, vocab_trans = get_nlu_dataloader(
        params)

    # build model
    lstm = Lstm_nlu(params, vocab_en, vocab_trans)

    intent_predictor = IntentPredictor(params)
    slot_predictor = SlotPredictor(params)
    lstm.cuda()
    intent_predictor.cuda()
    slot_predictor.cuda()

    # build trainer
    nlu_trainer = NLU_Trainer(params, lstm, intent_predictor, slot_predictor)

    for e in range(params.epoch):
        logger.info("============== epoch {} ==============".format(e + 1))
        intent_loss_list, slot_loss_list = [], []

        pbar = tqdm(enumerate(dataloader_tr), total=len(dataloader_tr))
        for i, (X, lengths, y1, y2) in pbar:
            # y2 stays on CPU because its length differs for each sequence
            X, lengths, y1 = X.cuda(), lengths.cuda(), y1.cuda()

            intent_loss, slot_loss = nlu_trainer.train_step(X, lengths, y1, y2)
            intent_loss_list.append(intent_loss)
            slot_loss_list.append(slot_loss)

            pbar.set_description(
                "(Epoch {}) INTENT LOSS:{:.4f} SLOT LOSS:{:.4f}".format(
                    e + 1, np.mean(intent_loss_list), np.mean(slot_loss_list)))

        logger.info(
            "Finish training epoch {}. Intent loss: {:.4f}. Slot loss: {:.4f}".
            format(e + 1, np.mean(intent_loss_list), np.mean(slot_loss_list)))

        logger.info("============== Evaluate %d ==============" % e)
        intent_acc, slot_f1, stop_training_flag = nlu_trainer.evaluate(
            dataloader_val)
        logger.info(
            "({}) Intent ACC: {:.4f} (Best Acc: {:.4f}). Slot F1: {:.4f}. (Best F1: {:.4f})"
            .format(params.trans_lang, intent_acc, nlu_trainer.best_intent_acc,
                    slot_f1, nlu_trainer.best_slot_f1))

        intent_acc, slot_f1, _ = nlu_trainer.evaluate(dataloader_test,
                                                      istestset=True)
        logger.info("({}) Intent ACC: {:.4f}. Slot F1: {:.4f}.".format(
            params.trans_lang, intent_acc, slot_f1))

        if stop_training_flag:
            break

    logger.info("============== Final Test ==============")
    intent_acc, slot_f1, _ = nlu_trainer.evaluate(dataloader_test,
                                                  istestset=True,
                                                  load_best_model=True)
    logger.info("Intent ACC: {:.4f}. Slot F1: {:.4f}.".format(
        intent_acc, slot_f1))
Example #11
def train_dst(params):
    # initialize experiment
    logger = init_experiment(params, logger_filename=params.logger_filename)

    with codecs.open(params.ontology_class_path, 'r', 'utf8') as f:
        dialogue_ontology = json.load(f)

    # get vocab and dialogue_ontology
    with open(params.vocab_path_en, "rb") as f:
        vocab_en = pickle.load(f)
    with open(params.vocab_path_trans, "rb") as f:
        vocab_trans = pickle.load(f)

    # dataloader
    dataloader_tr, dataloader_val, dataloader_test = get_dst_dataloader(
        params, vocab_en, vocab_trans, dialogue_ontology)
    dst_model = DialogueStateTracker(params, vocab_en, vocab_trans)
    dst_model.cuda()

    # build trainer
    dst_trainer = DST_Trainer(params, dst_model)

    for e in range(params.epoch):
        logger.info("============== epoch {} ==============".format(e + 1))
        food_loss_list, price_loss_list, area_loss_list, request_loss_list = [], [], [], []

        pbar = tqdm(enumerate(dataloader_tr), total=len(dataloader_tr))
        for i, (_, utters, lengths, acts_request, acts_slot, acts_values,
                slot_names, turn_slot_labels, turn_request_labels) in pbar:
            turn_slot_labels, turn_request_labels = (
                turn_slot_labels.cuda(), turn_request_labels.cuda())
            utters, lengths = utters.cuda(), lengths.cuda()

            food_loss, price_loss, area_loss, request_loss = dst_trainer.train_step(
                utters, lengths, acts_request, acts_slot, acts_values,
                slot_names, turn_slot_labels, turn_request_labels)

            food_loss_list.append(food_loss)
            price_loss_list.append(price_loss)
            area_loss_list.append(area_loss)
            request_loss_list.append(request_loss)

            pbar.set_description(
                "(Epoch {}) FOOD:{:.4f} PRICE:{:.4f} AREA:{:.4f} REQUEST:{:.4f}"
                .format(e + 1, np.mean(food_loss_list), np.mean(price_loss_list),
                        np.mean(area_loss_list), np.mean(request_loss_list)))

        logger.info(
            "Finish training epoch {}. FOOD:{:.4f} PRICE:{:.4f} AREA:{:.4f} REQUEST:{:.4f}"
            .format(e + 1, np.mean(food_loss_list), np.mean(price_loss_list),
                    np.mean(area_loss_list), np.mean(request_loss_list)))

        logger.info("============== Evaluate {} ==============".format(e + 1))
        goal_acc, request_acc, joint_goal_acc, avg_acc, stop_training_flag = dst_trainer.evaluate(
            dataloader_val, isTestset=False)
        logger.info(
            "({}) Goal ACC: {:.4f}. Joint ACC: {:.4f}. Request ACC: {:.4f}. Avg ACC: {:.4f} (Best Avg Acc: {:.4f})"
            .format(params.trans_lang, goal_acc, joint_goal_acc, request_acc,
                    avg_acc, dst_trainer.best_avg_acc))

        goal_acc, request_acc, joint_goal_acc, avg_acc, _ = dst_trainer.evaluate(
            dataloader_test, isTestset=True)
        logger.info(
            "({}) Goal ACC: {:.4f}. Joint ACC: {:.4f}. Request ACC: {:.4f}. Avg ACC: {:.4f}"
            .format(params.trans_lang, goal_acc, joint_goal_acc, request_acc,
                    avg_acc))

        if stop_training_flag:
            break

    logger.info("============== Final Test ==============")
    goal_acc, request_acc, joint_goal_acc, avg_acc, _ = dst_trainer.evaluate(
        dataloader_test, isTestset=True, load_best_model=True)
    logger.info(
        "Goal ACC: {:.4f}. Joint ACC: {:.4f}. Request ACC: {:.4f}. Avg ACC: {:.4f})"
        .format(goal_acc, joint_goal_acc, request_acc, avg_acc))
Example #12
def train(params):
    # initialize experiment
    logger = init_experiment(params, logger_filename=params.logger_filename)
    
    if params.bilstm:
        # dataloader
        dataloader_train, dataloader_dev, dataloader_test, vocab = get_dataloader_for_bilstmtagger(params)
        # bilstm-crf model
        model = BiLSTMTagger(params, vocab)
        model.cuda()
        # trainer
        trainer = BaseTrainer(params, model)
    elif params.coach:
        # dataloader
        dataloader_train, dataloader_dev, dataloader_test, vocab = get_dataloader_for_coach(params)
        # coach model
        binary_tagger = BiLSTMTagger(params, vocab)
        entity_predictor = EntityPredictor(params)
        binary_tagger.cuda()
        entity_predictor.cuda()
        # trainer
        trainer = CoachTrainer(params, binary_tagger, entity_predictor)
    else:
        # dataloader
        dataloader_train, dataloader_dev, dataloader_test = get_dataloader(params)
        # BERT-based NER Tagger
        model = BertTagger(params)
        model.cuda()
        # trainer
        trainer = BaseTrainer(params, model)

    if params.conll and not params.joint:
        conll_trainloader, conll_devloader, conll_testloader = get_conll2003_dataloader(params.batch_size, params.tgt_dm)
        trainer.train_conll(conll_trainloader, conll_devloader, conll_testloader, params.tgt_dm)

    no_improvement_num = 0
    best_f1 = 0
    logger.info("Training on target domain ...")
    for e in range(params.epoch):
        logger.info("============== epoch %d ==============" % e)
        
        pbar = tqdm(enumerate(dataloader_train), total=len(dataloader_train))
        if params.bilstm:
            loss_list = []
            for i, (X, lengths, y) in pbar:
                X, lengths = X.cuda(), lengths.cuda()
                loss = trainer.train_step_for_bilstm(X, lengths, y)
                loss_list.append(loss)
                pbar.set_description("(Epoch {}) LOSS:{:.4f}".format(e, np.mean(loss_list)))

            logger.info("Finish training epoch %d. loss: %.4f" % (e, np.mean(loss_list)))

        elif params.coach:
            loss_bin_list, loss_entity_list = [], []
            for i, (X, lengths, y_bin, y_final) in pbar:
                X, lengths = X.cuda(), lengths.cuda()
                loss_bin, loss_entityname = trainer.train_step(X, lengths, y_bin, y_final)
                loss_bin_list.append(loss_bin)
                loss_entity_list.append(loss_entityname)
                pbar.set_description("(Epoch {}) LOSS BIN:{:.4f}; LOSS ENTITY:{:.4f}".format(e, np.mean(loss_bin_list), np.mean(loss_entity_list)))
            
            logger.info("Finish training epoch %d. loss_bin: %.4f. loss_entity: %.4f" % (e, np.mean(loss_bin_list), np.mean(loss_entity_list)))

        else:
            loss_list = []
            for i, (X, y) in pbar:
                X, y = X.cuda(), y.cuda()
                loss = trainer.train_step(X, y)
                loss_list.append(loss)
                pbar.set_description("(Epoch {}) LOSS:{:.4f}".format(e, np.mean(loss_list)))

            logger.info("Finish training epoch %d. loss: %.4f" % (e, np.mean(loss_list)))

        logger.info("============== Evaluate epoch %d on Train Set ==============" % e)
        f1_train = trainer.evaluate(dataloader_train, params.tgt_dm, use_bilstm=params.bilstm)
        logger.info("Evaluate on Train Set. F1: %.4f." % f1_train)

        logger.info("============== Evaluate epoch %d on Dev Set ==============" % e)
        f1_dev = trainer.evaluate(dataloader_dev, params.tgt_dm, use_bilstm=params.bilstm)
        logger.info("Evaluate on Dev Set. F1: %.4f." % f1_dev)

        logger.info("============== Evaluate epoch %d on Test Set ==============" % e)
        f1_test = trainer.evaluate(dataloader_test, params.tgt_dm, use_bilstm=params.bilstm)
        logger.info("Evaluate on Test Set. F1: %.4f." % f1_test)

        if f1_dev > best_f1:
            logger.info("Found better model!!")
            best_f1 = f1_dev
            no_improvement_num = 0
            # trainer.save_model()
        else:
            no_improvement_num += 1
            logger.info("No better model found (%d/%d)" % (no_improvement_num, params.early_stop))

        if no_improvement_num >= params.early_stop:
            break
Example #13
def main(params):

    logger = init_experiment(params, logger_filename=params.logger_filename)

    dataloader_tr, dataloader_val, dataloader_test, vocab = get_dataloader(
        params.tgt_dm, params.batch_size, params.tr, params.n_samples)

    coarse_slutagger = CoarseSLUTagger(params, vocab)

    coarse_slutagger = coarse_slutagger.cuda()
    dm_coarse = get_coarse_labels_for_domains()

    fine_predictor = FinePredictor(params, dm_coarse)
    fine_predictor = fine_predictor.cuda()

    # if params.tr:
    sent_repre_generator = SentRepreGenerator(params, vocab)
    sent_repre_generator = sent_repre_generator.cuda()

    slu_trainer = SLUTrainer(params,
                             coarse_slutagger,
                             fine_predictor,
                             sent_repre_generator=sent_repre_generator)

    for e in range(params.epoch):
        loss_c_list = []
        pbar = tqdm(enumerate(dataloader_tr), total=len(dataloader_tr))
        logger.info("============== epoch {} ==============".format(e + 1))
        if (e < params.pretrained_epoch
                or e in (7, 8, 12, 13, 17, 20)):
            if params.tr:
                for i, (X, lengths, y_0, y_bin, y_final, y_dm, templates,
                        tem_lengths) in pbar:
                    X, lengths = X.cuda(), lengths.cuda()
                    loss_chunking = slu_trainer.chunking_pretrain(
                        X, lengths, y_0)
                    loss_c_list.append(loss_chunking)
                    pbar.set_description(
                        "(Epoch {}) LOSS CHUNKING:{:.4f}".format(
                            (e + 1), np.mean(loss_c_list)))

            else:
                for i, (X, lengths, y_0, y_bin, y_final, y_dm) in pbar:
                    X, lengths = X.cuda(), lengths.cuda()
                    loss_chunking = slu_trainer.chunking_pretrain(
                        X, lengths, y_0)
                    loss_c_list.append(loss_chunking)
                    pbar.set_description(
                        "(Epoch {}) LOSS CHUNKING:{:.4f}".format(
                            (e + 1), np.mean(loss_c_list)))

            logger.info(
                "============== Evaluate Epoch {} ==============".format(e + 1))
            bin_f1 = slu_trainer.chunking_eval(dataloader_val)
            logger.info(
                "Eval on dev set. Binary Slot-F1: {:.4f}".format(bin_f1))

            bin_f1 = slu_trainer.chunking_eval(dataloader_test)
            logger.info(
                "Eval on test set. Binary Slot-F1: {:.4f}".format(bin_f1))

            continue

        loss_bin_list, loss_slotname_list = [], []
        if params.tr:
            loss_tem0_list, loss_tem1_list = [], []

        # record = int(len(dataloader_tr) / 4)
        if params.tr:
            for i, (X, lengths, y_0, y_bin, y_final, y_dm, templates,
                    tem_lengths) in pbar:
                X, lengths, templates, tem_lengths = (
                    X.cuda(), lengths.cuda(), templates.cuda(), tem_lengths.cuda())
                loss_bin, loss_slotname, loss_tem0, loss_tem1 = slu_trainer.train_step(
                    X,
                    lengths,
                    y_bin,
                    y_final,
                    y_dm,
                    templates=templates,
                    tem_lengths=tem_lengths,
                    epoch=e)
                loss_bin_list.append(loss_bin)
                loss_slotname_list.append(loss_slotname)
                loss_tem0_list.append(loss_tem0)
                loss_tem1_list.append(loss_tem1)

                pbar.set_description(
                    "(Epoch {}) LOSS BIN:{:.4f} LOSS SLOT:{:.4f} LOSS TEM0:{:.4f} LOSS TEM1:{:.4f}"
                    .format((e + 1), np.mean(loss_bin_list),
                            np.mean(loss_slotname_list),
                            np.mean(loss_tem0_list), np.mean(loss_tem1_list)))
        else:
            for i, (X, lengths, y_0, y_bin, y_final, y_dm) in pbar:
                X, lengths = X.cuda(), lengths.cuda()
                loss_bin, loss_slotname = slu_trainer.train_step(
                    X, lengths, y_bin, y_final, y_dm)
                loss_bin_list.append(loss_bin)
                loss_slotname_list.append(loss_slotname)
                pbar.set_description(
                    "(Epoch {}) LOSS BIN:{:.4f} LOSS SLOT:{:.4f}".format(
                        (e + 1), np.mean(loss_bin_list),
                        np.mean(loss_slotname_list)))

        if params.tr:
            logger.info(
                "Finish training epoch {}. LOSS BIN:{:.4f} LOSS SLOT:{:.4f} LOSS TEM0:{:.4f} LOSS TEM1:{:.4f}"
                .format((e + 1), np.mean(loss_bin_list),
                        np.mean(loss_slotname_list), np.mean(loss_tem0_list),
                        np.mean(loss_tem1_list)))
        else:
            logger.info(
                "Finish training epoch {}. LOSS BIN:{:.4f} LOSS SLOT:{:.4f}".
                format((e + 1), np.mean(loss_bin_list),
                       np.mean(loss_slotname_list)))

        logger.info(
            "============== Evaluate Epoch {} ==============".format(e + 1))
        bin_f1, final_f1, stop_training_flag = slu_trainer.evaluate(
            dataloader_val, istestset=False)
        logger.info(
            "Eval on dev set. Binary Slot-F1: {:.4f}. Final Slot-F1: {:.4f}.".
            format(bin_f1, final_f1))

        bin_f1, final_f1, stop_training_flag = slu_trainer.evaluate(
            dataloader_test, istestset=True)
        logger.info(
            "Eval on test set. Binary Slot-F1: {:.4f}. Final Slot-F1: {:.4f}.".
            format(bin_f1, final_f1))

        if stop_training_flag:
            break
Example #14
def test_coach(params):

    logger = init_experiment(params, logger_filename='test')
    # get dataloader
    dataloader_tr, dataloader_val, dataloader_test, vocab = get_dataloader(
        params.tgt_dm, params.batch_size, params.tr, params.n_samples)
    # _, _, dataloader_test, _ = get_dataloader(params.tgt_dm, params.batch_size, params.tr, params.n_samples)

    print(params.model_path)
    model_path = params.model_path
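    # NOTE: the optimizer checkpoint path below is hard-coded to one experiment run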
    opti_path = './experiments/coach_patience/atp_0/opti.pth'

    assert os.path.isfile(model_path)

    reloaded = torch.load(model_path)
    coarse_slutagger = CoarseSLUTagger(params, vocab)

    coarse_slutagger = coarse_slutagger.cuda()
    dm_coarse = get_coarse_labels_for_domains()

    fine_tagger = FinePredictor(params, dm_coarse)
    fine_tagger = fine_tagger.cuda()

    coarse_slutagger.load_state_dict(reloaded["coarse_tagger"])

    fine_tagger.load_state_dict(reloaded["fine_tagger"])
    coarse_tagger = coarse_slutagger.cuda()
    # fine_tagger.cuda()

    # model_parameters = [
    #             {"params": coarse_tagger.parameters()},
    #             {"params": fine_tagger.parameters()}
    #         ]

    # optimizer = torch.optim.Adam(model_parameters, lr=self.lr)

    # optimizer.load_state_dict(torch.load(opti_path))

    slu_trainer = SLUTrainer(params, coarse_tagger, fine_tagger)
    slu_trainer.optimizer.load_state_dict(torch.load(opti_path))

    for e in range(params.epoch):
        logger.info("============== epoch {} ==============".format(e + 1))
        loss_bin_list, loss_slotname_list = [], []
        if params.tr:
            loss_tem0_list, loss_tem1_list = [], []
        pbar = tqdm(enumerate(dataloader_tr), total=len(dataloader_tr))
        # record = int(len(dataloader_tr) / 4)
        if params.tr:
            for i, (X, lengths, y_bin, y_final, y_dm, templates,
                    tem_lengths) in pbar:
                X, lengths, templates, tem_lengths = (
                    X.cuda(), lengths.cuda(), templates.cuda(), tem_lengths.cuda())
                loss_bin, loss_slotname, loss_tem0, loss_tem1 = slu_trainer.train_step(
                    X,
                    lengths,
                    y_bin,
                    y_final,
                    y_dm,
                    templates=templates,
                    tem_lengths=tem_lengths,
                    epoch=e)
                loss_bin_list.append(loss_bin)
                loss_slotname_list.append(loss_slotname)
                loss_tem0_list.append(loss_tem0)
                loss_tem1_list.append(loss_tem1)

                pbar.set_description(
                    "(Epoch {}) LOSS BIN:{:.4f} LOSS SLOT:{:.4f} LOSS TEM0:{:.4f} LOSS TEM1:{:.4f}"
                    .format((e + 1), np.mean(loss_bin_list),
                            np.mean(loss_slotname_list),
                            np.mean(loss_tem0_list), np.mean(loss_tem1_list)))
        else:
            for i, (X, lengths, y_bin, y_final, y_dm) in pbar:
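                # debug shortcut: only the first two batches are processed per epoch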
                if i == 2:
                    break
                # if i %record == 0 and i > 0:
                #     logger.info("============== Evaluate Epoch {} {}==============".format(e+1, i))
                #     bin_f1, final_f1, stop_training_flag = slu_trainer.evaluate(dataloader_val, istestset=False)
                #     logger.info("Eval on dev set. Binary Slot-F1: {:.4f}. Final Slot-F1: {:.4f}.".format(bin_f1, final_f1))

                #     bin_f1, final_f1, stop_training_flag = slu_trainer.evaluate(dataloader_test, istestset=True)
                #     logger.info("Eval on test set. Binary Slot-F1: {:.4f}. Final Slot-F1: {:.4f}.".format(bin_f1, final_f1))
                X, lengths = X.cuda(), lengths.cuda()
                loss_bin, loss_slotname = slu_trainer.train_step(
                    X, lengths, y_bin, y_final, y_dm)
                loss_bin_list.append(loss_bin)
                loss_slotname_list.append(loss_slotname)
                pbar.set_description(
                    "(Epoch {}) LOSS BIN:{:.4f} LOSS SLOT:{:.4f}".format(
                        (e + 1), np.mean(loss_bin_list),
                        np.mean(loss_slotname_list)))

        if params.tr:
            logger.info(
                "Finish training epoch {}. LOSS BIN:{:.4f} LOSS SLOT:{:.4f} LOSS TEM0:{:.4f} LOSS TEM1:{:.4f}"
                .format((e + 1), np.mean(loss_bin_list),
                        np.mean(loss_slotname_list), np.mean(loss_tem0_list),
                        np.mean(loss_tem1_list)))
        else:
            logger.info(
                "Finish training epoch {}. LOSS BIN:{:.4f} LOSS SLOT:{:.4f}".
                format((e + 1), np.mean(loss_bin_list),
                       np.mean(loss_slotname_list)))

        logger.info(
            "============== Evaluate Epoch {} ==============".format(e + 1))
        bin_f1, final_f1, stop_training_flag = slu_trainer.evaluate(
            dataloader_val, istestset=False)
        logger.info(
            "Eval on dev set. Binary Slot-F1: {:.4f}. Final Slot-F1: {:.4f}.".
            format(bin_f1, final_f1))

        bin_f1, final_f1, stop_training_flag = slu_trainer.evaluate(
            dataloader_test, istestset=True)
        logger.info(
            "Eval on test set. Binary Slot-F1: {:.4f}. Final Slot-F1: {:.4f}.".
            format(bin_f1, final_f1))

        if stop_training_flag:
            break
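
Each entry point above is presumably invoked from a __main__ guard that the snippets do not show; a minimal, hypothetical wiring consistent with the examples:

if __name__ == "__main__":
    params = get_params()  # from config, as used throughout the examples
    main(params)           # or train(params), run_baseline(params), ...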