def __init__(self, path, element_spec, compression=None, reader_func=None):
  """Creates the `load_dataset` op for the data stored under `path`.

  Args:
    path: Location forwarded to the `load_dataset` op.
    element_spec: Structure of a single element of the loaded dataset.
    compression: Compression setting forwarded to the `load_dataset` op.
    reader_func: Optional callable traced on a dataset of datasets of
      elements. When omitted, the nested datasets are interleaved in parallel
      with `cycle_length` equal to the CPU count.
  """
  if reader_func is None:
    reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
        lambda x: x,
        cycle_length=multiprocessing.cpu_count(),
        num_parallel_calls=dataset_ops.AUTOTUNE)

  self._path, self._compression = path, compression
  # Must be set before `self._flat_structure` is read below.
  self._element_spec = element_spec

  # Dataset of datasets of input elements
  nested_spec = dataset_ops.DatasetSpec(dataset_ops.DatasetSpec(element_spec))
  self._reader_func = dataset_ops.StructuredFunctionWrapper(
      reader_func, "load()", input_structure=nested_spec)

  traced_reader = self._reader_func.function
  variant = gen_experimental_dataset_ops.load_dataset(
      path,
      reader_func_other_args=traced_reader.captured_inputs,
      compression=compression,
      reader_func=traced_reader,
      **self._flat_structure)
  super(_LoadDataset, self).__init__(variant)
def __init__(self,
             path,
             element_spec=None,
             compression=None,
             reader_func=None):
  """Creates the `load_dataset` op, reading the element spec if not given.

  Args:
    path: Location forwarded to the `load_dataset` op. Also the directory
      from which the serialized spec is read when `element_spec` is None.
    element_spec: Optional structure of a single element. When None, it is
      decoded from the `DATASET_SPEC_FILENAME` file found under `path`.
    compression: Compression setting forwarded to the `load_dataset` op.
    reader_func: Optional callable traced on a dataset of datasets of
      elements. When omitted, the nested datasets are interleaved in parallel
      with `cycle_length` equal to the CPU count.
  """
  if reader_func is None:
    reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
        lambda x: x,
        cycle_length=multiprocessing.cpu_count(),
        num_parallel_calls=dataset_ops.AUTOTUNE)

  self._path = path
  if element_spec is not None:
    self._element_spec = element_spec
  else:
    # Recover the spec from the serialized proto stored next to the data.
    spec_file = os.path.join(path, DATASET_SPEC_FILENAME)
    with gfile.GFile(spec_file, "rb") as f:
      serialized = f.read()
    proto = nested_structure_coder.struct_pb2.StructuredValue()
    proto.ParseFromString(serialized)
    decoder = nested_structure_coder.StructureCoder()
    self._element_spec = decoder.decode_proto(proto)
  self._compression = compression

  # Dataset of datasets of input elements
  nested_spec = dataset_ops.DatasetSpec(
      dataset_ops.DatasetSpec(self._element_spec))
  self._reader_func = dataset_ops.StructuredFunctionWrapper(
      reader_func, "load()", input_structure=nested_spec)

  traced_reader = self._reader_func.function
  variant = gen_experimental_dataset_ops.load_dataset(
      path,
      reader_func_other_args=traced_reader.captured_inputs,
      compression=compression,
      reader_func=traced_reader,
      **self._flat_structure)
  super(_LoadDataset, self).__init__(variant)
def testDatasetSpecHierarchical(self):
  """Checks subtype and common-supertype relations between DatasetSpecs."""

  def make_spec(element_shape, dataset_shape):
    return dataset_ops.DatasetSpec(
        tensor_spec.TensorSpec(shape=element_shape, dtype=dtypes.int32),
        dataset_shape)

  spec_1 = make_spec((1, None), [5, None, 2])
  # Fully unknown shapes: the most general of the four specs.
  spec_2 = make_spec((None, None), [None, None, None])
  spec_3 = make_spec((1, 2), [5, 3, 2])
  spec_4 = make_spec((None, 2), [None, 1, None])

  for sub, sup in ((spec_1, spec_1), (spec_1, spec_2), (spec_3, spec_2),
                   (spec_4, spec_2)):
    self.assertTrue(sub.is_subtype_of(sup))

  for narrower in (spec_1, spec_3, spec_4):
    self.assertFalse(spec_2.is_subtype_of(narrower))

  self.assertEqual(spec_1.most_specific_common_supertype([]), spec_1)
  for others in ([spec_4], [spec_3, spec_4], [spec_2, spec_3, spec_4]):
    self.assertEqual(spec_1.most_specific_common_supertype(others), spec_2)
