def test_normalize_trivia_qa(self):
  # Each entry is (raw input, expected output of normalize_trivia_qa).
  # The first input mixes a backtick, a tab, an underscore, straight and
  # curly quotes and punctuation; the second is already normalized and
  # must come back unchanged.
  cases = (
      ("`Needs\tA_LOT of the 'normalization'.\"‘",
       "needs lot of normalization"),
      ("needs no normalization",
       "needs no normalization"),
  )
  for raw, expected in cases:
    self.assertEqual(qa_utils.normalize_trivia_qa(raw), expected)
def trivia_qa(targets, predictions):
  """Scores TriviaQA predictions after normalizing answers and predictions.

  Every reference answer and every prediction is passed through
  `qa_utils.normalize_trivia_qa`. The normalized values are then handed to
  `qa_utils.qa_metrics`, which (per the original docstring) maximizes over
  the acceptable answers given for each question.

  Args:
    targets: list of lists of strings; the acceptable answers per question.
    predictions: list of strings; one prediction per question.

  Returns:
    The result of `qa_utils.qa_metrics` on the normalized inputs. This is
    presumably a dict mapping metric name to score; confirm in qa_utils.
  """
  normalized_targets = []
  for answers in targets:
    normalized_targets.append(
        [qa_utils.normalize_trivia_qa(answer) for answer in answers])
  normalized_predictions = list(map(qa_utils.normalize_trivia_qa, predictions))
  return qa_utils.qa_metrics(normalized_targets, normalized_predictions)