示例#1
0
    def test_candidates(self, mock_read):
        """Checks the example produced when aggregation candidates are enabled.

        Temporarily lowers tf_example_utils' row and candidate limits, runs the
        classifier pipeline on interaction_03.pbtxt and compares the first
        emitted example (minus 'numeric_values') with tf_example_03.pbtxt.
        """
        with tf.gfile.Open(os.path.join(self._test_dir,
                                        'interaction_03.pbtxt')) as input_file:
            interaction = text_format.ParseLines(input_file,
                                                 interaction_pb2.Interaction())

        _set_mock_read(mock_read, [interaction])
        max_seq_length = 15

        # These are module-level globals shared by every test in the process.
        # Register restorers first (the current values are captured now) so the
        # lowered limits do not leak into tests that run after this one.
        self.addCleanup(setattr, tf_example_utils, '_MAX_NUM_ROWS',
                        tf_example_utils._MAX_NUM_ROWS)
        self.addCleanup(setattr, tf_example_utils, '_MAX_NUM_CANDIDATES',
                        tf_example_utils._MAX_NUM_CANDIDATES)
        tf_example_utils._MAX_NUM_ROWS = 4
        tf_example_utils._MAX_NUM_CANDIDATES = 10

        pipeline = create_data.build_classifier_pipeline(
            input_files=['input.tfrecord'],
            output_files=[self._output_path],
            config=_ClassifierConfig(
                vocab_file=os.path.join(self._test_dir, 'vocab.txt'),
                max_seq_length=max_seq_length,
                max_column_id=4,
                max_row_id=4,
                strip_column_names=False,
                add_aggregation_candidates=True,
            ),
        )

        result = beam.runners.direct.direct_runner.DirectRunner().run(pipeline)
        result.wait_until_finish()

        output = _read_examples(self._output_path)

        with tf.gfile.Open(os.path.join(self._test_dir,
                                        'tf_example_03.pbtxt')) as input_file:
            expected_example = text_format.ParseLines(input_file,
                                                      tf.train.Example())

        actual_example = output[0]
        logging.info('%s', actual_example)
        # assertEqual struggles with NaNs inside protos
        del actual_example.features.feature['numeric_values']
        self.assertEqual(actual_example, expected_example)
示例#2
0
    def test_tfrecord_io(self, mock_read):
        """Reads from TFRecord and writes to TFRecord."""
        interaction_path = os.path.join(self._test_dir, 'interaction_03.pbtxt')
        with tf.gfile.Open(interaction_path) as interaction_file:
            interaction = text_format.ParseLines(interaction_file,
                                                 interaction_pb2.Interaction())

        def dummy_read(file_pattern, coder, validate):
            # Replacement for the mocked reader: ignore its arguments and feed
            # the single parsed interaction into the pipeline instead.
            del file_pattern, coder, validate  # Unused.
            return beam.Create([interaction])

        mock_read.side_effect = dummy_read
        max_seq_length = 15

        classifier_config = _ClassifierConfig(
            vocab_file=os.path.join(self._test_dir, 'vocab.txt'),
            max_seq_length=max_seq_length,
            max_column_id=4,
            max_row_id=4,
            strip_column_names=False,
            add_aggregation_candidates=False,
        )
        pipeline = create_data.build_classifier_pipeline(
            input_files=['input.tfrecord'],
            output_files=[self._output_path],
            config=classifier_config)

        runner = beam.runners.direct.direct_runner.DirectRunner()
        runner.run(pipeline).wait_until_finish()

        # Deserialize every record written to the output TFRecord file.
        examples = [
            tf.train.Example.FromString(serialized)
            for serialized in tf.python_io.tf_record_iterator(self._output_path)
        ]

        self.assertLen(examples, 1)
        segment_ids = examples[0].features.feature['segment_ids']
        self.assertLen(segment_ids.int64_list.value, max_seq_length)