def __init__(self,
             input_dataset,
             path,
             shard_func,
             compression=None,
             reader_func=None,
             pending_snapshot_expiry_seconds=None,
             use_legacy_function=False):
  """Creates the `snapshot_dataset_v2` op for `input_dataset`.

  Args:
    input_dataset: The dataset to snapshot.
    path: Location forwarded to the snapshot op.
    shard_func: Callable traced on the elements of `input_dataset`; it must
      return a scalar `int32` or `int64` tensor.
    compression: Compression setting forwarded to the snapshot op.
    reader_func: Optional callable traced on a dataset of datasets of
      elements. When omitted, the nested datasets are interleaved in parallel
      with `cycle_length` equal to the CPU count.
    pending_snapshot_expiry_seconds: Accepted but not used in this method.
    use_legacy_function: Forwarded to both `StructuredFunctionWrapper`s.

  Raises:
    TypeError: If `shard_func` does not return a scalar integer tensor.
  """
  if reader_func is None:
    reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
        lambda x: x,
        cycle_length=multiprocessing.cpu_count(),
        num_parallel_calls=dataset_ops.AUTOTUNE)

  self._input_dataset = input_dataset
  self._path = path
  self._compression = compression

  # Dataset of datasets of input elements
  nested_spec = dataset_ops.DatasetSpec(
      dataset_ops.DatasetSpec(input_dataset.element_spec))
  self._reader_func = dataset_ops.StructuredFunctionWrapper(
      reader_func,
      self._transformation_name() + ".reader_func",
      input_structure=nested_spec,
      use_legacy_function=use_legacy_function)
  self._shard_func = dataset_ops.StructuredFunctionWrapper(
      shard_func,
      self._transformation_name() + ".shard_func",
      dataset=input_dataset,
      use_legacy_function=use_legacy_function)

  shard_output = self._shard_func.output_structure
  returns_scalar_int = (
      shard_output.is_compatible_with(
          tensor_spec.TensorSpec([], dtypes.int32)) or
      shard_output.is_compatible_with(
          tensor_spec.TensorSpec([], dtypes.int64)))
  if not returns_scalar_int:
    raise TypeError(
        "shard_func must return a 0-dimension tensor containing an int.")

  traced_reader = self._reader_func.function
  traced_shard = self._shard_func.function
  variant = ged_ops.snapshot_dataset_v2(
      input_dataset._variant_tensor,  # pylint: disable=protected-access
      path,
      traced_reader.captured_inputs,
      traced_shard.captured_inputs,
      compression=compression,
      reader_func=traced_reader,
      shard_func=traced_shard,
      **self._flat_structure)
  super(_SnapshotDataset, self).__init__(input_dataset, variant)
def testDatasetSpecConstructor(self):
  """Checks that DatasetSpec stores its element spec and dataset shape."""
  element_spec = {
      "rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32),
      "st": sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32),
      "t": tensor_spec.TensorSpec([10, 8], dtypes.string),
  }
  ds_struct = dataset_ops.DatasetSpec(element_spec, [5])

  self.assertEqual(ds_struct._element_spec, element_spec)
  # Note: shape was automatically converted from a list to a TensorShape.
  expected_shape = tensor_shape.TensorShape([5])
  self.assertEqual(ds_struct._dataset_shape, expected_shape)
def testEncodeDataSetSpec(self):
  """Checks that a DatasetSpec survives an encode/decode round trip."""
  element_spec = {
      "rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32),
      "st": sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32),
      "t": tensor_spec.TensorSpec([10, 8], dtypes.string),
  }
  structure = [dataset_ops.DatasetSpec(element_spec)]

  self.assertTrue(self._coder.can_encode(structure))
  round_tripped = self._coder.decode_proto(
      self._coder.encode_structure(structure))
  self.assertEqual(structure, round_tripped)
