def testTFRecordDatasetConstructorErrorsTensorInput(self):
  with self.assertRaisesRegex(TypeError,
                              "filenames.*must be.*Tensor.*string"):
    readers.TFRecordDataset([1, 2, 3])
  with self.assertRaisesRegex(TypeError,
                              "filenames.*must be.*Tensor.*string"):
    readers.TFRecordDataset(constant_op.constant([1, 2, 3]))
  # convert_to_tensor raises different errors in graph and eager
  with self.assertRaises(Exception):
    readers.TFRecordDataset(object())

def testConstructorErrorsTensorInput(self):
  with self.assertRaisesRegex(
      TypeError,
      "The `filenames` argument must contain `tf.string` elements. Got "
      "`tf.int32` elements."):
    readers.TFRecordDataset([1, 2, 3])
  with self.assertRaisesRegex(
      TypeError,
      "The `filenames` argument must contain `tf.string` elements. Got "
      "`tf.int32` elements."):
    readers.TFRecordDataset(constant_op.constant([1, 2, 3]))
  # convert_to_tensor raises different errors in graph and eager
  with self.assertRaises(Exception):
    readers.TFRecordDataset(object())

def make_dataset(self,
                 num_epochs,
                 batch_size=1,
                 compression_type=None,
                 buffer_size=None):
  filenames = self._createFiles()
  if compression_type == "ZLIB":
    zlib_files = []
    for i, fn in enumerate(filenames):
      with open(fn, "rb") as f:
        cdata = zlib.compress(f.read())
        zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
        with open(zfn, "wb") as f:
          f.write(cdata)
        zlib_files.append(zfn)
    filenames = zlib_files
  elif compression_type == "GZIP":
    gzip_files = []
    for i, fn in enumerate(self._filenames):
      with open(fn, "rb") as f:
        gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
        with gzip.GzipFile(gzfn, "wb") as gzf:
          gzf.write(f.read())
        gzip_files.append(gzfn)
    filenames = gzip_files
  return readers.TFRecordDataset(
      filenames, compression_type,
      buffer_size=buffer_size).repeat(num_epochs).batch(batch_size)

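# A minimal usage sketch, not part of the original test file: it assumes the
# surrounding test class exposes the `make_dataset` helper above, the
# `_num_records` attribute, and the standard `getNext`/`evaluate` helpers from
# the dataset test base.
def testGzipRoundTripSketch(self):
  dataset = self.make_dataset(
      num_epochs=1, batch_size=self._num_records, compression_type="GZIP")
  # With batch_size == num_records, each file should yield one full batch.
  get_next = self.getNext(dataset)
  self.evaluate(get_next())
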
def testName(self):
  files = [self._filenames[0]]
  expected_output = [self._record(0, i) for i in range(self._num_records)]
  ds = readers.TFRecordDataset(files, name="tf_record_dataset")
  self.assertDatasetProduces(
      ds, expected_output=expected_output, assert_items_equal=True)

def testPathlib(self):
  files = [pathlib.Path(self._filenames[0])]
  expected_output = [self._record(0, i) for i in range(self._num_records)]
  ds = readers.TFRecordDataset(files)
  self.assertDatasetProduces(
      ds, expected_output=expected_output, assert_items_equal=True)

def testReadFromDatasetOfFiles(self):
  files = dataset_ops.Dataset.from_tensor_slices(self.test_filenames)
  expected_output = []
  for j in range(self._num_files):
    expected_output.extend(
        [self._record(j, i) for i in range(self._num_records)])
  dataset = readers.TFRecordDataset(files)
  self.assertDatasetProduces(dataset, expected_output=expected_output)

def testDirectFilenameTFRecordReaderPipeline(self):
  dataset = core_readers.TFRecordDataset(self._filenames)
  dataset = distribute._AutoShardDataset(dataset, 5, 0)
  expected = [
      b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
      for f in (0, 5)
      for r in range(0, 10)
  ]
  self.assertDatasetProduces(dataset, expected)

def testZip(self):
  dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
  dataset2 = readers.TextLineDataset(self._createTextFiles())
  dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
  dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                         self._shard_index)
  record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f))
  self._verifySimpleShardingOutput(dataset, record_fn)

def testReadWithBuffer(self):
  one_mebibyte = 2**20
  dataset = readers.TFRecordDataset(
      self.test_filenames, buffer_size=one_mebibyte)
  expected_output = []
  for j in range(self._num_files):
    expected_output.extend(
        [self._record(j, i) for i in range(self._num_records)])
  self.assertDatasetProduces(dataset, expected_output=expected_output)

def testReadTenEpochsFromDatasetOfFilesInParallel(self):
  files = dataset_ops.Dataset.from_tensor_slices(
      self.test_filenames).repeat(10)
  expected_output = []
  for j in range(self._num_files):
    expected_output.extend(
        [self._record(j, i) for i in range(self._num_records)])
  dataset = readers.TFRecordDataset(files, num_parallel_reads=4)
  self.assertDatasetProduces(
      dataset, expected_output=expected_output * 10, assert_items_equal=True)

def setUp(self):
  super(TFRecordWriterTest, self).setUp()
  self._num_records = 7
  self.filename = array_ops.placeholder(dtypes.string, shape=[])
  self.compression_type = array_ops.placeholder_with_default("", shape=[])

  input_dataset = readers.TFRecordDataset([self.filename],
                                          self.compression_type)
  self.writer = writers.TFRecordWriter(
      self._outputFilename(), self.compression_type).write(input_dataset)

def testReadWithBuffer(self):
  one_mebibyte = 2**20
  d = readers.TFRecordDataset(self.test_filenames, buffer_size=one_mebibyte)
  iterator = d.make_one_shot_iterator()
  with self.test_session() as sess:
    for j in range(self._num_files):
      for i in range(self._num_records):
        self.assertAllEqual(self._record(j, i),
                            sess.run(iterator.get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(iterator.get_next())

def testReadFromDatasetOfFiles(self):
  files = dataset_ops.Dataset.from_tensor_slices(self.test_filenames)
  d = readers.TFRecordDataset(files)
  iterator = d.make_one_shot_iterator()
  next_element = iterator.get_next()
  with self.test_session() as sess:
    for j in range(self._num_files):
      for i in range(self._num_records):
        self.assertAllEqual(self._record(j, i), sess.run(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)

def _dataset_factory(self,
                     filenames,
                     compression_type="",
                     num_epochs=1,
                     batch_size=None):
  repeat_dataset = readers.TFRecordDataset(
      filenames, compression_type).repeat(num_epochs)
  if batch_size:
    return repeat_dataset.batch(batch_size)
  return repeat_dataset

def __init__(self, filenames, compression_type=None, buffer_size=None):
  """Creates a `TFRecordDataset`.

  Args:
    filenames: A `tf.string` tensor containing one or more filenames.
    compression_type: (Optional.) A `tf.string` scalar evaluating to one of
      `""` (no compression), `"ZLIB"`, or `"GZIP"`.
    buffer_size: (Optional.) A `tf.int64` scalar representing the number of
      bytes in the read buffer. 0 means no buffering.
  """
  dataset = readers.TFRecordDataset(filenames, compression_type, buffer_size)
  super(TFRecordDataset, self).__init__(dataset)

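# A minimal usage sketch, not part of the original source: it assumes the
# wrapper class defined above is in scope as `TFRecordDataset`; the file path
# is hypothetical.
dataset = TFRecordDataset(
    ["/tmp/example.tfrecord.gz"],
    compression_type="GZIP",
    buffer_size=1 << 20)  # 1 MiB read buffer; 0 would disable buffering.
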
def testAsFunctionFromReader(self):
  with ops.device("CPU"):
    file_path = os.path.join(self.get_temp_dir(),
                             "{}.tfrecord.gz".format("tf_record_asset"))
    with tf_record.TFRecordWriter(file_path, "GZIP") as f:
      for v in ["a", "aa", "aaa"]:
        f.write(str(v))
    original_dataset = readers.TFRecordDataset([file_path],
                                               compression_type="GZIP")
    fn = original_dataset._trace_variant_creation()
    variant = fn()

    revived_dataset = dataset_ops._VariantDataset(
        variant, original_dataset.element_spec)
    self.assertDatasetProduces(revived_dataset, ["a", "aa", "aaa"])

def testReadWithEquivalentDataset(self):
  features = {
      "file": parsing_ops.FixedLenFeature([], dtypes.int64),
      "record": parsing_ops.FixedLenFeature([], dtypes.int64),
  }
  dataset = (
      core_readers.TFRecordDataset(self._filenames)
      .map(lambda x: parsing_ops.parse_single_example(x, features))
      .repeat(10).batch(2))
  next_element = self.getNext(dataset)
  for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
      range(self._num_files), 2, 10):
    actual_batch = self.evaluate(next_element())
    self.assertAllEqual(file_batch, actual_batch["file"])
    self.assertAllEqual(record_batch, actual_batch["record"])
  with self.assertRaises(errors.OutOfRangeError):
    self.evaluate(next_element())

def testReadTenEpochsFromDatasetOfFilesInParallel(self):
  files = dataset_ops.Dataset.from_tensor_slices(
      self.test_filenames).repeat(10)
  d = readers.TFRecordDataset(files, num_parallel_reads=4)
  iterator = d.make_one_shot_iterator()
  next_element = iterator.get_next()
  expected = []
  actual = []
  with self.test_session() as sess:
    for _ in range(10):
      for j in range(self._num_files):
        for i in range(self._num_records):
          expected.append(self._record(j, i))
          actual.append(sess.run(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)
    self.assertEqual(sorted(expected), sorted(actual))

def testConcat(self):
  dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
  dataset2 = readers.TextLineDataset(self._createTextFiles())
  dataset = dataset1.concatenate(dataset2)
  dataset = input_ops.auto_shard_dataset(
      dataset, self._num_shards, self._shard_index)
  next_element_fn = self._getNext(dataset)
  for f in range(self._shard_index, self._num_files, self._num_shards):
    for r in range(self._num_records):
      self.assertAllEqual(
          self._record(r, f), self.evaluate(next_element_fn()))
  for f in range(self._shard_index, self._num_files, self._num_shards):
    for r in range(self._num_records):
      self.assertAllEqual(
          self._text_line(r, f), self.evaluate(next_element_fn()))
  with self.assertRaises(errors.OutOfRangeError):
    self.evaluate(next_element_fn())

def setUp(self):
  super(TFRecordDatasetTestBase, self).setUp()
  self._num_files = 2
  self._num_records = 7

  self.test_filenames = self._createFiles()

  self.filenames = array_ops.placeholder(dtypes.string, shape=[None])
  self.num_epochs = array_ops.placeholder_with_default(
      constant_op.constant(1, dtypes.int64), shape=[])
  self.compression_type = array_ops.placeholder_with_default("", shape=[])
  self.batch_size = array_ops.placeholder(dtypes.int64, shape=[])

  repeat_dataset = core_readers.TFRecordDataset(
      self.filenames, self.compression_type).repeat(self.num_epochs)
  batch_dataset = repeat_dataset.batch(self.batch_size)

  iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
  self.init_op = iterator.make_initializer(repeat_dataset)
  self.init_batch_op = iterator.make_initializer(batch_dataset)
  self.get_next = iterator.get_next()

def testReadWithEquivalentDataset(self):
  features = {
      "file": parsing_ops.FixedLenFeature([], dtypes.int64),
      "record": parsing_ops.FixedLenFeature([], dtypes.int64),
  }
  dataset = (
      core_readers.TFRecordDataset(self.test_filenames)
      .map(lambda x: parsing_ops.parse_single_example(x, features))
      .repeat(10).batch(2))
  iterator = dataset.make_initializable_iterator()
  init_op = iterator.initializer
  next_element = iterator.get_next()
  with self.test_session() as sess:
    sess.run(init_op)
    for file_batch, _, _, _, record_batch in self._next_expected_batch(
        range(self._num_files), 2, 10):
      actual_batch = sess.run(next_element)
      self.assertAllEqual(file_batch, actual_batch["file"])
      self.assertAllEqual(record_batch, actual_batch["record"])
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)

def testConcat(self):
  dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
  dataset2 = readers.TextLineDataset(self._createTextFiles())
  dataset = dataset1.concatenate(dataset2)
  dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                         self._shard_index)
  iterator = dataset.make_one_shot_iterator()
  next_element = iterator.get_next()
  with self.cached_session() as sess:
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(self._record(r, f), sess.run(next_element))
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(self._text_line(r, f), sess.run(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)

def testShard(self):
  filename = self._createFile()
  dataset = readers.TFRecordDataset([filename])

  def reduce_func(key, dataset):
    shard_filename = string_ops.string_join(
        [filename, string_ops.as_string(key)])
    writer = writers.TFRecordWriter(shard_filename)
    writer.write(dataset.map(lambda _, x: x))
    return dataset_ops.Dataset.from_tensors(shard_filename)

  dataset = dataset.enumerate()
  dataset = dataset.apply(
      grouping.group_by_window(lambda i, _: i % 2, reduce_func,
                               dtypes.int64.max))

  get_next = self.getNext(dataset)
  for i in range(2):
    shard_filename = (filename + str(i)).encode()
    self.assertEqual(self.evaluate(get_next()), shard_filename)
    for j, r in enumerate(tf_record.tf_record_iterator(shard_filename)):
      self.assertAllEqual(self._record(i + 2 * j), r)

def testTFRecordDatasetIgnoreError(self):
  filenames = []
  for i in range(5):
    fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
    filenames.append(fn)
    writer = python_io.TFRecordWriter(fn)
    for _ in range(10):
      writer.write(b"record")
    writer.close()
    # Append corrupted data
    with open(fn, "a") as f:
      f.write("corrupted data")

  dataset = readers.TFRecordDataset(filenames).apply(
      error_ops.ignore_errors())
  get_next = self.getNext(dataset)

  # All of the files are present.
  for _ in filenames:
    for _ in range(10):
      self.assertEqual(b"record", self.evaluate(get_next()))
  with self.assertRaises(errors.OutOfRangeError):
    self.evaluate(get_next())

def _TFRecordDataset(filename: Text) -> dataset_ops.Dataset:
  buffer_size = 8 * 1024 * 1024  # 8 MiB per file
  dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size)
  return dataset

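# A minimal sketch, not from the original source, of how a per-file factory
# like `_TFRecordDataset` above is typically combined with `interleave` to
# read several files concurrently; the glob pattern is hypothetical.
files = dataset_ops.Dataset.list_files("/tmp/data-*.tfrecord", shuffle=False)
records = files.interleave(
    _TFRecordDataset, cycle_length=4, num_parallel_calls=4)
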
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testAsSerializedGraph(self):
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    self.assertTrue(any([node.op != "RangeDataset" for node in graph.node]))

  @staticmethod
  def make_apply_fn(dataset):

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord", lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(
           DatasetTest.make_gen(), dtypes.int32), 1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("FromTensorSlices", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    self.assertEqual(num_inputs, len(dataset_fn()._inputs()))

  def testDatasetComplexSourceInputs(self):
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEqual(0, len(dataset_fn._inputs()))

  @parameterized.named_parameters(
      ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
      ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
      ("Filter", lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map", lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch", lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle", lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
      ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
      ("Window", lambda x: x.window(10),
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
    input_dataset = dataset_ops.Dataset.range(0)
    dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave", [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(
      self, interleave_fn_args, input_dataset_fn):
    input_dataset = input_dataset_fn()
    dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    input1 = input1_fn()
    input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    input_datasets = input_datasets_fn()
    self.assertEqual(
        nest.flatten(input_datasets), dataset_fn(input_datasets)._inputs())

  def testCollectInputs(self):
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  # TODO(b/119882922): use-after-free bug in eager mode.
  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]],
          values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))})),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetStructure(
           structure.TensorStructure(dtypes.int32, []))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testSkipEagerDatasetStructure(self, tf_value_fn,
                                    expected_element_structure):
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
    dataset_structure = structure.Structure.from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

    # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
    # the element structure.
    self.assertTrue(expected_element_structure.is_compatible_with(
        dataset_structure._element_structure))
    self.assertTrue(dataset_structure._element_structure.is_compatible_with(
        expected_element_structure))

    self.assertEqual([dtypes.variant], dataset_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value_fn())],
          requires_initialization=True)

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorOneShot(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      dataset = dataset.batch(2)
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset.make_one_shot_iterator()
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorOneShotSimple(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset.make_one_shot_iterator()
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorInitializable(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      dataset = dataset.batch(2)
      with self.assertRaisesRegexp(ValueError,
                                   "must be from the same graph"):
        _ = dataset.make_initializable_iterator()

class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testAsSerializedGraph(self):
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    self.assertTrue(any([node.op != "RangeDataset" for node in graph.node]))

  def testAsFunctionWithMap(self):
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      revived_dataset = _RevivedDataset(
          variant, original_dataset._element_structure)
      self.assertDatasetProduces(revived_dataset, range(0, 10, 2))

  def testAsFunctionWithMapInFlatMap(self):
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).flat_map(
          lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2))
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      revived_dataset = _RevivedDataset(
          variant, original_dataset._element_structure)
      self.assertDatasetProduces(revived_dataset, list(original_dataset))

  @staticmethod
  def make_apply_fn(dataset):

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord", lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(
           DatasetTest.make_gen(), dtypes.int32), 1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("FromTensorSlices", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    self.assertLen(dataset_fn()._inputs(), num_inputs)

  @test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
  def testDatasetComplexSourceInputs(self):
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEmpty(dataset_fn._inputs())

  @parameterized.named_parameters(
      ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
      ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
      ("Filter", lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map", lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch", lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle", lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
      ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
      ("Window", lambda x: x.window(10),
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
    input_dataset = dataset_ops.Dataset.range(0)
    dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave", [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(
      self, interleave_fn_args, input_dataset_fn):
    input_dataset = input_dataset_fn()
    dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testNoWarnings(self):
    with test.mock.patch.object(warnings, "warn") as mock_log:
      dataset_fn = self.make_interleave_fn(dataset_ops.Dataset.range(10))
      dataset_fn(dataset_ops.Dataset.range(10))
      self.assertEmpty(mock_log.call_args_list)

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    input1 = input1_fn()
    input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    input_datasets = input_datasets_fn()
    self.assertEqual(
        nest.flatten(input_datasets), dataset_fn(input_datasets)._inputs())

  def testFunctions(self):
    dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
    self.assertLen(dataset._functions(), 1)

  def testCollectInputs(self):
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]],
          values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))})),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetStructure(
           structure.TensorStructure(dtypes.int32, []))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testDatasetStructure(self, tf_value_fn, expected_element_structure):
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
    dataset_structure = structure.Structure.from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

    # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
    # the element structure.
    self.assertTrue(expected_element_structure.is_compatible_with(
        dataset_structure._element_structure))
    self.assertTrue(dataset_structure._element_structure.is_compatible_with(
        expected_element_structure))

    self.assertEqual([dtypes.variant], dataset_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value_fn())],
          requires_initialization=True)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorOneShot(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorOneShotSimple(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset_ops.make_one_shot_iterator(dataset)
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorInitializable(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @parameterized.named_parameters(
      ("Async", context.ASYNC),
      ("Sync", context.SYNC),
  )
  def testDatasetEagerIteration(self, execution_mode):
    with context.eager_mode(), context.execution_mode(execution_mode):
      val = 0
      dataset = dataset_ops.Dataset.range(10)
      for foo in dataset:
        self.assertEqual(val, foo.numpy())
        val += 1

  def testDatasetAsFunctionArgument(self):

    @def_function.function
    def _uses_dataset(d):
      accumulator = array_ops.zeros([], dtype=dtypes.int64)
      for value in d:
        accumulator += value
      return accumulator

    with ops.device("CPU"):
      first_dataset = dataset_ops.Dataset.range(10)
      self.assertEqual(45, self.evaluate(_uses_dataset(first_dataset)))
      second_dataset = dataset_ops.Dataset.range(11)
      self.assertEqual(55, self.evaluate(_uses_dataset(second_dataset)))
      first_concrete = _uses_dataset.get_concrete_function(first_dataset)
      # The dataset should not be a captured input
      self.assertEmpty(first_concrete.graph.captures)
      # The two datasets have the same structure and so should re-use a trace.
      self.assertIs(first_concrete,
                    _uses_dataset.get_concrete_function(second_dataset))
      # With a different structure we should use a different trace.
      self.assertIsNot(
          first_concrete,
          _uses_dataset.get_concrete_function(
              dataset_ops.Dataset.zip((first_dataset, second_dataset))))

  def testLimitedRetracing(self):
    trace_count = [0]

    @def_function.function
    def f(ds):
      trace_count[0] += 1
      counter = np.int64(0)
      for elem in ds:
        counter += elem
      return counter

    dataset = dataset_ops.Dataset.range(5)
    dataset2 = dataset_ops.Dataset.range(10)

    for _ in range(10):
      self.assertEqual(self.evaluate(f(dataset)), 10)
      self.assertEqual(self.evaluate(f(dataset2)), 45)
      self.assertEqual(trace_count[0], 1)

def testTFRecordInputs(self):
  dataset = readers.TFRecordDataset("")
  self.checkNumInputs(dataset, 1)

def _TFRecordDataset(filename):
  buffer_size = 8 * 1024 * 1024  # 8 MiB per file
  dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size)
  return dataset

def make_tf_record_dataset(file_pattern,
                           batch_size,
                           parser_fn=None,
                           num_epochs=None,
                           shuffle=True,
                           shuffle_buffer_size=None,
                           shuffle_seed=None,
                           prefetch_buffer_size=optimization.AUTOTUNE,
                           num_parallel_reads=None,
                           num_parallel_parser_calls=None,
                           drop_final_batch=False):
  """Reads and optionally parses TFRecord files into a dataset.

  Provides common functionality such as batching, optional parsing, shuffling,
  and performant defaults.

  Args:
    file_pattern: List of files or patterns of TFRecord file paths.
      See `tf.io.gfile.glob` for pattern rules.
    batch_size: An int representing the number of records to combine
      in a single batch.
    parser_fn: (Optional.) A function accepting string input to parse
      and process the record contents. This function must map records
      to components of a fixed shape, so they may be batched. By
      default, uses the record contents unmodified.
    num_epochs: (Optional.) An int specifying the number of times this
      dataset is repeated. If None (the default), cycles through the
      dataset forever.
    shuffle: (Optional.) A bool that indicates whether the input
      should be shuffled. Defaults to `True`.
    shuffle_buffer_size: (Optional.) Buffer size to use for
      shuffling. A large buffer size ensures better shuffling, but
      increases memory usage and startup time.
    shuffle_seed: (Optional.) Randomization seed to use for shuffling.
    prefetch_buffer_size: (Optional.) An int specifying the number of
      feature batches to prefetch for performance improvement.
      Defaults to auto-tune. Set to 0 to disable prefetching.
    num_parallel_reads: (Optional.) Number of threads used to read
      records from files. By default or if set to a value >1, the
      results will be interleaved.
    num_parallel_parser_calls: (Optional.) Number of records to parse
      in parallel. Defaults to an automatic selection.
    drop_final_batch: (Optional.) Whether the last batch should be
      dropped in case its size is smaller than `batch_size`; the
      default behavior is not to drop the smaller batch.

  Returns:
    A dataset, where each element matches the output of `parser_fn`
    except it will have an additional leading `batch_size` dimension,
    or a `batch_size`-length 1-D tensor of strings if `parser_fn` is
    unspecified.
  """
  files = dataset_ops.Dataset.list_files(
      file_pattern, shuffle=shuffle, seed=shuffle_seed)

  if num_parallel_reads is None:
    # Note: We considered auto-tuning this value, but there is a concern
    # that this affects the mixing of records from different files, which
    # could affect training convergence/accuracy, so we are defaulting to
    # a constant for now.
    num_parallel_reads = 24
  dataset = core_readers.TFRecordDataset(
      files, num_parallel_reads=num_parallel_reads)

  if shuffle_buffer_size is None:
    # TODO(josh11b): Auto-tune this value when not specified
    shuffle_buffer_size = 10000
  dataset = _maybe_shuffle_and_repeat(
      dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)

  # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to
  # improve the shape inference, because it makes the batch dimension static.
  # It is safe to do this because in that case we are repeating the input
  # indefinitely, and all batches will be full-sized.
  drop_final_batch = drop_final_batch or num_epochs is None

  if parser_fn is None:
    dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch)
  else:
    # TODO(josh11b): if num_parallel_parser_calls is None, use some function
    # of num cores instead of map_and_batch's default behavior of one batch.
    dataset = dataset.apply(batching.map_and_batch(
        parser_fn, batch_size,
        num_parallel_calls=num_parallel_parser_calls,
        drop_remainder=drop_final_batch))

  if prefetch_buffer_size == 0:
    return dataset
  else:
    return dataset.prefetch(buffer_size=prefetch_buffer_size)

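# A minimal usage sketch, not part of the original source. The file pattern
# and feature spec are hypothetical; `parse_fn` parses one serialized
# `tf.train.Example` per record so that the results can be batched.
def parse_fn(serialized):
  features = {"label": parsing_ops.FixedLenFeature([], dtypes.int64)}
  return parsing_ops.parse_single_example(serialized, features)

dataset = make_tf_record_dataset(
    file_pattern="/tmp/train-*.tfrecord",
    batch_size=32,
    parser_fn=parse_fn,
    num_epochs=1,
    shuffle=True)
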