def test_qa_metrics(self):
        """Checks length validation and em/f1 scoring of qa_utils.qa_metrics."""
        # Mismatched target/prediction counts must be rejected.
        with self.assertRaisesRegex(
                ValueError, "Number of targets and predictions must match."):
            qa_utils.qa_metrics([["answer"]] * 6, ["answer"] * 5)

        # Identical targets and predictions yield perfect scores.
        perfect = qa_utils.qa_metrics([["answer"]] * 5, ["answer"] * 5)
        self.assertDictEqual(perfect, {"em": 100.0, "f1": 100.0})

        # Mixed outcomes: partial token overlap, a miss, an exact match
        # against one of several candidates, and another miss.
        targets = [
            ["big moose", "hippo"],
            ["correct1"],
            ["correct2.1", "correct2.2"],
            ["a", "b"],
        ]
        predictions = [
            "a big moose‘",
            "wrong",
            "correct2.2",
            "c",
        ]
        scores = qa_utils.qa_metrics(targets, predictions)
        self.assertDictEqual(scores, {"em": 25., "f1": 35.})
# ---- Code example #2 ----
def trivia_qa(targets, predictions):
    """TriviaQA metrics, taking the best score over each question's answers.

  Args:
    targets: list of lists of strings (candidate answers per question)
    predictions: list of strings (one prediction per question)

  Returns:
    dict with score_key: squad score across all targets and predictions
  """
    # Normalize every candidate answer and every prediction with the
    # TriviaQA-specific normalizer before scoring.
    normalized_targets = [
        [qa_utils.normalize_trivia_qa(answer) for answer in candidates]
        for candidates in targets
    ]
    normalized_predictions = [
        qa_utils.normalize_trivia_qa(prediction) for prediction in predictions
    ]
    return qa_utils.qa_metrics(normalized_targets, normalized_predictions)
# ---- Code example #3 ----
# File: metrics.py  Project: yyht/multilingual-t5
def mlqa(targets, predictions, lang):
    """MLQA metrics, taking the best score over each question's answers.

  Args:
    targets: list of lists of strings (candidate answers per question)
    predictions: list of strings (one prediction per question)
    lang: ISO code of language

  Returns:
    dict with score_key: squad score across all targets and predictions
  """
    # Apply the language-aware MLQA normalizer to both sides before scoring.
    normalized_targets = [
        [normalize_mlqa(answer, lang) for answer in candidates]
        for candidates in targets
    ]
    normalized_predictions = [
        normalize_mlqa(prediction, lang) for prediction in predictions
    ]
    return qa_utils.qa_metrics(normalized_targets, normalized_predictions)
# ---- Code example #4 ----
# File: metrics.py  Project: yly1994/multilingual-t5
def mlqa(targets, predictions, lang=None):
    """Computes MLQA metrics, maximizing over answers per question.

  Args:
    targets: list of lists of strings
    predictions: list of strings
    lang: ISO code of language; required (the default exists only to keep the
      keyword-style call signature).

  Returns:
    dict with score_key: squad score across all targets and predictions

  Raises:
    ValueError: if `lang` is not provided.
  """
    # `assert` is stripped under `python -O`, so validate explicitly.
    if lang is None:
        raise ValueError("lang must be provided (ISO language code).")
    # Building the punctuation set scans every Unicode code point
    # (~1.1M iterations). It is call-invariant, so compute it once and
    # cache it on the function object instead of rebuilding per call.
    punct = getattr(mlqa, "_punct_cache", None)
    if punct is None:
        punct = {
            chr(i)
            for i in range(sys.maxunicode)
            if unicodedata.category(chr(i)).startswith('P')
        }.union(string.punctuation)
        mlqa._punct_cache = punct
    targets = [[normalize_mlqa(t, lang, punct) for t in u] for u in targets]
    predictions = [normalize_mlqa(p, lang, punct) for p in predictions]
    return qa_utils.qa_metrics(targets, predictions)