def testDatasetSpecTraceType(self):
  """Checks equality, hashing and subtyping of DatasetSpec trace types."""

  def scalar_int_spec(dataset_shape):
    return dataset_ops.DatasetSpec(
        tensor_spec.TensorSpec(shape=(), dtype=dtypes.int32), dataset_shape)

  # Two independently built but identical specs.
  trace_type_1 = scalar_int_spec([5])
  trace_type_2 = scalar_int_spec([5])

  self.assertEqual(trace_type_1, trace_type_2)
  self.assertEqual(hash(trace_type_1), hash(trace_type_2))
  self.assertTrue(trace_type_1.is_subtype_of(trace_type_2))
  self.assertTrue(trace_type_2.is_subtype_of(trace_type_1))

  # Same element spec, different dataset shape.
  trace_type_3 = scalar_int_spec([6])
  self.assertNotEqual(trace_type_1, trace_type_3)
  self.assertFalse(trace_type_1.is_subtype_of(trace_type_3))
  self.assertFalse(trace_type_3.is_subtype_of(trace_type_1))
def testInputSignature(self):
  """Checks that a DatasetSpec is accepted as a function input signature."""
  values = np.arange(10).astype(np.int32)
  dataset = dataset_ops.Dataset.from_tensor_slices(values).batch(5)

  batched_spec = dataset_ops.DatasetSpec(
      tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.int32, name=None),
      tensor_shape.TensorShape([]))

  @def_function.function(input_signature=[batched_spec])
  def fn(_):
    pass

  fn(dataset)
def _make_reduce_func(self, reduce_func, input_dataset):
  """Wraps `reduce_func` and records the element spec of its output.

  The function is traced with a scalar `int64` tensor and a dataset whose
  elements have the spec of `input_dataset`.

  Args:
    reduce_func: The user-supplied callable to wrap.
    input_dataset: Dataset whose `element_spec` describes the nested dataset
      argument.

  Raises:
    TypeError: If the traced output structure is not a `DatasetSpec`.
  """
  key_spec = tensor_spec.TensorSpec([], dtypes.int64)
  group_spec = dataset_ops.DatasetSpec(input_dataset.element_spec)
  self._reduce_func = dataset_ops.StructuredFunctionWrapper(
      reduce_func,
      self._transformation_name(),
      input_structure=(key_spec, group_spec))

  output_structure = self._reduce_func.output_structure
  if not isinstance(output_structure, dataset_ops.DatasetSpec):
    raise TypeError("`reduce_func` must return a `Dataset` object.")
  # pylint: disable=protected-access
  self._element_spec = self._reduce_func.output_structure._element_spec
def testDatasetSpecHierarchicalDict(self):
  """Checks subtype relations between DatasetSpecs with dict elements."""

  def keyed_spec(key, element_shape):
    element = tensor_spec.TensorSpec(shape=element_shape, dtype=dtypes.int32)
    return dataset_ops.DatasetSpec({key: element}, [])

  spec_1 = keyed_spec("a", (1, None))
  spec_2 = keyed_spec("a", (None, None))
  # Same element as spec_1 but stored under a different key.
  spec_3 = keyed_spec("b", (1, None))
  spec_4 = dataset_ops.DatasetSpec({"b": None}, [])

  self.assertTrue(spec_1.is_subtype_of(spec_1))
  self.assertTrue(spec_1.is_subtype_of(spec_2))

  unrelated_pairs = (
      (spec_2, spec_1),
      (spec_1, spec_3),
      (spec_3, spec_1),
      (spec_2, spec_3),
      (spec_3, spec_2),
  )
  for sub, sup in unrelated_pairs:
    self.assertFalse(sub.is_subtype_of(sup))

  self.assertTrue(spec_4.is_subtype_of(spec_4))
  for others in ([], [spec_4]):
    self.assertEqual(spec_4.most_specific_common_supertype(others), spec_4)
