# Example #1
# 0
def ensemble_evaluate(hmm_pred,
                      crf_pred,
                      lstm_pred,
                      lstmcrf_pred,
                      latticelstm_pred,
                      targets,
                      status='train',
                      transition_path=r"C:\Users\DELL\PycharmProjects\research\transition.csv"):
    """Ensemble several sequence-labelling models by equal-weight voting.

    Each ``*_pred`` argument is a list of predicted tag sequences (one per
    sentence).  Every model casts an equal vote for its tag at each position;
    a transition-score matrix loaded from ``transition_path`` then biases
    each position's scores toward tags consistent with the previously
    decoded tag.  ``<start>``/``<end>`` tags are never emitted.

    Args:
        hmm_pred, crf_pred, lstm_pred, lstmcrf_pred, latticelstm_pred:
            per-model predicted tag sequences.
        targets: gold tag sequences; used only to build the final Metrics.
        status: unused; kept for backward compatibility with callers.
        transition_path: path to the tab-separated transition matrix
            (header row skipped).  Defaults to the original hard-coded path.

    Returns:
        list[list[str]]: the ensembled tag sequence for each sentence.
    """
    tag2id1 = {
        '<start>': 0,
        'O': 1,
        'B-ATTRIBUTE': 2,
        'M-ATTRIBUTE': 3,
        'E-ATTRIBUTE': 4,
        'B-OBJECT': 5,
        'M-OBJECT': 6,
        'E-OBJECT': 7,
        'B-CONDITION': 8,
        'M-CONDITION': 9,
        'E-CONDITION': 10,
        'B-PARAMETERS': 11,
        'M-PARAMETERS': 12,
        'E-PARAMETERS': 13,
        'S-ATTRIBUTE': 14,
        'S-OBJECT': 15,
        'S-CONDITION': 16,
        'S-PARAMETERS': 17,
        '<end>': 18
    }
    length = len(tag2id1)
    start_id = tag2id1['<start>']
    end_id = tag2id1['<end>']
    # np.loadtxt accepts a path directly; no need to pre-open in binary mode.
    transition = np.loadtxt(transition_path, delimiter="\t", skiprows=1)

    preds = [append_start_end(p) for p in (hmm_pred, crf_pred, lstm_pred,
                                           lstmcrf_pred, latticelstm_pred)]
    weight = 1.0 / len(preds)  # equal vote per model: 0.2 for five models

    # Per-sentence, per-position vote scores over the tag vocabulary.
    pred_score = []
    for sent_preds in zip(*preds):
        scores = []
        # zip(*sent_preds) yields, per position, the tuple of the five
        # models' tags at that position.
        for position_tags in zip(*sent_preds):
            score = np.zeros(length)
            for tag in position_tags:
                score[tag2id1[tag]] += weight
            scores.append(score)
        pred_score.append(scores)

    # Decode with the transition constraint.
    id2tag = {index: tag for tag, index in tag2id1.items()}
    pred_tags = []
    for sent_scores in pred_score:
        decoded = []
        # Bug fix: reset the previous-tag state for every sentence; the
        # original initialized it once and carried it over from the end of
        # the preceding sentence into the next one.
        previous = None
        for score in sent_scores:
            if previous is not None:
                # Bias the votes toward tags reachable from the previous tag.
                score = score + 0.5 * np.array(transition[previous])
            best = np.argmax(score)  # computed once, not per comparison
            # Never emit the artificial <start>/<end> tags; when one wins,
            # the position is skipped and `previous` is left unchanged
            # (same behavior as the original).
            if best != start_id and best != end_id:
                decoded.append(id2tag[best])
                previous = best
        pred_tags.append(decoded)

    assert len(pred_tags) == len(targets)

    # Constructed for parity with the other evaluation paths; no report is
    # printed here (call metrics.report_scores() if one is wanted).
    metrics = Metrics(targets, pred_tags)

    return pred_tags
def main():
    """Load each trained tagger, report its test-set metrics, then ensemble.

    Evaluates HMM, CRF, BiLSTM and BiLSTM-CRF models on the test split and
    finally runs ensemble_evaluate over their predictions.
    """
    print("Read data...")
    # train/dev splits are loaded for the vocabulary (word2id/tag2id);
    # only the test split is evaluated below.
    train_word_lists, train_tag_lists, word2id, tag2id = \
        build_corpus("train")
    dev_word_lists, dev_tag_lists = build_corpus("dev", make_vocab=False)
    test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False)

    print("Load and evaluate the hmm model...")
    hmm_model = load_model(HMM_MODEL_PATH)
    hmm_pred = hmm_model.test(test_word_lists, word2id, tag2id)
    metrics = Metrics(test_tag_lists, hmm_pred, remove_O=REMOVE_O)
    metrics.report_scores(
    )  # Print the accuracy of each mark, recall rate, f1 score
    metrics.report_confusion_matrix()  # Print confusion matrix

    # Load and evaluate the CRF model
    print("Load and evaluate the crf model...")
    crf_model = load_model(CRF_MODEL_PATH)
    crf_pred = crf_model.test(test_word_lists)
    metrics = Metrics(test_tag_lists, crf_pred, remove_O=REMOVE_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()

    # bilstm Model
    print("Load and evaluate the bilstm model...")
    bilstm_word2id, bilstm_tag2id = extend_maps(word2id, tag2id, for_crf=False)
    bilstm_model = load_model(BiLSTM_MODEL_PATH)
    bilstm_model.model.bilstm.flatten_parameters()  # remove warning
    lstm_pred, target_tag_list = bilstm_model.test(test_word_lists,
                                                   test_tag_lists,
                                                   bilstm_word2id,
                                                   bilstm_tag2id)
    metrics = Metrics(target_tag_list, lstm_pred, remove_O=REMOVE_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()

    print("Load and evaluate the bilstm+crf model...")
    crf_word2id, crf_tag2id = extend_maps(word2id, tag2id, for_crf=True)
    bilstm_model = load_model(BiLSTMCRF_MODEL_PATH)
    bilstm_model.model.bilstm.bilstm.flatten_parameters()  # remove warning
    # NOTE(review): this rebinds test_word_lists/test_tag_lists to the
    # CRF-preprocessed versions, so the ensemble below compares against the
    # preprocessed tag lists, not the raw ones — confirm this is intended.
    test_word_lists, test_tag_lists = prepocess_data_for_lstmcrf(
        test_word_lists, test_tag_lists, test=True)
    lstmcrf_pred, target_tag_list = bilstm_model.test(test_word_lists,
                                                      test_tag_lists,
                                                      crf_word2id, crf_tag2id)
    metrics = Metrics(target_tag_list, lstmcrf_pred, remove_O=REMOVE_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()

    # Bug fix: the original passed a single 4-element list plus the targets,
    # but ensemble_evaluate requires five separate prediction lists followed
    # by the targets, so the call raised a TypeError.  No lattice-LSTM model
    # is loaded in this script, so the BiLSTM-CRF predictions are reused as
    # a stand-in for the fifth model (effectively giving them double weight)
    # until a lattice-LSTM model is available — TODO: load and use one.
    ensemble_evaluate(hmm_pred, crf_pred, lstm_pred, lstmcrf_pred,
                      lstmcrf_pred, test_tag_lists)