def test_token_selection(self):
  """Checks per-token answer extraction against a cell probability threshold.

  Tokens 1 and 3 have probability 0.5, just above the 0.49999 threshold, and
  are the only expected answers. Token 0 has probability 1.0 but segment id 0,
  and it is not among the expected answers. In the expected answers the
  column/row indices equal the token's column/row ids minus one.
  """
  prediction = dict(
      input_ids=np.array([0, 1, 2, 3, 4]),
      probabilities=np.array([1.0, 0.5, 0.2, 0.5, 0.3]),
      column_ids=np.array([0, 1, 2, 3, 4]),
      row_ids=np.array([0, 1, 1, 2, 2]),
      segment_ids=np.array([0, 1, 1, 1, 1]),
  )
  expected = [
      prediction_utils.TokenAnswer(
          column_index=0,
          row_index=0,
          begin_token_index=0,
          end_token_index=1,
          token_ids=[1],
          score=0.5),
      prediction_utils.TokenAnswer(
          column_index=2,
          row_index=1,
          begin_token_index=0,
          end_token_index=1,
          token_ids=[3],
          score=0.5),
  ]

  answers = prediction_utils._get_token_answers(
      prediction, cell_classification_threshold=0.49999)
  logging.info(answers)
  self.assertEqual(answers, expected)
def test_span_selection_with_row_boundary(self):
  """Checks that the highest-logit span becomes the single token answer.

  Span `[1, 2]` has the largest logit (10.0). The expected answer covers
  input tokens 1..2 (token ids `[2, 3]`), which share row id 2 and column
  id 2. Token 0 lies in row 1. NOTE(review): the expected token_ids suggest
  each `span_indexes` row is a (start, length) pair; confirm against
  `_get_token_answers`.
  """
  span_logits = np.array([-100.0, 10.0, 5.0])
  span_indexes = np.array([[1, 1], [1, 2], [2, 1]])
  prediction = dict(
      input_ids=np.array([1, 2, 3, 4]),
      span_indexes=span_indexes,
      span_logits=span_logits,
      column_ids=np.array([2, 2, 2, 2]),
      row_ids=np.array([1, 2, 2, 2]),
  )
  best_span = prediction_utils.TokenAnswer(
      column_index=1,
      row_index=1,
      begin_token_index=0,
      end_token_index=2,
      token_ids=[2, 3],
      score=10.0,
  )

  answers = prediction_utils._get_token_answers(
      prediction, cell_classification_threshold=0.5)
  self.assertEqual(answers, [best_span])
def test_end_to_end(self, span_prediction):
  """Runs predict() over two random batches and parses every prediction.

  When span prediction is enabled, each prediction must contain
  'span_logits' and 'span_indexes'. Every prediction is passed through
  `_get_token_answers`, which only needs to succeed. The test also checks
  that exactly `_BATCH_SIZE * 2` predictions are produced.
  """
  estimator = self._create_estimator(span_prediction=span_prediction)

  def random_input_fn(params):
    # Keep the parameter name `params`: the estimator presumably passes it by
    # name (TODO confirm).
    batch_size = params['batch_size']
    return table_dataset_test_utils.create_random_dataset(
        num_examples=batch_size * 2,
        batch_size=batch_size,
        repeat=False,
        generator_kwargs=self._generator_kwargs())

  expects_spans = span_prediction != _SpanPredictionMode.NONE
  prediction_count = 0
  for prediction in estimator.predict(random_input_fn):
    if expects_spans:
      self.assertIn('span_logits', prediction)
      self.assertIn('span_indexes', prediction)
    logging.info('prediction: %s', prediction)
    # Smoke check only: the extracted answers are discarded.
    prediction_utils._get_token_answers(
        prediction,
        cell_classification_threshold=0.5,
    )
    prediction_count += 1
  self.assertEqual(prediction_count, _BATCH_SIZE * 2)