Ejemplo n.º 1
0
  def testSpanSetEqual(self):
    """Exercises util.span_set_equal on matching and non-matching span lists.

    Covers three cases: two lists holding equal spans, the same lists with an
    extra all -1 span on one side (still expected to compare equal), and a
    list that is missing one of the spans (expected to compare unequal).
    """
    first = util.Span(-1, -1, 100, 102)
    first_copy = util.Span(-1, -1, 100, 102)
    second = util.Span(-1, -1, 101, 105)
    all_negative = util.Span(-1, -1, -1, -1)

    # Distinct Span objects with identical fields should be interchangeable.
    same_spans = util.span_set_equal([first, second], [first_copy, second])
    self.assertTrue(same_spans)

    # Appending the all -1 span to one side must not break equality.
    with_null = util.span_set_equal(
        [first, second], [first_copy, second, all_negative])
    self.assertTrue(with_null)

    # A list lacking `second` is a different set.
    missing_one = util.span_set_equal(
        [first], [first_copy, second, all_negative])
    self.assertFalse(missing_one)
def score_short_answer(gold_label_list, pred_label, score_thres):
    """Scores a predicted short answer as correct or not.

    The prediction is correct only when both the gold side and the prediction
    have a short answer, and either:
      a. the prediction is a yes/no answer (yes_no_answer != 'none') that
         equals the yes_no_answer of at least *one* gold label, or
      b. the prediction's short answer span *set* equals, according to
         util.span_set_equal, the span *set* of at least *one* gold label.

    Args:
      gold_label_list: A list of NQLabel.
      pred_label: A single NQLabel. Its short_score is read unconditionally,
        so it must not be None.
      score_thres: Minimum short_score the prediction needs in order to count
        as having an answer.

    Returns:
      A tuple (gold_has_answer, pred_has_answer, is_correct, score) where
        gold_has_answer is the result of
          util.gold_has_short_answer(gold_label_list),
        pred_has_answer is a bool: the prediction has a non-null short span
          list or a yes/no answer, and its score reaches score_thres,
        is_correct is a bool as described above,
        score is pred_label.short_score.
    """
    # Whether the gold annotations contain a short answer is decided entirely
    # by the util helper.
    gold_has_answer = util.gold_has_short_answer(gold_label_list)

    score = pred_label.short_score

    # There is a pred short answer if pred_label is not empty, it carries
    # either a non-null span list or a yes/no answer, and its score reaches
    # the threshold. bool() keeps the returned flag a real boolean instead of
    # whichever operand the and/or chain happened to stop on.
    pred_has_answer = bool(
        pred_label
        and (not util.is_null_span_list(pred_label.short_answer_span_list)
             or pred_label.yes_no_answer != 'none')
        and score >= score_thres)

    is_correct = False

    # Both sides have short answers, which includes yes/no questions.
    if gold_has_answer and pred_has_answer:
        if pred_label.yes_no_answer != 'none':
            # System thinks it is a y/n question: agreeing with any single
            # gold label is enough.
            is_correct = any(
                pred_label.yes_no_answer == gold_label.yes_no_answer
                for gold_label in gold_label_list)
        else:
            # Span answer: the predicted span set must equal the span set of
            # at least one gold label.
            is_correct = any(
                util.span_set_equal(gold_label.short_answer_span_list,
                                    pred_label.short_answer_span_list)
                for gold_label in gold_label_list)

    return gold_has_answer, pred_has_answer, is_correct, score