Example #1
 def test_token_selection(self):
     prediction = {
         'input_ids': np.array([0, 1, 2, 3, 4]),
         'probabilities': np.array([1.0, 0.5, 0.2, 0.5, 0.3]),
         'column_ids': np.array([0, 1, 2, 3, 4]),
         'row_ids': np.array([0, 1, 1, 2, 2]),
         'segment_ids': np.array([0, 1, 1, 1, 1]),
     }
     answers = prediction_utils._get_token_answers(
         prediction,
         cell_classification_threshold=0.49999,
     )
     logging.info(answers)
     self.assertEqual(answers, [
         prediction_utils.TokenAnswer(column_index=0,
                                      row_index=0,
                                      begin_token_index=0,
                                      end_token_index=1,
                                      token_ids=[1],
                                      score=0.5),
         prediction_utils.TokenAnswer(column_index=2,
                                      row_index=1,
                                      begin_token_index=0,
                                      end_token_index=1,
                                      token_ids=[3],
                                      score=0.5),
     ])
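The assertion above follows from thresholding the per-token cell probabilities: only table tokens (segment_ids == 1) whose probability clears cell_classification_threshold are kept, and the 1-based column_ids/row_ids are reported as 0-based indices. The function below is a minimal standalone sketch that reproduces the asserted answers for this input; it is not the actual prediction_utils._get_token_answers implementation, and the per-cell span indices and mean-probability score are assumptions made for illustration.

import numpy as np

def sketch_token_selection(prediction, threshold):
    input_ids = prediction['input_ids']
    probs = prediction['probabilities']
    cols = prediction['column_ids']
    rows = prediction['row_ids']
    segs = prediction['segment_ids']

    answers = []
    n = len(input_ids)
    i = 0
    while i < n:
        # Keep only table tokens whose probability clears the threshold.
        if segs[i] == 1 and cols[i] > 0 and probs[i] >= threshold:
            col, row = int(cols[i]), int(rows[i])
            # Sequence position of the first token belonging to this cell.
            cell_start = next(k for k in range(n)
                              if cols[k] == col and rows[k] == row)
            # Extend the run while the next token stays in the same cell
            # and is also above the threshold.
            j = i
            while (j + 1 < n and cols[j + 1] == col and rows[j + 1] == row
                   and probs[j + 1] >= threshold):
                j += 1
            answers.append(dict(
                column_index=col - 1,              # 1-based ids -> 0-based indices
                row_index=row - 1,
                begin_token_index=i - cell_start,
                end_token_index=j - cell_start + 1,  # exclusive end, within the cell
                token_ids=[int(t) for t in input_ids[i:j + 1]],
                score=float(np.mean(probs[i:j + 1])),
            ))
            i = j + 1
        else:
            i += 1
    return answers

Running sketch_token_selection on the prediction dict above with threshold 0.49999 yields the two answers the test asserts: (column 0, row 0, token_ids [1]) and (column 2, row 1, token_ids [3]), each with score 0.5.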
Example #2
 def test_span_selection_with_row_boundary(self):
     prediction = {
         'input_ids': np.array([1, 2, 3, 4]),
         'span_indexes': np.array([[1, 1], [1, 2], [2, 1]]),
         'span_logits': np.array([-100.0, 10.0, 5.0]),
         'column_ids': np.array([2, 2, 2, 2]),
         'row_ids': np.array([1, 2, 2, 2]),
     }
     answers = prediction_utils._get_token_answers(
         prediction, cell_classification_threshold=0.5)
     self.assertEqual(answers, [
         prediction_utils.TokenAnswer(
             column_index=1,
             row_index=1,
             begin_token_index=0,
             end_token_index=2,
             token_ids=[2, 3],
             score=10.0,
         )
     ])
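When the prediction carries candidate spans (span_indexes paired with span_logits), the test expects the single highest-logit span to be returned, with the answer's score equal to that logit and its cell coordinates taken from the span's first token. The sketch below is consistent with the assertion above but is not the library's implementation; in particular, the row-boundary splitting that the test name refers to is not modelled here.

import numpy as np

def sketch_span_selection(prediction):
    spans = prediction['span_indexes']
    logits = prediction['span_logits']
    cols = prediction['column_ids']
    rows = prediction['row_ids']
    ids = prediction['input_ids']

    # Pick the highest-logit candidate span.
    best = int(np.argmax(logits))
    begin, end = (int(x) for x in spans[best])     # inclusive token positions
    col, row = int(cols[begin]), int(rows[begin])
    # First sequence position belonging to the same cell as the span start.
    cell_start = next(k for k in range(len(ids))
                      if cols[k] == col and rows[k] == row)
    return dict(
        column_index=col - 1,                      # 1-based ids -> 0-based indices
        row_index=row - 1,
        begin_token_index=begin - cell_start,
        end_token_index=end - cell_start + 1,      # exclusive end, within the cell
        token_ids=[int(t) for t in ids[begin:end + 1]],
        score=float(logits[best]),
    )

For the prediction above this picks the span [1, 2] (logit 10.0), giving column_index 1, row_index 1, token_ids [2, 3] and score 10.0, matching the asserted TokenAnswer.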
Example #3
 def test_simple_token_answer(self):
   with tempfile.TemporaryDirectory() as input_dir:
     vocab_file = os.path.join(input_dir, "vocab.txt")
     _create_vocab(vocab_file, ["answer", "wrong"])
     interactions = [
         text_format.Parse(
             """
           table {
             rows {
               cells { text: "WRONG WRONG" }
             }
           }
           questions {
             id: "example_id-0_0"
             answer {
               class_index: 1
               answer_texts: "OTHER"
             }
             alternative_answers {
               answer_texts: "ANSWER"
             }
           }
         """, interaction_pb2.Interaction()),
         text_format.Parse(
             """
           table {
             rows {
               cells { text: "ANSWER WRONG" }
             }
           }
           questions {
             id: "example_id-1_0"
             answer {
               class_index: 1
               answer_texts: "ANSWER"
             }
           }
         """, interaction_pb2.Interaction())
     ]
     predictions = [
         {
             "question_id":
                 "example_id-0_0",
             "logits_cls":
                 "0",
             "answers":
                 prediction_utils.token_answers_to_text([
                     prediction_utils.TokenAnswer(
                         row_index=-1,
                         column_index=-1,
                         begin_token_index=-1,
                         end_token_index=-1,
                         token_ids=[1],
                         score=10.0,
                     )
                 ]),
         },
         {
             "question_id":
                 "example_id-1_0",
             "logits_cls":
                 "1",
             "answers":
                 prediction_utils.token_answers_to_text([
                     prediction_utils.TokenAnswer(
                         row_index=0,
                         column_index=0,
                         begin_token_index=-1,
                         end_token_index=-1,
                         token_ids=[6, 7],
                         score=10.0,
                     )
                 ]),
         },
     ]
     result = e2e_eval_utils._evaluate_retrieval_e2e(
         vocab_file,
         interactions,
         predictions,
     )
     logging.info("result: %s", result.to_dict())
     self.assertEqual(
         {
             "answer_accuracy": 0.0,
             "answer_precision": 0.0,
             "answer_token_f1": 0.6666666666666666,
             "answer_token_precision": 0.5,
             "answer_token_recall": 1.0,
             "oracle_answer_accuracy": 0.0,
             "oracle_answer_token_f1": 0.6666666666666666,
             "table_accuracy": 1.0,
             "table_precision": 1.0,
             "table_recall": 1.0,
             "answer_accuracy_table": None,
             "answer_accuracy_passage": None,
             "answer_token_f1_table": None,
             "answer_token_f1_passage": None,
         }, result.to_dict())
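The asserted token metrics (answer_token_precision 0.5, answer_token_recall 1.0, answer_token_f1 2/3) are exactly what standard token-overlap scoring gives when a two-token predicted answer contains the single reference token. The helper below only illustrates that arithmetic; it is an assumption that e2e_eval_utils scores answers with this kind of token overlap, and sketch_token_f1 is a hypothetical helper, not part of the library.

from collections import Counter

def sketch_token_f1(predicted_tokens, reference_tokens):
    # Token-overlap precision/recall/F1 between a predicted and a reference answer.
    overlap = sum((Counter(predicted_tokens) & Counter(reference_tokens)).values())
    if overlap == 0:
        return 0.0, 0.0, 0.0
    precision = overlap / len(predicted_tokens)
    recall = overlap / len(reference_tokens)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# Two predicted tokens, one of which matches the single reference token:
print(sketch_token_f1(["answer", "wrong"], ["answer"]))
# -> (0.5, 1.0, 0.6666666666666666), matching the asserted token metrics.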