示例#3
0
    def test_numeric_relations(self, mock_read):
        """Checks the per-token 'numeric_relations' feature for interaction_00.

        Runs the classifier pipeline with generous limits (512). It asserts the
        exact Beam counter values and then maps every input id > 0 back to its
        vocab token. The resulting (token, relation) pairs are compared against
        a golden sequence.
        """
        interaction_filename = 'interaction_00.pbtxt'
        expected_counters = {
            'Conversion success': 1,
            'Example emitted': 1,
            'Input question': 1,
            'Relation Set Index: 2': 5,
            'Relation Set Index: 4': 13,
            'Found answers: <= 4': 1,
        }

        # The filename and the open file handle get separate names, so
        # `input_file` is never rebound from a str to a file object.
        with tf.gfile.Open(os.path.join(self._test_dir,
                                        interaction_filename)) as input_file:
            interaction = text_format.ParseLines(input_file,
                                                 interaction_pb2.Interaction())

        _set_mock_read(mock_read, [interaction])

        max_seq_length = 512

        pipeline = create_data.build_classifier_pipeline(
            input_files=['input.tfrecord'],
            output_files=[self._output_path],
            config=_ClassifierConfig(
                vocab_file=os.path.join(self._test_dir, 'vocab.txt'),
                max_seq_length=max_seq_length,
                max_column_id=512,
                max_row_id=512,
                strip_column_names=False,
                add_aggregation_candidates=False,
            ))

        result = beam.runners.direct.direct_runner.DirectRunner().run(pipeline)
        result.wait_until_finish()

        self.assertEqual(
            {
                metric_result.key.metric.name: metric_result.committed
                for metric_result in result.metrics().query()['counters']
            }, expected_counters)

        output = _read_examples(self._output_path)

        self.assertLen(output, 1)
        actual_example = output[0]

        self.assertIn('numeric_relations',
                      actual_example.features.feature.keys())
        relations = actual_example.features.feature[
            'numeric_relations'].int64_list.value

        with tf.gfile.Open(os.path.join(self._test_dir,
                                        'vocab.txt')) as vocab_file:
            vocab = [line.strip() for line in vocab_file]
        inputs = actual_example.features.feature['input_ids'].int64_list.value
        # Id 0 is skipped (presumably padding); each remaining id is paired with
        # the relation at the same position.
        pairs = [(vocab[input_id], relation)
                 for (input_id, relation) in zip(inputs, relations)
                 if input_id > 0]
        logging.info('pairs: %s', pairs)
        self.assertSequenceEqual(pairs, [('[CLS]', 0), ('which', 0),
                                         ('cities', 0),
                                         ('had', 0), ('less', 0), ('than', 0),
                                         ('2', 0), (',', 0), ('000', 0),
                                         ('pass', 0), ('##en', 0), ('##ge', 0),
                                         ('##rs', 0), ('?', 0), ('[SEP]', 0),
                                         ('ran', 0), ('##k', 0), ('city', 0),
                                         ('pass', 0), ('##en', 0), ('##ge', 0),
                                         ('##rs', 0), ('ran', 0), ('##ki', 0),
                                         ('##ng', 0), ('air', 0), ('##li', 0),
                                         ('##ne', 0), ('1', 4), ('united', 0),
                                         ('states', 0), (',', 0), ('los', 0),
                                         ('angeles', 0), ('14', 2), (',', 2),
                                         ('7', 2), ('##4', 2), ('##9', 2),
                                         ('[EMPTY]', 0),
                                         ('al', 0), ('##as', 0), ('##ka', 0),
                                         ('air', 0), ('##li', 0), ('##ne', 0),
                                         ('##s', 0), ('2', 4), ('united', 0),
                                         ('states', 0), (',', 0), ('h', 0),
                                         ('##ous', 0), ('##ton', 0), ('5', 2),
                                         (',', 2), ('4', 2), ('##6', 2),
                                         ('##5', 2), ('[EMPTY]', 0),
                                         ('united', 0), ('e', 0), ('##x', 0),
                                         ('##p', 0), ('##re', 0), ('##s', 0),
                                         ('##s', 0), ('3', 4), ('canada', 0),
                                         (',', 0), ('c', 0), ('##al', 0),
                                         ('##ga', 0), ('##ry', 0), ('3', 2),
                                         (',', 2), ('7', 2), ('##6', 2),
                                         ('##1', 2), ('[EMPTY]', 0),
                                         ('air', 0), ('t', 0), ('##ra', 0),
                                         ('##ns', 0), ('##a', 0), ('##t', 0),
                                         (',', 0), ('west', 0), ('##j', 0),
                                         ('##et', 0), ('4', 4), ('canada', 0),
                                         (',', 0), ('s', 0), ('##as', 0),
                                         ('##ka', 0), ('##to', 0), ('##on', 0),
                                         ('2', 2), (',', 2), ('28', 2),
                                         ('##2', 2), ('4', 0), ('[EMPTY]', 0),
                                         ('5', 4), ('canada', 0), (',', 0),
                                         ('van', 0), ('##co', 0), ('##u', 0),
                                         ('##ve', 0), ('##r', 0), ('2', 2),
                                         (',', 2), ('10', 2), ('##3', 2),
                                         ('[EMPTY]', 0), ('air', 0), ('t', 0),
                                         ('##ra', 0), ('##ns', 0), ('##a', 0),
                                         ('##t', 0), ('6', 4), ('united', 0),
                                         ('states', 0), (',', 0), ('p', 0),
                                         ('##h', 0), ('##o', 0), ('##en', 0),
                                         ('##i', 0), ('##x', 0), ('1', 4),
                                         (',', 4), ('8', 4), ('##2', 4),
                                         ('##9', 4), ('1', 0), ('us', 0),
                                         ('air', 0), ('##w', 0), ('##a', 0),
                                         ('##y', 0), ('##s', 0), ('7', 4),
                                         ('canada', 0), (',', 0), ('to', 0),
                                         ('##ro', 0), ('##nt', 0), ('##o', 0),
                                         ('1', 4), (',', 4), ('20', 4),
                                         ('##2', 4), ('1', 0), ('air', 0),
                                         ('t', 0), ('##ra', 0), ('##ns', 0),
                                         ('##a', 0), ('##t', 0), (',', 0),
                                         ('can', 0), ('##j', 0), ('##et', 0),
                                         ('8', 4), ('canada', 0), (',', 0),
                                         ('ed', 0), ('##m', 0), ('##on', 0),
                                         ('##ton', 0), ('11', 4), ('##0', 4),
                                         ('[EMPTY]', 0), ('[EMPTY]', 0),
                                         ('9', 4), ('united', 0),
                                         ('states', 0), (',', 0), ('o', 0),
                                         ('##a', 0), ('##k', 0), ('##land', 0),
                                         ('10', 4), ('##7', 4), ('[EMPTY]', 0),
                                         ('[EMPTY]', 0)])
