def label_to_pred(labels):
    """Convert a list of gold human annotations to a perfect prediction.

    The prediction mirrors the gold annotations. With probability
    ``1 - FLAGS.desired_recall`` the gold answers are dropped. When
    ``FLAGS.generate_false_positives`` is set, a dummy answer is injected for
    examples whose gold has no short / long answer.

    Args:
      labels: A non-empty list of gold labels for one example. Each label is
        read for ``example_id`` (first label only), ``short_answer_span_list``,
        ``yes_no_answer`` and ``long_answer_span``.

    Returns:
      A dict with ``example_id``, ``short_answers``, ``short_answers_score``
      and ``long_answer_score``. ``long_answer`` and ``yes_no_answer`` are
      added only when an answer is emitted.
    """
    gold_has_short_answer = util.gold_has_short_answer(labels)
    gold_has_long_answer = util.gold_has_long_answer(labels)

    # `long_answer` and `yes_no_answer` are not put here; they should be
    # considered as null when loading from input.
    pred = {
        'example_id': labels[0].example_id,
        'short_answers': [],
        'short_answers_score': random.random(),
        'long_answer_score': random.random()
    }

    # A single draw per example: either all gold answers are kept or none.
    keep_answer = random.random() <= FLAGS.desired_recall
    for label in labels:
        if gold_has_short_answer and keep_answer:
            # Doubled once per annotator (presumably so kept answers tend to
            # outscore dropped ones).
            pred['short_answers_score'] *= 2
            if not util.is_null_span_list(label.short_answer_span_list):
                # Later annotators overwrite earlier ones: the last annotator
                # with a non-null span set wins.
                pred['short_answers'] = [{
                    'start_token': span.start_token_idx,
                    'end_token': span.end_token_idx,
                    'start_byte': span.start_byte,
                    'end_byte': span.end_byte
                } for span in label.short_answer_span_list]
                pred['yes_no_answer'] = 'none'
            elif label.yes_no_answer != 'none':
                pred['short_answers'] = []
                pred['yes_no_answer'] = label.yes_no_answer

        if (gold_has_long_answer
                and not label.long_answer_span.is_null_span()
                and keep_answer):
            pred['long_answer'] = {
                'start_token': label.long_answer_span.start_token_idx,
                'end_token': label.long_answer_span.end_token_idx,
                'start_byte': label.long_answer_span.start_byte,
                'end_byte': label.long_answer_span.end_byte
            }
            pred['long_answer_score'] *= 2

    if FLAGS.generate_false_positives:
        # Dummy answer at tokens 0..1 with unknown (-1) byte offsets.
        if not gold_has_short_answer:
            pred['short_answers'] = [{
                'start_token': 0,
                'end_token': 1,
                'start_byte': -1,
                'end_byte': -1
            }]

        if not gold_has_long_answer:
            # Written under `long_answer`, the same key used for real long
            # answers above, so it is loaded as a long answer instead of
            # being treated as a missing (null) one.
            pred['long_answer'] = {
                'start_token': 0,
                'end_token': 1,
                'start_byte': -1,
                'end_byte': -1
            }
            # Flat keys kept for backward compatibility with any consumer
            # that reads them.
            pred['long_answer_start_token'] = 0
            pred['long_answer_end_token'] = 1

    return pred
