Exemplo n.º 1
0
 def predict_one(self, sentence):
     """
     Run NER on one input sentence and keep only the first row of the batch.

     :param sentence: raw input, passed unchanged to
         ``self.dataManager.prepare_single_sentence``
     :return: the ``(entities, suffixes, indices)`` triple produced by
         ``extract_entity`` for the decoded label sequence
     """
     configs = self.configs
     data_manager = self.dataManager
     use_bert = configs.use_bert
     prepared = data_manager.prepare_single_sentence(sentence)
     if not use_bert:
         token_ids, targets, raw_tokens = prepared
         model_inputs = token_ids
     elif configs.finetune:
         # With finetune set, the NER model receives the ids and mask directly.
         token_ids, targets, att_mask, raw_tokens = prepared
         model_inputs = (token_ids, att_mask)
     else:
         # Otherwise the NER model receives the first output of self.bert_model.
         token_ids, targets, att_mask, raw_tokens = prepared
         model_inputs = self.bert_model(token_ids, attention_mask=att_mask)[0]

     # Per-row length = number of non-zero ids along axis 1
     # (presumably id 0 is padding -- TODO confirm against dataManager).
     lengths = tf.math.count_nonzero(token_ids, 1)
     logits, _log_likelihood, transitions = self.ner_model(
         inputs=model_inputs, inputs_length=lengths, targets=targets)
     decoded, _ = crf_decode(logits, transitions, lengths)

     first_row = decoded.numpy()[0]
     seq_len = lengths[0]
     tokens = raw_tokens[0, 0:seq_len]
     id2label = data_manager.id2label
     y_pred = [str(id2label[label_id]) for label_id in first_row[0:seq_len]]
     if use_bert:
         # Drop the positions that correspond to [CLS] and [SEP].
         y_pred = y_pred[1:-1]
     entities, suffixes, indices = extract_entity(tokens, y_pred, data_manager)
     return entities, suffixes, indices
Exemplo n.º 2
0
def _precision_recall_f1(hit_num, pred_num, true_num, default):
    """Return ``(precision, recall, f1)`` computed from entity counts.

    :param hit_num: number of predicted entities that also appear in the gold set
    :param pred_num: number of predicted entities
    :param true_num: number of gold entities
    :param default: value reported for a ratio whose denominator is zero, and
        for f1 whenever precision or recall is not positive
    """
    precision = hit_num / pred_num if pred_num != 0 else default
    recall = hit_num / true_num if true_num != 0 else default
    if precision > 0 and recall > 0:
        f1 = 2.0 * (precision * recall) / (precision + recall)
    else:
        f1 = default
    return precision, recall, f1


def metrics(X, y_true, y_pred, configs, data_manager, tokenizer):
    """Compute label accuracy and entity-level precision / recall / f1.

    :param X: batch of input id rows (``.tolist()`` is called on each row in
        the BERT branch, so numpy rows are expected there)
    :param y_true: batch of gold label-id rows
    :param y_pred: batch of predicted label-id rows; an object exposing
        ``.numpy()`` (e.g. an eager tensor) is converted first, array-likes
        are used as-is
    :param configs: reads ``use_bert`` and ``measuring_metrics``
    :param data_manager: supplies ``id2token``/``token2id``,
        ``id2label``/``label2id``, the ``PADDING`` key and the ``suffix``
        list of entity types
    :param tokenizer: used only when ``configs.use_bert`` is true
    :return: ``(results, label_metrics)``. ``results`` maps each name in
        ``configs.measuring_metrics`` (any of ``accuracy``, ``precision``,
        ``recall``, ``f1``) to its value, or -1.0 when it cannot be computed
        (zero denominator; for f1, non-positive precision or recall).
        ``label_metrics`` maps each suffix to its own
        ``{'precision', 'recall', 'f1'}``, using 0 when not computable.
    :raises KeyError: if ``configs.measuring_metrics`` names an unknown metric
    """
    # Tensors cannot be indexed row-by-row like arrays; convert to numpy first.
    if hasattr(y_pred, 'numpy'):
        y_pred = y_pred.numpy()

    pad_label_id = data_manager.label2id[data_manager.PADDING]
    # The token vocabulary is only consulted in the non-BERT branch.
    pad_token_id = (None if configs.use_bert
                    else data_manager.token2id[data_manager.PADDING])

    correct_label_num = 0
    total_label_num = 0
    hit_num = pred_num = true_num = 0
    label_num = {}

    for x_ids, true_ids, pred_ids in zip(X, y_true, y_pred):
        if configs.use_bert:
            # NOTE(review): special tokens are dropped from the tokens here,
            # but the label rows below only drop padding labels -- confirm the
            # label sequences stay aligned with the tokens in this branch.
            tokens = tokenizer.convert_ids_to_tokens(x_ids.tolist(),
                                                     skip_special_tokens=True)
        else:
            tokens = [
                str(data_manager.id2token[val]) for val in x_ids
                if val != pad_token_id
            ]
        gold = [
            str(data_manager.id2label[val]) for val in true_ids
            if val != pad_label_id
        ]
        pred = [
            str(data_manager.id2label[val]) for val in pred_ids
            if val != pad_label_id
        ]

        correct_label_num += sum(1 for a, b in zip(gold, pred) if a == b)
        total_label_num += len(gold)

        true_entities, true_suffixes, _ = extract_entity(tokens, gold, data_manager)
        pred_entities, pred_suffixes, _ = extract_entity(tokens, pred, data_manager)

        # NOTE(review): entities are compared as sets of extract_entity's first
        # return value, so repeated identical entities in one row count once.
        true_set = set(true_entities)
        pred_set = set(pred_entities)
        hit_num += len(true_set & pred_set)
        pred_num += len(pred_set)
        true_num += len(true_set)

        for label in data_manager.suffix:
            counts = label_num.setdefault(
                label, {'hit_num': 0, 'pred_num': 0, 'true_num': 0})
            true_lab = {
                ent for ent, suf in zip(true_entities, true_suffixes)
                if suf == label
            }
            pred_lab = {
                ent for ent, suf in zip(pred_entities, pred_suffixes)
                if suf == label
            }
            counts['hit_num'] += len(true_lab & pred_lab)
            counts['pred_num'] += len(pred_lab)
            counts['true_num'] += len(true_lab)

    # Default to -1.0 like the other overall metrics so an empty batch does
    # not leave `accuracy` unbound.
    accuracy = (correct_label_num / total_label_num
                if total_label_num != 0 else -1.0)
    precision, recall, f1 = _precision_recall_f1(
        hit_num, pred_num, true_num, default=-1.0)

    # Per-entity-type breakdown.
    label_metrics = {}
    for label, counts in label_num.items():
        tmp_precision, tmp_recall, tmp_f1 = _precision_recall_f1(
            counts['hit_num'], counts['pred_num'], counts['true_num'],
            default=0)
        label_metrics[label] = {
            'precision': tmp_precision,
            'recall': tmp_recall,
            'f1': tmp_f1,
        }

    # Explicit lookup table instead of introspecting locals via vars().
    available = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }
    results = {measure: available[measure]
               for measure in configs.measuring_metrics}
    return results, label_metrics