def crf_pred(train_word_lists, train_tag_lists, test_word_lists,
             test_tag_lists):
    """Train a CRF model, checkpoint it, and return its test-set predictions.

    Args:
        train_word_lists: training inputs passed to ``CRFModel.train``.
        train_tag_lists: training tags passed to ``CRFModel.train``.
        test_word_lists: inputs to tag with the trained model.
        test_tag_lists: not used by this function; kept so existing call
            sites that pass it keep working.

    Returns:
        Whatever ``CRFModel.test(test_word_lists)`` returns.
    """
    model = CRFModel()
    model.train(train_word_lists, train_tag_lists)
    # Persist the trained model so it can be reloaded later.
    save_model(model, "./ckpts/crf.pkl")
    return model.test(test_word_lists)
def crf_train_eval(train_data, test_word_lists, remove_O=False):
    """Train a CRF model, checkpoint it, and return predictions for the test inputs.

    Args:
        train_data: a 2-item ``(word_lists, tag_lists)`` pair used for training.
        test_word_lists: inputs to tag with the trained model.
        remove_O: accepted for signature compatibility with the other
            ``crf_train_eval`` variants; not used here.

    Returns:
        Whatever ``CRFModel.test(test_word_lists)`` returns.
    """
    words, tags = train_data

    model = CRFModel()
    model.train(words, tags)
    save_model(model, "./ckpts/crf.pkl")

    return model.test(test_word_lists)
def crf_train_eval(train_data, test_data, remove_O=False):
    """Train and checkpoint a CRF model, then print its results on ``test_data``.

    Args:
        train_data: a 2-item ``(word_lists, tag_lists)`` pair used for training.
        test_data: a 2-item ``(word_lists, tag_lists)`` pair used for evaluation.
        remove_O: forwarded unchanged to ``results_print``.

    Returns:
        The predicted tags from ``CRFModel.test`` on the test word lists.
    """
    words_train, tags_train = train_data
    words_test, tags_test = test_data

    checkpoint = "./ckpts/crf.pkl"
    model = CRFModel()
    model.train(words_train, tags_train)
    save_model(model, checkpoint)

    predictions = model.test(words_test)
    # Gold tags first, predictions second, as results_print expects here.
    results_print(tags_test, predictions, remove_O=remove_O)
    return predictions
예제 #4
0
파일: evaluate.py 프로젝트: lg995745318/-
def crf_train_eval(train_data, test_data):
    """Train a CRF model, checkpoint it, and report its scores on ``test_data``.

    Args:
        train_data: a 2-item ``(word_lists, tag_lists)`` pair used for training.
        test_data: a 2-item ``(word_lists, tag_lists)`` pair used for evaluation.

    Returns:
        The predicted tags from ``CRFModel.test`` on the test word lists.
    """
    # Train the CRF model.
    train_word_lists, train_tag_lists = train_data
    test_word_lists, test_tag_lists = test_data

    crf_model = CRFModel()
    crf_model.train(train_word_lists, train_tag_lists)
    save_model(crf_model, "./ckpts/crf.pkl")

    pred_tag_lists = crf_model.test(test_word_lists)

    # Building Metrics without using it would throw the evaluation away;
    # report the scores so the comparison against gold tags is visible.
    metrics = Metrics(test_tag_lists, pred_tag_lists)
    metrics.report_scores()

    return pred_tag_lists
def crf_train_eval(train_data, test_data):
    """Train and checkpoint a CRF model, then print its accuracy on ``test_data``.

    Args:
        train_data: a 2-item ``(word_lists, tag_lists)`` pair used for training.
        test_data: a 2-item ``(word_lists, tag_lists)`` pair used for evaluation.

    Returns:
        The predicted tags from ``CRFModel.test`` on the test word lists.
    """
    train_words, train_tags = train_data
    test_words, test_tags = test_data

    model = CRFModel()
    model.train(train_words, train_tags)
    save_model(model, "./ckpts/crf.pkl")

    predicted = model.test(test_words)
    # evaluate() is called with (predictions, gold tags); its result is
    # scaled by 100 and printed as "CRF model accuracy: XX.XX%".
    percent = evaluate(predicted, test_tags) * 100
    print("CRF 模型的准确率为:{:.2f}%".format(percent))

    return predicted