def score_short_answer(gold_label_list, pred_label, score_thres):
    """Judge a predicted short answer as correct or incorrect.

    Decision procedure:
      1) Whether gold has a short answer is decided by
         ``util.gold_has_short_answer`` (SHORT_NO_NULL_THRESHOLD, summed over
         annotators).
      2) The prediction has an answer when its span set is non-null or it
         gives a yes/no answer other than 'none', AND its ``short_score`` is
         at least ``score_thres``.
      3) When both sides have an answer, the prediction is correct if
         a. for a yes/no prediction, some annotator gave the same yes/no
            answer; otherwise
         b. the predicted span *set* equals exactly *one* annotator's
            non-null span *set*.

    Args:
      gold_label_list: A list of NQLabel.
      pred_label: A single NQLabel.
      score_thres: Minimum ``short_score`` for the prediction to count as
        having an answer.

    Returns:
      A tuple (gold_has_answer, pred_has_answer, is_correct, score), where
      score is ``pred_label.short_score``.
    """
    gold_has_answer = util.gold_has_short_answer(gold_label_list)
    score = pred_label.short_score

    # A pred short answer exists if pred_label is not empty, it has a non-null
    # span set or a yes/no answer, and its score clears the threshold.
    pred_has_answer = pred_label and (
        (not util.is_null_span_list(pred_label.short_answer_span_list))
        or pred_label.yes_no_answer != 'none') and score >= score_thres

    # Correctness is only possible when both sides claim a short answer.
    if not (gold_has_answer and pred_has_answer):
        return gold_has_answer, pred_has_answer, False, score

    if pred_label.yes_no_answer != 'none':
        # The system believes this is a yes/no question.
        is_correct = any(
            pred_label.yes_no_answer == gold_label.yes_no_answer
            for gold_label in gold_label_list)
    else:
        is_correct = any(
            util.span_set_equal(gold_label.short_answer_span_list,
                                pred_label.short_answer_span_list)
            for gold_label in gold_label_list)

    return gold_has_answer, pred_has_answer, is_correct, score
Example #3
0
def score_short_answer(gold_label_list, pred_label):
    """Score a short answer by token-overlap F1 against the gold annotations.

    1) First decide if there is a gold short answer with
       SHORT_NO_NULL_THRESHOLD.
    2) If both gold and prediction have a short answer:
       a. A yes/no prediction scores f1 = precision = recall = 1 when it
          equals any annotator's yes/no answer, otherwise 0.
       b. A span prediction is compared, as a token interval *set*, with
          each annotator's span *set*. Precision and recall are the overlap
          length divided by the predicted and gold lengths. The annotator
          giving the highest F1 wins.
    3) If neither side has a short answer, the prediction scores 1 (a
       correct null). If only one side has one, it scores 0.

    Args:
      gold_label_list: A list of NQLabel.
      pred_label: A single NQLabel.

    Returns:
      A 6-tuple (gold_has_answer, pred_has_answer, f1, precision, recall,
      score), where score is ``pred_label.short_score``.
    """
    # There is a gold short answer if gold_label_list is not empty and the
    # number of non-null answers (summed over annotators) is over the
    # threshold.
    gold_has_answer = util.gold_has_short_answer(gold_label_list)

    # There is a pred short answer if pred_label is not empty and either its
    # short answer span set is non-null or it gives a yes/no answer.
    pred_has_answer = pred_label and (
        (not util.is_null_span_list(pred_label.short_answer_span_list))
        or pred_label.yes_no_answer != 'none')

    f1 = 0
    p = 0
    r = 0
    score = pred_label.short_score

    def to_interval_set(span_list):
        """Build an IntervalSet of (start_token_idx, end_token_idx) intervals."""
        return IntervalSet([Interval(span.start_token_idx, span.end_token_idx)
                            for span in span_list])

    def covered_length(interval_set):
        """Sum of (upper_bound - lower_bound + 1) over the intervals."""
        # `+ 1` counts both bounds as covered.
        # NOTE(review): confirm this matches whether end_token_idx is
        # exclusive in these spans.
        return sum(interval.upper_bound - interval.lower_bound + 1
                   for interval in interval_set)

    # Both sides have short answers, which includes yes/no questions.
    if gold_has_answer and pred_has_answer:
        if pred_label.yes_no_answer != 'none':  # System thinks it's y/n.
            for gold_label in gold_label_list:
                if pred_label.yes_no_answer == gold_label.yes_no_answer:
                    f1 = p = r = 1
                    break
        else:
            # The prediction side does not depend on the annotator, so it is
            # built and measured once.
            pred_set = to_interval_set(pred_label.short_answer_span_list)
            pred_len = covered_length(pred_set)
            for gold_label in gold_label_list:
                gold_set = to_interval_set(gold_label.short_answer_span_list)
                overlap_len = covered_length(gold_set & pred_set)
                precision = safe_divide(overlap_len, pred_len)
                recall = safe_divide(overlap_len, covered_length(gold_set))
                candidate_f1 = safe_divide(2 * precision * recall,
                                           precision + recall)
                # Keep the best-matching annotator.
                if candidate_f1 > f1:
                    f1, p, r = candidate_f1, precision, recall
    elif not gold_has_answer and not pred_has_answer:
        # Both sides agree there is no short answer.
        f1 = p = r = 1

    return gold_has_answer, pred_has_answer, f1, p, r, score
