def bilstm_train_and_eval(train_data,
                          dev_data,
                          test_data,
                          word2id,
                          tag2id,
                          crf=True,
                          remove_O=False,
                          output_dir="./ckpts"):
    """Train a BiLSTM (optionally BiLSTM-CRF) tagger, save it, and evaluate it.

    Args:
        train_data: ``(word_lists, tag_lists)`` pair used for training.
        dev_data: ``(word_lists, tag_lists)`` pair handed to training as the
            development set.
        test_data: ``(word_lists, tag_lists)`` pair used for evaluation.
        word2id: vocabulary mapping; its length sets the model's vocab size.
        tag2id: tag mapping; its length sets the model's output size.
        crf: add a CRF layer on top of the BiLSTM when True.
        remove_O: forwarded to ``Metrics``; controls whether the ``O`` tag is
            dropped during scoring.
        output_dir: directory the pickled model is written to. Defaults to
            ``"./ckpts"``, the location this function previously hard-coded.

    Returns:
        The predicted tag lists returned by ``bilstm_model.test``.
    """
    import os

    train_word_lists, train_tag_lists = train_data
    dev_word_lists, dev_tag_lists = dev_data
    test_word_lists, test_tag_lists = test_data

    start = time.time()
    vocab_size = len(word2id)
    out_size = len(tag2id)
    bilstm_model = BILSTM_Model(vocab_size, out_size, crf=crf)
    bilstm_model.train(train_word_lists, train_tag_lists, dev_word_lists,
                       dev_tag_lists, word2id, tag2id)

    model_name = "bilstm_crf" if crf else "bilstm"
    save_model(bilstm_model, os.path.join(output_dir, model_name + ".pkl"))

    print("Training completed in {} seconds.".format(
        int(time.time() - start)))
    print("Evaluating {} model...".format(model_name))
    # test() also returns gold tag lists; they replace test_tag_lists so the
    # metrics compare matching sequences (presumably re-aligned inside
    # BILSTM_Model.test — confirm there).
    pred_tag_lists, test_tag_lists = bilstm_model.test(test_word_lists,
                                                       test_tag_lists, word2id,
                                                       tag2id)

    metrics = Metrics(test_tag_lists, pred_tag_lists, remove_O=remove_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()

    return pred_tag_lists
示例#2
0
def bilstm_train_and_eval(train_data,
                          dev_data,
                          test_data,
                          word2id,
                          tag2id,
                          output_dir,
                          crf=True,
                          remove_O=False):
    """Train, checkpoint and score a BiLSTM / BiLSTM-CRF sequence tagger.

    Args:
        train_data: ``(word_lists, tag_lists)`` used for fitting.
        dev_data: ``(word_lists, tag_lists)`` passed to training as dev data.
        test_data: ``(word_lists, tag_lists)`` used for the final evaluation.
        word2id: vocabulary map; ``len(word2id)`` is the model's vocab size.
        tag2id: tag map; ``len(tag2id)`` is the model's output size.
        output_dir: directory receiving ``bilstm.pkl`` / ``bilstm_crf.pkl``.
        crf: stack a CRF layer on the BiLSTM when True.
        remove_O: forwarded to ``Metrics``.

    Returns:
        The predicted tag lists produced by the model's ``test`` method.
    """
    (train_words, train_tags) = train_data
    (dev_words, dev_tags) = dev_data
    (test_words, test_tags) = test_data

    started_at = time.time()
    model = BILSTM_Model(len(word2id), len(tag2id), crf=crf)
    model.train(train_words, train_tags, dev_words, dev_tags, word2id, tag2id)

    if crf:
        model_name = "bilstm_crf"
    else:
        model_name = "bilstm"
    checkpoint_path = os.path.join(output_dir, model_name + ".pkl")
    save_model(model, checkpoint_path)

    elapsed_seconds = int(time.time() - started_at)
    print("训练完毕,共用时{}秒.".format(elapsed_seconds))
    print("评估{}模型中...".format(model_name))
    # The model hands back the gold tags alongside its predictions; those are
    # what the metrics are computed against.
    predictions, gold = model.test(test_words, test_tags, word2id, tag2id)

    scorer = Metrics(gold, predictions, remove_O=remove_O)
    scorer.report_scores()
    scorer.report_confusion_matrix()
    return predictions
示例#3
0
def hmm_eval(hmm_model, test_data, word2id, tag2id, remove_O=False):
    """Score an already-trained HMM tagger on labelled data.

    Args:
        hmm_model: object exposing ``test(word_lists, word2id, tag2id)``.
        test_data: ``(word_lists, tag_lists)`` pair to evaluate on.
        word2id: vocabulary map forwarded to ``hmm_model.test``.
        tag2id: tag map forwarded to ``hmm_model.test``.
        remove_O: forwarded to ``Metrics``.

    Returns:
        The predicted tag lists.
    """
    words, gold_tags = test_data
    predicted = hmm_model.test(words, word2id, tag2id)

    report = Metrics(gold_tags, predicted, remove_O=remove_O)
    report.report_scores()
    report.report_confusion_matrix()
    return predicted
def main():
    """Debug run: score the saved BiLSTM-CRF model on a 10-sentence sample.

    Vocabulary and tag maps come from the "train" corpus build. The evaluation
    sample is also read from the "train" corpus. The gold and predicted tag
    lists of the sample are printed before the metrics report.
    """
    print("读取数据...")
    _train_words, _train_tags, word2id, tag2id = build_corpus("train")
    # NOTE(review): the evaluation data is read from the *training* corpus,
    # so the scores below are not held-out scores — confirm this is intended.
    sample_words, sample_tags = build_corpus("train", make_vocab=False)

    print("加载并评估bilstm+crf模型...")
    crf_word2id, crf_tag2id = extend_maps(word2id, tag2id, for_crf=True)
    lstmcrf = load_model(BiLSTMCRF_MODEL_PATH)
    lstmcrf.model.bilstm.bilstm.flatten_parameters()  # remove warning

    limit = 10  # only the first sentences are evaluated in this debug run
    sample_words, sample_tags = prepocess_data_for_lstmcrf(
        sample_words[:limit], sample_tags[:limit], test=True)
    predicted, gold = lstmcrf.test(sample_words, sample_tags,
                                   crf_word2id, crf_tag2id)

    print(gold)
    print(predicted)

    report = Metrics(gold, predicted, remove_O=REMOVE_O)
    report.report_scores()
    report.report_confusion_matrix()
示例#5
0
def main():
    """Evaluate the four saved taggers on the test split and ensemble them.

    Additionally runs the HMM over the raw file loaded by ``loadDevFile`` and
    exports those predictions with ``output_pred``.
    """
    print("读取数据...")
    train_word_lists, train_tag_lists, word2id, tag2id = \
        build_corpus("train")
    dev_word_lists, dev_tag_lists = build_corpus("dev", make_vocab=False)
    test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False)
    dev_word_lists_, dev_word_lists_raw, article_id = loadDevFile("development_2.txt")

    print("加载并评估hmm模型...")
    hmm_model = load_model(HMM_MODEL_PATH)
    # Predictions for the unlabelled development file, written out by output_pred.
    hmm_pred_dev = hmm_model.test(dev_word_lists_,
                                  word2id,
                                  tag2id)
    output_pred(hmm_pred_dev, article_id, dev_word_lists_raw)
    # Predictions on the labelled test split. They are required by the metrics
    # below and by the final ensemble. This call was previously commented
    # out, which left hmm_pred undefined (NameError).
    hmm_pred = hmm_model.test(test_word_lists, word2id, tag2id)
    metrics = Metrics(test_tag_lists, hmm_pred, remove_O=REMOVE_O)
    metrics.report_scores()  # per-tag precision, recall and F1
    metrics.report_confusion_matrix()  # confusion matrix

    # Load and evaluate the CRF model
    print("加载并评估crf模型...")
    crf_model = load_model(CRF_MODEL_PATH)
    crf_pred = crf_model.test(test_word_lists)
    metrics = Metrics(test_tag_lists, crf_pred, remove_O=REMOVE_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()

    # BiLSTM model
    print("加载并评估bilstm模型...")
    bilstm_word2id, bilstm_tag2id = extend_maps(word2id, tag2id, for_crf=False)
    bilstm_model = load_model(BiLSTM_MODEL_PATH)
    bilstm_model.model.bilstm.flatten_parameters()  # remove warning
    lstm_pred, target_tag_list = bilstm_model.test(test_word_lists, test_tag_lists,
                                                   bilstm_word2id, bilstm_tag2id)
    metrics = Metrics(target_tag_list, lstm_pred, remove_O=REMOVE_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()

    print("加载并评估bilstm+crf模型...")
    crf_word2id, crf_tag2id = extend_maps(word2id, tag2id, for_crf=True)
    bilstm_model = load_model(BiLSTMCRF_MODEL_PATH)
    bilstm_model.model.bilstm.bilstm.flatten_parameters()  # remove warning
    test_word_lists, test_tag_lists = prepocess_data_for_lstmcrf(
        test_word_lists, test_tag_lists, test=True
    )
    lstmcrf_pred, target_tag_list = bilstm_model.test(test_word_lists, test_tag_lists,
                                                      crf_word2id, crf_tag2id)
    metrics = Metrics(target_tag_list, lstmcrf_pred, remove_O=REMOVE_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()

    ensemble_evaluate(
        [hmm_pred, crf_pred, lstm_pred, lstmcrf_pred],
        test_tag_lists
    )
def ensemble_evaluate(results, targets, remove_O=False):
    """Combine several models' predictions by per-token majority vote and score them.

    Args:
        results: one entry per model, each being that model's (possibly nested)
            tag lists. After flattening, all entries must line up token by
            token with the flattened ``targets``. The caller's list is not
            modified; earlier versions replaced each entry with its flattened
            form in place.
        targets: gold tag lists, nested the same way as each entry of results.
        remove_O: forwarded to ``Metrics``.

    Raises:
        ValueError: if the voted sequence and the gold sequence differ in length.

    Ties are resolved by ``Counter.most_common``, which keeps the tag that was
    encountered first, i.e. the vote of the earliest model in ``results``.
    """
    flat_results = [flatten_lists(result) for result in results]

    pred_tags = [Counter(votes).most_common(1)[0][0]
                 for votes in zip(*flat_results)]

    targets = flatten_lists(targets)
    if len(pred_tags) != len(targets):
        raise ValueError(
            "ensemble produced {} tags but there are {} gold tags".format(
                len(pred_tags), len(targets)))

    print("The results of the {} Ensemble models are as follows:".format(
        len(flat_results)))
    metrics = Metrics(targets, pred_tags, remove_O=remove_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()
def crf_train_eval(train_data, test_data, remove_O=False, output_dir="./ckpts"):
    """Train a CRF tagger, save it, and evaluate it on the test data.

    Args:
        train_data: ``(word_lists, tag_lists)`` used for training.
        test_data: ``(word_lists, tag_lists)`` used for evaluation.
        remove_O: forwarded to ``Metrics``.
        output_dir: directory that receives ``crf.pkl``. Defaults to
            ``"./ckpts"``, the location this function previously hard-coded.

    Returns:
        The predicted tag lists for ``test_data``.
    """
    import os

    # Train the CRF model
    train_word_lists, train_tag_lists = train_data
    test_word_lists, test_tag_lists = test_data

    crf_model = CRFModel()
    crf_model.train(train_word_lists, train_tag_lists)
    save_model(crf_model, os.path.join(output_dir, "crf.pkl"))

    pred_tag_lists = crf_model.test(test_word_lists)

    metrics = Metrics(test_tag_lists, pred_tag_lists, remove_O=remove_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()

    return pred_tag_lists
示例#8
0
def ensemble_evaluate(results, targets, remove_O=False):
    """Majority-vote ensemble of several models' tag predictions, then score it.

    Args:
        results: one entry per model, each being that model's (possibly nested)
            tag lists. After flattening, all entries must line up token by
            token with the flattened ``targets``. The caller's list is left
            untouched; earlier versions overwrote its entries in place.
        targets: gold tag lists, nested like each entry of results.
        remove_O: forwarded to ``Metrics``.

    Raises:
        ValueError: if the voted sequence and the gold sequence differ in length.

    Ties go to the tag first encountered, i.e. the earliest model in
    ``results`` (the documented ``Counter.most_common`` ordering).
    """
    flat_results = [flatten_lists(result) for result in results]

    pred_tags = [Counter(votes).most_common(1)[0][0]
                 for votes in zip(*flat_results)]

    targets = flatten_lists(targets)
    if len(pred_tags) != len(targets):
        raise ValueError(
            "ensemble produced {} tags but there are {} gold tags".format(
                len(pred_tags), len(targets)))

    # The model count is no longer hard-coded as four: callers pass 1..n models.
    print("Ensemble {}个模型的结果如下:".format(len(flat_results)))
    metrics = Metrics(targets, pred_tags, remove_O=remove_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()
示例#9
0
def crf_train_eval(train_data, test_data, output_dir, remove_O=False):
    """Fit a CRF tagger, save it as ``crf.pkl`` under output_dir, and score it.

    Args:
        train_data: ``(word_lists, tag_lists)`` for fitting.
        test_data: ``(word_lists, tag_lists)`` for evaluation.
        output_dir: directory that receives the pickled model.
        remove_O: forwarded to ``Metrics``.

    Returns:
        The predicted tag lists for the test words.
    """
    sentences, labels = train_data
    eval_sentences, eval_labels = test_data

    model = CRFModel()
    model.train(sentences, labels)
    save_model(model, os.path.join(output_dir, 'crf.pkl'))

    predicted = model.test(eval_sentences)
    scorer = Metrics(eval_labels, predicted, remove_O=remove_O)
    scorer.report_scores()
    scorer.report_confusion_matrix()
    return predicted
def hmm_train_eval(train_data, test_data, word2id, tag2id, remove_O=False):
    """Fit an HMM tagger, pickle it to ./ckpts/hmm.pkl and score it on test_data.

    Args:
        train_data: ``(word_lists, tag_lists)`` for fitting.
        test_data: ``(word_lists, tag_lists)`` for evaluation.
        word2id: vocabulary map; its size is passed to the HMM constructor.
        tag2id: tag map; its size is passed to the HMM constructor.
        remove_O: forwarded to ``Metrics``.

    Returns:
        The predicted tag lists.
    """
    words_train, tags_train = train_data
    words_test, tags_test = test_data

    model = HMM(len(tag2id), len(word2id))
    model.train(words_train, tags_train, word2id, tag2id)
    save_model(model, "./ckpts/hmm.pkl")

    predictions = model.test(words_test, word2id, tag2id)
    evaluation = Metrics(tags_test, predictions, remove_O=remove_O)
    evaluation.report_scores()
    evaluation.report_confusion_matrix()
    return predictions
def hmm_train_eval(train_data, test_data, word2id, tag2id, remove_O=False):
    """Train an HMM on train_data, save it, then report its test-set metrics.

    The model is built as ``HMM(len(tag2id), len(word2id))`` and written to
    ``./ckpts/hmm.pkl``. The function returns the predicted tag lists for the
    test words.
    """
    fit_words, fit_tags = train_data
    held_out_words, held_out_tags = test_data

    tagger = HMM(len(tag2id), len(word2id))
    tagger.train(fit_words, fit_tags, word2id, tag2id)
    save_model(tagger, "./ckpts/hmm.pkl")

    # Evaluation on the held-out split
    guessed = tagger.test(held_out_words, word2id, tag2id)
    result = Metrics(held_out_tags, guessed, remove_O=remove_O)
    result.report_scores()
    result.report_confusion_matrix()
    return guessed
def main():
    """Evaluate the saved HMM, CRF, BiLSTM and BiLSTM-CRF taggers, then ensemble them."""

    def report(gold, predicted):
        # Per-tag precision / recall / F1 followed by the confusion matrix.
        metrics = Metrics(gold, predicted, remove_O=REMOVE_O)
        metrics.report_scores()
        metrics.report_confusion_matrix()

    print("Read data...")
    train_word_lists, train_tag_lists, word2id, tag2id = build_corpus("train")
    dev_word_lists, dev_tag_lists = build_corpus("dev", make_vocab=False)
    test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False)

    print("Load and evaluate the hmm model...")
    hmm_pred = load_model(HMM_MODEL_PATH).test(test_word_lists, word2id, tag2id)
    report(test_tag_lists, hmm_pred)

    print("Load and evaluate the crf model...")
    crf_pred = load_model(CRF_MODEL_PATH).test(test_word_lists)
    report(test_tag_lists, crf_pred)

    print("Load and evaluate the bilstm model...")
    lstm_word2id, lstm_tag2id = extend_maps(word2id, tag2id, for_crf=False)
    lstm_model = load_model(BiLSTM_MODEL_PATH)
    lstm_model.model.bilstm.flatten_parameters()  # remove warning
    lstm_pred, lstm_gold = lstm_model.test(test_word_lists, test_tag_lists,
                                           lstm_word2id, lstm_tag2id)
    report(lstm_gold, lstm_pred)

    print("Load and evaluate the bilstm+crf model...")
    crf_word2id, crf_tag2id = extend_maps(word2id, tag2id, for_crf=True)
    lstmcrf_model = load_model(BiLSTMCRF_MODEL_PATH)
    lstmcrf_model.model.bilstm.bilstm.flatten_parameters()  # remove warning
    test_word_lists, test_tag_lists = prepocess_data_for_lstmcrf(
        test_word_lists, test_tag_lists, test=True)
    lstmcrf_pred, lstmcrf_gold = lstmcrf_model.test(test_word_lists,
                                                    test_tag_lists,
                                                    crf_word2id, crf_tag2id)
    report(lstmcrf_gold, lstmcrf_pred)

    all_predictions = [hmm_pred, crf_pred, lstm_pred, lstmcrf_pred]
    ensemble_evaluate(all_predictions, test_tag_lists)
def bilstm_train_and_eval(train_data,
                          dev_data,
                          test_data,
                          word2id,
                          tag2id,
                          crf=True,
                          remove_O=False):
    """Train a BiLSTM(-CRF) tagger, save it, dump its test predictions, and score it.

    Args:
        train_data: ``(word_lists, tag_lists)`` used for training.
        dev_data: ``(word_lists, tag_lists)`` passed to training as dev data.
        test_data: ``(word_lists, tag_lists)`` used for evaluation.
        word2id: vocabulary map; its length sets the model's vocab size.
        tag2id: tag map; its length sets the model's output size.
        crf: add a CRF layer on top of the BiLSTM when True.
        remove_O: forwarded to ``Metrics``.

    Returns:
        The predicted tag lists.

    Side effect: appends to ``./result.txt`` one ``"<gold> <predicted>"`` line
    per token, with a blank line after each sentence. The previous code
    concatenated a tag *list* with a string, which raises TypeError, and
    wrote the prediction twice instead of gold vs. prediction.
    """
    train_word_lists, train_tag_lists = train_data
    dev_word_lists, dev_tag_lists = dev_data
    test_word_lists, test_tag_lists = test_data

    start = time.time()
    vocab_size = len(word2id)
    out_size = len(tag2id)
    bilstm_model = BILSTM_Model(vocab_size, out_size, crf=crf)
    bilstm_model.train(train_word_lists, train_tag_lists, dev_word_lists,
                       dev_tag_lists, word2id, tag2id)

    model_name = "bilstm_crf" if crf else "bilstm"
    save_model(bilstm_model, "./ckpts/" + model_name + ".pkl")

    print("训练完毕,共用时{}秒.".format(int(time.time() - start)))
    print("评估{}模型中...".format(model_name))
    # test() also returns the gold tags that line up with its predictions.
    pred_tag_lists, test_tag_lists = bilstm_model.test(test_word_lists,
                                                       test_tag_lists, word2id,
                                                       tag2id)

    with open("./result.txt", "a+", encoding="utf-8") as f:
        for gold_tags, pred_tags in zip(test_tag_lists, pred_tag_lists):
            for gold_tag, pred_tag in zip(gold_tags, pred_tags):
                f.write("{} {}\n".format(gold_tag, pred_tag))
            f.write("\n")  # sentence separator

    metrics = Metrics(test_tag_lists, pred_tag_lists, remove_O=remove_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()

    return pred_tag_lists
def main_rep1(x, y):
    """Train (``x == 'train'``) or evaluate a saved checkpoint (any other x).

    ``y`` selects the model: ``'crf'``, ``'bilstm'`` or ``'bilstm-crf'``. Any
    other value only loads the corpora. On normal completion the function
    calls ``exit()``.
    """
    # Both modes start from the same three corpus splits.
    print("Read data...")
    train_word_lists, train_tag_lists, word2id, tag2id = build_corpus("train")
    dev_word_lists, dev_tag_lists = build_corpus("dev", make_vocab=False)
    test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False)

    if x == 'train':
        if y == 'crf':
            crf_pred = crf_train_eval((train_word_lists, train_tag_lists),
                                      (test_word_lists, test_tag_lists))
            ensemble_evaluate([crf_pred], test_tag_lists)
        elif y == 'bilstm':
            lstm_word2id, lstm_tag2id = extend_maps(word2id, tag2id,
                                                    for_crf=False)
            lstm_pred = bilstm_train_and_eval(
                (train_word_lists, train_tag_lists),
                (dev_word_lists, dev_tag_lists),
                (test_word_lists, test_tag_lists),
                lstm_word2id,
                lstm_tag2id,
                crf=False)
            ensemble_evaluate([lstm_pred], test_tag_lists)
        elif y == 'bilstm-crf':
            crf_word2id, crf_tag2id = extend_maps(word2id, tag2id,
                                                  for_crf=True)
            # The CRF variant preprocesses every split; the test split is
            # processed with test=True.
            train_word_lists, train_tag_lists = prepocess_data_for_lstmcrf(
                train_word_lists, train_tag_lists)
            dev_word_lists, dev_tag_lists = prepocess_data_for_lstmcrf(
                dev_word_lists, dev_tag_lists)
            test_word_lists, test_tag_lists = prepocess_data_for_lstmcrf(
                test_word_lists, test_tag_lists, test=True)
            lstmcrf_pred = bilstm_train_and_eval(
                (train_word_lists, train_tag_lists),
                (dev_word_lists, dev_tag_lists),
                (test_word_lists, test_tag_lists),
                crf_word2id,
                crf_tag2id)
            ensemble_evaluate([lstmcrf_pred], test_tag_lists)
    else:
        checkpoints = {
            'crf': './ckpts/crf.pkl',
            'bilstm': './ckpts/bilstm.pkl',
            'bilstm-crf': './ckpts/bilstm_crf.pkl',
        }
        keep_all_tags = False  # whether the O tag is dropped during evaluation

        def report(gold, predicted):
            scorer = Metrics(gold, predicted, remove_O=keep_all_tags)
            scorer.report_scores()
            scorer.report_confusion_matrix()

        if y == 'crf':
            crf_model = load_model_1(checkpoints['crf'])
            report(test_tag_lists, crf_model.test(test_word_lists))
        elif y == 'bilstm':
            lstm_word2id, lstm_tag2id = extend_maps(word2id, tag2id,
                                                    for_crf=False)
            lstm_model = load_model_1(checkpoints['bilstm'])
            lstm_model.model.bilstm.flatten_parameters()  # remove warning
            lstm_pred, gold = lstm_model.test(test_word_lists, test_tag_lists,
                                              lstm_word2id, lstm_tag2id)
            report(gold, lstm_pred)
        elif y == 'bilstm-crf':
            crf_word2id, crf_tag2id = extend_maps(word2id, tag2id,
                                                  for_crf=True)
            lstmcrf_model = load_model_1(checkpoints['bilstm-crf'])
            lstmcrf_model.model.bilstm.bilstm.flatten_parameters()  # remove warning
            test_word_lists, test_tag_lists = prepocess_data_for_lstmcrf(
                test_word_lists, test_tag_lists, test=True)
            lstmcrf_pred, gold = lstmcrf_model.test(
                test_word_lists, test_tag_lists, crf_word2id, crf_tag2id)
            report(gold, lstmcrf_pred)

    exit()
示例#15
0
文件: work.py 项目: darr/dlner
def _print_metrics(tag_lists, pred):
    """Print per-tag scores and the confusion matrix of *pred* against *tag_lists*.

    The ``O`` tag is always kept in the evaluation (``remove_O=False``).
    """
    report = Metrics(tag_lists, pred, remove_O=False)
    report.report_scores()  # precision, recall and F1 for every tag
    report.report_confusion_matrix()
def main():
    """Command-line entry point: evaluate whichever saved taggers are selected.

    ``--hmm``, ``--crf``, ``--bilstm`` and ``--bilstm-crf`` each enable the
    evaluation of one model on the test split. Any combination may be given,
    and the models run in that order.
    """
    import argparse

    parser = argparse.ArgumentParser(description='main.py')
    flag_specs = (
        ('--hmm', 'Test HMM'),
        ('--crf', 'Test CRF'),
        ('--bilstm', 'Test BiLSTM'),
        ('--bilstm-crf', 'Test BiLSTM-CRF'),
        ('--cbow', 'Use CBOW embedding for BiLSTM-CRF'),
    )
    for option, help_text in flag_specs:
        parser.add_argument(option,
                            action='store_true',
                            default=False,
                            help=help_text)
    # NOTE(review): --cbow is parsed but never read below.
    args = parser.parse_args()

    def show(gold, predicted):
        # Per-tag precision / recall / F1, then the confusion matrix.
        metrics = Metrics(gold, predicted, remove_O=REMOVE_O)
        metrics.report_scores()
        metrics.report_confusion_matrix()

    print("读取数据...")
    train_word_lists, train_tag_lists, word2id, tag2id = build_corpus("train")
    dev_word_lists, dev_tag_lists = build_corpus("dev", make_vocab=False)
    test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False)

    if args.hmm:
        print("加载并评估hmm模型...")
        hmm_model = load_model(HMM_MODEL_PATH)
        show(test_tag_lists, hmm_model.test(test_word_lists, word2id, tag2id))

    if args.crf:
        print("加载并评估crf模型...")
        crf_model = load_model(CRF_MODEL_PATH)
        show(test_tag_lists, crf_model.test(test_word_lists))

    if args.bilstm:
        print("加载并评估bilstm模型...")
        lstm_word2id, lstm_tag2id = extend_maps(word2id, tag2id, for_crf=False)
        lstm_model = load_model(BiLSTM_MODEL_PATH)
        lstm_model.model.bilstm.flatten_parameters()  # remove warning
        lstm_pred, lstm_gold = lstm_model.test(test_word_lists, test_tag_lists,
                                               lstm_word2id, lstm_tag2id)
        show(lstm_gold, lstm_pred)

    if args.bilstm_crf:
        print("加载并评估bilstm+crf模型...")
        crf_word2id, crf_tag2id = extend_maps(word2id, tag2id, for_crf=True)
        lstmcrf_model = load_model(BiLSTMCRF_MODEL_PATH)
        lstmcrf_model.model.bilstm.bilstm.flatten_parameters()  # remove warning
        test_word_lists, test_tag_lists = prepocess_data_for_lstmcrf(
            test_word_lists, test_tag_lists, test=True)
        lstmcrf_pred, lstmcrf_gold = lstmcrf_model.test(
            test_word_lists, test_tag_lists, crf_word2id, crf_tag2id)
        show(lstmcrf_gold, lstmcrf_pred)
示例#17
0
def main():
    """Evaluate the saved BiLSTM-CRF model and inspect its SYM-tagged tokens.

    After the metrics report, the function prints the predicted and gold tag
    lists. For the first test sentence, it then prints the tokens that carry
    one of the B-/M-/E-SYM tags in the gold data and in the predictions.
    """
    print("读取数据...")
    _, _, word2id, tag2id = build_corpus("train")
    test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False)

    print("加载并评估bilstm+crf模型...")
    crf_word2id, crf_tag2id = extend_maps(word2id, tag2id, for_crf=True)
    bilstm_model = load_model(BiLSTMCRF_MODEL_PATH)
    bilstm_model.model.bilstm.bilstm.flatten_parameters()  # remove warning
    test_word_lists, test_tag_lists = prepocess_data_for_lstmcrf(
        test_word_lists, test_tag_lists, test=True)

    lstmcrf_pred, target_tag_list = bilstm_model.test(test_word_lists,
                                                      test_tag_lists,
                                                      crf_word2id, crf_tag2id)
    metrics = Metrics(target_tag_list, lstmcrf_pred, remove_O=REMOVE_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()

    # Tokens of the first test sentence that carry a SYM tag (gold vs. predicted).
    sym_tags = {'B-SYM', 'M-SYM', 'E-SYM'}
    first_words = test_word_lists[0] if test_word_lists else []
    first_gold = test_tag_lists[0] if test_tag_lists else []
    first_pred = lstmcrf_pred[0] if lstmcrf_pred else []
    # zip stops at the shorter sequence. Any extra trailing word (possibly
    # one appended by the preprocessing step — confirm) is ignored, just as
    # the original index-based loops did.
    selected_word = [word for word, tag in zip(first_words, first_gold)
                     if tag in sym_tags]
    selected_predict_word = [word for word, tag in zip(first_words, first_pred)
                             if tag in sym_tags]

    print('predict list:', lstmcrf_pred)
    print('target list:', target_tag_list)
    print(selected_word)
    print(selected_predict_word)