Example #1
 def _set_up_test_pipeline_pb(self):
   """Read expected pipeline pb from a text proto file."""
   test_pb_filepath = os.path.join(
       os.path.dirname(__file__), "testdata", "iris_pipeline_ir.pbtxt")
   with open(test_pb_filepath) as text_pb_file:
     self._pipeline_pb = text_format.ParseLines(text_pb_file,
                                                pipeline_pb2.Pipeline())
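All of these snippets lean on the same protobuf helper: text_format.ParseLines takes any iterable of text lines (a list, a generator, or an open file object, as above) plus a message instance, fills that message in place, and returns the same instance. A minimal self-contained sketch, using a message that ships with the protobuf package:

from google.protobuf import descriptor_pb2
from google.protobuf import text_format

# Any iterable of lines works; an open file object is simply iterated.
lines = ['name: "example.proto"', 'syntax: "proto3"']
msg = text_format.ParseLines(lines, descriptor_pb2.FileDescriptorProto())
assert msg.name == 'example.proto'  # The passed-in message is filled and returned.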
Example #2
    def __init__(self, config_file):
        """
        Construct a DeepLab2 model.

        Parameters
        ----------
        config_file: str
            Path to a text protobuf config file, as used by deeplab.
        """
        if isinstance(config_file, str):
            with tf.io.gfile.GFile(config_file, 'r') as proto_file:
                self._config_text = proto_file.readlines()
        else:
            self._config_text = config_file['config_file']
        self._config = text_format.ParseLines(self._config_text,
                                              config_pb2.ExperimentOptions())
        # The dataset descriptor is unused downstream except for ignore_label.
        ds = dataset.DatasetDescriptor(
            dataset_name=None,
            splits_to_sizes=None,
            num_classes=None,
            ignore_label=None,
            panoptic_label_divisor=None,
            class_has_instances_list=None,
            is_video_dataset=None,
            colormap=None,
            is_depth_dataset=None,
            ignore_depth=None,
        )
        super().__init__(self._config, ds)
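One caveat worth noting here: ParseLines raises if a non-repeated scalar field appears twice in the input, whereas the sibling MergeLines lets a later value override an earlier one. A short sketch of the difference, again using a stock protobuf message purely for illustration:

from google.protobuf import descriptor_pb2
from google.protobuf import text_format

lines = ['name: "a.proto"', 'name: "b.proto"']  # scalar field assigned twice
merged = text_format.MergeLines(lines, descriptor_pb2.FileDescriptorProto())
assert merged.name == 'b.proto'  # MergeLines keeps the last value.
try:
    text_format.ParseLines(lines, descriptor_pb2.FileDescriptorProto())
except text_format.ParseError:
    pass  # ParseLines rejects the duplicate scalar assignment.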
Example #3
 def _get_test_pipeline_pb(self, file_name: str) -> pipeline_pb2.Pipeline:
     """Reads expected pipeline pb from a text proto file."""
     test_pb_filepath = os.path.join(os.path.dirname(__file__), "testdata",
                                     file_name)
     with open(test_pb_filepath) as text_pb_file:
         return text_format.ParseLines(text_pb_file,
                                       pipeline_pb2.Pipeline())
Example #4
 def testProtoOperatorDescriptor(self):
     test_pb_filepath = os.path.join(os.path.dirname(__file__), 'testdata',
                                     'proto_placeholder_operator.pbtxt')
     with open(test_pb_filepath) as text_pb_file:
         expected_pb = text_format.ParseLines(
             text_pb_file, placeholder_pb2.PlaceholderExpression())
     placeholder = ph.exec_property('splits_config').analyze[0]
     component_spec = standard_component_specs.TransformSpec
     self.assertProtoEquals(placeholder.encode(component_spec), expected_pb)
Example #5
  def testParseLinesGolden(self):
    opened = self.ReadGolden('text_format_unittest_data.txt')
    parsed_message = unittest_pb2.TestAllTypes()
    r = text_format.ParseLines(opened, parsed_message)
    self.assertIs(r, parsed_message)

    message = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(message)
    self.assertEqual(message, parsed_message)
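The assertIs above pins down a useful property: ParseLines mutates the message it is given and returns that same object. The equivalent single-string entry point is text_format.Parse; a quick sketch showing the two agree:

from google.protobuf import descriptor_pb2
from google.protobuf import text_format

text = 'name: "example.proto"'
a = text_format.Parse(text, descriptor_pb2.FileDescriptorProto())
b = text_format.ParseLines(text.splitlines(True),
                           descriptor_pb2.FileDescriptorProto())
assert a == b  # Parse and ParseLines produce identical messages.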
Example #6
 def testProtoFutureValueOperator(self):
     test_pb_filepath = os.path.join(
         os.path.dirname(__file__), 'testdata',
         'proto_placeholder_future_value_operator.pbtxt')
     with open(test_pb_filepath) as text_pb_file:
         expected_pb = text_format.ParseLines(
             text_pb_file, placeholder_pb2.PlaceholderExpression())
     output_channel = Channel(type=standard_artifacts.Integer)
     placeholder = output_channel.future()[0].value
     placeholder._key = '_component.num'
     self.assertProtoEquals(placeholder.encode(), expected_pb)
Example #7
    def test_candidates(self, mock_read):

        with tf.gfile.Open(os.path.join(self._test_dir,
                                        'interaction_03.pbtxt')) as input_file:
            interaction = text_format.ParseLines(input_file,
                                                 interaction_pb2.Interaction())

        _set_mock_read(mock_read, [interaction])
        max_seq_length = 15

        tf_example_utils._MAX_NUM_ROWS = 4
        tf_example_utils._MAX_NUM_CANDIDATES = 10

        pipeline = create_data.build_classifier_pipeline(
            input_files=['input.tfrecord'],
            output_files=[self._output_path],
            config=_ClassifierConfig(
                vocab_file=os.path.join(self._test_dir, 'vocab.txt'),
                max_seq_length=max_seq_length,
                max_column_id=4,
                max_row_id=4,
                strip_column_names=False,
                add_aggregation_candidates=True,
            ),
        )

        result = beam.runners.direct.direct_runner.DirectRunner().run(pipeline)
        result.wait_until_finish()

        output = _read_examples(self._output_path)

        with tf.gfile.Open(os.path.join(self._test_dir,
                                        'tf_example_03.pbtxt')) as input_file:
            expected_example = text_format.ParseLines(input_file,
                                                      tf.train.Example())

        actual_example = output[0]
        logging.info('%s', actual_example)
        # assertEqual struggles with NaNs inside protos
        del actual_example.features.feature['numeric_values']
        self.assertEqual(actual_example, expected_example)
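Several of these tests call a helper, _set_mock_read, whose definition is not shown. Example #13 below inlines the same pattern with a dummy_read function, so the helper plausibly looks like this sketch (a reconstruction, not the original source):

import apache_beam as beam

def _set_mock_read(mock_read, records):
    """Hypothetical helper: make the mocked Beam read yield in-memory records."""
    def dummy_read(file_pattern, coder, validate):
        del file_pattern, coder, validate  # Unused.
        return beam.Create(records)
    mock_read.side_effect = dummy_read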
Example #8
  def test_process_doc(self, version):
    interactions = list(
        sem_tab_fact_utils._process_doc(
            os.path.join(
                self.test_data_dir,
                'sem_tab_fact_20502.xml',
            ),
            version,
        ))

    if version == sem_tab_fact_utils.Version.V1:
      name = 'sem_tab_fact_20502_interaction.txtpb'
    elif version == sem_tab_fact_utils.Version.V2:
      name = 'sem_tab_fact_20502_interaction_v2.txtpb'
    else:
      raise ValueError(f'Unsupported version: {version.name}')
    interaction_file = os.path.join(self.test_data_dir, name)
    with open(interaction_file) as input_file:
      interaction = text_format.ParseLines(input_file,
                                           interaction_pb2.Interaction())
    self.assertLen(interactions, 4)
    logging.info(interactions[0])
    self.assertEqual(interactions[0], interaction)
    questions = [
        (  # pylint: disable=g-complex-comprehension
            i.questions[0].id,
            i.questions[0].original_text,
            i.questions[0].answer.class_index,
        ) for i in interactions
    ]
    self.assertEqual(questions, [
        (
            'sem_tab_fact_20502_Table_2_2_0',
            'At the same time, these networks often occur in tandem at the firm level.',
            1,
        ),
        (
            'sem_tab_fact_20502_Table_2_3_0',
            'For each network interaction, there is considerable variation both across and within countries.',
            1,
        ),
        (
            'sem_tab_fact_20502_Table_2_5_0',
            'The n value is same for Hong Kong and Malaysia.',
            0,
        ),
        (
            'sem_tab_fact_20502_Table_2_8_0',
            'There are 9 different types country in the given table.',
            1,
        ),
    ])
Example #9
    def test_end_to_end(self, input_format, impl, mock_read):
        self._create_vocab(list(_RESERVED_SYMBOLS))

        with tf.gfile.Open(
                os.path.join(self._test_dir,
                             'retrieval_interaction.pbtxt')) as input_file:
            interaction = text_format.ParseLines(input_file,
                                                 interaction_pb2.Interaction())
        if input_format == _InputFormat.INTERACTION:
            samples = [interaction]
        elif input_format == _InputFormat.TABLE:
            samples = [interaction.table]
        else:
            raise ValueError(f'Unknown format: {input_format}')

        _set_mock_read(mock_read, samples)

        pipeline = create_data.build_retrieval_pipeline(
            input_files=['input.tfrecord'],
            input_format=input_format,
            output_files=[self._output_path],
            config=retrieval_utils.RetrievalConversionConfig(
                vocab_file=self._vocab_path,
                max_seq_length=15,
                max_column_id=5,
                max_row_id=5,
                strip_column_names=False),
            converter_impl=impl,
        )
        result = beam.runners.direct.direct_runner.DirectRunner().run(pipeline)
        result.wait_until_finish()
        counters = {
            metric_result.key.metric.name: metric_result.committed
            for metric_result in result.metrics().query()['counters']
        }

        if input_format == _InputFormat.INTERACTION:
            self.assertEqual(counters, {
                'Input question': 1,
                'Conversion success': 1,
            })
        else:
            self.assertEqual(
                counters, {
                    'Input question': 1,
                    'Conversion success': 1,
                    'Fake Questions added for table only example': 1,
                })

        output = _read_examples(self._output_path)
        self.assertLen(output, 1)
Example #10
def create_selector(
    table_pruning_config_file,
    vocab_size,
    hidden_size,
    initializer_range,
    max_num_columns,
    max_num_rows,
    type_vocab_size,
    disabled_features,
    disable_position_embeddings,
    max_position_embeddings,
):
    """Activates the scoring model according to table pruning config."""
    if not table_pruning_config_file:
        return NoTablePruning()
    with tf.gfile.Open(table_pruning_config_file) as input_file:
        config = text_format.ParseLines(input_file,
                                        table_pruning_pb2.TablePruningModel())
    model = config.WhichOneof("table_pruning_model")
    max_num_tokens = config.max_num_tokens
    if model == "avg_cos_similarity":
        return AverageCosineSimilaritySelector(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            initializer_range=initializer_range,
            max_num_columns=max_num_columns,
            type_vocab_size=type_vocab_size,
            disabled_features=disabled_features,
            disable_position_embeddings=disable_position_embeddings,
            max_position_embeddings=max_position_embeddings,
            config=config.avg_cos_similarity,
            max_num_tokens=max_num_tokens)
    elif model == "tapas":
        return TapasPruningSelector(config=config.tapas,
                                    max_num_columns=max_num_columns,
                                    max_num_tokens=max_num_tokens,
                                    max_num_rows=max_num_rows)
    elif model == "first_tokens":
        return OnesTablePruning(config=config.first_tokens,
                                max_num_columns=max_num_columns,
                                max_num_tokens=max_num_tokens)
    else:
        raise NotImplementedError(f"TablePruningModel not implemented {model}")
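Nothing in create_selector guards against a malformed config file: ParseLines raises text_format.ParseError, which carries the offending line and column. Where a friendlier failure is wanted, a caller could wrap the parse along these lines (an illustrative sketch; parse_pruning_config is not part of the original code):

def parse_pruning_config(path):
    """Illustrative wrapper that labels parse failures with the file path."""
    with tf.gfile.Open(path) as input_file:
        try:
            return text_format.ParseLines(input_file,
                                          table_pruning_pb2.TablePruningModel())
        except text_format.ParseError as e:
            raise ValueError(f'Bad table pruning config {path}: {e}') from e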
Example #11
    def testMapOrderSemantics(self):
        golden_lines = self.ReadGolden('map_test_data.txt')
        # The C++ implementation emits defaulted-value fields, while the Python
        # implementation does not.  Adjusting for this is awkward, but it is
        # valuable to test against a common golden file.
        line_blacklist = ('  key: 0\n', '  value: 0\n', '  key: false\n',
                          '  value: false\n')
        golden_lines = [
            line for line in golden_lines if line not in line_blacklist
        ]

        message = map_unittest_pb2.TestMap()
        text_format.ParseLines(golden_lines, message)
        candidate = text_format.MessageToString(message)
        # The Python implementation emits "1.0" for the double value that the C++
        # implementation emits as "1".
        candidate = candidate.replace('1.0', '1', 2)
        self.assertMultiLineEqual(candidate, ''.join(golden_lines))
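The golden-file comparison above is, at its core, a text round trip: parse the golden lines into a message, serialize with MessageToString, and compare. Stripped of the C++-versus-Python formatting quirks the test compensates for, the round trip reduces to this sketch:

from google.protobuf import descriptor_pb2
from google.protobuf import text_format

original = descriptor_pb2.FileDescriptorProto(name='example.proto')
text = text_format.MessageToString(original)  # message -> text proto
restored = text_format.ParseLines(text.splitlines(True),
                                  descriptor_pb2.FileDescriptorProto())
assert restored == original  # text proto -> identical message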
Example #12
    def test_get_empty_example(self):
        max_seq_length = 15

        input_path = os.path.join(self.test_data_dir,
                                  'retrieval_interaction.pbtxt')
        with open(input_path) as input_file:
            interaction = text_format.ParseLines(input_file,
                                                 interaction_pb2.Interaction())
        number_annotation_utils.add_numeric_values(interaction)

        with tempfile.TemporaryDirectory() as input_dir:
            vocab_file = os.path.join(input_dir, 'vocab.txt')
            _create_vocab(vocab_file, [
                'by', 'created', 'do', 'dragon', 'go', 'hannibal', 'harris',
                'in', 'lecter', 'movies', 'novels', 'order', 'original', 'red',
                'the', 'thomas', 'what', 'work'
            ])
            converter = tf_example_utils.ToRetrievalTensorflowExample(
                config=tf_example_utils.RetrievalConversionConfig(
                    vocab_file=vocab_file,
                    max_seq_length=max_seq_length,
                    max_column_id=max_seq_length,
                    max_row_id=max_seq_length,
                    strip_column_names=False,
                ))
            example = converter.convert(interaction,
                                        index=0,
                                        negative_example=None)
            logging.info(example)
            # Check the question.
            self.assertEqual(
                _get_int_feature(example, 'question_input_ids'),
                [2, 22, 17, 8, 20, 11, 14, 15, 10, 13, 3, 0, 0, 0, 0])
            self.assertEqual(_get_int_feature(example, 'question_input_mask'),
                             [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0])
            # Check document title + table.
            self.assertEqual(
                _get_int_feature(example, 'input_ids'),
                [2, 11, 14, 3, 7, 6, 18, 23, 16, 21, 12, 19, 9, 19, 9])
            self.assertEqual(_get_int_feature(example, 'input_mask'),
                             [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
            self.assertEqual(_get_int_feature(example, 'segment_ids'),
                             [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
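_get_int_feature is another helper whose body is not shown; given the assertions above, it presumably reads an int64 feature out of the tf.train.Example, roughly as follows (a guess at the helper, not the original):

def _get_int_feature(example, key):
    """Hypothetical helper: return the int64 values stored under `key`."""
    return list(example.features.feature[key].int64_list.value)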
Example #13
    def test_tfrecord_io(self, mock_read):
        """Reads from TFRecord and writes to TFRecord."""

        with tf.gfile.Open(os.path.join(self._test_dir,
                                        'interaction_03.pbtxt')) as input_file:
            interaction = text_format.ParseLines(input_file,
                                                 interaction_pb2.Interaction())

        def dummy_read(file_pattern, coder, validate):
            del file_pattern, coder, validate  # Unused.
            return beam.Create([interaction])

        mock_read.side_effect = dummy_read
        max_seq_length = 15

        pipeline = create_data.build_classifier_pipeline(
            input_files=['input.tfrecord'],
            output_files=[self._output_path],
            config=_ClassifierConfig(
                vocab_file=os.path.join(self._test_dir, 'vocab.txt'),
                max_seq_length=max_seq_length,
                max_column_id=4,
                max_row_id=4,
                strip_column_names=False,
                add_aggregation_candidates=False,
            ))

        result = beam.runners.direct.direct_runner.DirectRunner().run(pipeline)
        result.wait_until_finish()

        output = []
        for value in tf.python_io.tf_record_iterator(self._output_path):
            example = tf.train.Example()
            example.ParseFromString(value)
            output.append(example)

        self.assertLen(output, 1)
        sid = output[0].features.feature['segment_ids']
        self.assertLen(sid.int64_list.value, max_seq_length)
Example #14
def main(argv):
    if len(argv) > 3:
        raise app.UsageError('Too many command-line arguments.')

    # Read in the results file to see what configs to test.
    results = lc_pb2.LecTiming()
    if FLAGS.results_path and gfile.exists(FLAGS.results_path):
        with gfile.open(FLAGS.results_path, 'r') as fd:
            results = text_format.ParseLines(fd, lc_pb2.LecTiming())

    with gfile.open(FLAGS.cell_library_textproto_path, 'r') as fd:
        cell_library_textproto = fd.read()

    lc = lc_mod.LecCharacterizer(FLAGS.synthesis_server_address)

    for width in FLAGS.widths:
        bits_type = xls_type_pb2.TypeProto(
            type_enum=xls_type_pb2.TypeProto.BITS, bit_count=int(width))

        function_type = xls_type_pb2.FunctionTypeProto()
        function_type.parameters.add().CopyFrom(bits_type)
        function_type.parameters.add().CopyFrom(bits_type)
        function_type.return_type.CopyFrom(bits_type)

        test_case = None
        for result_case in results.test_cases:
            # Find or create a matching test case for this function type.
            if result_case.function_type == function_type:
                test_case = result_case

        if test_case is None:
            test_case = results.test_cases.add()
            test_case.function_type.CopyFrom(function_type)

        runs_left = FLAGS.runs_per_type - len(test_case.exec_times_us)
        if runs_left > 0:
            lc.run(results, op_pb2.OpProto.Value(FLAGS.op), function_type,
                   int(runs_left), cell_library_textproto, z3_lec.run,
                   _save_results)
Example #15
    def test_numeric_relations(self, mock_read):
        input_file = 'interaction_00.pbtxt'
        expected_counters = {
            'Conversion success': 1,
            'Example emitted': 1,
            'Input question': 1,
            'Relation Set Index: 2': 5,
            'Relation Set Index: 4': 13,
            'Found answers: <= 4': 1,
        }

        with tf.gfile.Open(os.path.join(self._test_dir,
                                        input_file)) as input_file:
            interaction = text_format.ParseLines(input_file,
                                                 interaction_pb2.Interaction())

        _set_mock_read(mock_read, [interaction])

        max_seq_length = 512

        pipeline = create_data.build_classifier_pipeline(
            input_files=['input.tfrecord'],
            output_files=[self._output_path],
            config=_ClassifierConfig(
                vocab_file=os.path.join(self._test_dir, 'vocab.txt'),
                max_seq_length=max_seq_length,
                max_column_id=512,
                max_row_id=512,
                strip_column_names=False,
                add_aggregation_candidates=False,
            ))

        result = beam.runners.direct.direct_runner.DirectRunner().run(pipeline)
        result.wait_until_finish()

        self.assertEqual(
            {
                metric_result.key.metric.name: metric_result.committed
                for metric_result in result.metrics().query()['counters']
            }, expected_counters)

        output = _read_examples(self._output_path)

        self.assertLen(output, 1)
        actual_example = output[0]

        self.assertIn('numeric_relations',
                      actual_example.features.feature.keys())
        relations = actual_example.features.feature[
            'numeric_relations'].int64_list.value

        with tf.gfile.Open(os.path.join(self._test_dir,
                                        'vocab.txt')) as vocab_file:
            vocab = [line.strip() for line in vocab_file]
        inputs = actual_example.features.feature['input_ids'].int64_list.value
        pairs = [(vocab[input_id], relation)
                 for (input_id, relation) in zip(inputs, relations)
                 if input_id > 0]
        logging.info('pairs: %s', pairs)
        self.assertSequenceEqual(pairs, [('[CLS]', 0), ('which', 0),
                                         ('cities', 0),
                                         ('had', 0), ('less', 0), ('than', 0),
                                         ('2', 0), (',', 0), ('000', 0),
                                         ('pass', 0), ('##en', 0), ('##ge', 0),
                                         ('##rs', 0), ('?', 0), ('[SEP]', 0),
                                         ('ran', 0), ('##k', 0), ('city', 0),
                                         ('pass', 0), ('##en', 0), ('##ge', 0),
                                         ('##rs', 0), ('ran', 0), ('##ki', 0),
                                         ('##ng', 0), ('air', 0), ('##li', 0),
                                         ('##ne', 0), ('1', 4), ('united', 0),
                                         ('states', 0), (',', 0), ('los', 0),
                                         ('angeles', 0), ('14', 2), (',', 2),
                                         ('7', 2), ('##4', 2), ('##9', 2),
                                         ('[EMPTY]', 0),
                                         ('al', 0), ('##as', 0), ('##ka', 0),
                                         ('air', 0), ('##li', 0), ('##ne', 0),
                                         ('##s', 0), ('2', 4), ('united', 0),
                                         ('states', 0), (',', 0), ('h', 0),
                                         ('##ous', 0), ('##ton', 0), ('5', 2),
                                         (',', 2), ('4', 2), ('##6', 2),
                                         ('##5', 2), ('[EMPTY]', 0),
                                         ('united', 0), ('e', 0), ('##x', 0),
                                         ('##p', 0), ('##re', 0), ('##s', 0),
                                         ('##s', 0), ('3', 4), ('canada', 0),
                                         (',', 0), ('c', 0), ('##al', 0),
                                         ('##ga', 0), ('##ry', 0), ('3', 2),
                                         (',', 2), ('7', 2), ('##6', 2),
                                         ('##1', 2), ('[EMPTY]', 0),
                                         ('air', 0), ('t', 0), ('##ra', 0),
                                         ('##ns', 0), ('##a', 0), ('##t', 0),
                                         (',', 0), ('west', 0), ('##j', 0),
                                         ('##et', 0), ('4', 4), ('canada', 0),
                                         (',', 0), ('s', 0), ('##as', 0),
                                         ('##ka', 0), ('##to', 0), ('##on', 0),
                                         ('2', 2), (',', 2), ('28', 2),
                                         ('##2', 2), ('4', 0), ('[EMPTY]', 0),
                                         ('5', 4), ('canada', 0), (',', 0),
                                         ('van', 0), ('##co', 0), ('##u', 0),
                                         ('##ve', 0), ('##r', 0), ('2', 2),
                                         (',', 2), ('10', 2), ('##3', 2),
                                         ('[EMPTY]', 0), ('air', 0), ('t', 0),
                                         ('##ra', 0), ('##ns', 0), ('##a', 0),
                                         ('##t', 0), ('6', 4), ('united', 0),
                                         ('states', 0), (',', 0), ('p', 0),
                                         ('##h', 0), ('##o', 0), ('##en', 0),
                                         ('##i', 0), ('##x', 0), ('1', 4),
                                         (',', 4), ('8', 4), ('##2', 4),
                                         ('##9', 4), ('1', 0), ('us', 0),
                                         ('air', 0), ('##w', 0), ('##a', 0),
                                         ('##y', 0), ('##s', 0), ('7', 4),
                                         ('canada', 0), (',', 0), ('to', 0),
                                         ('##ro', 0), ('##nt', 0), ('##o', 0),
                                         ('1', 4), (',', 4), ('20', 4),
                                         ('##2', 4), ('1', 0), ('air', 0),
                                         ('t', 0), ('##ra', 0), ('##ns', 0),
                                         ('##a', 0), ('##t', 0), (',', 0),
                                         ('can', 0), ('##j', 0), ('##et', 0),
                                         ('8', 4), ('canada', 0), (',', 0),
                                         ('ed', 0), ('##m', 0), ('##on', 0),
                                         ('##ton', 0), ('11', 4), ('##0', 4),
                                         ('[EMPTY]', 0), ('[EMPTY]', 0),
                                         ('9', 4), ('united', 0),
                                         ('states', 0), (',', 0), ('o', 0),
                                         ('##a', 0), ('##k', 0), ('##land', 0),
                                         ('10', 4), ('##7', 4), ('[EMPTY]', 0),
                                         ('[EMPTY]', 0)])
Example #16
    def test_gracefully_handle_big_examples(self, max_seq_length,
                                            max_column_id, max_row_id,
                                            expected_counters, mock_read):

        with tf.gfile.Open(os.path.join(self._test_dir,
                                        'interaction_02.pbtxt')) as input_file:
            interaction = text_format.ParseLines(input_file,
                                                 interaction_pb2.Interaction())

        _set_mock_read(mock_read, [interaction])

        pipeline = create_data.build_classifier_pipeline(
            input_files=['input.tfrecord'],
            output_files=[self._output_path],
            config=_ClassifierConfig(
                vocab_file=self._vocab_path,
                max_seq_length=60
                if max_seq_length is None else max_seq_length,
                max_column_id=5 if max_column_id is None else max_column_id,
                max_row_id=10 if max_row_id is None else max_row_id,
                strip_column_names=False,
                add_aggregation_candidates=False,
            ))

        result = beam.runners.direct.direct_runner.DirectRunner().run(pipeline)
        result.wait_until_finish()

        self.assertEqual(
            {
                metric_result.key.metric.name: metric_result.committed
                for metric_result in result.metrics().query()['counters']
            }, expected_counters)

        if max_seq_length is None and max_column_id is None and max_row_id is None:
            output = _read_examples(self._output_path)

            with tf.gfile.Open(
                    os.path.join(self._test_dir,
                                 'tf_example_02.pbtxt')) as input_file:
                expected_example = text_format.ParseLines(
                    input_file, tf.train.Example())
            with tf.gfile.Open(
                    os.path.join(self._test_dir,
                                 'tf_example_02_conv.pbtxt')) as input_file:
                expected_conversational_example = text_format.ParseLines(
                    input_file, tf.train.Example())

            self.assertLen(output, 2)

            actual_example = output[0]
            del actual_example.features.feature['column_ranks']
            del actual_example.features.feature['inv_column_ranks']
            del actual_example.features.feature['numeric_relations']
            del actual_example.features.feature['numeric_values']
            del actual_example.features.feature['numeric_values_scale']
            del actual_example.features.feature['question_id_ints']
            # assertEqual struggles with NaNs inside protos
            del actual_example.features.feature['answer']

            self.assertEqual(actual_example, expected_example)

            actual_example = output[1]
            del actual_example.features.feature['column_ranks']
            del actual_example.features.feature['inv_column_ranks']
            del actual_example.features.feature['numeric_relations']
            del actual_example.features.feature['numeric_values']
            del actual_example.features.feature['numeric_values_scale']
            del actual_example.features.feature['question_id_ints']
            # assertEqual struggles with NaNs inside protos
            del actual_example.features.feature['answer']

            self.assertEqual(actual_example, expected_conversational_example)
Example #17
    def test_end_to_end_multiple_interactions(self, mock_read):
        with tf.gfile.Open(os.path.join(self._test_dir,
                                        'interaction_01.pbtxt')) as input_file:
            interaction = text_format.ParseLines(input_file,
                                                 interaction_pb2.Interaction())

        interactions = []
        for trial in range(100):
            table_id = f'table_id_{trial}'
            new_interaction = interaction_pb2.Interaction()
            new_interaction.CopyFrom(interaction)
            new_interaction.table.table_id = table_id
            new_interaction.id = table_id
            interactions.append(new_interaction)

        _set_mock_read(mock_read, interactions)

        self._create_vocab(
            list(_RESERVED_SYMBOLS) + list(string.ascii_lowercase) +
            ['##' + letter for letter in string.ascii_lowercase])

        pipeline = create_data.build_pretraining_pipeline(
            input_file='input.tfrecord',
            output_suffix='.tfrecord',
            output_dir=self._output_path,
            config=_PretrainConfig(
                vocab_file=self._vocab_path,
                max_seq_length=40,
                max_predictions_per_seq=10,
                random_seed=5,
                masked_lm_prob=0.5,
                max_column_id=5,
                max_row_id=5,
                min_question_length=5,
                max_question_length=10,
                always_continue_cells=True,
                strip_column_names=False,
            ),
            dupe_factor=1,
            min_num_columns=0,
            min_num_rows=0,
            num_random_table_bins=10,
            # A high bin count sends all examples to the train set.
            num_corpus_bins=100000,
            add_random_table=True,
        )

        result = beam.runners.direct.direct_runner.DirectRunner().run(pipeline)
        result.wait_until_finish()

        counters = {
            metric_result.key.metric.name: metric_result.committed
            for metric_result in result.metrics().query()['counters']
        }

        self.assertEqual(
            counters, {
                'Examples': 100,
                'Examples with tables': 100,
                'Interactions': 100,
                'Interactions without random interaction': 11,
                'Question Length: < inf': 31,
                'Question Length: <= 10': 53,
                'Question Length: <= 7': 16,
                'Real Table Size: <= 8': 100,
                'Trimmed Table Size: <= 8': 100,
                'Column Sizes: <= 8': 100,
                'Row Sizes: <= 8': 100,
                'Table Token Sizes: <= 8': 100,
                'Inputs': 100,
            })

        output = _read_examples(
            os.path.join(self._output_path, 'train.tfrecord'))
        self.assertLen(output, 100)
Example #18
def read_textproto(path: str) -> citylex_pb2.Lexicon:
    """Parses textproto."""
    lexicon = citylex_pb2.Lexicon()
    with open(path, "r") as source:
        text_format.ParseLines(source, lexicon)
    return lexicon
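read_textproto has a natural inverse; a minimal sketch of the matching writer (write_textproto is illustrative, not part of the original module) would serialize with MessageToString:

def write_textproto(lexicon: citylex_pb2.Lexicon, path: str) -> None:
    """Serializes a Lexicon back to a textproto file."""
    with open(path, "w") as sink:
        sink.write(text_format.MessageToString(lexicon))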