def test_qa_metrics(self):
  """Tests qa_utils.qa_metrics: input validation and em/f1 aggregation."""
  # A count mismatch between targets and predictions must raise ValueError
  # with exactly this message.
  with self.assertRaisesRegex(
      ValueError, "Number of targets and predictions must match."):
    qa_utils.qa_metrics([["answer"]] * 6, ["answer"] * 5)
  # All-exact-match input scores 100 for both em and f1.
  self.assertDictEqual(
      qa_utils.qa_metrics([["answer"]] * 5, ["answer"] * 5), {
          "em": 100.0,
          "f1": 100.0
      })
  # Mixed case: each prediction is scored against its list of acceptable
  # answers (max over the list), then averaged over the 4 questions.
  # Expected values (em 25, f1 35) follow qa_utils' normalization/scoring;
  # note the prediction "a big moose‘" carries a Unicode quote on purpose.
  self.assertDictEqual(
      qa_utils.qa_metrics(
          [
              ["big moose", "hippo"],
              ["correct1"],
              ["correct2.1", "correct2.2"],
              ["a", "b"],
          ],
          [
              "a big moose‘",
              "wrong",
              "correct2.2",
              "c",
          ],
      ),
      {
          "em": 25.,
          "f1": 35.
      },
  )
def trivia_qa(targets, predictions):
  """TriviaQA metric: normalize both sides, then score with qa_metrics.

  The underlying scorer maximizes over the acceptable answers for each
  question.

  Args:
    targets: list of lists of strings; ground-truth answers per question.
    predictions: list of strings; one model answer per question.

  Returns:
    dict with score_key: squad score across all targets and predictions
  """
  normalized_targets = [
      [qa_utils.normalize_trivia_qa(answer) for answer in answer_group]
      for answer_group in targets
  ]
  normalized_predictions = [
      qa_utils.normalize_trivia_qa(prediction) for prediction in predictions
  ]
  return qa_utils.qa_metrics(normalized_targets, normalized_predictions)
def mlqa(targets, predictions, lang):
  """MLQA metric: language-aware normalization, then SQuAD-style scoring.

  The underlying scorer maximizes over the acceptable answers for each
  question.

  NOTE(review): a later `def mlqa` in this module appears to redefine (and
  thus shadow) this function — confirm which definition is intended.

  Args:
    targets: list of lists of strings; reference answers per question.
    predictions: list of strings; one predicted answer per question.
    lang: ISO code of language

  Returns:
    dict with score_key: squad score across all targets and predictions
  """
  normalized_targets = [
      [normalize_mlqa(answer, lang) for answer in answer_group]
      for answer_group in targets
  ]
  normalized_predictions = [
      normalize_mlqa(prediction, lang) for prediction in predictions
  ]
  return qa_utils.qa_metrics(normalized_targets, normalized_predictions)
def mlqa(targets, predictions, lang=None): """Computes MLQA metrics, maximizing over answers per question. Args: targets: list of lists of strings predictions: list of strings lang: ISO code of language Returns: dict with score_key: squad score across all targets and predictions """ assert lang is not None punct = { chr(i) for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith('P') }.union(string.punctuation) targets = [[normalize_mlqa(t, lang, punct) for t in u] for u in targets] predictions = [normalize_mlqa(p, lang, punct) for p in predictions] return qa_utils.qa_metrics(targets, predictions)