def eval_checkpoint(model_object, eval_dataloader, config, device, n_gpu, label_list, eval_sign="dev"):
    """Evaluate ``model_object`` on a dev or test dataloader and score its span predictions.

    Every batch goes through the model twice under ``torch.no_grad()``:

    * once with the gold start/end/span targets, which must return a loss tensor;
    * once without targets, which must return ``(start_logits, end_logits, span_logits)``.

    Predictions and gold labels from all batches are collected as nested Python
    lists. They are then scored with ``flat_ner_performance`` when
    ``config.entity_sign == "flat"``, and with ``nested_ner_performance`` otherwise.

    Args:
        model_object: model under evaluation; it is switched to eval mode here.
        eval_dataloader: yields 7-tuples ``(input_ids, input_mask, segment_ids,
            start_pos, end_pos, span_pos, ner_cate)``. It is expected to be
            either the dev or the test dataloader.
        config: the fields ``entity_sign`` and ``entity_threshold`` are read.
        device: device that the model inputs and targets are moved to.
        n_gpu: not used by this function.
        label_list: passed through to the scoring function.
        eval_sign: not used by this function.

    Returns:
        ``(average_loss, accuracy, precision, recall, f1)``, each rounded to
        4 decimals. ``average_loss`` is the mean over batches of ``loss.mean()``.
    """
    def _as_list(tensor):
        # Copy to host memory and unpack into nested Python lists.
        return tensor.detach().cpu().numpy().tolist()

    model_object.eval()

    total_loss = 0
    n_batches = 0
    pred_starts, pred_ends, pred_spans = [], [], []
    gold_starts, gold_ends, gold_spans = [], [], []
    categories = []

    for batch in eval_dataloader:
        input_ids, input_mask, segment_ids, start_pos, end_pos, span_pos, ner_cate = batch
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        start_pos = start_pos.to(device)
        end_pos = end_pos.to(device)
        span_pos = span_pos.to(device)

        with torch.no_grad():
            # First pass with targets gives the loss; second pass gives the logits.
            batch_loss = model_object(input_ids, segment_ids, input_mask, start_pos, end_pos, span_pos)
            start_logits, end_logits, span_logits = model_object(input_ids, segment_ids, input_mask)

        pred_starts.extend(_as_list(start_logits))
        pred_ends.extend(_as_list(end_logits))
        pred_spans.extend(_as_list(span_logits))
        gold_starts.extend(_as_list(start_pos))
        gold_ends.extend(_as_list(end_pos))
        gold_spans.extend(_as_list(span_pos))
        categories.extend(ner_cate.numpy().tolist())

        total_loss += batch_loss.mean().item()
        n_batches += 1

    scorer = flat_ner_performance if config.entity_sign == "flat" else nested_ner_performance
    accuracy, precision, recall, f1 = scorer(
        pred_starts, pred_ends, pred_spans,
        gold_starts, gold_ends, gold_spans,
        categories, label_list,
        threshold=config.entity_threshold, dims=2,
    )

    average_loss = round(total_loss / n_batches, 4)
    return (
        average_loss,
        round(accuracy, 4),
        round(precision, 4),
        round(recall, 4),
        round(f1, 4),
    )
def _decode_flat_entities(tags):
    """Decode one predicted tag sequence into a list of entity dicts.

    Each tag is either ``"O"`` or a string of the form ``"<prefix>-<type>"``,
    for example ``"B-PER"``.

    A candidate entity is a run of consecutive non-``"O"`` tags. A run ends at
    whichever comes first:

    * the next ``"O"``;
    * the end of the sequence;
    * right after an ``"E-"`` tag, so that back-to-back entities
      (``B-X E-X B-Y E-Y``) stay separate.

    A run is kept only if it starts with prefix ``B`` and ends with prefix
    ``E``. Its type is taken from the closing tag, and ``end`` is exclusive.
    Each kept entity is returned as a dict with keys in the order
    ``'tag'``, ``'begin'``, ``'end'``.
    """
    entities = []
    length = len(tags)
    index = 0
    while index < length:
        if tags[index] == 'O':
            index += 1
            continue
        begin = index
        # Bounded by ``length`` so that a run reaching the end of the sequence
        # stops there instead of indexing past it.
        while index < length and tags[index] != 'O':
            index += 1
            if tags[index - 1].startswith('E-'):
                break
        start_prefix = tags[begin].split('-', 1)[0]
        end_prefix, _, entity_type = tags[index - 1].partition('-')
        if start_prefix == 'B' and end_prefix == 'E':
            entities.append({'tag': entity_type, 'begin': begin, 'end': index})
    return entities


