# Example #1
# 0
def eval_checkpoint(model_object, eval_dataloader, config, device, n_gpu, label_list, eval_sign="dev"):
    """Evaluate ``model_object`` on ``eval_dataloader`` and return loss plus NER metrics.

    Each batch is run through the model twice under ``torch.no_grad()``:
    once with the gold start/end/span labels (the result is treated as the
    loss) and once without labels (the result is unpacked as start, end and
    span logits). Predictions and gold labels for all batches are collected
    as nested Python lists and scored in a single call.

    Args:
        model_object: Callable model; also must provide ``eval()``.
        eval_dataloader: Iterable yielding 7-tuples
            ``(input_ids, input_mask, segment_ids, start_pos, end_pos,
            span_pos, ner_cate)`` of tensors.
        config: Object providing ``entity_sign`` and ``entity_threshold``.
        device: Device the batch tensors are moved to.
        n_gpu: Unused; kept for interface compatibility.
        label_list: Passed through to the scoring function.
        eval_sign: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(average_loss, eval_accuracy, eval_precision, eval_recall,
        eval_f1)``, each rounded to 4 decimal places. ``average_loss`` is
        ``0`` when the dataloader yields no batches.
    """
    model_object.eval()

    eval_loss = 0
    eval_steps = 0
    start_pred_lst, end_pred_lst, span_pred_lst = [], [], []
    start_gold_lst, end_gold_lst, span_gold_lst = [], [], []
    ner_cate_lst = []

    for input_ids, input_mask, segment_ids, start_pos, end_pos, span_pos, ner_cate in eval_dataloader:
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        start_pos = start_pos.to(device)
        end_pos = end_pos.to(device)
        span_pos = span_pos.to(device)

        with torch.no_grad():
            # Called with labels the model yields the loss; without labels, the logits.
            tmp_eval_loss = model_object(input_ids, segment_ids, input_mask, start_pos, end_pos, span_pos)
            start_logits, end_logits, span_logits = model_object(input_ids, segment_ids, input_mask)

        # .mean() collapses a multi-element loss into one scalar before accumulation.
        eval_loss += tmp_eval_loss.mean().item()
        eval_steps += 1

        start_pred_lst += start_logits.detach().cpu().numpy().tolist()
        end_pred_lst += end_logits.detach().cpu().numpy().tolist()
        span_pred_lst += span_logits.detach().cpu().numpy().tolist()

        start_gold_lst += start_pos.to("cpu").numpy().tolist()
        end_gold_lst += end_pos.to("cpu").numpy().tolist()
        span_gold_lst += span_pos.to("cpu").numpy().tolist()

        ner_cate_lst += ner_cate.numpy().tolist()

    # Both scorers take the same arguments; only the one actually needed is looked up.
    performance_fn = flat_ner_performance if config.entity_sign == "flat" else nested_ner_performance
    eval_accuracy, eval_precision, eval_recall, eval_f1 = performance_fn(
        start_pred_lst,
        end_pred_lst,
        span_pred_lst,
        start_gold_lst,
        end_gold_lst,
        span_gold_lst,
        ner_cate_lst,
        label_list,
        threshold=config.entity_threshold,
        dims=2,
    )

    # max(..., 1) avoids ZeroDivisionError when the dataloader produced no batches.
    average_loss = round(eval_loss / max(eval_steps, 1), 4)
    eval_f1 = round(eval_f1, 4)
    eval_precision = round(eval_precision, 4)
    eval_recall = round(eval_recall, 4)
    eval_accuracy = round(eval_accuracy, 4)

    return average_loss, eval_accuracy, eval_precision, eval_recall, eval_f1