def testDatasetDatasetSpec(self):
  """Checks the spec of a dataset of scalar int32 slices."""
  values = constant_op.constant([1, 2, 3])
  dataset = dataset_ops.Dataset.from_tensor_slices(values)
  expected_spec = dataset_ops.DatasetSpec(
      tensor_spec.TensorSpec([], dtypes.int32))
  self.checkDatasetSpec(dataset, expected_spec)
def __init__(self,
             input_dataset,
             functions,
             ratio_numerator=1,
             ratio_denominator=1,
             num_elements_per_branch=None):
  """Chooses the fastest of some dataset functions.

  Given dataset functions that take input_dataset as input and output
  another dataset, produces elements as quickly as the fastest of these
  output datasets. Note that datasets in the dataset functions are assumed
  to be stateless, and the iterators created by the functions' output
  datasets will, given the same input elements, all produce the same output
  elements. Datasets in the functions are also expected to iterate over the
  input dataset at most once. The violation of these conditions may lead to
  undefined behavior.

  For example:

  ```python
  dataset = tf.data.Dataset.range(100)
  dataset = _ChooseFastestBranchDataset(
      dataset,
      [
          lambda ds: ds.map(lambda x: tf.reshape(x, [1])).batch(10),
          lambda ds: ds.batch(10).map(lambda x: tf.reshape(x, [10, 1]))
      ],
      ratio_numerator=10,
      num_elements_per_branch=10
  )
  ```

  The resulting dataset will produce elements equivalent to
  `tf.data.Dataset.range(100).map(lambda x: tf.reshape(x, [1])).batch(10)`,
  or
  `tf.data.Dataset.range(100).batch(10).map(lambda x: tf.reshape(x, [10, 1]))`

  Note that the first `num_elements_per_branch` iterations may be slower due
  to the overhead of dynamically picking the fastest dataset. Namely, for
  these iterations, the dataset will produce elements from any of branches
  to determine which input is the fastest. For all subsequent iterations,
  that input will be used.

  Args:
    input_dataset: A `Dataset` that can be used as input to `functions`.
    functions: A non-empty list of callables, each of which takes a
      `Dataset` as input and returns a `Dataset`.
    ratio_numerator: The numerator in the ratio of input elements consumed
      to output elements produced for each function. This should be the
      same for all functions. For example, if the function is
      `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
      must produce 10 elements for every element of the output dataset. In
      this case, ratio_numerator should be 10.
    ratio_denominator: The denominator in the ratio of input elements
      consumed to output elements produced for each function. This should
      be the same for all functions. For example, if the function is
      `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
      must produce 10 elements for every element of the output dataset. In
      this case, ratio_denominator should be 1.
    num_elements_per_branch: The number of elements to get from each branch
      before deciding which dataset is fastest. In the first
      len(functions) * num_elements_per_branch iterations, the dataset will
      call from one of the branches, and update its knowledge of which
      input is the fastest. Note that (num_elements_per_branch * ratio) is
      expected to be an integer. Defaults to `10 * ratio_denominator`.

  Returns:
    A `Dataset` that has the same elements as the inputs.

  Raises:
    ValueError: If `ratio_numerator` or `ratio_denominator` is not positive,
      or if `functions` is empty.
  """
  # Validate the cheap scalar arguments first so that a bad ratio is
  # reported before any of the functions is traced.
  if ratio_numerator <= 0 or ratio_denominator <= 0:
    raise ValueError("ratio must be positive.")

  input_structure = dataset_ops.DatasetSpec(input_dataset.element_spec)
  self._funcs = [
      dataset_ops.StructuredFunctionWrapper(
          f, "ChooseFastestV2", input_structure=input_structure)
      for f in functions
  ]
  if not self._funcs:
    # Without this check the lookup below fails with a bare IndexError.
    raise ValueError("functions must contain at least one callable.")

  # All branches are expected to produce the same elements, so the first
  # branch's output structure is taken as the element spec.
  self._element_spec = self._funcs[0].output_structure._element_spec  # pylint: disable=protected-access

  # Captured inputs of every branch are concatenated into one flat list;
  # the per-branch lengths are passed alongside it to the op.
  self._captured_arguments = []
  for f in self._funcs:
    self._captured_arguments.extend(f.function.captured_inputs)
  self._capture_lengths = [
      len(f.function.captured_inputs) for f in self._funcs
  ]

  if num_elements_per_branch is None:
    # Pick a sensible default based on `ratio_denominator`
    num_elements_per_branch = 10 * ratio_denominator

  variant_tensor = (
      gen_experimental_dataset_ops.choose_fastest_branch_dataset(
          input_dataset._variant_tensor,  # pylint: disable=protected-access
          ratio_numerator=ratio_numerator,
          ratio_denominator=ratio_denominator,
          other_arguments=self._captured_arguments,
          num_elements_per_branch=num_elements_per_branch,
          branches=[f.function for f in self._funcs],
          other_arguments_lengths=self._capture_lengths,
          **self._flat_structure))
  super(_ChooseFastestBranchDataset, self).__init__(input_dataset,
                                                    variant_tensor)
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `Dataset` inputs, specs, graph checks and function tracing."""

  def testAsSerializedGraph(self):
    """The serialized graph of a range dataset contains a RangeDataset op."""
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    # Must be `==`: a `!=` check passes for any graph that has some other
    # node, whether or not the dataset op was serialized.
    self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

  def testAsFunctionWithMap(self):
    """A dataset revived from its traced variant yields the same elements."""
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      revived_dataset = dataset_ops._VariantDataset(
          variant, original_dataset.element_spec)
      self.assertDatasetProduces(revived_dataset, range(0, 10, 2))

  def testAsFunctionWithMapInFlatMap(self):
    """Same as above, with a `map` nested inside a `flat_map`."""
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).flat_map(
          lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2))
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      revived_dataset = dataset_ops._VariantDataset(
          variant, original_dataset.element_spec)
      self.assertDatasetProduces(revived_dataset, list(original_dataset))

  @staticmethod
  def make_apply_fn(dataset):
    """Returns a function that applies `cache()` through `Dataset.apply`."""

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():
    """Returns a generator function that yields the single value 42."""

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):
    """Returns a function that interleaves empty range datasets."""

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord",
       lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(
           DatasetTest.make_gen(), dtypes.int32),
       1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      # Uses `from_tensor_slices` so that this case covers the source its
      # name refers to rather than repeating the "FromTensors" case.
      ("FromTensorSlices",
       lambda: dataset_ops.Dataset.from_tensor_slices([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    """Source datasets report the expected number of inputs."""
    self.assertLen(dataset_fn()._inputs(), num_inputs)

  @test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
  def testDatasetComplexSourceInputs(self):
    """A dataset built from sparse tensor slices has no inputs."""
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEmpty(dataset_fn._inputs())

  @parameterized.named_parameters(
      ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
      ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
      ("Filter", lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map", lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch", lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle", lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
      ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
      ("Window", lambda x: x.window(10), lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    """A unary transformation reports its input dataset as its only input."""
    input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
    """`Dataset.apply` reports the dataset it was applied to as its input."""
    input_dataset = dataset_ops.Dataset.range(0)
    dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))

    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave",
       [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave",
       [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(
      self, interleave_fn_args, input_dataset_fn):
    """Interleave reports its input dataset as its only input."""
    input_dataset = input_dataset_fn()
    dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testNoWarnings(self):
    """Building an interleave dataset does not call `warnings.warn`."""
    with test.mock.patch.object(warnings, "warn") as mock_log:
      dataset_fn = self.make_interleave_fn(dataset_ops.Dataset.range(10))
      dataset_fn(dataset_ops.Dataset.range(10))
      self.assertEmpty(mock_log.call_args_list)

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    """A binary transformation reports both of its inputs, in order."""
    input1 = input1_fn()
    input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    """`zip` reports the flattened structure of its arguments as inputs."""
    input_datasets = input_datasets_fn()
    self.assertEqual(
        nest.flatten(input_datasets),
        dataset_fn(input_datasets)._inputs())

  def testFunctions(self):
    """A single `map` contributes exactly one function."""
    dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
    self.assertLen(dataset._functions(), 1)

  def testCollectInputs(self):
    """Breadth-first traversal over `_inputs()` visits every occurrence."""
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    # ds1 appears once directly under ds3 and twice under each ds2.
    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       tensor_spec.TensorSpec([], dtypes.float32)),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]],
          values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]), sparse_tensor.SparseTensorSpec([1], dtypes.int32)),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
      }, {
          "a": tensor_spec.TensorSpec([], dtypes.float32),
          "b": (
              tensor_spec.TensorSpec([1], dtypes.string),
              tensor_spec.TensorSpec([], dtypes.string),
          )
      }),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetSpec(tensor_spec.TensorSpec([], dtypes.int32))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalSpec(tensor_spec.TensorSpec([], dtypes.float32))),
  )
  def testDatasetSpec(self, tf_value_fn, expected_element_structure):
    """Checks the DatasetSpec of a dataset and its tensor-list round trip."""
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
    dataset_structure = structure.type_spec_from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetSpec)

    self.assertTrue(
        structure.are_compatible(
            dataset_ops.get_structure(dataset), expected_element_structure))
    self.assertEqual([dtypes.variant],
                     structure.get_flat_tensor_types(dataset_structure))
    self.assertEqual([tensor_shape.scalar()],
                     structure.get_flat_tensor_shapes(dataset_structure))

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value_fn())],
          requires_initialization=True)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorOneShot(self):
    """Transforming a dataset from another graph raises a ValueError."""
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorOneShotSimple(self):
    """A one-shot iterator in another graph logs a same-graph warning."""
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset_ops.make_one_shot_iterator(dataset)
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorInitializable(self):
    """Transforming a dataset from another graph raises a ValueError."""
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @parameterized.named_parameters(
      ("Async", context.ASYNC),
      ("Sync", context.SYNC),
  )
  def testDatasetEagerIteration(self, execution_mode):
    """Eager iteration yields the range values in order in both modes."""
    with context.eager_mode(), context.execution_mode(execution_mode):
      val = 0
      dataset = dataset_ops.Dataset.range(10)
      for foo in dataset:
        self.assertEqual(val, foo.numpy())
        val += 1

  def testDatasetAsFunctionArgument(self):
    """Datasets passed to a `function` are arguments, not captures."""

    @def_function.function
    def _uses_dataset(d):
      accumulator = array_ops.zeros([], dtype=dtypes.int64)
      for value in d:
        accumulator += value
      return accumulator

    with ops.device("CPU"):
      first_dataset = dataset_ops.Dataset.range(10)
      self.assertEqual(45, self.evaluate(_uses_dataset(first_dataset)))
      second_dataset = dataset_ops.Dataset.range(11)
      self.assertEqual(55, self.evaluate(_uses_dataset(second_dataset)))
      first_concrete = _uses_dataset.get_concrete_function(first_dataset)
      # The dataset should not be a captured input
      self.assertEmpty(first_concrete.graph.captures)
      # The two datasets have the same structure and so should re-use a trace.
      self.assertIs(first_concrete,
                    _uses_dataset.get_concrete_function(second_dataset))
      # With a different structure we should use a different trace.
      self.assertIsNot(
          first_concrete,
          _uses_dataset.get_concrete_function(
              dataset_ops.Dataset.zip((first_dataset, second_dataset))))

  def testLimitedRetracing(self):
    """Datasets with the same structure share a single trace."""
    trace_count = [0]

    @def_function.function
    def f(ds):
      trace_count[0] += 1
      counter = np.int64(0)
      for elem in ds:
        counter += elem
      return counter

    dataset = dataset_ops.Dataset.range(5)
    dataset2 = dataset_ops.Dataset.range(10)

    for _ in range(10):
      self.assertEqual(self.evaluate(f(dataset)), 10)
      self.assertEqual(self.evaluate(f(dataset2)), 45)
      self.assertEqual(trace_count[0], 1)
def testDatasetSpecInnerSpec(self):
  """Checks that `element_spec` returns the spec passed to the constructor."""
  scalar_int = tensor_spec.TensorSpec(shape=(), dtype=dtypes.int32)
  wrapped = dataset_ops.DatasetSpec(scalar_int)
  self.assertEqual(wrapped.element_spec, scalar_int)