def eval_checkpoint(model_object, eval_dataloader, config, device, n_gpu, label_list,
                    eval_sign="test", result_path="result.json"):
    """Evaluate a start/end-only model, dump predicted entities to JSON, and print the metrics.

    Every batch ``(input_ids, input_mask, segment_ids, start_pos, end_pos, ner_cate)``
    is run through the model under ``torch.no_grad()``. The model must return
    ``(start_logits, end_logits)``.

    This model has no span head, so the scorer receives an all-ones
    placeholder span matrix for both the predictions and the gold labels.
    The scorer is ``flat_ner_performance`` when ``config.entity_sign == "flat"``
    and ``nested_ner_performance`` otherwise. Either one must return
    ``(acc, pre, rec, f1, pred_span_triple_lst)``.

    In flat mode, each sequence in ``pred_span_triple_lst`` is decoded into
    entities. In nested mode, every example gets an empty list. The per-example
    lists are written as JSON to ``result_path``.

    Args:
        model_object: model under evaluation; it is switched to eval mode here.
        eval_dataloader: yields the 6-tuples described above.
        config: the fields ``entity_sign`` and ``entity_threshold`` are read.
        device: device that the model inputs are moved to.
        n_gpu: not used by this function.
        label_list: passed through to the scoring function.
        eval_sign: not used by this function.
        result_path: output path for the JSON predictions. Defaults to the
            previously hard-coded ``"result.json"``.

    Returns:
        ``(acc, pre, rec, f1)`` exactly as returned by the scoring function.
    """
    model_object.eval()

    start_pred_lst, end_pred_lst = [], []
    start_gold_lst, end_gold_lst = [], []
    ner_cate_lst = []

    for input_ids, input_mask, segment_ids, start_pos, end_pos, ner_cate in eval_dataloader:
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        with torch.no_grad():
            start_logits, end_logits = model_object(input_ids, segment_ids, input_mask)

        start_pred_lst += start_logits.detach().cpu().numpy().tolist()
        end_pred_lst += end_logits.detach().cpu().numpy().tolist()
        # Gold labels are only needed on the host, so no device round trip.
        start_gold_lst += start_pos.to("cpu").numpy().tolist()
        end_gold_lst += end_pos.to("cpu").numpy().tolist()
        ner_cate_lst += ner_cate.numpy().tolist()

    # Placeholder spans: an all-ones seq_len x seq_len matrix per example.
    # seq_len comes from the first example, which assumes all examples are
    # padded to the same length (TODO confirm).
    # The rows are shared via list repetition rather than copied, so memory
    # stays O(seq_len). This assumes the scorer does not mutate them.
    seq_len = len(start_gold_lst[0]) if start_gold_lst else 0
    n_examples = len(start_gold_lst)
    span_pred_lst = [[[1] * seq_len] * seq_len] * n_examples
    span_gold_lst = [[[1] * seq_len] * seq_len] * n_examples

    scorer = flat_ner_performance if config.entity_sign == "flat" else nested_ner_performance
    acc, pre, rec, f1, pred_span_triple_lst = scorer(
        start_pred_lst, end_pred_lst, span_pred_lst,
        start_gold_lst, end_gold_lst, span_gold_lst,
        ner_cate_lst, label_list,
        threshold=config.entity_threshold, dims=2,
    )

    # Build the answer: decode entities from the predicted tag sequences
    # (flat mode only). In nested mode every example keeps an empty list.
    if config.entity_sign == "flat":
        span_triple_lst = [_decode_flat_entities(tags) for tags in pred_span_triple_lst]
    else:
        span_triple_lst = [[] for _ in pred_span_triple_lst]

    with open(result_path, 'w', encoding='utf-8') as file_obj:
        json.dump(span_triple_lst, file_obj, ensure_ascii=False)
    print('保存成功')  # "saved successfully"

    print("=*=" * 10)
    print("eval: acc, pre, rec, f1")
    print(acc, pre, rec, f1)
    return acc, pre, rec, f1