# Example #2
# 0
def eval_checkpoint(model_object, eval_dataloader, config,
                    device, n_gpu, label_list, eval_sign="test"):
    """Evaluate a start/end-only model, dump predicted entities to ``result.json``.

    The model is called without labels and its output is unpacked as start
    and end logits. Because this variant has no span head, all-ones
    placeholder span matrices are passed to the scoring function for both
    prediction and gold. The fifth value returned by the scorer is, for the
    ``"flat"`` setting, decoded as one tag sequence per example and turned
    into ``{'tag', 'begin', 'end'}`` records; for any other setting an empty
    list is written for every example.

    Args:
        model_object: Callable model; also must provide ``eval()``.
        eval_dataloader: Iterable yielding 6-tuples
            ``(input_ids, input_mask, segment_ids, start_pos, end_pos,
            ner_cate)`` of tensors.
        config: Object providing ``entity_sign`` and ``entity_threshold``.
        device: Device the model inputs are moved to.
        n_gpu: Unused; kept for interface compatibility.
        label_list: Passed through to the scoring function.
        eval_sign: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(acc, pre, rec, f1)`` as returned by the scoring function.

    Side effects:
        Overwrites ``result.json`` in the current working directory and
        prints the metrics to stdout.
    """
    model_object.eval()

    start_pred_lst = []
    end_pred_lst = []
    start_gold_lst = []
    end_gold_lst = []
    ner_cate_lst = []

    for input_ids, input_mask, segment_ids, start_pos, end_pos, ner_cate in eval_dataloader:
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)

        with torch.no_grad():
            start_logits, end_logits = model_object(input_ids, segment_ids,
                                                    input_mask)

        start_pred_lst += start_logits.detach().cpu().numpy().tolist()
        end_pred_lst += end_logits.detach().cpu().numpy().tolist()

        # Gold labels are only consumed on the host, so they are never moved to `device`.
        start_gold_lst += start_pos.to("cpu").numpy().tolist()
        end_gold_lst += end_pos.to("cpu").numpy().tolist()

        ner_cate_lst += ner_cate.numpy().tolist()

    # Placeholder span matrices (all ones) sized from the first gold sequence.
    # List multiplication shares the inner rows; that is only safe while the
    # scorer treats them as read-only -- NOTE(review): confirm.
    seq_len = len(start_gold_lst[0]) if start_gold_lst else 0
    span_pred_lst = [[[1] * seq_len] * seq_len] * len(start_gold_lst)
    span_gold_lst = [[[1] * seq_len] * seq_len] * len(start_gold_lst)

    # Both scorers take the same arguments; only the one actually needed is looked up.
    performance_fn = flat_ner_performance if config.entity_sign == "flat" else nested_ner_performance
    acc, pre, rec, f1, pred_span_triple_lst = performance_fn(
        start_pred_lst,
        end_pred_lst,
        span_pred_lst,
        start_gold_lst,
        end_gold_lst,
        span_gold_lst,
        ner_cate_lst,
        label_list,
        threshold=config.entity_threshold,
        dims=2)

    # Build the answer: one list of entity records per example. Entities are
    # only decoded from tag sequences in the flat setting.
    if config.entity_sign == "flat":
        span_triple_lst = [_decode_flat_entities(tags) for tags in pred_span_triple_lst]
    else:
        span_triple_lst = [[] for _ in pred_span_triple_lst]

    with open('result.json', 'w', encoding='utf-8') as file_obj:
        json.dump(span_triple_lst, file_obj, ensure_ascii=False)

    print('保存成功')
    print("=*=" * 10)
    print("eval: acc, pre, rec, f1")
    print(acc, pre, rec, f1)

    return acc, pre, rec, f1


def _decode_flat_entities(tags):
    """Turn one tag sequence into a list of ``{'tag', 'begin', 'end'}`` records.

    A maximal run of consecutive non-``'O'`` tags is kept as an entity only
    when its first tag has prefix ``'B'`` and its last tag has prefix ``'E'``.
    ``begin`` is the index of the first tag of the run, ``end`` is the index
    just past its last tag, and ``tag`` is the type part of the last tag.
    """
    entities = []
    total = len(tags)
    index = 0
    while index < total:
        if tags[index] == 'O':
            index += 1
            continue

        begin = index
        # Bounds check first: a run may extend to the very last tag.
        while index < total and tags[index] != 'O':
            index += 1

        # maxsplit=1 keeps entity types that themselves contain a hyphen intact.
        start_prefix, _ = tags[begin].split('-', 1)
        end_prefix, entity_type = tags[index - 1].split('-', 1)
        if start_prefix == 'B' and end_prefix == 'E':
            entities.append({'tag': entity_type, 'begin': begin, 'end': index})
    return entities
