def test_token_selection(self):
    """Checks per-token answer extraction against a fixed prediction dict.

    With the threshold just under 0.5, exactly the two tokens whose
    probability is 0.5 (input ids 1 and 3) are expected to be selected.
    """
    pred = {
        'input_ids': np.array([0, 1, 2, 3, 4]),
        'probabilities': np.array([1.0, 0.5, 0.2, 0.5, 0.3]),
        'column_ids': np.array([0, 1, 2, 3, 4]),
        'row_ids': np.array([0, 1, 1, 2, 2]),
        'segment_ids': np.array([0, 1, 1, 1, 1]),
    }
    answers = prediction_utils._get_token_answers(
        pred,
        cell_classification_threshold=0.49999,
    )
    logging.info(answers)
    expected = [
        prediction_utils.TokenAnswer(
            column_index=0,
            row_index=0,
            begin_token_index=0,
            end_token_index=1,
            token_ids=[1],
            score=0.5,
        ),
        prediction_utils.TokenAnswer(
            column_index=2,
            row_index=1,
            begin_token_index=0,
            end_token_index=1,
            token_ids=[3],
            score=0.5,
        ),
    ]
    self.assertEqual(answers, expected)
def test_span_selection_with_row_boundary(self):
    """Checks span-based answer extraction from span indexes and logits.

    The highest-scoring span (logit 10.0 for index pair [1, 2]) is expected
    to be returned as a single TokenAnswer covering tokens 2 and 3.
    """
    pred = {
        'input_ids': np.array([1, 2, 3, 4]),
        'span_indexes': np.array([[1, 1], [1, 2], [2, 1]]),
        'span_logits': np.array([-100.0, 10.0, 5.0]),
        'column_ids': np.array([2, 2, 2, 2]),
        'row_ids': np.array([1, 2, 2, 2]),
    }
    answers = prediction_utils._get_token_answers(
        pred, cell_classification_threshold=0.5)
    expected_answer = prediction_utils.TokenAnswer(
        column_index=1,
        row_index=1,
        begin_token_index=0,
        end_token_index=2,
        token_ids=[2, 3],
        score=10.0,
    )
    self.assertEqual(answers, [expected_answer])
def test_simple_token_answer(self):
    """End-to-end retrieval eval over two interactions with token answers.

    Builds a two-word vocab, one wrong and one correct prediction, and
    verifies the exact metrics dict produced by _evaluate_retrieval_e2e.
    """
    with tempfile.TemporaryDirectory() as input_dir:
        vocab_file = os.path.join(input_dir, "vocab.txt")
        _create_vocab(vocab_file, ["answer", "wrong"])
        # Two interactions: the first question's gold answer is "OTHER"
        # (with "ANSWER" as an alternative), the second's is "ANSWER".
        first_interaction = text_format.Parse(
            """ table { rows { cells { text: "WRONG WRONG" } } } questions { id: "example_id-0_0" answer { class_index: 1 answer_texts: "OTHER" } alternative_answers { answer_texts: "ANSWER" } } """,
            interaction_pb2.Interaction())
        second_interaction = text_format.Parse(
            """ table { rows { cells { text: "ANSWER WRONG" } } } questions { id: "example_id-1_0" answer { class_index: 1 answer_texts: "ANSWER" } } """,
            interaction_pb2.Interaction())
        interactions = [first_interaction, second_interaction]
        # One prediction per question; answers are serialized TokenAnswers.
        first_prediction = {
            "question_id": "example_id-0_0",
            "logits_cls": "0",
            "answers": prediction_utils.token_answers_to_text([
                prediction_utils.TokenAnswer(
                    row_index=-1,
                    column_index=-1,
                    begin_token_index=-1,
                    end_token_index=-1,
                    token_ids=[1],
                    score=10.0,
                )
            ]),
        }
        second_prediction = {
            "question_id": "example_id-1_0",
            "logits_cls": "1",
            "answers": prediction_utils.token_answers_to_text([
                prediction_utils.TokenAnswer(
                    row_index=0,
                    column_index=0,
                    begin_token_index=-1,
                    end_token_index=-1,
                    token_ids=[6, 7],
                    score=10.0,
                )
            ]),
        }
        predictions = [first_prediction, second_prediction]
        result = e2e_eval_utils._evaluate_retrieval_e2e(
            vocab_file,
            interactions,
            predictions,
        )
        logging.info("result: %s", result.to_dict())
        expected_metrics = {
            "answer_accuracy": 0.0,
            "answer_precision": 0.0,
            "answer_token_f1": 0.6666666666666666,
            "answer_token_precision": 0.5,
            "answer_token_recall": 1.0,
            "oracle_answer_accuracy": 0.0,
            "oracle_answer_token_f1": 0.6666666666666666,
            "table_accuracy": 1.0,
            "table_precision": 1.0,
            "table_recall": 1.0,
            "answer_accuracy_table": None,
            "answer_accuracy_passage": None,
            "answer_token_f1_table": None,
            "answer_token_f1_passage": None,
        }
        self.assertEqual(expected_metrics, result.to_dict())