def test_token_selection(self):
     """Check token answers extracted from per-token probabilities.

     Builds a five-token prediction and calls `_get_token_answers` with a
     threshold of 0.49999. Exactly two answers are expected, one for each
     token whose probability is 0.5 (input ids 1 and 3); each answer carries
     0.5 as its score.

     NOTE(review): position 0 has probability 1.0 yet is not expected in the
     output -- presumably because its segment id is 0; confirm against
     `_get_token_answers`. The expected column/row indexes are one less than
     the matching `column_ids`/`row_ids` entries, which looks like an
     id-to-index shift -- confirm as well.
     """

     def _one_token_answer(column_index, row_index, token_id):
         # Both expected answers are single-token spans scored 0.5.
         return prediction_utils.TokenAnswer(
             column_index=column_index,
             row_index=row_index,
             begin_token_index=0,
             end_token_index=1,
             token_ids=[token_id],
             score=0.5,
         )

     token_probabilities = np.array([1.0, 0.5, 0.2, 0.5, 0.3])
     model_output = {
         'input_ids': np.array([0, 1, 2, 3, 4]),
         'probabilities': token_probabilities,
         'column_ids': np.array([0, 1, 2, 3, 4]),
         'row_ids': np.array([0, 1, 1, 2, 2]),
         'segment_ids': np.array([0, 1, 1, 1, 1]),
     }

     # The threshold sits just below 0.5 so the 0.5-probability tokens pass.
     answers = prediction_utils._get_token_answers(
         model_output, cell_classification_threshold=0.49999)
     logging.info(answers)

     expected_answers = [
         _one_token_answer(column_index=0, row_index=0, token_id=1),
         _one_token_answer(column_index=2, row_index=1, token_id=3),
     ]
     self.assertEqual(answers, expected_answers)
 def test_span_selection_with_row_boundary(self):
     """Check span-based answer extraction from span logits.

     Supplies three candidate spans with logits -100.0, 10.0 and 5.0 and
     expects `_get_token_answers` to return a single answer scored 10.0 whose
     token ids are [2, 3].

     NOTE(review): [2, 3] matches `input_ids` at positions 1-2, i.e. the
     [1, 2] entry of `span_indexes` read with inclusive bounds -- confirm.
     `row_ids` places position 0 in a different row from positions 1-3; going
     by the test name, that boundary is presumably what is being exercised.
     """
     candidate_spans = np.array([[1, 1], [1, 2], [2, 1]])
     candidate_logits = np.array([-100.0, 10.0, 5.0])
     model_output = {
         'input_ids': np.array([1, 2, 3, 4]),
         'span_indexes': candidate_spans,
         'span_logits': candidate_logits,
         'column_ids': np.array([2, 2, 2, 2]),
         'row_ids': np.array([1, 2, 2, 2]),
     }

     expected_answer = prediction_utils.TokenAnswer(
         column_index=1,
         row_index=1,
         begin_token_index=0,
         end_token_index=2,
         token_ids=[2, 3],
         score=10.0,
     )

     answers = prediction_utils._get_token_answers(
         model_output,
         cell_classification_threshold=0.5,
     )
     self.assertEqual(answers, [expected_answer])
    def test_end_to_end(self, span_prediction):
        """Run estimator prediction and post-process every prediction.

        Creates an estimator through `_create_estimator`, predicts on a
        non-repeating random dataset holding twice the batch size in examples,
        and passes each prediction to `_get_token_answers`. Finally asserts
        that `_BATCH_SIZE * 2` predictions were produced.

        Args:
            span_prediction: Span prediction mode forwarded to
                `_create_estimator`. For any value other than
                `_SpanPredictionMode.NONE`, every prediction must contain
                'span_logits' and 'span_indexes'.
        """
        estimator = self._create_estimator(span_prediction=span_prediction)

        def _random_input_fn(params):
            batch_size = params['batch_size']
            return table_dataset_test_utils.create_random_dataset(
                num_examples=batch_size * 2,
                batch_size=batch_size,
                repeat=False,
                generator_kwargs=self._generator_kwargs(),
            )

        expects_span_outputs = span_prediction != _SpanPredictionMode.NONE

        # Stays 0 if the estimator yields nothing, so the final check fails.
        num_examples = 0
        for num_examples, prediction in enumerate(
                estimator.predict(_random_input_fn), start=1):
            if expects_span_outputs:
                for span_key in ('span_logits', 'span_indexes'):
                    self.assertIn(span_key, prediction)
            logging.info('prediction: %s', prediction)
            # Only checks that answer extraction runs; the result is unused.
            prediction_utils._get_token_answers(
                prediction, cell_classification_threshold=0.5)

        self.assertEqual(num_examples, _BATCH_SIZE * 2)