def testFixedLengthReader(self):
  dataset = readers.FixedLengthRecordDataset(
      self._createFixedLengthRecordFiles(), self._record_bytes)
  dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                         self._shard_index)
  self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
def _build_reader_dataset_graph():
  filenames = ["test"]  # Does not exist but we don't care in this test.
  iterator = readers.FixedLengthRecordDataset(
      filenames, 1, 0, 0).make_initializable_iterator()
  init_op = iterator.initializer
  get_next_op = iterator.get_next()
  save_op = _save_op(iterator._iterator_resource)
  restore_op = _restore_op(iterator._iterator_resource)
  return init_op, get_next_op, save_op, restore_op
def testFixedLengthReaderWithFlatMap(self):
  dataset = dataset_ops.Dataset.from_tensor_slices(
      self._createFixedLengthRecordFiles())
  dataset = dataset.flat_map(
      lambda f: readers.FixedLengthRecordDataset(f, self._record_bytes))
  dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                         self._shard_index)
  self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
def dataset_fn(filenames, num_epochs, batch_size=None):
  repeat_dataset = readers.FixedLengthRecordDataset(
      filenames,
      self._record_bytes,
      self._header_bytes,
      self._footer_bytes,
      compression_type=compression_type).repeat(num_epochs)
  if batch_size:
    return repeat_dataset.batch(batch_size)
  return repeat_dataset
def _build_iterator_graph(self, num_epochs):
  filenames = self._createFiles()
  dataset = (readers.FixedLengthRecordDataset(
      filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
             .repeat(num_epochs))
  iterator = dataset.make_initializable_iterator()
  init_op = iterator.initializer
  get_next_op = iterator.get_next()
  save_op = self._save_op(iterator._iterator_resource)
  restore_op = self._restore_op(iterator._iterator_resource)
  return init_op, get_next_op, save_op, restore_op
def testFixedLengthRecordDatasetBuffering(self):
  test_filenames = self._createFiles()
  dataset = readers.FixedLengthRecordDataset(
      test_filenames,
      self._record_bytes,
      self._header_bytes,
      self._footer_bytes,
      buffer_size=10)
  expected_output = []
  for j in range(self._num_files):
    expected_output.extend(
        [self._record(j, i) for i in range(self._num_records)])
  self.assertDatasetProduces(dataset, expected_output=expected_output)
def _build_reader_dataset_graph():
  filenames = ["test"]  # Does not exist but we don't care in this test.
  path = _iterator_checkpoint_prefix()
  iterator = readers.FixedLengthRecordDataset(
      filenames, 1, 0, 0).make_initializable_iterator()
  init_op = iterator.initializer
  get_next_op = iterator.get_next()
  save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
  restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
                                                path)
  return init_op, get_next_op, save_op, restore_op
def testFixedLengthRecordDatasetParallelRead(self):
  test_filenames = self._createFiles()
  dataset = readers.FixedLengthRecordDataset(
      test_filenames,
      self._record_bytes,
      self._header_bytes,
      self._footer_bytes,
      buffer_size=10,
      num_parallel_reads=4)
  expected_output = []
  for j in range(self._num_files):
    expected_output.extend(
        [self._record(j, i) for i in range(self._num_records)])
  self.assertDatasetProduces(
      dataset, expected_output=expected_output, assert_items_equal=True)
def testFixedLengthRecordDatasetWrongSize(self):
  test_filenames = self._createFiles()
  dataset = readers.FixedLengthRecordDataset(
      test_filenames,
      self._record_bytes + 1,  # Incorrect record length.
      self._header_bytes,
      self._footer_bytes,
      buffer_size=10)
  self.assertDatasetProduces(
      dataset,
      expected_error=(
          errors.InvalidArgumentError,
          r"Excluding the header \(5 bytes\) and footer \(2 bytes\), input "
          r"file \".*fixed_length_record.0.txt\" has body length 21 bytes, "
          r"which is not an exact multiple of the record length \(4 bytes\)."))
def testFixedLengthRecordDatasetBuffering(self):
  test_filenames = self._createFiles()
  dataset = readers.FixedLengthRecordDataset(
      test_filenames,
      self._record_bytes,
      self._header_bytes,
      self._footer_bytes,
      buffer_size=10)
  iterator = dataset.make_one_shot_iterator()
  with self.test_session() as sess:
    for j in range(self._num_files):
      for i in range(self._num_records):
        self.assertEqual(self._record(j, i), sess.run(iterator.get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(iterator.get_next())
def testFixedLengthRecordDatasetWrongSize(self):
  test_filenames = self._createFiles()
  dataset = readers.FixedLengthRecordDataset(
      test_filenames,
      self._record_bytes + 1,  # Incorrect record length.
      self._header_bytes,
      self._footer_bytes,
      buffer_size=10)
  iterator = dataset.make_one_shot_iterator()
  with self.test_session() as sess:
    with self.assertRaisesRegexp(
        errors.InvalidArgumentError,
        r"Excluding the header \(5 bytes\) and footer \(2 bytes\), input "
        r"file \".*fixed_length_record.0.txt\" has body length 21 bytes, "
        r"which is not an exact multiple of the record length \(4 bytes\)."):
      sess.run(iterator.get_next())
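# A quick check of the arithmetic behind the expected error in the wrong-size
# tests above. This is a sketch: the fixture geometry (5-byte header, seven
# 3-byte records, 2-byte footer) is inferred from the error regex itself, not
# stated elsewhere in these excerpts.
header_bytes, footer_bytes = 5, 2
record_bytes, num_records = 3, 7
body_length = num_records * record_bytes  # 28-byte file minus header/footer = 21.
wrong_record_length = record_bytes + 1    # The tests pass record_bytes + 1 == 4.
# 21 % 4 == 1, so the body is not an exact multiple of the record length and
# the reader raises InvalidArgumentError, as the tests assert.
assert body_length % wrong_record_length != 0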
def __init__(self,
             filenames,
             record_bytes,
             header_bytes=None,
             footer_bytes=None,
             buffer_size=None):
  """Creates a `FixedLengthRecordDataset`.

  Args:
    filenames: A `tf.string` tensor containing one or more filenames.
    record_bytes: A `tf.int64` scalar representing the number of bytes in
      each record.
    header_bytes: (Optional.) A `tf.int64` scalar representing the number of
      bytes to skip at the start of a file.
    footer_bytes: (Optional.) A `tf.int64` scalar representing the number of
      bytes to ignore at the end of a file.
    buffer_size: (Optional.) A `tf.int64` scalar representing the number of
      bytes to buffer when reading.
  """
  dataset = readers.FixedLengthRecordDataset(filenames, record_bytes,
                                             header_bytes, footer_bytes,
                                             buffer_size)
  super(FixedLengthRecordDataset, self).__init__(dataset)
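# A minimal usage sketch for the wrapper above. The file path and record
# geometry here are hypothetical; the assumed layout per file is
# [header_bytes][N x record_bytes][footer_bytes].
dataset = FixedLengthRecordDataset(
    filenames=["/tmp/data.bin"],  # Hypothetical path.
    record_bytes=4,               # Each element is a 4-byte string.
    header_bytes=5,               # Skipped at the start of each file.
    footer_bytes=2,               # Ignored at the end of each file.
    buffer_size=1024)             # Read-ahead buffer, in bytes.
# Each element of `dataset` is a scalar tf.string tensor of length 4.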
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testAsSerializedGraph(self):
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

  @staticmethod
  def make_apply_fn(dataset):

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord", lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(DatasetTest.make_gen(),
                                                  dtypes.int32), 1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("FromTensorSlices",
       lambda: dataset_ops.Dataset.from_tensor_slices([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    self.assertEqual(num_inputs, len(dataset_fn()._inputs()))

  def testDatasetComplexSourceInputs(self):
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEqual(0, len(dataset_fn._inputs()))

  @parameterized.named_parameters(
      ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
      ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
      ("Filter", lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map", lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch", lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle", lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
      ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
      ("Window", lambda x: x.window(10),
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
    input_dataset = dataset_ops.Dataset.range(0)
    dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave", [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(
      self, interleave_fn_args, input_dataset_fn):
    input_dataset = input_dataset_fn()
    dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    input1 = input1_fn()
    input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    input_datasets = input_datasets_fn()
    self.assertEqual(
        nest.flatten(input_datasets), dataset_fn(input_datasets)._inputs())

  def testCollectInputs(self):
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  # TODO(b/119882922): use-after-free bug in eager mode.
  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]],
          values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))})),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetStructure(
           structure.TensorStructure(dtypes.int32, []))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testSkipEagerDatasetStructure(self, tf_value_fn,
                                    expected_element_structure):
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
    dataset_structure = structure.Structure.from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

    # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
    # the element structure.
    self.assertTrue(expected_element_structure.is_compatible_with(
        dataset_structure._element_structure))
    self.assertTrue(dataset_structure._element_structure.is_compatible_with(
        expected_element_structure))

    self.assertEqual([dtypes.variant], dataset_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value_fn())],
          requires_initialization=True)

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorOneShot(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      dataset = dataset.batch(2)
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset.make_one_shot_iterator()
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorOneShotSimple(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset.make_one_shot_iterator()
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorInitializable(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      dataset = dataset.batch(2)
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        _ = dataset.make_initializable_iterator()
class InputsTest(test_base.DatasetTestBase, parameterized.TestCase):

  @staticmethod
  def make_apply_fn(dataset):

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord", readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       dataset_ops.Dataset.from_generator(make_gen.__func__(), dtypes.int32),
       1),
      ("FromSparseTensorSlices",
       dataset_ops.Dataset.from_sparse_tensor_slices(
           sparse_tensor.SparseTensor(
               indices=np.array([[0, 0], [1, 0], [2, 0]]),
               values=np.array([0, 0, 0]),
               dense_shape=np.array([3, 1])))),
      ("FromTensors", dataset_ops.Dataset.from_tensors([42])),
      ("FromTensorSlices", dataset_ops.Dataset.from_tensor_slices([42])),
      ("Range", dataset_ops.Dataset.range(10)),
      ("TextLine", readers.TextLineDataset("")),
      ("TFRecord", readers.TFRecordDataset(""), 1),
  )
  def testDatasetSourceInputs(self, dataset, num_inputs=0):
    self.assertEqual(num_inputs, len(dataset._inputs()))

  @parameterized.named_parameters(
      ("Apply", make_apply_fn.__func__(dataset_ops.Dataset.range(0)),
       dataset_ops.Dataset.range(0)),
      ("Batch", lambda x: x.batch(10), dataset_ops.Dataset.range(0)),
      ("Cache", lambda x: x.cache(), dataset_ops.Dataset.range(0)),
      ("Filter", lambda x: x.filter(lambda x: True),
       dataset_ops.Dataset.range(0)),
      ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       dataset_ops.Dataset.range(0)),
      ("Interleave", make_interleave_fn.__func__(dataset_ops.Dataset.range(0)),
       dataset_ops.Dataset.range(0)),
      ("Map", lambda x: x.map(lambda x: x), dataset_ops.Dataset.range(0)),
      ("PaddedBatch", lambda x: x.padded_batch(10, []),
       dataset_ops.Dataset.range(0)),
      ("ParallelInterleave",
       make_interleave_fn.__func__(dataset_ops.Dataset.range(0), 2),
       dataset_ops.Dataset.range(0)),
      ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
       dataset_ops.Dataset.range(0)),
      ("Repeat", lambda x: x.repeat(), dataset_ops.Dataset.range(0)),
      ("Shuffle", lambda x: x.shuffle(10), dataset_ops.Dataset.range(0)),
      ("Skip", lambda x: x.skip(1), dataset_ops.Dataset.range(0)),
      ("Take", lambda x: x.take(1), dataset_ops.Dataset.range(0)),
      ("Window", lambda x: x.window(10), dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset):
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1, input2):
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne", dataset_ops.Dataset.zip, (dataset_ops.Dataset.range(0))),
      ("ZipNest", dataset_ops.Dataset.zip,
       (dataset_ops.Dataset.range(0),
        (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
      ("ZipTuple", dataset_ops.Dataset.zip,
       (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))))
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets):
    self.assertEqual(
        nest.flatten(input_datasets), dataset_fn(input_datasets)._inputs())

  def testCollectInputs(self):
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))
def testFixedLengthRecordInputs(self):
  dataset = readers.FixedLengthRecordDataset("", 42)
  self.checkNumInputs(dataset, 0)
def testFixedLengthRecordDataset(self):
  test_filenames = self._createFiles()
  filenames = array_ops.placeholder(dtypes.string, shape=[None])
  num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
  batch_size = array_ops.placeholder(dtypes.int64, shape=[])

  repeat_dataset = (readers.FixedLengthRecordDataset(
      filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
                    .repeat(num_epochs))
  batch_dataset = repeat_dataset.batch(batch_size)

  iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
  init_op = iterator.make_initializer(repeat_dataset)
  init_batch_op = iterator.make_initializer(batch_dataset)
  get_next = iterator.get_next()

  with self.test_session() as sess:
    # Basic test: read from file 0.
    sess.run(
        init_op, feed_dict={filenames: [test_filenames[0]], num_epochs: 1})
    for i in range(self._num_records):
      self.assertEqual(self._record(0, i), sess.run(get_next))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)

    # Basic test: read from file 1.
    sess.run(
        init_op, feed_dict={filenames: [test_filenames[1]], num_epochs: 1})
    for i in range(self._num_records):
      self.assertEqual(self._record(1, i), sess.run(get_next))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)

    # Basic test: read from both files.
    sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
    for j in range(self._num_files):
      for i in range(self._num_records):
        self.assertEqual(self._record(j, i), sess.run(get_next))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)

    # Test repeated iteration through both files.
    sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 10})
    for _ in range(10):
      for j in range(self._num_files):
        for i in range(self._num_records):
          self.assertEqual(self._record(j, i), sess.run(get_next))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)

    # Test batched and repeated iteration through both files.
    sess.run(
        init_batch_op,
        feed_dict={
            filenames: test_filenames,
            num_epochs: 10,
            batch_size: self._num_records
        })
    for _ in range(10):
      for j in range(self._num_files):
        self.assertAllEqual(
            [self._record(j, i) for i in range(self._num_records)],
            sess.run(get_next))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testAsSerializedGraph(self):
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

  def testAsFunctionWithMap(self):
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      revived_dataset = _RevivedDataset(
          variant, original_dataset._element_structure)
      self.assertDatasetProduces(revived_dataset, range(0, 10, 2))

  def testAsFunctionWithMapInFlatMap(self):
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).flat_map(
          lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2))
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      revived_dataset = _RevivedDataset(
          variant, original_dataset._element_structure)
      self.assertDatasetProduces(revived_dataset, list(original_dataset))

  @staticmethod
  def make_apply_fn(dataset):

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord", lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(DatasetTest.make_gen(),
                                                  dtypes.int32), 1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("FromTensorSlices",
       lambda: dataset_ops.Dataset.from_tensor_slices([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    self.assertLen(dataset_fn()._inputs(), num_inputs)

  @test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
  def testDatasetComplexSourceInputs(self):
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEmpty(dataset_fn._inputs())

  @parameterized.named_parameters(
      ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
      ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
      ("Filter", lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map", lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch", lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle", lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
      ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
      ("Window", lambda x: x.window(10),
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
    input_dataset = dataset_ops.Dataset.range(0)
    dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave", [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(
      self, interleave_fn_args, input_dataset_fn):
    input_dataset = input_dataset_fn()
    dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testNoWarnings(self):
    with test.mock.patch.object(warnings, "warn") as mock_log:
      dataset_fn = self.make_interleave_fn(dataset_ops.Dataset.range(10))
      dataset_fn(dataset_ops.Dataset.range(10))
      self.assertEmpty(mock_log.call_args_list)

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    input1 = input1_fn()
    input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    input_datasets = input_datasets_fn()
    self.assertEqual(
        nest.flatten(input_datasets), dataset_fn(input_datasets)._inputs())

  def testFunctions(self):
    dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
    self.assertLen(dataset._functions(), 1)

  def testCollectInputs(self):
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]],
          values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))})),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetStructure(
           structure.TensorStructure(dtypes.int32, []))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testDatasetStructure(self, tf_value_fn, expected_element_structure):
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
    dataset_structure = structure.Structure.from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

    # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
    # the element structure.
    self.assertTrue(expected_element_structure.is_compatible_with(
        dataset_structure._element_structure))
    self.assertTrue(dataset_structure._element_structure.is_compatible_with(
        expected_element_structure))

    self.assertEqual([dtypes.variant], dataset_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value_fn())],
          requires_initialization=True)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorOneShot(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorOneShotSimple(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset_ops.make_one_shot_iterator(dataset)
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorInitializable(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @parameterized.named_parameters(
      ("Async", context.ASYNC),
      ("Sync", context.SYNC),
  )
  def testDatasetEagerIteration(self, execution_mode):
    with context.eager_mode(), context.execution_mode(execution_mode):
      val = 0
      dataset = dataset_ops.Dataset.range(10)
      for foo in dataset:
        self.assertEqual(val, foo.numpy())
        val += 1

  def testDatasetAsFunctionArgument(self):

    @def_function.function
    def _uses_dataset(d):
      accumulator = array_ops.zeros([], dtype=dtypes.int64)
      for value in d:
        accumulator += value
      return accumulator

    with ops.device("CPU"):
      first_dataset = dataset_ops.Dataset.range(10)
      self.assertEqual(45, self.evaluate(_uses_dataset(first_dataset)))
      second_dataset = dataset_ops.Dataset.range(11)
      self.assertEqual(55, self.evaluate(_uses_dataset(second_dataset)))
      first_concrete = _uses_dataset.get_concrete_function(first_dataset)
      # The dataset should not be a captured input.
      self.assertEmpty(first_concrete.graph.captures)
      # The two datasets have the same structure and so should re-use a trace.
      self.assertIs(first_concrete,
                    _uses_dataset.get_concrete_function(second_dataset))
      # With a different structure we should use a different trace.
      self.assertIsNot(
          first_concrete,
          _uses_dataset.get_concrete_function(
              dataset_ops.Dataset.zip((first_dataset, second_dataset))))

  def testLimitedRetracing(self):
    trace_count = [0]

    @def_function.function
    def f(ds):
      trace_count[0] += 1
      counter = np.int64(0)
      for elem in ds:
        counter += elem
      return counter

    dataset = dataset_ops.Dataset.range(5)
    dataset2 = dataset_ops.Dataset.range(10)

    for _ in range(10):
      self.assertEqual(self.evaluate(f(dataset)), 10)
      self.assertEqual(self.evaluate(f(dataset2)), 45)
      self.assertEqual(trace_count[0], 1)
def FixedLengthFile(filename):
  return readers.FixedLengthRecordDataset(filename, record_bytes)
class DatasetOpsTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testAsSerializedGraph(self):
    dataset = dataset_ops.Dataset.range(10)
    with self.cached_session() as sess:
      graph = graph_pb2.GraphDef().FromString(
          sess.run(dataset._as_serialized_graph()))
      self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

  @staticmethod
  def make_apply_fn(dataset):

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord", lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(DatasetOpsTest.make_gen(),
                                                  dtypes.int32), 1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("FromTensorSlices",
       lambda: dataset_ops.Dataset.from_tensor_slices([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    self.assertEqual(num_inputs, len(dataset_fn()._inputs()))

  def testDatasetComplexSourceInputs(self):
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEqual(0, len(dataset_fn._inputs()))

  @parameterized.named_parameters(
      ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
      ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
      ("Filter", lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map", lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch", lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle", lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
      ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
      ("Window", lambda x: x.window(10),
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
    input_dataset = dataset_ops.Dataset.range(0)
    dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave", [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(
      self, interleave_fn_args, input_dataset_fn):
    input_dataset = input_dataset_fn()
    dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    input1 = input1_fn()
    input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    input_datasets = input_datasets_fn()
    self.assertEqual(
        nest.flatten(input_datasets), dataset_fn(input_datasets)._inputs())

  def testCollectInputs(self):
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  def testOptionsDefault(self):
    ds = dataset_ops.Dataset.range(0)
    self.assertEqual(dataset_ops.Options(), ds.options())

  def testOptionsOnce(self):
    options = dataset_ops.Options()
    ds = dataset_ops.Dataset.range(0).with_options(options).cache()
    self.assertEqual(options, ds.options())

  def testOptionsTwiceSame(self):
    options = dataset_ops.Options()
    options.experimental_autotune = True
    ds = dataset_ops.Dataset.range(0).with_options(options).with_options(
        options)
    self.assertEqual(options, ds.options())

  def testOptionsTwiceDifferent(self):
    options1 = dataset_ops.Options()
    options1.experimental_autotune = True
    options2 = dataset_ops.Options()
    options2.experimental_filter_fusion = False
    ds = dataset_ops.Dataset.range(0).with_options(options1).with_options(
        options2)
    self.assertTrue(ds.options().experimental_autotune)
    # Explicitly check that the flag is False, since assertFalse allows None.
    self.assertIs(ds.options().experimental_filter_fusion, False)

  def testOptionsTwiceDifferentError(self):
    options1 = dataset_ops.Options()
    options1.experimental_autotune = True
    options2 = dataset_ops.Options()
    options2.experimental_autotune = False
    with self.assertRaisesRegexp(ValueError,
                                 "Cannot merge incompatible values"):
      dataset_ops.Dataset.range(0).with_options(options1).with_options(
          options2)

  def testOptionsMergeOptionsFromMultipleInputs(self):
    options1 = dataset_ops.Options()
    options1.experimental_autotune = True
    options2 = dataset_ops.Options()
    options2.experimental_filter_fusion = True
    ds = dataset_ops.Dataset.zip(
        (dataset_ops.Dataset.range(0).with_options(options1),
         dataset_ops.Dataset.range(0).with_options(options2)))
    self.assertTrue(ds.options().experimental_autotune)
    self.assertTrue(ds.options().experimental_filter_fusion)

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]],
          values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))})),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetStructure(
           structure.TensorStructure(dtypes.int32, []))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testDatasetStructure(self, tf_value_fn, expected_element_structure):
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
    dataset_structure = structure.Structure.from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

    # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
    # the element structure.
    self.assertTrue(
        expected_element_structure.is_compatible_with(
            dataset_structure._element_structure))
    self.assertTrue(
        dataset_structure._element_structure.is_compatible_with(
            expected_element_structure))

    self.assertEqual([dtypes.variant], dataset_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value_fn())],
          requires_initialization=True)
class DatasetOpsTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testAsSerializedGraph(self):
    dataset = dataset_ops.Dataset.range(10)
    with self.cached_session() as sess:
      graph = graph_pb2.GraphDef().FromString(
          sess.run(dataset._as_serialized_graph()))
      self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

  @staticmethod
  def make_apply_fn(dataset):

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord", readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       dataset_ops.Dataset.from_generator(make_gen.__func__(), dtypes.int32),
       1),
      ("FromSparseTensorSlices",
       dataset_ops.Dataset.from_sparse_tensor_slices(
           sparse_tensor.SparseTensor(
               indices=np.array([[0, 0], [1, 0], [2, 0]]),
               values=np.array([0, 0, 0]),
               dense_shape=np.array([3, 1])))),
      ("FromTensors", dataset_ops.Dataset.from_tensors([42])),
      ("FromTensorSlices", dataset_ops.Dataset.from_tensor_slices([42])),
      ("Range", dataset_ops.Dataset.range(10)),
      ("TextLine", readers.TextLineDataset("")),
      ("TFRecord", readers.TFRecordDataset(""), 1),
  )
  def testDatasetSourceInputs(self, dataset, num_inputs=0):
    self.assertEqual(num_inputs, len(dataset._inputs()))

  @parameterized.named_parameters(
      ("Apply", make_apply_fn.__func__(dataset_ops.Dataset.range(0)),
       dataset_ops.Dataset.range(0)),
      ("Batch", lambda x: x.batch(10), dataset_ops.Dataset.range(0)),
      ("Cache", lambda x: x.cache(), dataset_ops.Dataset.range(0)),
      ("Filter", lambda x: x.filter(lambda x: True),
       dataset_ops.Dataset.range(0)),
      ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       dataset_ops.Dataset.range(0)),
      ("Interleave", make_interleave_fn.__func__(dataset_ops.Dataset.range(0)),
       dataset_ops.Dataset.range(0)),
      ("Map", lambda x: x.map(lambda x: x), dataset_ops.Dataset.range(0)),
      ("PaddedBatch", lambda x: x.padded_batch(10, []),
       dataset_ops.Dataset.range(0)),
      ("ParallelInterleave",
       make_interleave_fn.__func__(dataset_ops.Dataset.range(0), 2),
       dataset_ops.Dataset.range(0)),
      ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
       dataset_ops.Dataset.range(0)),
      ("Repeat", lambda x: x.repeat(), dataset_ops.Dataset.range(0)),
      ("Shuffle", lambda x: x.shuffle(10), dataset_ops.Dataset.range(0)),
      ("Skip", lambda x: x.skip(1), dataset_ops.Dataset.range(0)),
      ("Take", lambda x: x.take(1), dataset_ops.Dataset.range(0)),
      ("Window", lambda x: x.window(10), dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset):
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1, input2):
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne", dataset_ops.Dataset.zip, (dataset_ops.Dataset.range(0))),
      ("ZipNest", dataset_ops.Dataset.zip,
       (dataset_ops.Dataset.range(0),
        (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
      ("ZipTuple", dataset_ops.Dataset.zip,
       (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))))
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets):
    self.assertEqual(
        nest.flatten(input_datasets), dataset_fn(input_datasets)._inputs())

  def testCollectInputs(self):
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  def testOptionsDefault(self):
    ds = dataset_ops.Dataset.range(0)
    self.assertEqual(dataset_ops.Options(), ds.options())

  def testOptionsOnce(self):
    options = dataset_ops.Options()
    ds = dataset_ops.Dataset.range(0).with_options(options).cache()
    self.assertEqual(options, ds.options())

  def testOptionsTwiceSame(self):
    options = dataset_ops.Options()
    options.experimental_autotune = True
    ds = dataset_ops.Dataset.range(0).with_options(options).with_options(
        options)
    self.assertEqual(options, ds.options())

  def testOptionsTwiceDifferent(self):
    options1 = dataset_ops.Options()
    options1.experimental_autotune = True
    options2 = dataset_ops.Options()
    options2.experimental_filter_fusion = False
    ds = dataset_ops.Dataset.range(0).with_options(options1).with_options(
        options2)
    self.assertTrue(ds.options().experimental_autotune)
    self.assertFalse(ds.options().experimental_filter_fusion)

  def testOptionsTwiceDifferentError(self):
    options1 = dataset_ops.Options()
    options1.experimental_autotune = True
    options2 = dataset_ops.Options()
    options2.experimental_autotune = False
    with self.assertRaisesRegexp(ValueError,
                                 "Cannot merge incompatible values of option"):
      dataset_ops.Dataset.range(0).with_options(options1).with_options(
          options2)
def _build_iterator_graph(self, num_epochs, compression_type=None):
  filenames = self._createFiles()
  return core_readers.FixedLengthRecordDataset(
      filenames, self._record_bytes, self._header_bytes,
      self._footer_bytes).repeat(num_epochs)