def pred(model, best_model_path, result_conll_path, fname="test.csv"):
    """Predict on a CSV split, print a classification report, write CoNLL output.

    Relies on the module-level names ``data``, ``model_dir`` and ``data_path``.

    Args:
        model: Model instance handed to ``NerLearner``.
        best_model_path: Checkpoint path passed to ``learner.load_model``.
        result_conll_path: Destination path passed to
            ``write_true_and_pred_to_conll``.
        fname: File name appended to ``data_path`` to locate the input CSV.
    """
    # Labels left out of the supervised label set given to the learner.
    excluded = ('<pad>', '[CLS]', 'X', 'B_O', 'I_')
    supervised = [label for label in data.id2label if label not in excluded]

    # The learner is constructed with a fixed checkpoint path; the checkpoint
    # that is explicitly loaded further down is `best_model_path`.
    ner_learner = NerLearner(
        model,
        data,
        best_model_path=model_dir + "/norne/bilstm_attn_cased_en.cpt",
        lr=0.01,
        clip=1.0,
        sup_labels=supervised,
    )

    predict_dl = get_bert_data_loader_for_predict(data_path + fname, ner_learner)
    ner_learner.load_model(best_model_path)
    predictions = ner_learner.predict(predict_dl)

    tokens, y_true, y_pred, set_labels = bert_preds_to_ys(predict_dl, predictions)

    # Token-level report; get_bert_span_report(dl, preds) was previously
    # tried here as an alternative.
    report = flat_classification_report(y_true, y_pred, set_labels, digits=3)
    print(report)

    write_true_and_pred_to_conll(
        tokens=tokens,
        y_true=y_true,
        y_pred=y_pred,
        conll_fpath=result_conll_path,
    )
def train(model, num_epochs=20):
    """Fit the model, then print validation metrics and a span report.

    Relies on the module-level names ``data``, ``model_dir`` and ``data_path``.

    Args:
        model: Model instance handed to ``NerLearner``.
        num_epochs: Number of epochs passed to ``learner.fit``; also used to
            compute ``t_total`` as ``num_epochs * len(data.train_dl)``.
    """
    # Labels left out of the supervised label set given to the learner.
    excluded = ('<pad>', '[CLS]', 'X', 'B_O', 'I_')
    supervised = [label for label in data.id2label if label not in excluded]
    total_steps = num_epochs * len(data.train_dl)

    ner_learner = NerLearner(
        model,
        data,
        best_model_path=model_dir + "/norne/bilstm_attn_lr0_1_cased_en.cpt",
        lr=0.1,
        clip=1.0,
        sup_labels=supervised,
        t_total=total_steps,
    )
    ner_learner.fit(num_epochs, target_metric='f1')

    valid_dl = get_bert_data_loader_for_predict(data_path + "valid.csv", ner_learner)
    # Called without a path; presumably this restores the checkpoint stored at
    # the learner's best_model_path -- confirm against NerLearner.load_model.
    ner_learner.load_model()
    predictions = ner_learner.predict(valid_dl)

    validation_result = validate_step(
        ner_learner.data.valid_dl,
        ner_learner.model,
        ner_learner.data.id2label,
        ner_learner.sup_labels,
    )
    print(validation_result)

    span_report = get_bert_span_report(valid_dl, predictions, [])
    print(span_report)
# Dataset directory and per-split file paths. `data_path` is read by pred()
# and train(); the three *_path names are not referenced in the visible code
# but are kept in case other parts of the file use them.
data_path = "/media/liah/DATA/ner_data_other/norne/"
train_path = data_path + "train.txt"
dev_path = data_path + "valid.txt"
test_path = data_path + "test.txt"

# `data`, `bert_config_file`, `init_checkpoint_pt` and `model_dir` are expected
# to be defined earlier in the file (not visible in this section).
model = BertBiLSTMAttnNMT.create(
    len(data.label2idx),
    bert_config_file,
    init_checkpoint_pt,
    enc_hidden_dim=128,
    dec_hidden_dim=128,
    dec_embedding_dim=16,
)

learner = NerLearner(
    model,
    data,
    best_model_path=model_dir + "conll-2003/bilstm_attn_cased.cpt",
    lr=0.01,
    clip=1.0,
    sup_labels=[
        l for l in data.id2label
        if l not in ['<pad>', '[CLS]', 'X', 'B_O', 'I_']
    ],
    # NOTE(review): `num_epochs` is only a parameter of train() in the visible
    # code; confirm it is assigned at module level before this point.
    t_total=num_epochs * len(data.train_dl),
)

# The prediction loader takes the learner as an argument, so it must be built
# after `learner` is assigned. It was previously built first, which raises
# NameError unless an earlier `learner` happens to exist.
dl = get_bert_data_loader_for_predict(data_path + "valid.csv", learner)

# NOTE(review): `best_model_path` is not assigned at module level in the
# visible code (it is a parameter of pred()); confirm it is defined earlier,
# or call learner.load_model() with the intended checkpoint path.
learner.load_model(best_model_path)
preds = learner.predict(dl)