# Imports for the excerpts below (module paths follow the TAPAS repo layout;
# assumed, since the original excerpt omits them).
from typing import Iterable, Text

import tensorflow.compat.v1 as tf

from tapas.datasets import dataset
from tapas.datasets import table_dataset


def input_fn(
    name,
    file_patterns,
    data_format,
    compression_type,
    is_training,
    max_seq_length,
    max_predictions_per_seq,
    params,
):
  """Returns an input_fn compatible with the tf.estimator API."""
  parse_example_fn = table_dataset.parse_table_examples(
      max_seq_length=max_seq_length,
      max_predictions_per_seq=max_predictions_per_seq,
      task_type=table_dataset.TableTask.PRETRAINING,
      add_aggregation_function_id=False,
      add_classification_labels=False,
      add_answer=False,
      include_id=False,
      add_candidate_answers=False,
      max_num_candidates=0,
      params=params)
  ds = dataset.read_dataset(
      parse_example_fn,
      name=name,
      file_patterns=file_patterns,
      data_format=data_format,
      compression_type=compression_type,
      is_training=is_training,
      params=params)
  return ds
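# Usage sketch (an illustration, not part of the original module): bind every
# argument except `params`, which tf.estimator supplies at call time with at
# least a "batch_size" entry. The file pattern and length values below are
# hypothetical.
import functools

train_input_fn = functools.partial(
    input_fn,
    name="pretrain",
    file_patterns=["/tmp/pretrain/*.tfrecord"],  # hypothetical path
    data_format="tfrecord",
    compression_type="",
    is_training=True,
    max_seq_length=128,
    max_predictions_per_seq=20,
)
# estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)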
def input_fn(name: Text,
             file_patterns: Iterable[Text],
             data_format: Text,
             is_training: bool,
             max_seq_length: int,
             include_id: bool,
             compression_type: Text,
             use_mined_negatives: bool,
             params):
  """Returns an input_fn compatible with the tf.estimator API."""
  task_type = (
      table_dataset.TableTask.RETRIEVAL_NEGATIVES
      if use_mined_negatives else table_dataset.TableTask.RETRIEVAL)
  parse_example_fn = table_dataset.parse_table_examples(
      max_seq_length=max_seq_length,
      max_predictions_per_seq=None,
      task_type=task_type,
      add_aggregation_function_id=False,
      add_classification_labels=False,
      add_answer=False,
      include_id=include_id,
      add_candidate_answers=False,
      max_num_candidates=0,
      params=params)
  ds = dataset.read_dataset(
      parse_example_fn,
      name=name,
      file_patterns=file_patterns,
      data_format=data_format,
      is_training=is_training,
      compression_type=compression_type,
      params=dict(params, max_eval_count=None),
  )
  return ds
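# Note on `params=dict(params, max_eval_count=None)` above: dict(mapping,
# **overrides) builds a shallow copy with the given key replaced, so the
# caller's params dict is left untouched. A minimal illustration:
base_params = {"batch_size": 32, "max_eval_count": 150000}
eval_params = dict(base_params, max_eval_count=None)
assert base_params["max_eval_count"] == 150000  # original unchanged
assert eval_params["max_eval_count"] is None    # override applied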
def test_read_dataset(self, data_format, is_training, params,
                      include_patterns):
  write_tf_example(
      self._file1, data_format, {
          "name":
              tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"one"])),
          "number":
              tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
      })
  write_tf_example(
      self._file2, data_format, {
          "name":
              tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"two"])),
          "number":
              tf.train.Feature(int64_list=tf.train.Int64List(value=[2])),
      })
  feature_types = {
      "name": tf.io.FixedLenFeature([], tf.string),
      "number": tf.io.FixedLenFeature([], tf.int64),
  }
  parse_fn = dataset.build_parser_function(feature_types, params)

  def filter_fn(xs):
    return [x for (x, include) in zip(xs, include_patterns) if include]

  patterns = filter_fn(self._file_patterns)
  ds = dataset.read_dataset(
      parse_fn,
      "dataset",
      patterns,
      data_format,
      compression_type="",
      is_training=is_training,
      params=params,
  )
  feature_tuple = tf.data.make_one_shot_iterator(ds).get_next()
  with self.test_session() as sess:
    feature_tuple = sess.run(feature_tuple)
  if params["batch_size"] == 1:
    self.assertIsInstance(feature_tuple, dict)
  else:
    self.assertLen(feature_tuple, params["batch_size"])
  if not is_training:
    expected_names = filter_fn([b"one", b"two"])
    expected_numbers = filter_fn([1, 2])
    self.assertSequenceEqual(list(feature_tuple["name"]), expected_names)
    self.assertSequenceEqual(list(feature_tuple["number"]), expected_numbers)
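# The tests in this excerpt rely on a `write_tf_example` helper that is not
# shown here. A minimal sketch of what it presumably does, assuming the
# "tfrecord" data format (handling of any other format is a guess):
def write_tf_example(filename, data_format, feature_map):
  example = tf.train.Example(
      features=tf.train.Features(feature=feature_map))
  if data_format == "tfrecord":
    with tf.io.TFRecordWriter(filename) as writer:
      writer.write(example.SerializeToString())
  else:
    raise ValueError("Unsupported data format: %s" % data_format)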
def test_read_dataset_test_shape_is_fully_known(self, data_format):
  write_tf_example(
      self._file1, data_format, {
          "number": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
      })
  feature_types = {
      "number": tf.io.FixedLenFeature([], tf.int64),
  }
  params = {"batch_size": 5}
  parse_fn = dataset.build_parser_function(feature_types, params)
  ds = dataset.read_dataset(
      parse_fn,
      "dataset",
      file_patterns=[self._file1],
      data_format=data_format,
      compression_type="",
      is_training=True,
      params=params,
  )
  feature_tuple = tf.data.make_one_shot_iterator(ds).get_next()
  feature_tuple["number"].shape.assert_is_fully_defined()
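# Why the batched shape is static in the training path (a sketch of the
# assumed mechanism, not the actual read_dataset implementation): batching
# with drop_remainder=True fixes the leading dimension at graph-construction
# time, whereas a plain batch() leaves it unknown because the final batch
# may be smaller.
ds_static = tf.data.Dataset.range(10).batch(5, drop_remainder=True)
print(ds_static.output_shapes)   # (5,) -- fully defined
ds_dynamic = tf.data.Dataset.range(10).batch(5)
print(ds_dynamic.output_shapes)  # (?,) -- leading dimension unknown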