示例#4
0
    def test_gracefully_handle_big_examples(self, max_seq_length,
                                            max_column_id, max_row_id,
                                            expected_counters, mock_read):
        """Runs interaction_02 through the pipeline under the given size limits.

        A limit passed as None falls back to the default used here
        (max_seq_length=60, max_column_id=5, max_row_id=10). The Beam counters
        must equal `expected_counters`. When all three limits are None, the two
        emitted examples are also compared against tf_example_02.pbtxt and
        tf_example_02_conv.pbtxt, with some features excluded.
        """
        with tf.gfile.Open(os.path.join(self._test_dir,
                                        'interaction_02.pbtxt')) as input_file:
            interaction = text_format.ParseLines(input_file,
                                                 interaction_pb2.Interaction())

        _set_mock_read(mock_read, [interaction])

        pipeline = create_data.build_classifier_pipeline(
            input_files=['input.tfrecord'],
            output_files=[self._output_path],
            config=_ClassifierConfig(
                vocab_file=self._vocab_path,
                max_seq_length=(60 if max_seq_length is None
                                else max_seq_length),
                max_column_id=5 if max_column_id is None else max_column_id,
                max_row_id=10 if max_row_id is None else max_row_id,
                strip_column_names=False,
                add_aggregation_candidates=False,
            ))

        result = beam.runners.direct.direct_runner.DirectRunner().run(pipeline)
        result.wait_until_finish()

        self.assertEqual(
            {
                metric_result.key.metric.name: metric_result.committed
                for metric_result in result.metrics().query()['counters']
            }, expected_counters)

        # Golden examples only exist for the all-defaults configuration.
        if (max_seq_length is not None or max_column_id is not None or
                max_row_id is not None):
            return

        output = _read_examples(self._output_path)

        expected_examples = []
        for golden_name in ('tf_example_02.pbtxt', 'tf_example_02_conv.pbtxt'):
            with tf.gfile.Open(os.path.join(self._test_dir,
                                            golden_name)) as input_file:
                expected_examples.append(
                    text_format.ParseLines(input_file, tf.train.Example()))

        self.assertLen(output, 2)

        # Features excluded from the golden comparison. 'answer' is among them
        # because assertEqual struggles with NaNs inside protos.
        ignored_features = (
            'column_ranks',
            'inv_column_ranks',
            'numeric_relations',
            'numeric_values',
            'numeric_values_scale',
            'question_id_ints',
            'answer',
        )
        # output[0] pairs with the plain golden, output[1] with the
        # conversational one.
        for actual_example, expected_example in zip(output, expected_examples):
            for feature_name in ignored_features:
                del actual_example.features.feature[feature_name]
            self.assertEqual(actual_example, expected_example)