Example #4
0
def score_short_answer(gold_label_list, pred_label, threshold=0):
    """Score a short answer by exact-span F1 against the gold annotations.

    1) First decide if there is a gold short answer with
       SHORT_NO_NULL_THRESHOLD.
    2) If both gold and prediction have a short answer:
       a. A yes/no prediction scores f1 = precision = recall = 1 when it
          equals any annotator's yes/no answer, otherwise 0.
       b. Otherwise, predicted spans scoring at least ``threshold`` are
          matched exactly against each annotator's spans. Gold spans that are
          at most 5 tokens apart are merged first. Precision and recall count
          exact matches, and the annotator giving the highest F1 wins.
    3) If neither side has a short answer, the prediction scores 1 (a
       correct null). If only one side has one, it scores 0.

    Args:
      gold_label_list: A list of NQLabel.
      pred_label: A single NQLabel. Its ``short_score_list`` is paired
        element-wise with ``short_answer_span_list``.
      threshold: Minimum per-span score for a predicted span to be used.

    Returns:
      A 5-tuple (gold_has_answer, pred_has_answer, f1, precision, recall).
    """
    # A gold span starting at most this many tokens after the end of the
    # preceding (merged) span is joined into it before matching.
    max_merge_gap = 5

    def merge_nearby_spans(spans):
        """Sort (start, end) spans by start and merge those within max_merge_gap.

        Returns [] for empty input. A merged span keeps the largest end of the
        spans it absorbed, so a nested span never shrinks it.
        """
        merged = []
        for start, end in sorted(spans, key=lambda span: span[0]):
            if merged and start - merged[-1][1] <= max_merge_gap:
                merged[-1] = (merged[-1][0], max(merged[-1][1], end))
            else:
                merged.append((start, end))
        return merged

    # There is a gold short answer if gold_label_list is not empty and the
    # number of non-null answers (summed over annotators) is over the
    # threshold.
    gold_has_answer = util.gold_has_short_answer(gold_label_list)

    # There is a pred short answer if pred_label is not empty and either its
    # span list is non-null (util.is_null_span_list is also given the span
    # scores and threshold, presumably to ignore low-scoring spans) or it
    # gives a yes/no answer.
    pred_has_answer = pred_label and (
        (not util.is_null_span_list(pred_label.short_answer_span_list,
                                    pred_label.short_score_list, threshold))
        or pred_label.yes_no_answer != 'none')

    f1 = 0
    p = 0
    r = 0

    # Both sides have short answers, which includes yes/no questions.
    if gold_has_answer and pred_has_answer:
        if pred_label.yes_no_answer != 'none':  # System thinks it's y/n.
            if any(pred_label.yes_no_answer == gold_label.yes_no_answer
                   for gold_label in gold_label_list):
                f1 = p = r = 1
        else:
            # Extractive (span) answer comparison. Predicted spans below the
            # score threshold are ignored. The prediction does not depend on
            # the annotator, so it is built once.
            pred_spans = [
                (span.start_token_idx, span.end_token_idx)
                for span, span_score in zip(pred_label.short_answer_span_list,
                                            pred_label.short_score_list)
                if span_score >= threshold
            ]
            for gold_label in gold_label_list:
                # An annotator with no short-answer spans yields an empty
                # list here; merge_nearby_spans handles that without error.
                gold_spans = merge_nearby_spans(
                    (span.start_token_idx, span.end_token_idx)
                    for span in gold_label.short_answer_span_list)
                gold_lookup = set(gold_spans)
                # Each predicted span counts once if it exactly equals a
                # merged gold span.
                num_correct = sum(1 for span in pred_spans
                                  if span in gold_lookup)
                precision = safe_divide(num_correct, len(pred_spans))
                recall = safe_divide(num_correct, len(gold_spans))
                candidate_f1 = safe_divide(2 * precision * recall,
                                           precision + recall)
                # Keep the best-matching annotator.
                if candidate_f1 > f1:
                    f1, p, r = candidate_f1, precision, recall
    elif not gold_has_answer and not pred_has_answer:
        # Both sides agree there is no short answer.
        f1 = p = r = 1

    return gold_has_answer, pred_has_answer, f1, p, r