# Example #3
# 0
def eval_checkpoint(
    model_object,
    eval_dataloader,
    config,
    device,
    n_gpu,
    label_list,
    eval_sign="test",
):
    """Score ``model_object`` on ``eval_dataloader`` and print the NER metrics.

    The model is called without labels under ``torch.no_grad()`` and its
    output is unpacked as start, end and span logits. Logits and gold labels
    from every batch are gathered as nested lists and passed to
    ``flat_ner_performance`` when ``config.entity_sign == "flat"``, otherwise
    to ``nested_ner_performance``.

    Args:
        model_object: Callable model; also must provide ``eval()``.
        eval_dataloader: Iterable yielding 7-tuples
            ``(input_ids, input_mask, segment_ids, start_pos, end_pos,
            span_pos, ner_cate)`` of tensors.
        config: Object providing ``entity_sign`` and ``entity_threshold``.
        device: Device the batch tensors are moved to.
        n_gpu: Not used by this function.
        label_list: Passed through to the scoring function.
        eval_sign: Not used by this function.

    Returns:
        Tuple ``(acc, pre, rec, f1)`` as returned by the scoring function.
    """
    model_object.eval()

    def _detached_list(tensor):
        # Detach, bring to host memory and convert to nested Python lists.
        return tensor.detach().cpu().numpy().tolist()

    pred_starts, pred_ends, pred_spans = [], [], []
    gold_starts, gold_ends, gold_spans = [], [], []
    collected_masks = []
    categories = []

    for batch in eval_dataloader:
        input_ids, input_mask, segment_ids, start_pos, end_pos, span_pos, ner_cate = batch

        input_ids, input_mask, segment_ids, start_pos, end_pos, span_pos = (
            tensor.to(device)
            for tensor in (input_ids, input_mask, segment_ids, start_pos, end_pos, span_pos)
        )

        with torch.no_grad():
            start_logits, end_logits, span_logits = model_object(
                input_ids, segment_ids, input_mask)

        gold_starts.extend(start_pos.cpu().numpy().tolist())
        gold_ends.extend(end_pos.cpu().numpy().tolist())
        gold_spans.extend(span_pos.cpu().numpy().tolist())

        pred_starts.extend(_detached_list(start_logits))
        pred_ends.extend(_detached_list(end_logits))
        pred_spans.extend(_detached_list(span_logits))

        categories.extend(ner_cate.numpy().tolist())
        collected_masks.extend(_detached_list(input_mask))

    # Same argument list for either scorer; only the selected name is resolved.
    scorer = flat_ner_performance if config.entity_sign == "flat" else nested_ner_performance
    acc, pre, rec, f1 = scorer(
        pred_starts,
        pred_ends,
        pred_spans,
        gold_starts,
        gold_ends,
        gold_spans,
        categories,
        label_list,
        threshold=config.entity_threshold,
        dims=2,
    )

    print("=*=" * 10)
    print("eval: acc, pre, rec, f1")
    print(acc, pre, rec, f1)

    return acc, pre, rec, f1
def eval_checkpoint(
    model_object,
    eval_dataloader,
    config,
    device,
    n_gpu,
    label_list,
    logger,
    eval_sign="test",
):
    """Score ``model_object`` on ``eval_dataloader`` and log the NER metrics.

    Each batch is an 8-tuple; its ``span_label_mask`` element is unpacked
    but not used. The model is called without labels under
    ``torch.no_grad()`` and its output is unpacked as start, end and span
    logits. Logits and gold labels from every batch are gathered as nested
    lists and passed to ``flat_ner_performance`` when
    ``config.entity_sign == "flat"``, otherwise to ``nested_ner_performance``.

    Args:
        model_object: Callable model; also must provide ``eval()``.
        eval_dataloader: Iterable yielding 8-tuples
            ``(input_ids, input_mask, segment_ids, start_pos, end_pos,
            span_pos, span_label_mask, ner_cate)`` of tensors.
        config: Object providing ``saved_model``, ``entity_sign`` and
            ``entity_threshold``.
        device: Device the batch tensors are moved to.
        n_gpu: Not used by this function.
        label_list: Passed through to the scoring function.
        logger: Object with an ``info`` method; receives the banner and metrics.
        eval_sign: Not used by this function.

    Returns:
        Tuple ``(acc, pre, rec, f1)`` as returned by the scoring function.
    """
    logger.info("$="*20)
    logger.info(f"EVAL {config.saved_model} on Test Set. ")
    model_object.eval()

    def _detached_list(tensor):
        # Detach, bring to host memory and convert to nested Python lists.
        return tensor.detach().cpu().numpy().tolist()

    pred_starts, pred_ends, pred_spans = [], [], []
    gold_starts, gold_ends, gold_spans = [], [], []
    collected_masks = []
    categories = []

    for eval_batch in eval_dataloader:
        (input_ids, input_mask, segment_ids, start_pos, end_pos,
         span_pos, span_label_mask, ner_cate) = eval_batch

        input_ids, input_mask, segment_ids, start_pos, end_pos, span_pos = (
            tensor.to(device)
            for tensor in (input_ids, input_mask, segment_ids, start_pos, end_pos, span_pos)
        )

        with torch.no_grad():
            start_logits, end_logits, span_logits = model_object(input_ids, segment_ids, input_mask)

        gold_starts.extend(start_pos.cpu().numpy().tolist())
        gold_ends.extend(end_pos.cpu().numpy().tolist())
        gold_spans.extend(span_pos.cpu().numpy().tolist())

        pred_starts.extend(_detached_list(start_logits))
        pred_ends.extend(_detached_list(end_logits))
        pred_spans.extend(_detached_list(span_logits))

        categories.extend(ner_cate.numpy().tolist())
        collected_masks.extend(_detached_list(input_mask))

    # Same argument list for either scorer; only the selected name is resolved.
    scorer = flat_ner_performance if config.entity_sign == "flat" else nested_ner_performance
    acc, pre, rec, f1 = scorer(
        pred_starts,
        pred_ends,
        pred_spans,
        gold_starts,
        gold_ends,
        gold_spans,
        categories,
        label_list,
        threshold=config.entity_threshold,
        dims=2,
    )

    logger.info("=*="*10)
    logger.info("eval on test set : acc, pre, rec, f1")
    logger.info(f"{acc}, {pre}, {rec}, {f1}")

    return acc, pre, rec, f1