    def test_overlap_in_correct_cases(self):
        assert get_metrics(["Green bay packers"],
                           ["Green bay packers"]) == (1.0, 1.0)
        assert get_metrics(["Green bay", "packers"],
                           ["Green bay", "packers"]) == (1.0, 1.0)
        assert get_metrics(["Green", "bay", "packers"],
                           ["Green", "bay", "packers"]) == (1.0, 1.0)

    def test_multi_span_overlap_in_incorrect_cases(self):
        # bags containing numbers are only considered when the numbers match
        # F1 per span:  1.0   2/3   0.0   0.0   0.0   0.0
        # Averaging them gives an overall F1 of 0.28
        assert get_metrics(
            ["78-yard", "56", "28", "40", "44", "touchdown"],
            ["78-yard", "56 yard", "1 yard touchdown"],
        ) == (0.0, 0.28)

        # two copies of the same value account for only one match (using an optimal 1-1 bag alignment)
        assert get_metrics(["23", "23 yard"],
                           ["23-yard", "56 yards"]) == (0.0, 0.5)

        # matching is done at the individual span level, not pooled into one global bag
        assert get_metrics(["John Karman", "Joe Hardy"],
                           ["Joe Karman", "John Hardy"]) == (0.0, 0.5)

        # macro-averaging F1 over spans
        assert get_metrics(["ottoman", "Kantakouzenous"],
                           ["ottoman", "army of Kantakouzenous"]) == (0.0, 0.75)
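
    # A minimal sketch (not the library's implementation) of the "optimal 1-1
    # bag alignment" referenced above: score every predicted span against
    # every gold span, choose the assignment that maximizes total F1, and
    # macro-average over max(#predicted, #gold) slots. `pairwise_f1` is a
    # hypothetical precomputed matrix of span-level F1 values.
    @staticmethod
    def _align_bags_sketch(pairwise_f1):
        import numpy as np
        from scipy.optimize import linear_sum_assignment

        size = max(pairwise_f1.shape)
        padded = np.zeros((size, size))
        padded[: pairwise_f1.shape[0], : pairwise_f1.shape[1]] = pairwise_f1
        row_ind, col_ind = linear_sum_assignment(padded, maximize=True)
        # Unmatched spans contribute 0.0, so dividing by `size` is what makes
        # the metric length-aware: extra or missing spans drag the average down.
        return padded[row_ind, col_ind].sum() / size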

    def test_simple_overlap_in_incorrect_cases(self):
        assert get_metrics([""], ["army"]) == (0.0, 0.0)
        assert get_metrics(["packers"], ["Green bay packers"]) == (0.0, 0.5)
        assert get_metrics(["packers"], ["Green bay"]) == (0.0, 0.0)
        # if the numbers in the span don't match, F1 is 0
        assert get_metrics(["yard"], ["36 yard td"]) == (0.0, 0.0)
        assert get_metrics(["23 yards"], ["43 yards"]) == (0.0, 0.0)
        # however, a matching number gets no extra weight over the other words
        assert get_metrics(["56 yards"], ["56 yd"]) == (0.0, 0.5)
        assert get_metrics(["26"], ["26 yard td"]) == (0.0, 0.5)
def evaluate_json(annotations: Dict[str, Any],
                  predicted_answers: Dict[str, Any]) -> Tuple[float, float]:
    """
    Takes gold annotations and predicted answers and evaluates the predictions for each question
    in the gold annotations.  Both JSON dictionaries must have query_id keys, which are used to
    match predictions to gold annotations.

    The ``predicted_answers`` JSON must be a dictionary keyed by query id, where the value is a
    list of strings (or just one string) that is the answer.
    The ``annotations`` are assumed to have either the format of the dev set in the Quoref data release, or the
    same format as the predicted answers file.
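
    For example, with a made-up query id and both dictionaries in the flat
    format::

        annotations = {"some_query_id": ["Green Bay Packers"]}
        predicted_answers = {"some_query_id": "Green Bay Packers"}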
    """
    instance_exact_match = []
    instance_f1 = []
    if "data" in annotations:
        # We're looking at annotations in the original data format. Let's extract the answers.
        annotated_answers = _get_answers_from_data(annotations)
    else:
        annotated_answers = annotations
    for query_id, candidate_answers in annotated_answers.items():
        max_em_score = 0.0
        max_f1_score = 0.0
        if query_id in predicted_answers:
            predicted = predicted_answers[query_id]
            gold_answer = tuple(candidate_answers)
            em_score, f1_score = drop_eval.get_metrics(predicted, gold_answer)
            if gold_answer[0].strip() != "":
                max_em_score = max(max_em_score, em_score)
                max_f1_score = max(max_f1_score, f1_score)
        else:
            print("Missing prediction for question: {}".format(query_id))
            max_em_score = 0.0
            max_f1_score = 0.0
        instance_exact_match.append(max_em_score)
        instance_f1.append(max_f1_score)

    global_em = np.mean(instance_exact_match)
    global_f1 = np.mean(instance_f1)
    print("Exact-match accuracy {0:.2f}".format(global_em * 100))
    print("F1 score {0:.2f}".format(global_f1 * 100))
    print("{0:.2f}   &   {1:.2f}".format(global_em * 100, global_f1 * 100))
    return global_em, global_f1
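
# A minimal usage sketch for ``evaluate_json`` (hypothetical file names; the
# function only needs the two parsed dictionaries, however they are loaded):
#
#     import json
#     with open("annotations.json") as annotations_file:
#         annotations = json.load(annotations_file)
#     with open("predictions.json") as predictions_file:
#         predicted_answers = json.load(predictions_file)
#     evaluate_json(annotations, predicted_answers)
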
    def test_metric_is_length_aware(self):
        # Overall F1 should be mean([1.0, 0.0])
        assert get_metrics(predicted=["td"], gold=["td", "td"]) == (0.0, 0.5)
        assert get_metrics("td", ["td", "td"]) == (0.0, 0.5)
        # Overall F1 should be mean([1.0, 0.0]) = 0.5
        assert get_metrics(predicted=["td", "td"], gold=["td"]) == (0.0, 0.5)
        assert get_metrics(predicted=["td", "td"], gold="td") == (0.0, 0.5)

        # F1 score is mean([0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
        assert get_metrics(predicted=[
            "the", "fat", "cat", "the fat", "fat cat", "the fat cat"
        ],
                           gold=["cat"]) == (0.0, 0.17)
        assert get_metrics(
            predicted=["cat"],
            gold=["the", "fat", "cat", "the fat", "fat cat",
                  "the fat cat"]) == (0.0, 0.17)
        # F1 score is mean([1.0, 0.5, 0.0, 0.0, 0.0, 0.0])
        assert get_metrics(
            predicted=[
                "the", "fat", "cat", "the fat", "fat cat", "the fat cat"
            ],
            gold=["cat", "cat dog"],
        ) == (0.0, 0.25)

    def test_casing_is_ignored(self):
        assert get_metrics(["This was a triumph"],
                           ["tHIS Was A TRIUMPH"]) == (1.0, 1.0)

    def test_splitting_on_hyphens(self):
        assert get_metrics(["78-yard"], ["78 yard"]) == (1.0, 1.0)
        assert get_metrics(["78 yard"], ["78-yard"]) == (1.0, 1.0)
        assert get_metrics(["78"], ["78-yard"]) == (0.0, 0.67)
        assert get_metrics(["78-yard"], ["78"]) == (0.0, 0.67)

    def test_periods_commas_and_spaces_are_ignored(self):
        assert get_metrics(["Per.i.o.d...."],
                           [".P....e.r,,i;;;o...d,,"]) == (1.0, 1.0)
        assert get_metrics(["Spa     c  e   s     "],
                           ["    Spa c     e s"]) == (1.0, 1.0)

    def test_f1_ignores_word_order(self):
        assert get_metrics(["John Elton"], ["Elton John"]) == (0.0, 1.0)
        assert get_metrics(["50 yard"], ["yard 50"]) == (0.0, 1.0)
        assert get_metrics(["order word right"],
                           ["right word order"]) == (0.0, 1.0)

    def test_articles_are_ignored(self):
        assert get_metrics(["td"], ["the td"]) == (1.0, 1.0)
        assert get_metrics(["the a NOT an ARTICLE the an a"],
                           ["NOT ARTICLE"]) == (1.0, 1.0)

    def test_float_numbers(self):
        assert get_metrics(["78"], ["78.0"]) == (1.0, 1.0)

    def test_order_invariance(self):
        assert get_metrics(["a"], ["a", "b"]) == (0, 0.5)
        assert get_metrics(["b"], ["a", "b"]) == (0, 0.5)
        assert get_metrics(["b"], ["b", "a"]) == (0, 0.5)