def eval_checkpoint(model_object, eval_dataloader, config, device, n_gpu, label_list, eval_sign="test"):
    """Evaluate start/end/span predictions of ``model_object`` and print the metrics.

    Every batch ``(input_ids, input_mask, segment_ids, start_pos, end_pos,
    span_pos, ner_cate)`` is run through the model without targets under
    ``torch.no_grad()``. The model must return
    ``(start_logits, end_logits, span_logits)``.

    Logits and gold labels are gathered as nested Python lists and scored
    with ``flat_ner_performance`` when ``config.entity_sign == "flat"``, and
    with ``nested_ner_performance`` otherwise.

    Args:
        model_object: model under evaluation; it is switched to eval mode here.
        eval_dataloader: yields the 7-tuples described above.
        config: the fields ``entity_sign`` and ``entity_threshold`` are read.
        device: device that the model inputs are moved to.
        n_gpu: not used by this function.
        label_list: passed through to the scoring function.
        eval_sign: not used by this function.

    Returns:
        ``(acc, pre, rec, f1)`` exactly as returned by the scoring function.
    """
    def host_list(tensor):
        # Copy a tensor to host memory and unpack it into nested Python lists.
        return tensor.detach().cpu().numpy().tolist()

    model_object.eval()

    pred_starts, pred_ends, pred_spans = [], [], []
    gold_starts, gold_ends, gold_spans = [], [], []
    categories = []

    for batch in eval_dataloader:
        input_ids, input_mask, segment_ids, start_pos, end_pos, span_pos, ner_cate = batch
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)

        with torch.no_grad():
            outputs = model_object(input_ids, segment_ids, input_mask)
        start_logits, end_logits, span_logits = outputs

        pred_starts.extend(host_list(start_logits))
        pred_ends.extend(host_list(end_logits))
        pred_spans.extend(host_list(span_logits))
        gold_starts.extend(host_list(start_pos))
        gold_ends.extend(host_list(end_pos))
        gold_spans.extend(host_list(span_pos))
        categories.extend(ner_cate.numpy().tolist())

    scorer = flat_ner_performance if config.entity_sign == "flat" else nested_ner_performance
    acc, pre, rec, f1 = scorer(
        pred_starts, pred_ends, pred_spans,
        gold_starts, gold_ends, gold_spans,
        categories, label_list,
        threshold=config.entity_threshold, dims=2,
    )

    print("=*=" * 10)
    print("eval: acc, pre, rec, f1")
    print(acc, pre, rec, f1)
    return acc, pre, rec, f1
def eval_checkpoint(model_object, eval_dataloader, config, device, n_gpu, label_list, logger, eval_sign="test"):
    """Evaluate ``model_object`` on the test set and log the metrics through ``logger``.

    Every batch is an 8-tuple ``(input_ids, input_mask, segment_ids, start_pos,
    end_pos, span_pos, span_label_mask, ner_cate)``; ``span_label_mask`` is
    ignored. The model is called without targets under ``torch.no_grad()``
    and must return ``(start_logits, end_logits, span_logits)``.

    Collected predictions and gold labels are scored with
    ``flat_ner_performance`` when ``config.entity_sign == "flat"``, and with
    ``nested_ner_performance`` otherwise.

    Args:
        model_object: model under evaluation; it is switched to eval mode here.
        eval_dataloader: yields the 8-tuples described above.
        config: the fields ``saved_model`` (used only in the log header),
            ``entity_sign`` and ``entity_threshold`` are read.
        device: device that the model inputs are moved to.
        n_gpu: not used by this function.
        label_list: passed through to the scoring function.
        logger: object with an ``info`` method used for all output.
        eval_sign: not used by this function.

    Returns:
        ``(acc, pre, rec, f1)`` exactly as returned by the scoring function.
    """
    logger.info("$="*20)
    logger.info(f"EVAL {config.saved_model} on Test Set. ")

    def host_list(tensor):
        # Copy a tensor to host memory and unpack it into nested Python lists.
        return tensor.detach().cpu().numpy().tolist()

    model_object.eval()

    pred_starts, pred_ends, pred_spans = [], [], []
    gold_starts, gold_ends, gold_spans = [], [], []
    categories = []

    for eval_batch in eval_dataloader:
        (input_ids, input_mask, segment_ids,
         start_pos, end_pos, span_pos, _span_label_mask, ner_cate) = eval_batch
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)

        with torch.no_grad():
            outputs = model_object(input_ids, segment_ids, input_mask)
        start_logits, end_logits, span_logits = outputs

        pred_starts.extend(host_list(start_logits))
        pred_ends.extend(host_list(end_logits))
        pred_spans.extend(host_list(span_logits))
        gold_starts.extend(host_list(start_pos))
        gold_ends.extend(host_list(end_pos))
        gold_spans.extend(host_list(span_pos))
        categories.extend(ner_cate.numpy().tolist())

    scorer = flat_ner_performance if config.entity_sign == "flat" else nested_ner_performance
    acc, pre, rec, f1 = scorer(
        pred_starts, pred_ends, pred_spans,
        gold_starts, gold_ends, gold_spans,
        categories, label_list,
        threshold=config.entity_threshold, dims=2,
    )

    logger.info("=*="*10)
    logger.info("eval on test set : acc, pre, rec, f1")
    logger.info(f"{acc}, {pre}, {rec}, {f1}")
    return acc, pre, rec, f1