def test_candidates(self, mock_read):
    """Golden-file check for examples built with aggregation candidates.

    Feeds interaction_03 through the classifier pipeline with
    add_aggregation_candidates=True and compares the emitted tf.Example
    against tf_example_03.pbtxt.
    """
    with tf.gfile.Open(os.path.join(self._test_dir,
                                    'interaction_03.pbtxt')) as input_file:
        interaction = text_format.ParseLines(input_file,
                                             interaction_pb2.Interaction())
    _set_mock_read(mock_read, [interaction])
    max_seq_length = 15
    # Shrink the module-level limits so candidate enumeration stays small.
    # Save and restore them so this test does not leak state into the other
    # tests in this file (the original left them permanently overridden).
    old_max_num_rows = tf_example_utils._MAX_NUM_ROWS
    old_max_num_candidates = tf_example_utils._MAX_NUM_CANDIDATES
    tf_example_utils._MAX_NUM_ROWS = 4
    tf_example_utils._MAX_NUM_CANDIDATES = 10
    try:
        pipeline = create_data.build_classifier_pipeline(
            input_files=['input.tfrecord'],
            output_files=[self._output_path],
            config=_ClassifierConfig(
                vocab_file=os.path.join(self._test_dir, 'vocab.txt'),
                max_seq_length=max_seq_length,
                max_column_id=4,
                max_row_id=4,
                strip_column_names=False,
                add_aggregation_candidates=True,
            ),
        )
        result = beam.runners.direct.direct_runner.DirectRunner().run(pipeline)
        result.wait_until_finish()
    finally:
        tf_example_utils._MAX_NUM_ROWS = old_max_num_rows
        tf_example_utils._MAX_NUM_CANDIDATES = old_max_num_candidates
    output = _read_examples(self._output_path)
    with tf.gfile.Open(os.path.join(self._test_dir,
                                    'tf_example_03.pbtxt')) as input_file:
        expected_example = text_format.ParseLines(input_file,
                                                  tf.train.Example())
    # Fail with a clear message if nothing was emitted, instead of an
    # IndexError on output[0].
    self.assertLen(output, 1)
    actual_example = output[0]
    logging.info('%s', actual_example)
    # assertEqual struggles with NaNs inside protos
    del actual_example.features.feature['numeric_values']
    self.assertEqual(actual_example, expected_example)
def test_tfrecord_io(self, mock_read):
    """Reads from TFRecord and writes to TFRecord."""
    with tf.gfile.Open(os.path.join(self._test_dir,
                                    'interaction_03.pbtxt')) as input_file:
        interaction = text_format.ParseLines(input_file,
                                             interaction_pb2.Interaction())
    # Consistency: use the shared _set_mock_read helper (as every sibling
    # test here does) instead of hand-rolling an equivalent dummy_read
    # closure that re-implements the same side_effect wiring.
    _set_mock_read(mock_read, [interaction])
    max_seq_length = 15
    pipeline = create_data.build_classifier_pipeline(
        input_files=['input.tfrecord'],
        output_files=[self._output_path],
        config=_ClassifierConfig(
            vocab_file=os.path.join(self._test_dir, 'vocab.txt'),
            max_seq_length=max_seq_length,
            max_column_id=4,
            max_row_id=4,
            strip_column_names=False,
            add_aggregation_candidates=False,
        ))
    result = beam.runners.direct.direct_runner.DirectRunner().run(pipeline)
    result.wait_until_finish()
    # Read the raw TFRecord output back to verify the end-to-end round trip.
    output = []
    for value in tf.python_io.tf_record_iterator(self._output_path):
        example = tf.train.Example()
        example.ParseFromString(value)
        output.append(example)
    self.assertLen(output, 1)
    sid = output[0].features.feature['segment_ids']
    self.assertLen(sid.int64_list.value, max_seq_length)
