예제 #1
0
파일: tydi_eval.py 프로젝트: wgc20/tydiqa
def compute_macro_f1(answer_stats, prefix=''):
    """Computes F1, precision, and recall for a list of answer scores.

    This computes the *language-wise macro F1*. For minimal answers,
    we also compute a partial match score that uses F1, which would be
    included in this computation via `answer_stats`.

    Precision is `sum(f1) / sum(has_pred)`, recall is
    `sum(f1) / sum(has_gold)`, and F1 is their harmonic mean. All divisions
    go through `eval_utils.safe_divide` (presumably yielding 0 when the
    denominator is 0 -- confirm against `eval_utils`).

    Args:
      answer_stats: List of per-example 4-tuples
        `(has_gold, has_pred, f1, _)`; the fourth element is ignored here.
        May be empty.
      prefix (''): Prefix to prepend to every key of the score dictionary.

    Returns:
      `collections.OrderedDict` with the keys `prefix + 'n'`,
      `prefix + 'f1'`, `prefix + 'precision'`, `prefix + 'recall'`, in that
      order. `n` is `len(answer_stats)`.
    """
    total_has_gold = 0
    total_has_pred = 0
    total_f1 = 0

    # Accumulate in a single pass. Unpacking `zip(*answer_stats)` into four
    # names raises ValueError when `answer_stats` is empty, whereas this loop
    # simply leaves all totals at zero.
    for has_gold, has_pred, f1, _ in answer_stats:
        total_has_gold += has_gold
        total_has_pred += has_pred
        total_f1 += f1

    macro_precision = eval_utils.safe_divide(total_f1, total_has_pred)
    macro_recall = eval_utils.safe_divide(total_f1, total_has_gold)
    macro_f1 = eval_utils.safe_divide(2 * macro_precision * macro_recall,
                                      macro_precision + macro_recall)

    # Build from a list of pairs so that the key order is guaranteed on every
    # Python version, not only where dict literals preserve insertion order.
    return collections.OrderedDict([
        (prefix + 'n', len(answer_stats)),
        (prefix + 'f1', macro_f1),
        (prefix + 'precision', macro_precision),
        (prefix + 'recall', macro_recall),
    ])
예제 #2
0
파일: tydi_eval.py 프로젝트: wgc20/tydiqa
def compute_pr_curves(answer_stats, targets=None):
    """Computes a PR curve; returns the best-F1 point and R@P for targets.

    Precision and recall are accumulated over `answer_stats` in iteration
    order, and one (precision, recall) point is kept per distinct score.
    For each target, the point with maximum recall whose
    precision >= target is reported.

    This is only relevant if you return the system scores in your predictions.
    You may find this useful when attempting to tune the threshold for your
    system on the dev set before requesting an evaluation on the test set
    via the leaderboard.

    NOTE(review): because totals are cumulative, the per-score points are
    presumably only meaningful as thresholds when `answer_stats` is already
    sorted by score (best first) -- confirm with the caller.

    Arguments:
      answer_stats: List of statistic tuples
        `(has_gold, has_pred, is_correct_or_f1, score)`. The third element
        is either a number or a 3-tuple whose last item is the F1 value.
      targets (None): Iterable of precision thresholds to target. `None`
        is treated as "no targets".

    Returns:
      A 2-tuple `(best, table)` where `best` is
      `(best_f1, best_precision, best_recall, best_threshold)` for the
      score with the highest F1 (all `0.0` if no point has positive F1), and
      `table` is a list of rows `(target, recall, precision, score)`, one
      per target; a target no point satisfies gets `(target, 0, 0, None)`.
    """
    # The declared default is None, which cannot be iterated; normalise it to
    # an empty list. Materialising also lets a one-shot iterable be traversed
    # several times below.
    targets = [] if targets is None else list(targets)

    total_f1 = 0
    total_has_pred = 0

    # Count the number of gold annotations.
    total_has_gold = sum(has_gold for has_gold, _, _, _ in answer_stats)

    # Keep track of the point of maximum recall for each target.
    max_recall = [0 for _ in targets]
    max_precision = [0 for _ in targets]
    max_scores = [None for _ in targets]

    # Only keep track of unique thresholds in this dictionary.
    scores_to_stats = collections.OrderedDict()

    # Loop through every possible threshold and compute precision + recall.
    for _, has_pred, is_correct_or_f1, score in answer_stats:
        if isinstance(is_correct_or_f1, tuple):
            # Tuple form: only the last of its three items (F1) is used.
            _, _, f1 = is_correct_or_f1
        else:
            f1 = is_correct_or_f1
        total_f1 += f1
        total_has_pred += has_pred

        precision = eval_utils.safe_divide(total_f1, total_has_pred)
        recall = eval_utils.safe_divide(total_f1, total_has_gold)

        # If there are any ties, this will be updated multiple times until the
        # ties are all counted.
        scores_to_stats[score] = [precision, recall]

    best_f1 = 0.0
    best_precision = 0.0
    best_recall = 0.0
    best_threshold = 0.0

    for threshold, (precision, recall) in scores_to_stats.items():
        # Match the thresholds to find the maximum recall at or above each
        # target precision.
        for t, target in enumerate(targets):
            if precision >= target and recall > max_recall[t]:
                max_recall[t] = recall
                max_precision[t] = precision
                max_scores[t] = threshold

        # Compute optimal threshold. The strict `>` keeps the first point
        # encountered when several share the same F1.
        f1 = eval_utils.safe_divide(2 * precision * recall, precision + recall)
        if f1 > best_f1:
            best_f1 = f1
            best_precision = precision
            best_recall = recall
            best_threshold = threshold

    return ((best_f1, best_precision, best_recall, best_threshold),
            list(zip(targets, max_recall, max_precision, max_scores)))