示例#1
0
def input_fn(
    name,
    file_patterns,
    data_format,
    compression_type,
    is_training,
    max_seq_length,
    max_predictions_per_seq,
    params,
):
    """Builds the pretraining dataset by parsing table examples.

    A parser is created with `table_dataset.parse_table_examples` configured
    for `TableTask.PRETRAINING` (all optional label/answer/id features are
    disabled), and is then handed to `dataset.read_dataset`.

    Args:
        name: Forwarded to `dataset.read_dataset` as `name`.
        file_patterns: Forwarded to `dataset.read_dataset` as `file_patterns`.
        data_format: Forwarded to `dataset.read_dataset` as `data_format`.
        compression_type: Forwarded to `dataset.read_dataset`.
        is_training: Forwarded to `dataset.read_dataset`.
        max_seq_length: Forwarded to the example parser.
        max_predictions_per_seq: Forwarded to the example parser.
        params: Passed unchanged to both the parser factory and the reader.

    Returns:
        Whatever `dataset.read_dataset` returns (presumably a
        `tf.data.Dataset` -- confirm against that helper).
    """
    # Pretraining needs none of the optional supervised features.
    parser_options = dict(
        max_seq_length=max_seq_length,
        max_predictions_per_seq=max_predictions_per_seq,
        task_type=table_dataset.TableTask.PRETRAINING,
        add_aggregation_function_id=False,
        add_classification_labels=False,
        add_answer=False,
        include_id=False,
        add_candidate_answers=False,
        max_num_candidates=0,
        params=params,
    )
    example_parser = table_dataset.parse_table_examples(**parser_options)
    return dataset.read_dataset(
        example_parser,
        name=name,
        file_patterns=file_patterns,
        data_format=data_format,
        compression_type=compression_type,
        is_training=is_training,
        params=params,
    )
def input_fn(name: Text, file_patterns: Iterable[Text], data_format: Text,
             is_training: bool, max_seq_length: int, include_id: bool,
             compression_type: Text, use_mined_negatives: bool, params):
    """Builds the retrieval dataset by parsing table examples.

    Args:
        name: Forwarded to `dataset.read_dataset` as `name`.
        file_patterns: Forwarded to `dataset.read_dataset`.
        data_format: Forwarded to `dataset.read_dataset`.
        is_training: Forwarded to `dataset.read_dataset`.
        max_seq_length: Forwarded to the example parser.
        include_id: Forwarded to the example parser as `include_id`.
        compression_type: Forwarded to `dataset.read_dataset`.
        use_mined_negatives: Selects `TableTask.RETRIEVAL_NEGATIVES` when
            truthy, otherwise `TableTask.RETRIEVAL`.
        params: Passed to the parser factory as-is; the reader receives a
            copy with `max_eval_count` set to None.

    Returns:
        Whatever `dataset.read_dataset` returns (presumably a
        `tf.data.Dataset` -- confirm against that helper).
    """
    if use_mined_negatives:
        retrieval_task = table_dataset.TableTask.RETRIEVAL_NEGATIVES
    else:
        retrieval_task = table_dataset.TableTask.RETRIEVAL

    example_parser = table_dataset.parse_table_examples(
        max_seq_length=max_seq_length,
        max_predictions_per_seq=None,
        task_type=retrieval_task,
        add_aggregation_function_id=False,
        add_classification_labels=False,
        add_answer=False,
        include_id=include_id,
        add_candidate_answers=False,
        max_num_candidates=0,
        params=params,
    )

    # Build a fresh dict so the caller's params are not mutated while
    # max_eval_count is overridden to None for the reader.
    reader_params = dict(params, max_eval_count=None)
    return dataset.read_dataset(
        example_parser,
        name=name,
        file_patterns=file_patterns,
        data_format=data_format,
        is_training=is_training,
        compression_type=compression_type,
        params=reader_params,
    )
示例#3
0
    def test_read_dataset(self, data_format, is_training, params,
                          include_patterns):
        """Round-trips two single-example files through `dataset.read_dataset`.

        One example is written to each of `self._file1` and `self._file2`.
        The file patterns selected by `include_patterns` are then read back,
        and the structure of the first fetched element is checked. When not
        training, the fetched "name" and "number" values are also compared
        against the examples of the included files.

        Args:
            data_format: Format passed to both the writer and the reader.
            is_training: Forwarded to `dataset.read_dataset`; content checks
                run only when this is falsy.
            params: Reader/parser params; must contain "batch_size".
            include_patterns: Per-pattern flags, positionally aligned with
                `self._file_patterns`, selecting which patterns to read.
        """

        def make_features(name_value, number_value):
            # Feature map for one example with a bytes "name" and an int64
            # "number".
            name_feature = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[name_value]))
            number_feature = tf.train.Feature(
                int64_list=tf.train.Int64List(value=[number_value]))
            return {"name": name_feature, "number": number_feature}

        def select_included(items):
            # Keep only the entries whose aligned include flag is truthy.
            return [
                item for item, keep in zip(items, include_patterns) if keep
            ]

        write_tf_example(self._file1, data_format, make_features(b"one", 1))
        write_tf_example(self._file2, data_format, make_features(b"two", 2))

        feature_spec = {
            "name": tf.io.FixedLenFeature([], tf.string),
            "number": tf.io.FixedLenFeature([], tf.int64),
        }
        parser = dataset.build_parser_function(feature_spec, params)

        selected_patterns = select_included(self._file_patterns)
        ds = dataset.read_dataset(
            parser,
            "dataset",
            selected_patterns,
            data_format,
            compression_type="",
            is_training=is_training,
            params=params,
        )
        next_element = tf.data.make_one_shot_iterator(ds).get_next()

        with self.test_session() as sess:
            fetched = sess.run(next_element)

        batch_size = params["batch_size"]
        if batch_size == 1:
            self.assertIsInstance(fetched, dict)
        else:
            # NOTE(review): this checks the length of the fetched structure
            # itself against batch_size -- confirm that is the intent.
            self.assertLen(fetched, batch_size)

        if is_training:
            return

        # Contents are only compared in eval mode.
        expected_names = select_included([b"one", b"two"])
        expected_numbers = select_included([1, 2])
        self.assertSequenceEqual(list(fetched["name"]), expected_names)
        self.assertSequenceEqual(list(fetched["number"]), expected_numbers)
示例#4
0
 def test_read_dataset_test_shape_is_fully_known(self, data_format):
     """Checks that a training dataset yields a fully defined static shape.

     A single example with an int64 "number" feature is written to
     `self._file1` and read back with `is_training=True` and a batch size
     of 5. The "number" tensor produced by the iterator must then have a
     fully defined shape.

     Args:
         data_format: Format passed to both the writer and the reader.
     """
     number_feature = tf.train.Feature(
         int64_list=tf.train.Int64List(value=[1]))
     write_tf_example(self._file1, data_format, {"number": number_feature})

     feature_spec = {"number": tf.io.FixedLenFeature([], tf.int64)}
     batch_params = {"batch_size": 5}
     parser = dataset.build_parser_function(feature_spec, batch_params)

     ds = dataset.read_dataset(
         parser,
         "dataset",
         file_patterns=[self._file1],
         data_format=data_format,
         compression_type="",
         is_training=True,
         params=batch_params,
     )
     next_element = tf.data.make_one_shot_iterator(ds).get_next()
     # Only the static (graph-time) shape is inspected; no session is run.
     next_element["number"].shape.assert_is_fully_defined()