def test_numeric_relations(self, mock_read):
    """Checks the numeric_relations feature emitted for interaction_00.

    Runs the classifier pipeline on interaction_00.pbtxt, verifies the
    pipeline counters, then decodes (token, relation) pairs from the single
    emitted example and compares them to a golden list.
    """
    input_file = 'interaction_00.pbtxt'
    # Expected Beam counter values for this single interaction.
    expected_counters = {
        'Conversion success': 1,
        'Example emitted': 1,
        'Input question': 1,
        'Relation Set Index: 2': 5,
        'Relation Set Index: 4': 13,
        'Found answers: <= 4': 1,
    }
    with tf.gfile.Open(os.path.join(self._test_dir, input_file)) as input_file:
        interaction = text_format.ParseLines(input_file,
                                             interaction_pb2.Interaction())
    _set_mock_read(mock_read, [interaction])
    max_seq_length = 512
    pipeline = create_data.build_classifier_pipeline(
        input_files=['input.tfrecord'],
        output_files=[self._output_path],
        config=_ClassifierConfig(
            vocab_file=os.path.join(self._test_dir, 'vocab.txt'),
            max_seq_length=max_seq_length,
            max_column_id=512,
            max_row_id=512,
            strip_column_names=False,
            add_aggregation_candidates=False,
        ))
    result = beam.runners.direct.direct_runner.DirectRunner().run(pipeline)
    result.wait_until_finish()
    # Compare committed counter values against the expectations above.
    self.assertEqual(
        {
            metric_result.key.metric.name: metric_result.committed
            for metric_result in result.metrics().query()['counters']
        }, expected_counters)
    output = _read_examples(self._output_path)
    self.assertLen(output, 1)
    actual_example = output[0]
    self.assertIn('numeric_relations', actual_example.features.feature.keys())
    relations = actual_example.features.feature[
        'numeric_relations'].int64_list.value
    # Map input_ids back to vocab tokens so the golden list below is
    # human-readable; padding positions (input_id == 0) are skipped.
    with tf.gfile.Open(os.path.join(self._test_dir,
                                    'vocab.txt')) as vocab_file:
        vocab = [line.strip() for line in vocab_file]
    inputs = actual_example.features.feature['input_ids'].int64_list.value
    pairs = [(vocab[input_id], relation)
             for (input_id, relation) in zip(inputs, relations)
             if input_id > 0]
    logging.info('pairs: %s', pairs)
    # Golden (token, relation) pairs. NOTE(review): the relation integers
    # presumably index relation sets (see the 'Relation Set Index' counters
    # above) — confirm against the tf_example_utils implementation.
    self.assertSequenceEqual(pairs, [
        ('[CLS]', 0), ('which', 0), ('cities', 0), ('had', 0), ('less', 0),
        ('than', 0), ('2', 0), (',', 0), ('000', 0), ('pass', 0),
        ('##en', 0), ('##ge', 0), ('##rs', 0), ('?', 0), ('[SEP]', 0),
        ('ran', 0), ('##k', 0), ('city', 0), ('pass', 0), ('##en', 0),
        ('##ge', 0), ('##rs', 0), ('ran', 0), ('##ki', 0), ('##ng', 0),
        ('air', 0), ('##li', 0), ('##ne', 0), ('1', 4), ('united', 0),
        ('states', 0), (',', 0), ('los', 0), ('angeles', 0), ('14', 2),
        (',', 2), ('7', 2), ('##4', 2), ('##9', 2), ('[EMPTY]', 0),
        ('al', 0), ('##as', 0), ('##ka', 0), ('air', 0), ('##li', 0),
        ('##ne', 0), ('##s', 0), ('2', 4), ('united', 0), ('states', 0),
        (',', 0), ('h', 0), ('##ous', 0), ('##ton', 0), ('5', 2),
        (',', 2), ('4', 2), ('##6', 2), ('##5', 2), ('[EMPTY]', 0),
        ('united', 0), ('e', 0), ('##x', 0), ('##p', 0), ('##re', 0),
        ('##s', 0), ('##s', 0), ('3', 4), ('canada', 0), (',', 0),
        ('c', 0), ('##al', 0), ('##ga', 0), ('##ry', 0), ('3', 2),
        (',', 2), ('7', 2), ('##6', 2), ('##1', 2), ('[EMPTY]', 0),
        ('air', 0), ('t', 0), ('##ra', 0), ('##ns', 0), ('##a', 0),
        ('##t', 0), (',', 0), ('west', 0), ('##j', 0), ('##et', 0),
        ('4', 4), ('canada', 0), (',', 0), ('s', 0), ('##as', 0),
        ('##ka', 0), ('##to', 0), ('##on', 0), ('2', 2), (',', 2),
        ('28', 2), ('##2', 2), ('4', 0), ('[EMPTY]', 0), ('5', 4),
        ('canada', 0), (',', 0), ('van', 0), ('##co', 0), ('##u', 0),
        ('##ve', 0), ('##r', 0), ('2', 2), (',', 2), ('10', 2),
        ('##3', 2), ('[EMPTY]', 0), ('air', 0), ('t', 0), ('##ra', 0),
        ('##ns', 0), ('##a', 0), ('##t', 0), ('6', 4), ('united', 0),
        ('states', 0), (',', 0), ('p', 0), ('##h', 0), ('##o', 0),
        ('##en', 0), ('##i', 0), ('##x', 0), ('1', 4), (',', 4),
        ('8', 4), ('##2', 4), ('##9', 4), ('1', 0), ('us', 0),
        ('air', 0), ('##w', 0), ('##a', 0), ('##y', 0), ('##s', 0),
        ('7', 4), ('canada', 0), (',', 0), ('to', 0), ('##ro', 0),
        ('##nt', 0), ('##o', 0), ('1', 4), (',', 4), ('20', 4),
        ('##2', 4), ('1', 0), ('air', 0), ('t', 0), ('##ra', 0),
        ('##ns', 0), ('##a', 0), ('##t', 0), (',', 0), ('can', 0),
        ('##j', 0), ('##et', 0), ('8', 4), ('canada', 0), (',', 0),
        ('ed', 0), ('##m', 0), ('##on', 0), ('##ton', 0), ('11', 4),
        ('##0', 4), ('[EMPTY]', 0), ('[EMPTY]', 0), ('9', 4),
        ('united', 0), ('states', 0), (',', 0), ('o', 0), ('##a', 0),
        ('##k', 0), ('##land', 0), ('10', 4), ('##7', 4),
        ('[EMPTY]', 0), ('[EMPTY]', 0)])