def crf_train_eval(train_data, test_data, remove_O=False):
    """Train and checkpoint a CRF model, then report scores and a confusion matrix.

    Args:
        train_data: a 2-item ``(word_lists, tag_lists)`` pair used for training.
        test_data: a 2-item ``(word_lists, tag_lists)`` pair used for evaluation.
        remove_O: forwarded unchanged to ``Metrics``.

    Returns:
        The predicted tags from ``CRFModel.test`` on the test word lists.
    """
    train_words, train_tags = train_data
    test_words, test_tags = test_data

    model = CRFModel()
    model.train(train_words, train_tags)
    save_model(model, "./ckpts/crf.pkl")

    predicted = model.test(test_words)

    report = Metrics(test_tags, predicted, remove_O=remove_O)
    report.report_scores()
    report.report_confusion_matrix()

    return predicted
예제 #7
0
def main():
    """Train and evaluate HMM, CRF, BiLSTM and BiLSTM-CRF taggers in turn.

    ``word2id``/``tag2id`` come from the 'train' split call to
    ``build_corpus``. The HMM and CRF are scored on the test split with
    ``Metrics.report_scores``. The two neural models are trained with the
    dev split and evaluated on the test split via their own ``dev_test``.
    """
    print('读取数据...')  # "Loading data..."
    train_words, train_tags, word2id, tag2id = build_corpus('train')
    # NOTE(review): the keyword is spelled ``maek_vocab``. It must match the
    # parameter name of ``build_corpus`` (defined elsewhere); confirm it is
    # not a typo for ``make_vocab``.
    dev_words, dev_tags = build_corpus('dev', maek_vocab=False)
    test_words, test_tags = build_corpus('test', maek_vocab=False)

    # ---- HMM ----
    print('训练HMM模型...')
    hmm = HMMModel(len(tag2id), len(word2id))
    hmm.train(train_words, train_tags, word2id, tag2id)
    hmm_tags = hmm.test(test_words, word2id, tag2id)
    Metrics(test_tags, hmm_tags).report_scores()

    # ---- CRF ----
    print('训练CRF模型...')
    crf = CRFModel(max_iterations=90)
    crf.train(train_words, train_tags)
    crf_tags = crf.test(test_words)
    Metrics(test_tags, crf_tags).report_scores()

    # Final positional argument to the neural models' train(); its meaning
    # is not visible here (possibly a dropout keep-prob) - confirm.
    ratio = 0.8

    # ---- BiLSTM ----
    print('训练BiLSTM模型...')
    # The extended maps are reused by the BiLSTM-CRF below as well.
    word2id, tag2id = extend_maps(word2id, tag2id)
    bilstm = BiLSTM(len(word2id), len(tag2id))
    bilstm.train(train_words, train_tags, dev_words, dev_tags,
                 word2id, tag2id, ratio)
    bilstm.dev_test(test_words, test_tags, word2id, tag2id)
    bilstm.close_sess()

    # ---- BiLSTM-CRF ----
    print('训练BiLSTM-CRF模型...')
    bilstm_crf = BiLSTM_CRF(len(word2id), len(tag2id))
    bilstm_crf.train(train_words, train_tags, dev_words, dev_tags,
                     word2id, tag2id, ratio)
    bilstm_crf.dev_test(test_words, test_tags, word2id, tag2id)
    bilstm_crf.close_sess()
예제 #8
0
파일: work.py 프로젝트: darr/dlner
def crf_train(train_data):
    """Train a CRF model on ``train_data`` and save it to ``CRF_MODEL_PATH``.

    Args:
        train_data: a 2-item ``(word_lists, tag_lists)`` pair passed to
            ``CRFModel.train``.
    """
    train_word_lists, train_tag_lists = train_data
    crf_model = CRFModel()
    crf_model.train(train_word_lists, train_tag_lists)
    # CRF_MODEL_PATH is a module-level constant defined outside this view.
    save_model(crf_model, CRF_MODEL_PATH)