def test_gracefully_handle_big_examples(self, max_seq_length, max_column_id,
                                        max_row_id, expected_counters,
                                        mock_read):
    """Runs the pipeline on an oversized interaction under various limits.

    Parameterized: each None limit falls back to a default (60/5/10) and
    expected_counters describes the outcome. Only the all-defaults case also
    compares the two emitted examples against golden protos.
    """

    def _strip_flaky_features(example):
        # Rank/value features contain NaNs and derived ranks that make
        # assertEqual on the full proto unreliable, so drop them (and the
        # NaN-prone 'answer') before comparing. Extracted to avoid the
        # copy-pasted del-block the original repeated for both outputs.
        for name in ('column_ranks', 'inv_column_ranks', 'numeric_relations',
                     'numeric_values', 'numeric_values_scale',
                     'question_id_ints', 'answer'):
            del example.features.feature[name]

    with tf.gfile.Open(os.path.join(self._test_dir,
                                    'interaction_02.pbtxt')) as input_file:
        interaction = text_format.ParseLines(input_file,
                                             interaction_pb2.Interaction())
    _set_mock_read(mock_read, [interaction])
    pipeline = create_data.build_classifier_pipeline(
        input_files=['input.tfrecord'],
        output_files=[self._output_path],
        config=_ClassifierConfig(
            vocab_file=self._vocab_path,
            max_seq_length=60 if max_seq_length is None else max_seq_length,
            max_column_id=5 if max_column_id is None else max_column_id,
            max_row_id=10 if max_row_id is None else max_row_id,
            strip_column_names=False,
            add_aggregation_candidates=False,
        ))
    result = beam.runners.direct.direct_runner.DirectRunner().run(pipeline)
    result.wait_until_finish()
    self.assertEqual(
        {
            metric_result.key.metric.name: metric_result.committed
            for metric_result in result.metrics().query()['counters']
        }, expected_counters)
    if max_seq_length is None and max_column_id is None and max_row_id is None:
        output = _read_examples(self._output_path)
        with tf.gfile.Open(
            os.path.join(self._test_dir,
                         'tf_example_02.pbtxt')) as input_file:
            expected_example = text_format.ParseLines(input_file,
                                                      tf.train.Example())
        with tf.gfile.Open(
            os.path.join(self._test_dir,
                         'tf_example_02_conv.pbtxt')) as input_file:
            expected_conversational_example = text_format.ParseLines(
                input_file, tf.train.Example())
        self.assertLen(output, 2)
        actual_example = output[0]
        _strip_flaky_features(actual_example)
        self.assertEqual(actual_example, expected_example)
        actual_example = output[1]
        _strip_flaky_features(actual_example)
        self.assertEqual(actual_example, expected_conversational_example)