def _weights_type_combinations():
  return combinations.combine(weights_type=["list", "tensor", "dataset"])


def default_test_combinations():
  """Returns the default test combinations for tf.data tests."""
  return combinations.combine(tf_api_version=[1, 2], mode=["eager", "graph"])
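

# A hedged illustration of how the combinations above expand (assumption:
# `combinations.combine` follows the TF test-combinations API and returns a
# list of parameter dictionaries, with `combinations.times` taking their
# cross product):
#
#   combinations.combine(tf_api_version=[1, 2], mode=["eager", "graph"])
#   # -> [{"tf_api_version": 1, "mode": "eager"},
#   #     {"tf_api_version": 1, "mode": "graph"},
#   #     {"tf_api_version": 2, "mode": "eager"},
#   #     {"tf_api_version": 2, "mode": "graph"}]
#
# `@combinations.generate(...)` then instantiates one parameterized test case
# per dictionary.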


class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):

  def create_cluster(self, num_workers):
    """Creates a cluster of tf.data service servers.

    Args:
      num_workers: The number of workers in the cluster.

    Returns:
      A target for connecting to the service, e.g.
      "grpc+local://localhost:2000".
    """
    self._master = server_lib.MasterServer(PROTOCOL)
    master_address = self._master.target[len(PROTOCOL + "://"):]
    self._servers = []
    for _ in range(num_workers):
      self._servers.append(
          server_lib.WorkerServer(PROTOCOL, master_address=master_address))
    return self._master.target

  @combinations.generate(test_base.eager_only_combinations())
  def testMultipleEpochs(self):
    service = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(3)
    ds = ds.apply(data_service_ops.distribute(service))
    for _ in range(10):
      token = data_service_ops.create_job(
          ds, processing_mode="parallel_epochs")
      it = data_service_ops.create_iterator(ds, token)
      self.assertEqual(list(range(3)), [t.numpy() for t in it])

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeBasic(self):
    num_elements = 10
    service = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.apply(data_service_ops.distribute(service))
    token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
    it = data_service_ops.create_iterator(ds, token)
    results = [t.numpy() for t in it]
    self.assertEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testConcurrentEpoch(self):
    num_elements = 10
    num_datasets = 3
    service = self.create_cluster(1)
    iterators = []
    results = []
    for _ in range(num_datasets):
      ds = dataset_ops.Dataset.range(num_elements)
      ds = ds.apply(data_service_ops.distribute(service))
      token = data_service_ops.create_job(
          ds, processing_mode="parallel_epochs")
      it = data_service_ops.create_iterator(ds, token)
      iterators.append(it)
      results.append([])

    for _ in range(num_elements):
      for dataset_ind in range(num_datasets):
        result = next(iterators[dataset_ind]).numpy()
        results[dataset_ind].append(result)
    for result in results:
      self.assertEqual(list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedEpoch(self):
    num_elements = 10
    num_iterators = 3
    service = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.apply(data_service_ops.distribute(service))
    result = []
    iterators = []
    token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
    for _ in range(num_iterators):
      iterators.append(data_service_ops.create_iterator(ds, token))

    # Alternate reading between the iterators.
    for _ in range(2):
      for it in iterators:
        result.append(next(it).numpy())

    # Drain the rest of the elements.
    for it in iterators:
      for elem in it:
        result.append(elem.numpy())

    self.assertCountEqual(list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testMultiWorker(self):
    num_workers = 3
    num_elements = 10
    service = self.create_cluster(num_workers)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.apply(data_service_ops.distribute(service))
    token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
    iterator = data_service_ops.create_iterator(ds, token)
    results = [elem.numpy() for elem in iterator]
    self.assertCountEqual(num_workers * list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testInsideFunction(self):
    num_workers = 3
    num_elements = 10
    service = self.create_cluster(num_workers)

    @def_function.function
    def f():
      ds = dataset_ops.Dataset.range(num_elements)
      ds = ds.apply(data_service_ops.distribute(service))
      token = data_service_ops.create_job(
          ds, processing_mode="parallel_epochs")
      it = data_service_ops.create_iterator(ds, token)
      result = tensor_array_ops.TensorArray(
          dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
      i = 0
      for elem in it:
        result = result.write(i, elem)
        i += 1
      return result.stack()

    result = list(f().numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), result)

  def run_stateful(self, external_state_policy):
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements).map(
        lambda _: random_ops.random_uniform(()))
    options = dataset_ops.Options()
    options.experimental_external_state_policy = external_state_policy
    ds = ds.with_options(options)
    service = self.create_cluster(3)
    ds = ds.apply(data_service_ops.distribute(service))
    token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
    iterator = data_service_ops.create_iterator(ds, token)
    next(iterator)

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(external_state_policy=[
              distribute_options.ExternalStatePolicy.IGNORE,
              distribute_options.ExternalStatePolicy.WARN
          ])))
  def testStatefulNoError(self, external_state_policy):
    self.run_stateful(external_state_policy)

  @combinations.generate(test_base.eager_only_combinations())
  def testStatefulError(self):
    with self.assertRaises(errors.FailedPreconditionError):
      self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)

  @combinations.generate(test_base.eager_only_combinations())
  def testNoDistributeCalls(self):
    ds = dataset_ops.Dataset.range(1)
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "Dataset does not contain any distribute() transformations"):
      data_service_ops.create_job(ds, processing_mode="parallel_epochs")

  @combinations.generate(test_base.eager_only_combinations())
  def testMultipleDistributeCalls(self):
    service = self.create_cluster(1)
    ds1 = dataset_ops.Dataset.range(1)
    ds1 = ds1.apply(data_service_ops.distribute(service))
    ds2 = dataset_ops.Dataset.range(1)
    ds2 = ds2.apply(data_service_ops.distribute(service))
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    with self.assertRaisesWithLiteralMatch(
        ValueError, "Datasets containing multiple calls to .distribute(...) "
" "are not supported"): data_service_ops.create_job(ds, processing_mode="parallel_epochs") @combinations.generate(test_base.eager_only_combinations()) def testDistributeFromInterleave(self): service = self.create_cluster(1) ds = dataset_ops.Dataset.range(2) def interleave_fn(_): ds = dataset_ops.Dataset.range(2) ds = ds.apply(data_service_ops.distribute(service)) return ds with self.assertRaisesRegex( errors.InvalidArgumentError, r"The `.distribute\(...\)` dataset " "transformation is not supported within tf.data functions"): ds = ds.interleave(interleave_fn, cycle_length=2) @combinations.generate(test_base.eager_only_combinations()) def testDistributeNonStringAddresses(self): ds = dataset_ops.Dataset.range(10) with self.assertRaisesRegex(ValueError, "service must be a string"): ds = ds.apply(data_service_ops.distribute(service=1)) @combinations.generate(test_base.eager_only_combinations()) def testDistributeEmptyAddress(self): ds = dataset_ops.Dataset.range(10) with self.assertRaisesWithLiteralMatch(ValueError, "service must not be empty"): ds = ds.apply(data_service_ops.distribute(service=""))


class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):

  def _testFromGenerator(self, generator, elem_sequence, num_repeats,
                         requires_initialization):
    dataset = dataset_ops.Dataset.from_generator(
        generator, output_types=dtypes.int64).repeat(num_repeats).prefetch(5)
    self.assertDatasetProduces(
        dataset,
        elem_sequence * num_repeats,
        requires_initialization=requires_initialization,
        num_test_iterations=2)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_repeats=[1, 5],
                               requires_initialization=[True, False])))
  def testFromGeneratorUsingFn(self, num_repeats, requires_initialization):

    def generator():
      for i in range(1, 100):
        yield [i] * i

    elem_sequence = list(generator())
    self._testFromGenerator(
        generator,
        elem_sequence,
        num_repeats=num_repeats,
        requires_initialization=requires_initialization)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_repeats=[1, 5],
                               requires_initialization=[True, False])))
  def testFromGeneratorUsingList(self, num_repeats, requires_initialization):
    generator = lambda: [[i] * i for i in range(1, 100)]
    elem_sequence = list(generator())
    self._testFromGenerator(
        generator,
        elem_sequence,
        num_repeats=num_repeats,
        requires_initialization=requires_initialization)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_repeats=[1, 5],
                               requires_initialization=[True, False])))
  def testFromGeneratorUsingNdarray(self, num_repeats,
                                    requires_initialization):
    generator = lambda: np.arange(100, dtype=np.int64)
    elem_sequence = list(generator())
    self._testFromGenerator(
        generator,
        elem_sequence,
        num_repeats=num_repeats,
        requires_initialization=requires_initialization)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_repeats=[1, 5],
                               requires_initialization=[True, False])))
  def testFromGeneratorUsingGeneratorExpression(self, num_repeats,
                                                requires_initialization):
    # NOTE(mrry): Generator *expressions* are not repeatable (or in general
    # reusable), because they eagerly evaluate the `for` expression as
    # `iter(range(1, 100))` and discard the means of reconstructing
    # `range(1, 100)`. Wrapping the generator expression in a `lambda` makes
    # it repeatable.
    generator = lambda: ([i] * i for i in range(1, 100))
    elem_sequence = list(generator())
    self._testFromGenerator(
        generator,
        elem_sequence,
        num_repeats=num_repeats,
        requires_initialization=requires_initialization)

  @combinations.generate(test_base.default_test_combinations())
  def testFromMultipleConcurrentGenerators(self):
    num_inner_repeats = 5
    num_outer_repeats = 100

    def generator():
      for i in range(1, 10):
        yield ([i] * i, [i, i**2, i**3])

    input_list = list(generator())

    # The interleave transformation is essentially a flat map that
    # draws from multiple input datasets concurrently (in a cyclic
    # fashion). By placing `Dataset.from_generator()` inside an
    # interleave, we test its behavior when multiple iterators are
    # active at the same time; by additionally prefetching inside the
    # interleave, we create the possibility of parallel (modulo GIL)
    # invocations to several iterators created by the same dataset.
    def interleave_fn(_):
      return (dataset_ops.Dataset.from_generator(
          generator,
          output_types=(dtypes.int64, dtypes.int64),
          output_shapes=([None], [3])).repeat(num_inner_repeats).prefetch(5))

    dataset = dataset_ops.Dataset.range(num_outer_repeats).interleave(
        interleave_fn, cycle_length=10, block_length=len(input_list))
    get_next = self.getNext(dataset)
    for _ in range(num_inner_repeats * num_outer_repeats):
      for elem in input_list:
        val0, val1 = self.evaluate(get_next())
        self.assertAllEqual(elem[0], val0)
        self.assertAllEqual(elem[1], val1)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  # TODO(b/67868766): Reenable this when the source of flakiness is
  # discovered.
  def _testFromGeneratorsRunningInParallel(self):
    num_parallel_iterators = 3

    # Define shared state that multiple iterator instances will access to
    # demonstrate their concurrent activity.
    lock = threading.Lock()
    condition = threading.Condition(lock)
    next_ticket = [0]  # GUARDED_BY(lock)

    def generator():
      # NOTE(mrry): We yield one element before the barrier, because
      # the current implementation of `Dataset.interleave()` must
      # fetch one element from each incoming dataset to start the
      # prefetching.
      yield 0

      # Define a barrier that `num_parallel_iterators` iterators must enter
      # before any can proceed. Demonstrates that multiple iterators may be
      # active at the same time.
      condition.acquire()
      ticket = next_ticket[0]
      next_ticket[0] += 1
      if ticket == num_parallel_iterators - 1:
        # The last iterator to join the barrier notifies the others.
        condition.notify_all()
      else:
        # Wait until the last iterator enters the barrier.
        while next_ticket[0] < num_parallel_iterators:
          condition.wait()
      condition.release()

      yield 1

    # As in `testFromMultipleConcurrentGenerators()`, we use a combination of
    # `Dataset.interleave()` and `Dataset.prefetch()` to cause multiple
    # iterators to be active concurrently.
    def interleave_fn(_):
      return dataset_ops.Dataset.from_generator(
          generator, output_types=dtypes.int64, output_shapes=[]).prefetch(2)

    dataset = dataset_ops.Dataset.range(num_parallel_iterators).interleave(
        interleave_fn, cycle_length=num_parallel_iterators, block_length=1)
    get_next = self.getNext(dataset)

    for elem in [0, 1]:
      for _ in range(num_parallel_iterators):
        self.assertAllEqual(elem, self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorImplicitConversion(self):

    def generator():
      yield [1]
      yield [2]
      yield [3]

    for dtype in [dtypes.int8, dtypes.int32, dtypes.int64]:
      dataset = dataset_ops.Dataset.from_generator(
          generator, output_types=dtype, output_shapes=[1])
      get_next = self.getNext(dataset)

      for expected in [[1], [2], [3]]:
        next_val = self.evaluate(get_next())
        self.assertEqual(dtype.as_numpy_dtype, next_val.dtype)
        self.assertAllEqual(expected, next_val)
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorString(self):

    def generator():
      yield "foo"
      yield b"bar"
      yield u"baz"

    dataset = dataset_ops.Dataset.from_generator(
        generator, output_types=dtypes.string, output_shapes=[])
    self.assertDatasetProduces(
        dataset, expected_output=[b"foo", b"bar", b"baz"])

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorTypeError(self):

    def generator():
      yield np.array([1, 2, 3], dtype=np.int64)
      yield np.array([4, 5, 6], dtype=np.int64)
      yield "ERROR"
      yield np.array([7, 8, 9], dtype=np.int64)

    dataset = dataset_ops.Dataset.from_generator(
        generator, output_types=dtypes.int64, output_shapes=[3])
    get_next = self.getNext(dataset)

    self.assertAllEqual([1, 2, 3], self.evaluate(get_next()))
    self.assertAllEqual([4, 5, 6], self.evaluate(get_next()))
    with self.assertRaisesOpError("The expected type was int64"):
      self.evaluate(get_next())
    self.assertAllEqual([7, 8, 9], self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorShapeError(self):

    def generator():
      yield np.array([1, 2, 3], dtype=np.int64)
      yield np.array([4, 5, 6], dtype=np.int64)
      yield np.array([7, 8, 9, 10], dtype=np.int64)
      yield np.array([11, 12, 13], dtype=np.int64)

    dataset = dataset_ops.Dataset.from_generator(
        generator, output_types=dtypes.int64, output_shapes=[3])
    get_next = self.getNext(dataset)

    self.assertAllEqual([1, 2, 3], self.evaluate(get_next()))
    self.assertAllEqual([4, 5, 6], self.evaluate(get_next()))
    with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
      self.evaluate(get_next())
    self.assertAllEqual([11, 12, 13], self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorStructureError(self):

    def generator():
      yield 1, 2
      yield 3, 4
      yield 5
      yield 6, 7, 8
      yield 9, 10

    dataset = dataset_ops.Dataset.from_generator(
        generator, output_types=(dtypes.int64, dtypes.int64))
    get_next = self.getNext(dataset)

    self.assertEqual((1, 2), self.evaluate(get_next()))
    self.assertEqual((3, 4), self.evaluate(get_next()))
    with self.assertRaisesOpError(
        r"The expected structure was \(tf\.int64, tf\.int64\)"):
      self.evaluate(get_next())
    with self.assertRaisesOpError(
        r"The expected structure was \(tf\.int64, tf\.int64\)"):
      self.evaluate(get_next())
    self.assertEqual((9, 10), self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorHeterogeneous(self):

    def generator():
      yield 1
      yield [2, 3]

    dataset = dataset_ops.Dataset.from_generator(
        generator, output_types=dtypes.int64)
    self.assertDatasetProduces(dataset, expected_output=[1, [2, 3]])

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorStopShort(self):

    def generator():
      yield 0
      yield 1
      yield 2

    dataset = dataset_ops.Dataset.from_generator(
        generator, output_types=dtypes.int64)
    get_next = self.getNext(dataset)
    self.assertAllEqual(0, self.evaluate(get_next()))
    self.assertAllEqual(1, self.evaluate(get_next()))

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorDestructorCalled(self):
    # Use an `Event` to signal that the generator has been deleted.
    event = threading.Event()

    class GeneratorWrapper(object):

      def __iter__(self):
        return self

      def next(self):
        return self.__next__()

      def __next__(self):
        return 42

      def __del__(self):
        event.set()

    dataset = dataset_ops.Dataset.from_generator(
        GeneratorWrapper, output_types=dtypes.int64).take(2)
    get_next = self.getNext(dataset)

    self.assertAllEqual(42, self.evaluate(get_next()))
    self.assertAllEqual(42, self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    # Test that the `GeneratorWrapper` object is destroyed when the
    # iterator terminates (and the generator iterator is deleted).
    self.assertTrue(event.is_set())

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorWithArgs(self):

    def flat_map_fn(elem):

      def generator_with_arg(n):
        for _ in range(n):
          yield np.array(n, dtype=np.int64)

      return dataset_ops.Dataset.from_generator(
          generator_with_arg,
          output_types=dtypes.int64,
          output_shapes=(),
          args=(elem,))

    dataset = dataset_ops.Dataset.range(5).flat_map(flat_map_fn)
    self.assertDatasetProduces(
        dataset, expected_output=[1, 2, 2, 3, 3, 3, 4, 4, 4, 4])

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorWithTwoArgs(self):

    def flat_map_fn(elem, message):

      def generator_with_arg(n, msg):
        for i in range(n):
          yield i, msg

      return dataset_ops.Dataset.from_generator(
          generator_with_arg,
          output_types=(dtypes.int64, dtypes.string),
          output_shapes=((), ()),
          args=(elem, message))

    dataset = dataset_ops.Dataset.zip(
        (dataset_ops.Dataset.range(5),
         dataset_ops.Dataset.from_tensors("Hi!").repeat(None)
        )).flat_map(flat_map_fn)
    self.assertDatasetProduces(
        dataset,
        expected_output=[(0, b"Hi!"), (0, b"Hi!"), (1, b"Hi!"), (0, b"Hi!"),
                         (1, b"Hi!"), (2, b"Hi!"), (0, b"Hi!"), (1, b"Hi!"),
                         (2, b"Hi!"), (3, b"Hi!")])

  @combinations.generate(test_base.default_test_combinations())
  def testGeneratorDatasetFinalizeFunctionCalled(self):
    # NOTE(mrry): This test tests the internal `_GeneratorDataset`,
    # which affords more control over what the finalize function can do than
    # the `Dataset.from_generator()` wrapper.

    # Use an `Event` to signal that the generator has been deleted.
    event = threading.Event()

    def finalize_fn(_):

      def finalize_py_func():
        event.set()
        return 0

      return script_ops.py_func(
          finalize_py_func, [], [dtypes.int64], stateful=True)

    dummy = constant_op.constant(37)
    dataset = dataset_ops._GeneratorDataset(dummy, lambda x: x, lambda x: x,
                                            finalize_fn).take(2)
    get_next = self.getNext(dataset)

    self.assertAllEqual(37, self.evaluate(get_next()))
    self.assertAllEqual(37, self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testSharedName(self):

    def generator():
      for _ in range(10):
        yield [20]

    dataset = dataset_ops.Dataset.from_generator(
        generator, output_types=(dtypes.int64))
    get_next = self.getNext(
        dataset, requires_initialization=True, shared_name="shared_dataset")
    self.assertAllEqual([20], self.evaluate(get_next()))
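

# A plain-Python sketch of the repeatability issue described in
# `testFromGeneratorUsingGeneratorExpression` above: a generator *expression*
# is exhausted after one pass, while a lambda builds a fresh generator on
# every call, which is what `Dataset.from_generator` needs to support
# `repeat()`. (Illustrative helper, not part of the original test suite.)
def _demo_generator_expression_reuse():
  gen_expr = ([i] * i for i in range(1, 4))
  assert list(gen_expr) == [[1], [2, 2], [3, 3, 3]]
  assert not list(gen_expr)  # Exhausted; cannot be re-iterated.

  gen_fn = lambda: ([i] * i for i in range(1, 4))
  assert list(gen_fn()) == [[1], [2, 2], [3, 3, 3]]
  assert list(gen_fn()) == [[1], [2, 2], [3, 3, 3]]  # Fresh generator per call.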


class RangeTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(output_type=[
              dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
          ])))
  def testStop(self, output_type):
    stop = 5
    dataset = dataset_ops.Dataset.range(stop, output_type=output_type)
    expected_output = np.arange(stop, dtype=output_type.as_numpy_dtype)
    self.assertDatasetProduces(dataset, expected_output=expected_output)
    self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(output_type=[
              dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
          ])))
  def testStartStop(self, output_type):
    start, stop = 2, 5
    dataset = dataset_ops.Dataset.range(start, stop, output_type=output_type)
    expected_output = np.arange(start, stop, dtype=output_type.as_numpy_dtype)
    self.assertDatasetProduces(dataset, expected_output=expected_output)
    self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(output_type=[
              dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
          ])))
  def testStartStopStep(self, output_type):
    start, stop, step = 2, 10, 2
    dataset = dataset_ops.Dataset.range(
        start, stop, step, output_type=output_type)
    expected_output = np.arange(
        start, stop, step, dtype=output_type.as_numpy_dtype)
    self.assertDatasetProduces(dataset, expected_output=expected_output)
    self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(output_type=[
              dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
          ])))
  def testZeroStep(self, output_type):
    start, stop, step = 2, 10, 0
    with self.assertRaises(errors.InvalidArgumentError):
      dataset = dataset_ops.Dataset.range(
          start, stop, step, output_type=output_type)
      self.evaluate(dataset._variant_tensor)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(output_type=[
              dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
          ])))
  def testNegativeStep(self, output_type):
    start, stop, step = 2, 10, -1
    dataset = dataset_ops.Dataset.range(
        start, stop, step, output_type=output_type)
    expected_output = np.arange(
        start, stop, step, dtype=output_type.as_numpy_dtype)
    self.assertDatasetProduces(dataset, expected_output=expected_output)
    self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(output_type=[
              dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
          ])))
  def testStopLessThanStart(self, output_type):
    start, stop = 10, 2
    dataset = dataset_ops.Dataset.range(start, stop, output_type=output_type)
    expected_output = np.arange(start, stop, dtype=output_type.as_numpy_dtype)
    self.assertDatasetProduces(dataset, expected_output=expected_output)
    self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(output_type=[
              dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
          ])))
  def testStopLessThanStartWithPositiveStep(self, output_type):
    start, stop, step = 10, 2, 2
    dataset = dataset_ops.Dataset.range(
        start, stop, step, output_type=output_type)
    expected_output = np.arange(
        start, stop, step, dtype=output_type.as_numpy_dtype)
    self.assertDatasetProduces(dataset, expected_output=expected_output)
    self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(output_type=[
              dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
          ])))
  def testStopLessThanStartWithNegativeStep(self, output_type):
    start, stop, step = 10, 2, -1
    dataset = dataset_ops.Dataset.range(
        start, stop, step, output_type=output_type)
    expected_output = np.arange(
        start, stop, step, dtype=output_type.as_numpy_dtype)
    self.assertDatasetProduces(dataset, expected_output=expected_output)
    self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset))

  @combinations.generate(test_base.default_test_combinations())
  def testName(self):
    dataset = dataset_ops.Dataset.range(5, name="range")
    self.assertDatasetProduces(dataset, list(range(5)))
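

# `Dataset.range` mirrors `np.arange` semantics, which the tests above rely
# on: a positive step with stop < start (and vice versa) yields an empty
# range, while a negative step counts down. (Illustrative helper, not part
# of the original test suite.)
def _demo_range_semantics():
  assert list(np.arange(2, 10, -1)) == []  # testNegativeStep
  assert list(np.arange(10, 2)) == []  # testStopLessThanStart
  assert list(np.arange(10, 2, -1)) == [10, 9, 8, 7, 6, 5, 4, 3]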


class DynamicShardingTest(data_service_test_base.TestBase,
                          parameterized.TestCase):

  def _make_dynamic_sharding_dataset(self, dataset, cluster):
    return self.make_distributed_dataset(
        dataset,
        cluster,
        processing_mode=data_service_ops.ShardingPolicy.DYNAMIC,
        job_name="job_name")

  @combinations.generate(test_base.default_test_combinations())
  def testBasic(self):
    cluster = data_service_test_base.TestCluster(num_workers=2)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    self.assertDatasetProduces(
        ds, list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testTensorSlices(self):
    cluster = data_service_test_base.TestCluster(num_workers=2)
    vals = [5, 1, 2, 4]
    ds = dataset_ops.Dataset.from_tensor_slices(vals)
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    self.assertDatasetProduces(ds, vals, assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testInterleave(self):
    cluster = data_service_test_base.TestCluster(num_workers=2)
    elements = [1, 5, 0]
    ds = dataset_ops.Dataset.from_tensor_slices(elements)
    ds = ds.interleave(lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    self.assertDatasetProduces(ds, elements, assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testParallelInterleave(self):
    cluster = data_service_test_base.TestCluster(num_workers=2)
    elements = [1, 5, 0]
    ds = dataset_ops.Dataset.from_tensor_slices(elements)
    ds = ds.interleave(
        lambda x: dataset_ops.Dataset.from_tensor_slices([x]),
        num_parallel_calls=dataset_ops.AUTOTUNE)
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    self.assertDatasetProduces(ds, elements, assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testFlatMap(self):
    cluster = data_service_test_base.TestCluster(num_workers=2)
    elements = [1, 5, 0]
    ds = dataset_ops.Dataset.from_tensor_slices(elements)
    ds = ds.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    self.assertDatasetProduces(ds, elements, assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testGroupByWindow(self):
    # Verify that split providers are not propagated into iterators created
    # for the reduce datasets created by the reduce_fn in group_by_window.
    cluster = data_service_test_base.TestCluster(num_workers=2)
    elements = [1, 5, 0]
    ds = dataset_ops.Dataset.from_tensor_slices(elements)

    def reduce_fn(_, window):
      return dataset_ops.Dataset.zip((window, dataset_ops.Dataset.range(100)))

    ds = ds.group_by_window(lambda x: 0, reduce_fn, window_size=3)
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    # This will fail if the tensor_slices split provider is propagated into
    # the `reduce_fn`, since the `zip` requires either 0 or 2 split providers.
    self.getDatasetOutput(ds)

  @combinations.generate(test_base.default_test_combinations())
  def testRepeatBeforeDistribution(self):
    cluster = data_service_test_base.TestCluster(num_workers=2)
    num_repeats = 5
    num_elements = 20
    ds = dataset_ops.Dataset.range(num_elements).repeat(num_repeats)
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    self.assertDatasetProduces(
        ds, num_repeats * list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testRepeatAfterDistribution(self):
    cluster = data_service_test_base.TestCluster(num_workers=2)
    num_repeats = 5
    num_elements = 20
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    ds = ds.repeat(num_repeats)
    self.assertDatasetProduces(
        ds, num_repeats * list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testForeverRepeat(self):
    cluster = data_service_test_base.TestCluster(num_workers=2)
    num_elements = 20
    elements_to_read = 1000
    ds = dataset_ops.Dataset.range(num_elements).repeat()
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    get_next = self.getNext(ds)
    results = {}
    for _ in range(elements_to_read):
      val = self.evaluate(get_next())
      if val not in results:
        results[val] = 0
      results[val] += 1
    for i in range(num_elements):
      self.assertGreater(results[i], elements_to_read / num_elements / 2)

  @combinations.generate(test_base.default_test_combinations())
  def testForeverRepeatFewElements(self):
    num_workers = 5
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    # Fewer elements than the number of workers, so that some workers get
    # zero elements on the first repetition.
    num_elements = 1
    ds = dataset_ops.Dataset.range(num_elements).repeat()
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    get_next = self.getNext(ds)
    for _ in range(20):
      self.assertEqual(self.evaluate(get_next()), 0)

    # Stop all but one worker and check that we can still read.
    for i in range(num_workers - 1):
      cluster.workers[i].stop()
    for _ in range(20):
      self.assertEqual(self.evaluate(get_next()), 0)

  @combinations.generate(test_base.default_test_combinations())
  def testShuffleAndRepeat(self):
    cluster = data_service_test_base.TestCluster(num_workers=2)
    num_repeats = 5
    num_elements = 20
    ds = dataset_ops.Dataset.range(num_elements).shuffle(num_elements).repeat(
        num_repeats)
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    self.assertDatasetProduces(
        ds, num_repeats * list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testZip(self):
    num_elements = 10
    cluster = data_service_test_base.TestCluster(num_workers=1)
    a = dataset_ops.Dataset.range(num_elements)
    ds = dataset_ops.Dataset.zip((a, a))
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    self.assertDatasetProduces(
        ds, list(zip(range(num_elements), range(num_elements))))

  @combinations.generate(test_base.default_test_combinations())
  def testNestedZip(self):
    num_elements = 10
    cluster = data_service_test_base.TestCluster(num_workers=1)
    a = dataset_ops.Dataset.range(num_elements)
    ds = dataset_ops.Dataset.zip((a, a))
    ds = dataset_ops.Dataset.zip((a, a, ds, a))
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    b = list(range(10))
    self.assertDatasetProduces(ds, list(zip(b, b, zip(b, b), b)))

  @combinations.generate(test_base.default_test_combinations())
  def testImbalancedZip(self):
    smaller_num_elements = 200
    larger_num_elements = 1000
    cluster = data_service_test_base.TestCluster(num_workers=1)
    a = dataset_ops.Dataset.range(smaller_num_elements)
    b = dataset_ops.Dataset.range(larger_num_elements)
    ds = dataset_ops.Dataset.zip((a, b))
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    self.assertDatasetProduces(
        ds,
        list(zip(range(smaller_num_elements), range(smaller_num_elements))))

  @combinations.generate(test_base.default_test_combinations())
  def testImbalancedZipMultiWorker(self):
    smaller_num_elements = 200
    larger_num_elements = 1000
    cluster = data_service_test_base.TestCluster(num_workers=3)
    a = dataset_ops.Dataset.range(smaller_num_elements)
    b = dataset_ops.Dataset.range(larger_num_elements)
    ds = dataset_ops.Dataset.zip((a, b))
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    # Cannot assert specific elements because the range datasets are split
    # nondeterministically and may not line up.
    self.assertLen(self.getDatasetOutput(ds), smaller_num_elements)

  @combinations.generate(test_base.default_test_combinations())
  def testZipDifferentRates(self):
    cluster = data_service_test_base.TestCluster(num_workers=3)
    a = dataset_ops.Dataset.range(100)
    b = dataset_ops.Dataset.range(100).filter(
        lambda x: math_ops.equal(x % 10, 0))
    ds = dataset_ops.Dataset.zip((a, b))
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    self.assertLen(self.getDatasetOutput(ds), 10)

  @combinations.generate(test_base.default_test_combinations())
  def testZipDifferentRepeats(self):
    cluster = data_service_test_base.TestCluster(num_workers=3)
    a = dataset_ops.Dataset.range(50)
    b = dataset_ops.Dataset.range(10).repeat(10)
    ds = dataset_ops.Dataset.zip((a, b))
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    self.assertLen(self.getDatasetOutput(ds), 50)

  @combinations.generate(test_base.default_test_combinations())
  def testSampleFromDatasets(self):
    cluster = data_service_test_base.TestCluster(num_workers=3)
    num_samples = 200
    weights = [.6, .3, .1]
    classes = len(weights)

    # Create a dataset that samples each integer in `[0, num_datasets)`
    # with probability given by `weights[i]`.
    ds = dataset_ops.Dataset.sample_from_datasets(
        [dataset_ops.Dataset.from_tensors(i).repeat() for i in range(classes)],
        weights)
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    ds = ds.take(num_samples)

    freqs = np.zeros([classes])
    for v in self.getDatasetOutput(ds):
      freqs[v] += 1
    self.assertGreater(freqs[0], freqs[1])
    self.assertGreater(freqs[1], freqs[2])

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(num_workers=[1, 3])))
  def testChooseFromDatasets(self, num_workers):
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    words = [b"foo", b"bar", b"baz"]
    datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words]
    choice_array = np.random.randint(3, size=(15,), dtype=np.int64)
    choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array)
    ds = dataset_ops.Dataset.choose_from_datasets(datasets, choice_dataset)
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    expected = [words[i] for i in choice_array]
    if compat.forward_compatible(2022, 6, 6):
      expected *= num_workers
    assert_items_equal = (num_workers > 1)
    self.assertDatasetProduces(
        ds, expected, assert_items_equal=assert_items_equal)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations()))
  def testEnumerateReplicateOnSplit(self):
    if not compat.forward_compatible(2022, 6, 6):
      self.skipTest("Replicate on split is not yet available.")
    num_workers = 3
    cluster = data_service_test_base.TestCluster(num_workers)
    ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"]).repeat()
    ds = ds.enumerate()
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    get_next = self.getNext(ds)
    counts = collections.defaultdict(int)
    while True:
      i, _ = self.evaluate(get_next())
      counts[i] += 1
      # Read until all workers have reached enumeration index 10.
      if counts[10] == num_workers:
        break
    for i in range(10):
      self.assertEqual(counts[i], num_workers)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(num_workers=[1, 3])))
  def testConcatenate(self, num_workers):
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    a = dataset_ops.Dataset.range(100)
    b = dataset_ops.Dataset.range(100, 200)
    ds = a.concatenate(b)
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    assert_items_equal = (num_workers > 1)
    self.assertDatasetProduces(
        ds, list(range(200)), assert_items_equal=assert_items_equal)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(already_written=[True, False])))
  def testSnapshot(self, already_written):
    num_workers = 3
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    ds = dataset_ops.Dataset.range(100)
    ds = ds.snapshot(self.get_temp_dir())
    if already_written:
      # Materialize the snapshot.
      self.getDatasetOutput(ds)
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    error_regex = "Splitting is not implemented for snapshot datasets"
    with self.assertRaisesRegex(errors.UnimplementedError, error_regex):
      self.getDatasetOutput(ds)

  @combinations.generate(test_base.default_test_combinations())
  def testDistributedDataset(self):
    cluster_1 = data_service_test_base.TestCluster(num_workers=1)
    cluster_2 = data_service_test_base.TestCluster(num_workers=1)
    num_sizes = 10
    size_repeats = 5
    numbers = [1 * i for i in range(num_sizes)] * size_repeats
    ds = dataset_ops.Dataset.from_tensor_slices(numbers)
    ds = self.make_distributed_dataset(
        ds, cluster_1, processing_mode=data_service_ops.ShardingPolicy.OFF)
    ds = ds.map(lambda x: x + 1)
    ds = self._make_dynamic_sharding_dataset(ds, cluster_2)

    error_regex = ("Cannot create split providers for dataset "
                   "of type DataServiceDataset")
    with self.assertRaisesRegex(errors.UnimplementedError, error_regex):
      self.getDatasetOutput(ds)

  @combinations.generate(test_base.default_test_combinations())
  def testDistributedEpoch(self):
    cluster = data_service_test_base.TestCluster(num_workers=2)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(
        ds, list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testFlatMapWithRepeat(self):
    cluster = data_service_test_base.TestCluster(num_workers=3)
    ds = dataset_ops.Dataset.range(5)

    def flat_map_fn(_):
      return dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"]).repeat(10)

    ds = ds.flat_map(flat_map_fn)
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    self.assertDatasetProduces(
        ds, [b"a", b"b", b"c"] * 50, assert_items_equal=True)
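

# A note on the behavior the tests above rely on (hedged summary of dynamic
# sharding as exercised here, not a formal contract): splits are handed to
# workers on demand rather than statically partitioned, so every element is
# produced exactly once per repetition but in no guaranteed order (hence
# `assert_items_equal=True`); zipped inputs may be split inconsistently
# across workers, so only output lengths are asserted; and sources without
# splitting support (snapshot, nested data service datasets) raise
# `UnimplementedError`.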


class RemoteReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super(RemoteReplicateTest, self).__init__(methodName)
    self._cached_server1 = server_lib.Server.create_local_server()
    self._cached_server2 = server_lib.Server.create_local_server()
    self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
    self._cached_server2_target = self._cached_server2.target[len("grpc://"):]
    self._device0 = "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME
    self._device1 = "/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME
    self._device2 = "/job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME

  def setUp(self):
    super(RemoteReplicateTest, self).setUp()
    # Start the local server.
    local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie()
    context.set_server_def(
        server_def=_get_server_def(
            JOB_NAME,
            local_server_port=local_port,
            remote_server_addresses=[
                self._cached_server1_target, self._cached_server2_target
            ],
            task_index=0))

  @combinations.generate(
      combinations.combine(tf_api_version=[2], mode=["eager"]))
  def testBasic(self):
    with ops.device(self._device0):
      dataset0 = dataset_ops.Dataset.range(100)
    replicated_ds = distribute.replicate(dataset0,
                                         [self._device1, self._device2])
    dataset1 = replicated_ds[self._device1]
    dataset2 = replicated_ds[self._device2]
    with ops.device(self._device0):
      self.assertDatasetProduces(dataset0, range(100))
    with ops.device(self._device1):
      self.assertDatasetProduces(dataset1, range(100))
    with ops.device(self._device2):
      self.assertDatasetProduces(dataset2, range(100))

  @combinations.generate(
      combinations.combine(tf_api_version=[2], mode=["eager"]))
  def testMap(self):
    with ops.device(self._device0):
      dataset0 = dataset_ops.Dataset.range(100).map(lambda x: x * 2)
    replicated_ds = distribute.replicate(dataset0,
                                         [self._device1, self._device2])
    dataset1 = replicated_ds[self._device1]
    dataset2 = replicated_ds[self._device2]
    with ops.device(self._device0):
      self.assertDatasetProduces(dataset0, range(0, 200, 2))
    with ops.device(self._device1):
      self.assertDatasetProduces(dataset1, range(0, 200, 2))
    with ops.device(self._device2):
      self.assertDatasetProduces(dataset2, range(0, 200, 2))

  @combinations.generate(
      combinations.combine(tf_api_version=[2], mode=["eager"]))
  def testVariableInput(self):
    with ops.device(self._device0):
      counter_var = variable_scope.get_variable(
          "counter", (), dtypes.int32, use_resource=True)
      dataset0 = dataset_ops.Dataset.range(100).map(
          lambda _: counter_var.assign_add(1))
    # We don't support stateful ops in functions as of now.
    with self.assertRaises(errors.FailedPreconditionError):
      replicated_ds = distribute.replicate(dataset0,
                                           [self._device1, self._device2])
      self.evaluate(replicated_ds[self._device1]._variant_tensor)

  @combinations.generate(
      combinations.combine(tf_api_version=[2], mode=["eager"]))
  def testAllowStatefulOp(self):
    with compat.forward_compatibility_horizon(2019, 9, 12):
      with ops.device(self._device0):
        dataset0 = dataset_ops.Dataset.range(100).map(
            lambda _: random_ops.random_uniform(  # pylint:disable=g-long-lambda
                [], minval=1, maxval=10, dtype=dtypes.float32))
        opt = dataset_ops.Options()
        opt.experimental_allow_stateful = True
        dataset0 = dataset0.with_options(opt)
      replicated_ds = distribute.replicate(dataset0,
                                           [self._device1, self._device2])
      dataset1 = replicated_ds[self._device1]
      dataset2 = replicated_ds[self._device2]
      with ops.device(self._device0):
        get_next0 = self.getNext(dataset0)
      with ops.device(self._device1):
        get_next1 = self.getNext(dataset1)
      with ops.device(self._device2):
        get_next2 = self.getNext(dataset2)
      for _ in range(100):
        get_next0()
        get_next1()
        get_next2()


class MakeDeterministicTest(test_base.DatasetTestBase, parameterized.TestCase):

  def _set_seed(self):
    # Set the seed, since in graph mode some non-random dataset ops call
    # tf.compat.v1.get_seed to copy the seed to a Defun. Calling get_seed
    # raises an error with determinism if no seed is set.
    # TODO(reedwm): Ensure such dataset ops do not raise an error when no
    # seed is set.
    random_seed.set_random_seed(1)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(use_function=[False, True],
                               use_legacy_interleave=[False, True])))
  def test_stateful_ops_interleave(self, use_function, use_legacy_interleave):
    with test_util.deterministic_ops():
      v = variables.Variable(0.)

      def map_fn(x):
        v.assign_add(1.)
        return (x, v.read_value())

      def interleave_fn(x):
        del x
        return dataset_ops.Dataset.range(2).map(map_fn)

      if use_function:
        map_fn = def_function.function(map_fn)
        interleave_fn = def_function.function(interleave_fn)

      dataset = dataset_ops.Dataset.range(5)
      if use_legacy_interleave:
        dataset = dataset.apply(
            interleave_ops.parallel_interleave(interleave_fn, cycle_length=5))
      else:
        dataset = dataset.interleave(
            interleave_fn, cycle_length=5, num_parallel_calls=3)
      self.evaluate(variables.global_variables_initializer())
      expected_output = list(zip([0] * 5 + [1] * 5, range(1, 11)))
      self.assertDatasetProduces(
          dataset,
          expected_output=expected_output,
          requires_initialization=True)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(use_function=[False, True])))
  def test_stateful_ops_map(self, use_function):
    with test_util.deterministic_ops():
      v = variables.Variable(0.)

      def map_fn(x):
        v.assign_add(1.)
        return (x, v.read_value())

      if use_function:
        map_fn = def_function.function(map_fn)

      dataset = dataset_ops.Dataset.range(5)
      dataset = dataset.map(map_fn, num_parallel_calls=5)
      self.evaluate(variables.global_variables_initializer())
      expected_output = list(zip(range(0, 5), range(1, 6)))
      self.assertDatasetProduces(
          dataset,
          expected_output=expected_output,
          requires_initialization=True)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(use_function=[False, True])))
  def test_stateful_ops_batch(self, use_function):
    with test_util.deterministic_ops():
      v = variables.Variable(0.)

      def map_fn(x):
        return (x, v.read_value())

      if use_function:
        map_fn = def_function.function(map_fn)

      dataset = dataset_ops.Dataset.range(5)
      dataset = dataset.map(map_fn)
      dataset = dataset.apply(testing.assert_next(["Batch"]))
      dataset = dataset.batch(2, num_parallel_calls=2)
      self.evaluate(variables.global_variables_initializer())
      expected_output = [
          (np.array([0, 1]), np.array([0, 0])),
          (np.array([2, 3]), np.array([0, 0])),
          (np.array([4]), np.array([0])),
      ]
      self.assertDatasetProduces(
          dataset,
          expected_output=expected_output,
          requires_initialization=True)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(use_function=[False, True],
                               use_legacy_map_and_batch=[False, True])))
  def test_stateful_ops_map_and_batch(self, use_function,
                                      use_legacy_map_and_batch):
    with test_util.deterministic_ops():
      v = variables.Variable(0.)

      def map_fn(x):
        v.assign_add(1.)
        return (x, v.read_value())

      if use_function:
        map_fn = def_function.function(map_fn)

      dataset = dataset_ops.Dataset.range(5)
      if use_legacy_map_and_batch:
        dataset = dataset.apply(
            batching.map_and_batch(map_fn, 2, num_parallel_calls=5))
      else:
        dataset = dataset.map(map_fn, num_parallel_calls=5)
        dataset = dataset.batch(2)
      self.evaluate(variables.global_variables_initializer())
      expected_output = [
          (np.array([0, 1]), np.array([1, 2])),
          (np.array([2, 3]), np.array([3, 4])),
          (np.array([4]), np.array([5])),
      ]
      self.assertDatasetProduces(
          dataset,
          expected_output=expected_output,
          requires_initialization=True)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(use_function=[False, True],
                               use_legacy_interleave=[False, True])))
  def test_no_stateful_ops_interleave(self, use_function,
                                      use_legacy_interleave):
    self._set_seed()
    with test_util.deterministic_ops():

      def interleave_fn(x):
        del x
        return dataset_ops.Dataset.range(2)

      if use_function:
        interleave_fn = def_function.function(interleave_fn)

      dataset = dataset_ops.Dataset.range(5)
      if use_legacy_interleave:
        dataset = dataset.apply(
            testing.assert_next(["LegacyParallelInterleaveV2"]))
        dataset = dataset.apply(
            interleave_ops.parallel_interleave(interleave_fn, cycle_length=5))
      else:
        dataset = dataset.apply(testing.assert_next(["ParallelInterleave"]))
        dataset = dataset.interleave(
            interleave_fn, cycle_length=5, num_parallel_calls=3)
      self.evaluate(variables.global_variables_initializer())
      self.assertDatasetProduces(dataset, expected_output=[0] * 5 + [1] * 5)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(use_function=[False, True])))
  def test_no_stateful_ops_map(self, use_function):
    self._set_seed()
    with test_util.deterministic_ops():

      def map_fn(x):
        return x + 1

      if use_function:
        map_fn = def_function.function(map_fn)

      dataset = dataset_ops.Dataset.range(5)
      dataset = dataset.apply(testing.assert_next(["ParallelMap"]))
      dataset = dataset.map(map_fn, num_parallel_calls=5)
      self.evaluate(variables.global_variables_initializer())
      expected_output = range(1, 6)
      self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(use_function=[False, True],
                               use_control_flow=[False, True])))
  def test_text_line_dataset(self, use_function, use_control_flow):
    self._set_seed()
    with test_util.deterministic_ops():

      def write_nums_to_file(filename, numbers):
        path = os.path.join(self.get_temp_dir(), filename)
        with open(path, "w") as f:
          f.write("\n".join(str(n) for n in numbers))
        return path

      f1 = write_nums_to_file("f1", (1, 2, 3))
      f2 = write_nums_to_file("f2", (4, 5, 6))
      f3 = write_nums_to_file("f3", (7, 8, 9))

      if use_control_flow:

        def interleave_fn(filename):
          # Test a function that uses control flow. The true branch is never
          # taken.
          concat = string_ops.string_join([filename, "abc"])
          return control_flow_ops.cond(
              math_ops.equal(filename, "abc"),
              lambda: reader_ops.TextLineDataset(concat),
              lambda: reader_ops.TextLineDataset(filename))

      else:

        def interleave_fn(filename):
          return reader_ops.TextLineDataset(filename)

      if use_function:
        interleave_fn = def_function.function(interleave_fn)

      dataset = dataset_ops.Dataset.from_tensor_slices([f1, f2, f3])
      dataset = dataset.apply(testing.assert_next(["ParallelInterleave"]))
      dataset = dataset.interleave(
          interleave_fn, cycle_length=3, num_parallel_calls=3)
      self.assertDatasetProduces(
          dataset,
          expected_output=["1", "4", "7", "2", "5", "8", "3", "6", "9"])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(local_determinism=[None, True, False],
                               global_determinism=[True, False])))
  def test_deterministic_attribute(self, local_determinism,
                                   global_determinism):
    self._set_seed()
    with test_util.deterministic_ops():

      def sleep(x):
        time.sleep(0.1)
        return x

      def map_function(x):
        if math_ops.equal(x, 0):
          return script_ops.py_func(sleep, [x], x.dtype, stateful=False)
        else:
          return x

      dataset = dataset_ops.Dataset.range(100)
      dataset = dataset.map(
          map_function, num_parallel_calls=2, deterministic=local_determinism)

      opts = options_lib.Options()
      opts.deterministic = global_determinism
      dataset = dataset.with_options(opts)

      self.assertDatasetProduces(dataset, expected_output=range(100))

  @combinations.generate(test_base.default_test_combinations())
  def test_rewrite_prefetch(self):
    with test_util.deterministic_ops():
      v = variables.Variable(-1, dtype=dtypes.int64)

      def map_fn(x):
        v.assign(x)
        return x

      dataset = dataset_ops.Dataset.range(5)
      dataset = dataset.map(map_fn)
      dataset = dataset.prefetch(5)
      self.evaluate(variables.global_variables_initializer())
      get_next = self.getNext(dataset, requires_initialization=True)
      self.assertEqual(self.evaluate(v), -1)
      self.assertEqual(self.evaluate(get_next()), 0)
      time.sleep(0.01)
      self.assertEqual(self.evaluate(v), 0)
      self.assertEqual(self.evaluate(get_next()), 1)
      time.sleep(0.01)
      self.assertEqual(self.evaluate(v), 1)

  @combinations.generate(test_base.default_test_combinations())
  def test_no_determinism(self):
    config.disable_op_determinism()
    v = variables.Variable(0.)

    def interleave_fn(x):
      del x
      v.assign(1.)
      return dataset_ops.Dataset.range(2)

    dataset = dataset_ops.Dataset.range(5)
    dataset = dataset.apply(testing.assert_next(["ParallelInterleave"]))
    dataset = dataset.interleave(
        interleave_fn, cycle_length=5, num_parallel_calls=3)
    self.evaluate(variables.global_variables_initializer())
    expected_output = [0] * 5 + [1] * 5
    self.assertDatasetProduces(
        dataset,
        expected_output=expected_output,
        requires_initialization=True)
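

# Hedged summary of the rewrite the tests above exercise: under
# `test_util.deterministic_ops()`, parallel tf.data transformations whose
# user functions contain stateful ops (the map, interleave, batch, and
# map_and_batch variants above) are rewritten to run serially, so the
# expected outputs can assume a fixed visitation order, while stateless
# functions keep their parallel kernels, as the `assert_next` checks verify.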


def _test_combinations():
  return combinations.combine(tf_api_version=[1], mode=["graph"])


class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherStop(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []
    results.append(next(iterator).numpy())
    cluster.stop_dispatcher()
    # After the dispatcher dies, the worker should continue providing the
    # rest of the dataset's elements.
    for _ in range(num_elements - 1):
      results.append(next(iterator).numpy())
    self.assertEqual(results, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartBeforeReading(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    cluster.restart_dispatcher()
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartDuringReading(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []
    for _ in range(num_elements // 2):
      results.append(next(iterator).numpy())
    cluster.restart_dispatcher()
    for elem in iterator:
      results.append(elem.numpy())
    self.assertEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartBetweenIterations(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(100, cluster)
    self.assertDatasetProduces(ds, list(range(num_elements)))
    cluster.restart_dispatcher()
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherManyRestarts(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements_start = 10
    num_elements_end = 15
    datasets = []
    for num_elements in range(num_elements_start, num_elements_end):
      datasets.append(
          self.make_distributed_range_dataset(num_elements, cluster))
      cluster.restart_dispatcher()
    for ds, num_elements in zip(
        datasets, range(num_elements_start, num_elements_end)):
      self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherAndWorkerRestart(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)

    cluster.restart_dispatcher()
    cluster.restart_worker()
    self.assertDatasetProduces(ds, list(range(num_elements)))
    cluster.restart_dispatcher()
    cluster.restart_worker()
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherAndMultiWorkerRestart(self):
    num_workers = 2
    cluster = self.create_cluster(num_workers=num_workers)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []

    cluster.restart_dispatcher()
    for worker_index in range(num_workers):
      cluster.restart_worker(worker_index=worker_index)
    for elem in iterator:
      results.append(elem.numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), results)

    cluster.restart_dispatcher()
    for worker_index in range(num_workers):
      cluster.restart_worker(worker_index=worker_index)
    for elem in iterator:
      results.append(elem.numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testStartServersLate(self):
    # Test that the data service client performs retries instead of failing
    # when the dataset is created before the dispatcher and worker are
    # started.
    try:
      import portpicker  # pylint: disable=g-import-not-at-top
      dispatcher_port = portpicker.pick_unused_port()
    except Exception:  # pylint: disable=broad-except
      raise self.skipTest("Flakes in portpicker library do not represent "
                          "TensorFlow errors.")
    cluster = self.create_cluster(
        num_workers=1, dispatcher_port=dispatcher_port, start=False)

    def start_servers():
      time.sleep(0.5)
      cluster.start_dispatcher()
      cluster.start_workers()

    start_servers_thread = threading.Thread(target=start_servers, daemon=True)
    start_servers_thread.start()

    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    results = [elem.numpy() for elem in ds]
    self.assertEqual(list(range(num_elements)), results)
    start_servers_thread.join()

  @combinations.generate(test_base.eager_only_combinations())
  def testAddWorkerMidJob(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []
    # Read halfway through the dataset.
    for _ in range(num_elements // 2):
      results.append(next(iterator).numpy())

    cluster.add_worker()
    # Wait for the new worker to register with the dispatcher.
    while cluster.num_registered_workers() < 2:
      time.sleep(10 / 1000)  # 10ms

    for elem in iterator:
      results.append(elem.numpy())
    self.assertCountEqual(2 * list(range(num_elements)), results)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(use_same_port=[True, False]),
                         data_service_test_base.all_cluster_configurations()))
  def testRestartWorker(self, use_same_port, work_dir, fault_tolerant_mode):
    cluster = self.create_cluster(
        num_workers=1,
        work_dir=work_dir,
        fault_tolerant_mode=fault_tolerant_mode)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    # Read halfway through the dataset.
    midpoint = num_elements // 2
    for i in range(midpoint):
      self.assertEqual(i, next(iterator).numpy())

    # Stop the original worker and start a new one.
    cluster.restart_worker(use_same_port=use_same_port)

    # There may have been some elements prefetched from the first worker
    # before it was stopped.
    while True:
      val = next(iterator).numpy()
      if val == 0:
        break

    # The dataset starts over now that we read from the new worker.
    # TODO(b/157086991): Iterate until end of sequence when we support
    # detecting lost workers.
    for i in range(1, num_elements // 2):
      val = next(iterator).numpy()
      self.assertEqual(i, val)

  @combinations.generate(test_base.eager_only_combinations())
  def testChangeProcessingModeAfterRestart(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    range_dataset = dataset_ops.Dataset.range(num_elements)
    ds = range_dataset.apply(
        data_service_ops.distribute(
            processing_mode="parallel_epochs",
            service=cluster.target,
            job_name="test"))
    iterator = iter(ds)
    for i in range(num_elements // 2):
      self.assertEqual(i, next(iterator).numpy())
    cluster.restart_dispatcher()
    ds = range_dataset.apply(
        data_service_ops.distribute(
            processing_mode="distributed_epoch",
            service=cluster.target,
            job_name="test"))
    with self.assertRaisesOpError(
        "already an existing job with that name "
        "using processing mode <parallel_epochs>"):
      next(iter(ds)).numpy()

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(work_dir=[TMP_WORK_DIR, NO_WORK_DIR])))
  def testDistributeLargeGraphThenRegisterWorker(self, work_dir):
    cluster = self.create_cluster(
        num_workers=0, work_dir=work_dir, fault_tolerant_mode=False)
    # Larger than the default OSS gRPC message size limit of 4MB.
    tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
    ds = dataset_ops.Dataset.from_tensors(tensor)
    ds = self.make_distributed_dataset(ds, cluster)
    it = iter(ds)
    cluster.add_worker()
    self.assertAllEqual(next(it), tensor)
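

# Fault-tolerance semantics exercised above, in brief: dispatcher restarts
# are transparent to clients (reads resume where they left off, as in
# testDispatcherRestartDuringReading), whereas a worker restart loses that
# worker's progress, so the restarted worker replays its data from the
# beginning and clients may observe repeated elements (testRestartWorker).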
class CollectiveOpsV1(object): all_reduce = _collective_ops.all_reduce all_gather = _collective_ops.all_gather class CollectiveOpsV2(object): all_reduce = _collective_ops.all_reduce_v2 all_gather = _collective_ops.all_gather_v2 @combinations.generate( combinations.combine(collective_ops=[ combinations.NamedObject('v1', CollectiveOpsV1), combinations.NamedObject('v2', CollectiveOpsV2) ], mode='eager')) class CollectiveOpsTest(test.TestCase, parameterized.TestCase): def setUp(self): _setup_context() super().setUp() def testReduce(self, collective_ops): @def_function.function def run_all_reduce_1cpu(): with ops.device('/device:CPU:0'): in_value = constant_op.constant([1.]) group_size = 1 group_key = 1 instance_key = 1
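# Illustrative sketch (an assumption about usage, not original test code):
# both entry points selected by the NamedObject parameterization above take
# the same leading arguments, so a single helper can serve v1 and v2. With
# group_size=1 there is nothing to combine and the all-reduce returns its
# input unchanged.


def _sketch_single_member_all_reduce(collective_ops):
  with ops.device('/device:CPU:0'):
    return collective_ops.all_reduce(
        constant_op.constant([1.]),
        group_size=1,
        group_key=1,
        instance_key=1)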
class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(drop_remainder=[True, False]))) def testBasic(self, drop_remainder): dataset = dataset_ops.Dataset.range(1024).batch( 32, drop_remainder=drop_remainder) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual( [[8] if drop_remainder else [None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testCanHandleUnknownRank(self): dataset = dataset_ops.Dataset.from_tensors("xxx") # decode_image results in a tensor of completely unknown shape (i.e. unknown # rank) dataset = dataset.map(image_ops.decode_image) self.assertEqual([tensor_shape.TensorShape(None)], _flat_shapes(dataset)) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) # Note that we are just testing the dataset shapes, not the actual output. self.assertEqual([tensor_shape.TensorShape(None)], _flat_shapes(rebatched_dataset)) @combinations.generate(test_base.default_test_combinations()) def testCanHandleUnknownDims(self): dataset = dataset_ops.Dataset.range(1000) dataset = dataset.batch(10, drop_remainder=False) dataset = dataset.batch(10, drop_remainder=False) self.assertEqual([[None, None]], [ts.as_list() for ts in _flat_shapes(dataset)]) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) # Note that we are just testing the dataset shapes, not the actual output. self.assertEqual( [[None, None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) @combinations.generate(test_base.default_test_combinations()) def testScalarInputError(self): dataset = dataset_ops.Dataset.range(1024) distribute._RebatchDataset(dataset.batch(4), num_replicas=4) with self.assertRaisesRegex(ValueError, ("You can fix the issue " "by adding the `batch`")): distribute._RebatchDataset(dataset, num_replicas=4) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(drop_remainder=[True, False]))) def testBatchNotDivisibleByNumReplicas(self, drop_remainder): dataset = dataset_ops.Dataset.range(1024).batch( 32, drop_remainder=drop_remainder) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5) self.assertEqual( [[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [] i = 0 for _ in range(32): # number of steps # first four minibatches have seven elements for _ in range(4): expected_output.append([k for k in range(i, i + 7)]) i += 7 # last minibatch has four elements expected_output.append([k for k in range(i, i + 4)]) i += 4 self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testBatchSizeNotDivisibleByNumReplicas2(self): dataset = dataset_ops.Dataset.range(32).batch(16, drop_remainder=True) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5) # This will rebatch into sub-batches of size 4, since # ceil(16 / 5) = 4. However, that means only the first 4 replicas will get # data.
expected_output = [[k for k in range(i, i + 4)] for i in range(0, 16, 4)] expected_output.extend([[]]) # Last replica gets an empty batch expected_output.extend([[k for k in range(i, i + 4)] for i in range(16, 32, 4)]) expected_output.extend([[]]) # Last replica gets an empty batch self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testTupleOutput(self): dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch( 32) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) expected_output = [ ( [k for k in range(i, i + 8)], # pylint: disable=g-complex-comprehension [k for k in range(i, i + 8)]) for i in range(0, 1024, 8) ] self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testNestedDictionaryOutput(self): dataset = dataset_ops.Dataset.range(1024).map(lambda x: { "a": x, "b": { "c": x } }).batch(32) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) expected_output = [ { "a": [k for k in range(i, i + 8)], # pylint: disable=g-complex-comprehension "b": { "c": [k for k in range(i, i + 8)] } } for i in range(0, 1024, 8) ] self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(drop_remainder=[True, False]))) def testFinalPartialBatch(self, drop_remainder): dataset = dataset_ops.Dataset.range(1032).batch( 32, drop_remainder=drop_remainder) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual( [[8] if drop_remainder else [None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) # if drop_remainder, the final partial batch is dropped, even though it # makes up a complete minibatch. expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension if not drop_remainder: # The last partial batch of size 8 is split over 4 replicas expected_output.extend([[k for k in range(i, i + 2)] for i in range(1024, 1032, 2)]) self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(drop_remainder=[True, False]))) def testFinalPartialBatchAfterRebatch(self, drop_remainder): dataset = dataset_ops.Dataset.range(34).batch( 32, drop_remainder=drop_remainder) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual( [[8] if drop_remainder else [None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)] # pylint: disable=g-complex-comprehension if not drop_remainder: # The last partial batch of size 2 is split over 4 replicas expected_output += [[32], [33], [], []] self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testMultipleBatches(self): dataset = dataset_ops.Dataset.range(128).batch(4).batch(8) self.assertEqual([[None, None]], [ts.as_list() for ts in _flat_shapes(dataset)]) # Each element is a list of 8 elements where each element is a list of 4. 
expected_output = [ [ [j, j + 1, j + 2, j + 3] # pylint: disable=g-complex-comprehension for j in range(i, i + 32, 4) ] # generates 8 elements for i in range(0, 128, 32) ] self.assertDatasetProduces(dataset, expected_output) rebatched_dataset = distribute._RebatchDataset(dataset, 4) self.assertEqual( [[None, None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) # Each element is a list of 2 elements where each element is a list of 4. expected_output = [ [ [j, j + 1, j + 2, j + 3] # pylint: disable=g-complex-comprehension for j in range(i, i + 8, 4) ] # generates 2 elements for i in range(0, 128, 8) ] self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testRaggedTensorDataset(self): # Set up a dataset that produces ragged tensors with a static batch size. row_lengths = np.random.randint(8, size=128) values = np.random.normal(size=np.sum(row_lengths)).astype(np.float32) dataset = dataset_ops.Dataset.from_tensor_slices( ragged_tensor.RaggedTensor.from_row_lengths(values, row_lengths)) dataset = dataset.batch(32, drop_remainder=True) # The map changes the internal representation of the ragged tensor. # This test will fail if we don't normalize the tensor representation. dataset = dataset.map(lambda x: x) dataset = distribute._RebatchDataset(dataset, num_replicas=8) # After rebatching, batch size is now 4. expected_output = [] value_index = 0 for batch_row_lengths in row_lengths.reshape((-1, 4)): num_values = np.sum(batch_row_lengths) expected_output.append( ragged_tensor.RaggedTensor.from_row_lengths( values[value_index:(value_index + num_values)], batch_row_lengths)) value_index += num_values self.assertDatasetProduces(dataset, expected_output)
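# Illustrative sketch (plain Python, mirroring the expectations asserted
# above): rebatching splits each global batch into num_replicas sub-batches
# of size ceil(batch_size / num_replicas), so trailing sub-batches can be
# short or empty when the sizes do not divide evenly.
import math


def _sketch_rebatch_one_global_batch(batch, num_replicas):
  """Splits one global batch the way the tests above expect."""
  sub_batch_size = math.ceil(len(batch) / num_replicas)
  return [
      batch[i * sub_batch_size:(i + 1) * sub_batch_size]
      for i in range(num_replicas)
  ]


# Batch size 32 over 5 replicas: four sub-batches of 7 and one of 4
# (testBatchNotDivisibleByNumReplicas).
assert [len(b) for b in _sketch_rebatch_one_global_batch(list(range(32)), 5)
       ] == [7, 7, 7, 7, 4]
# Batch size 16 over 5 replicas: ceil(16 / 5) = 4, so the fifth replica gets
# an empty batch (testBatchSizeNotDivisibleByNumReplicas2).
assert _sketch_rebatch_one_global_batch(list(range(16)), 5) == [
    [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], []]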
class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testBasic(self): components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), np.array([9.0, 10.0, 11.0, 12.0])) def dataset_fn(count=5, buffer_size=None, seed=0): repeat_dataset = (dataset_ops.Dataset.from_tensor_slices( components).repeat(count)) if buffer_size: shuffle_dataset = repeat_dataset.shuffle(buffer_size, seed) self.assertEqual( tuple([c.shape[1:] for c in components]), dataset_ops.get_legacy_output_shapes(shuffle_dataset)) return shuffle_dataset else: return repeat_dataset # First run without shuffling to collect the "ground truth". get_next = self.getNext(dataset_fn()) unshuffled_elements = [] for _ in range(20): unshuffled_elements.append(self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) # Assert that the shuffled dataset has the same elements as the # "ground truth". get_next = self.getNext(dataset_fn(buffer_size=100, seed=37)) shuffled_elements = [] for _ in range(20): shuffled_elements.append(self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) self.assertAllEqual(sorted(unshuffled_elements), sorted(shuffled_elements)) # Assert that shuffling twice with the same seeds gives the same sequence. get_next = self.getNext(dataset_fn(buffer_size=100, seed=37)) reshuffled_elements_same_seed = [] for _ in range(20): reshuffled_elements_same_seed.append(self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) self.assertEqual(shuffled_elements, reshuffled_elements_same_seed) # Assert that shuffling twice with a different seed gives a different # permutation of the same elements. get_next = self.getNext(dataset_fn(buffer_size=100, seed=137)) reshuffled_elements_different_seed = [] for _ in range(20): reshuffled_elements_different_seed.append(self.evaluate( get_next())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) self.assertNotEqual(shuffled_elements, reshuffled_elements_different_seed) self.assertAllEqual(sorted(shuffled_elements), sorted(reshuffled_elements_different_seed)) # Assert that the shuffled dataset has the same elements as the # "ground truth" when the buffer size is smaller than the input # dataset. get_next = self.getNext(dataset_fn(buffer_size=2, seed=37)) reshuffled_elements_small_buffer = [] for _ in range(20): reshuffled_elements_small_buffer.append(self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) self.assertAllEqual(sorted(unshuffled_elements), sorted(reshuffled_elements_small_buffer)) # Test the case of shuffling an empty dataset.
get_next = self.getNext(dataset_fn(count=0, buffer_size=100, seed=37)) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) @combinations.generate(combinations.combine(tf_api_version=1, mode="graph")) def testSeedZero(self): """Test for same behavior when the seed is a Python or Tensor zero.""" iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.range(10).shuffle(10, seed=0)) get_next = iterator.get_next() elems = [] with self.cached_session() as sess: for _ in range(10): elems.append(sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.range(10).shuffle(10, seed=seed_placeholder)) get_next = iterator.get_next() with self.cached_session() as sess: sess.run(iterator.initializer, feed_dict={seed_placeholder: 0}) for elem in elems: self.assertEqual(elem, sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @combinations.generate(test_base.default_test_combinations()) def testDefaultArguments(self): components = [0, 1, 2, 3, 4] dataset = dataset_ops.Dataset.from_tensor_slices(components).shuffle( 5).repeat() get_next = self.getNext(dataset) counts = collections.defaultdict(lambda: 0) for _ in range(10): for _ in range(5): counts[self.evaluate(get_next())] += 1 for i in range(5): self.assertEqual(10, counts[i]) @combinations.generate( combinations.times( test_base.graph_only_combinations(), combinations.combine(reshuffle=[True, False]), combinations.combine(graph_seed=38, op_seed=None) + combinations.combine(graph_seed=None, op_seed=42) + combinations.combine(graph_seed=38, op_seed=42))) def testShuffleSeed(self, reshuffle, graph_seed, op_seed): results = [] for _ in range(2): with ops.Graph().as_default() as g: random_seed.set_random_seed(graph_seed) dataset = dataset_ops.Dataset.range(10).shuffle( 10, seed=op_seed, reshuffle_each_iteration=reshuffle).repeat(3) iterator = dataset_ops.make_one_shot_iterator(dataset) next_element = iterator.get_next() run_results = [] with self.session(graph=g) as sess: for _ in range(30): run_results.append(sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) results.append(run_results) self.assertAllEqual(results[0], results[1]) # TODO(b/117581999): enable this test for eager-mode. 
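# Illustrative sketch (TF2 eager public API, a companion to the graph-mode
# test above): fixing both the global seed and the op-level seed makes the
# shuffle order reproducible across separate instantiations of the pipeline.


def _sketch_seeded_shuffle_order(op_seed):
  import tensorflow as tf  # assumes a TF2 installation

  tf.random.set_seed(38)  # plays the role of graph_seed above
  ds = tf.data.Dataset.range(10).shuffle(10, seed=op_seed)
  return list(ds.as_numpy_iterator())
# Usage: two calls with the same op_seed return identical permutations;
# different op_seeds generally return different permutations of 0..9.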
@combinations.generate( combinations.times( test_base.graph_only_combinations(), combinations.combine(reshuffle=[True, False], initializable=[True, False]))) def testMultipleIterators(self, reshuffle, initializable): with ops.Graph().as_default() as g: dataset = dataset_ops.Dataset.range(100).shuffle( 10, reshuffle_each_iteration=reshuffle).repeat(3) if initializable: iterators = [ dataset_ops.make_initializable_iterator(dataset) for _ in range(2) ] else: iterators = [ dataset_ops.make_one_shot_iterator(dataset) for _ in range(2) ] results = [] with self.session(graph=g) as sess: for iterator in iterators: if initializable: sess.run(iterator.initializer) next_element = iterator.get_next() run_results = [] for _ in range(300): run_results.append(sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) results.append(run_results) self.assertNotEqual(results[0], results[1]) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(reshuffle=[True, False], seed=[None, 42]))) def testReshuffleRepeatEpochs(self, reshuffle, seed): dataset = dataset_ops.Dataset.range(10).shuffle( 10, seed=seed, reshuffle_each_iteration=reshuffle).repeat(2) next_element = self.getNext(dataset) first_epoch = [] for _ in range(10): first_epoch.append(self.evaluate(next_element())) second_epoch = [] for _ in range(10): second_epoch.append(self.evaluate(next_element())) self.assertEqual(first_epoch == second_epoch, not reshuffle) @combinations.generate( combinations.times( combinations.combine(tf_api_version=2, mode="eager"), combinations.combine(reshuffle=[True, False], seed=[None, 42]))) def testReshuffleIterationEpochs(self, reshuffle, seed): dataset = dataset_ops.Dataset.range(10).shuffle( 10, seed=seed, reshuffle_each_iteration=reshuffle) first_epoch = [] for elem in dataset: first_epoch.append(elem.numpy()) second_epoch = [] for elem in dataset: second_epoch.append(elem.numpy()) self.assertEqual(first_epoch == second_epoch, not reshuffle) @combinations.generate(combinations.combine(tf_api_version=2, mode="eager")) def testShuffleV2ResourceCapture(self): def make_dataset(): ids = dataset_ops.Dataset.range(10) ids = ids.shuffle(1) def interleave_fn(dataset, _): return dataset dataset = dataset_ops.Dataset.range(1) dataset = dataset.interleave(functools.partial(interleave_fn, ids)) return dataset results = [] for elem in make_dataset(): results.append(elem.numpy()) self.assertAllEqual(results, range(10)) @combinations.generate( combinations.times( test_base.eager_only_combinations(), combinations.combine(reshuffle=[True, False], seed=[None, 42]))) def testReshuffleSeparateTransformations(self, reshuffle, seed): dataset = dataset_ops.Dataset.range(10) first_epoch = [] for elem in dataset.shuffle(10, seed=seed, reshuffle_each_iteration=reshuffle): first_epoch.append(elem.numpy()) second_epoch = [] for elem in dataset.shuffle(10, seed=seed, reshuffle_each_iteration=reshuffle): second_epoch.append(elem.numpy()) self.assertEqual(first_epoch != second_epoch, seed is None) @combinations.generate(combinations.combine(tf_api_version=2, mode="eager")) def testShuffleV2InFunction(self): counter_var = variables.Variable(0) @function.defun def consume(): ds = dataset_ops.Dataset.range(10) ds = ds.shuffle(1) for _ in ds: counter_var.assign(counter_var + 1) consume() self.assertAllEqual(self.evaluate(counter_var), 10) @combinations.generate(test_base.default_test_combinations()) def testEmptyDataset(self): dataset = 
dataset_ops.Dataset.from_tensors(1) def map_fn(x): with ops.control_dependencies([check_ops.assert_equal(x, 0)]): return x dataset = dataset.map(map_fn) dataset = dataset.cache() dataset = dataset.shuffle(buffer_size=10).repeat() get_next = self.getNext(dataset) # First time around, we get an error for the failed assertion. with self.assertRaises(errors.InvalidArgumentError): self.evaluate(get_next()) # Second time around, we get an EOF because the cached dataset is empty. with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(reshuffle=[True, False]))) def testRerandomizeOnReplicate(self, reshuffle): random_seed.set_random_seed(None) # When no seeds are fixed, each instantiation of the shuffle dataset should # produce elements in a different order. num_elements = 100 dataset = dataset_ops.Dataset.range(num_elements) dataset = dataset.shuffle(num_elements, reshuffle_each_iteration=reshuffle) shuffle_1 = self.getDatasetOutput(dataset) dataset = self.graphRoundTrip(dataset, allow_stateful=True) shuffle_2 = self.getDatasetOutput(dataset) self.assertCountEqual(shuffle_1, shuffle_2) self.assertNotEqual(shuffle_1, shuffle_2)
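# Illustrative sketch (TF2 eager public API): reshuffle_each_iteration
# controls whether two passes over the same dataset object see the same
# permutation, which is what the epoch tests above assert.


def _sketch_reshuffle_semantics():
  import tensorflow as tf  # assumes a TF2 installation

  ds = tf.data.Dataset.range(10).shuffle(
      10, seed=42, reshuffle_each_iteration=False)
  first_epoch = list(ds.as_numpy_iterator())
  second_epoch = list(ds.as_numpy_iterator())
  # With reshuffle_each_iteration=True the two epochs would almost always
  # differ instead.
  assert first_epoch == second_epoch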
class DataServiceOpsTest(data_service_test_base.TestBase, parameterized.TestCase): @combinations.generate( combinations.times(test_base.default_test_combinations(), data_service_test_base.all_cluster_configurations()) ) def testDistributeBasic(self, work_dir, fault_tolerant_mode): cluster = data_service_test_base.TestCluster( num_workers=1, work_dir=work_dir, fault_tolerant_mode=fault_tolerant_mode) num_elements = 10 ds = self.make_distributed_range_dataset(num_elements, cluster) self.assertDatasetProduces(ds, list(range(num_elements))) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(compression=[None, "AUTO"]))) def testDistributeCompression(self, compression): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 10 ds = self.make_distributed_range_dataset(num_elements, cluster, compression=compression) self.assertDatasetProduces(ds, list(range(num_elements))) @combinations.generate(test_base.default_test_combinations()) def testDistributeInvalidCompression(self): cluster = data_service_test_base.TestCluster(num_workers=1) with self.assertRaisesRegex(ValueError, "Invalid compression argument"): self.make_distributed_range_dataset(10, cluster, compression="foo") @combinations.generate(test_base.eager_only_combinations()) def testDistributeSparse(self): cluster = data_service_test_base.TestCluster(num_workers=1) element = sparse_tensor.SparseTensor(indices=[[0]], values=constant_op.constant( [0], dtype=dtypes.int32), dense_shape=[1]) ds = dataset_ops.Dataset.from_tensors(element) ds = self.make_distributed_dataset(ds, cluster) results = [sparse_ops.sparse_tensor_to_dense(elem) for elem in ds] self.assertAllEqual(results, [[0]]) @combinations.generate(test_base.eager_only_combinations()) def testDistributeRagged(self): cluster = data_service_test_base.TestCluster(num_workers=1) ds = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8]) ds = ds.map(math_ops.range) ds = ds.apply(batching.dense_to_ragged_batch(2)) ds = self.make_distributed_dataset(ds, cluster) results = [elem.to_tensor() for elem in ds] self.assertAllEqual(results[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]]) self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]]) self.assertAllEqual(results[2], [[0, 1, 2, 3, 4, 5, 6, 7]]) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine( init_source=["textfile", "keyvaluetensor", "dataset"]))) def testDistributeLookupTable(self, init_source): cluster = data_service_test_base.TestCluster(num_workers=1) initializer = self.lookupTableInitializer(init_source, [10, 11]) table = lookup_ops.StaticHashTable(initializer, -1) ds = dataset_ops.Dataset.range(3) ds = ds.map(table.lookup) ds = self.make_distributed_dataset(ds, cluster) self.evaluate(lookup_ops.tables_initializer()) self.assertDatasetProduces(ds, [10, 11, -1], requires_initialization=True) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(value_rank=[0, 1]))) def testDistributeMutableHashTable(self, value_rank): def value(v): for _ in range(value_rank): v = [v, v] return v v1 = value(10) v2 = value(11) default_value = value(-1) cluster = data_service_test_base.TestCluster(num_workers=1) table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.int64, default_value) self.evaluate(table.insert([0, 1], [v1, v2])) ds = dataset_ops.Dataset.range(3) ds = ds.map(table.lookup) ds = self.make_distributed_dataset(ds, cluster) self.assertDatasetProduces(ds, [v1, v2, 
default_value], requires_initialization=True) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(shuffle_seed=[None, 10]))) def testShuffleOrder(self, shuffle_seed): random_seed.set_random_seed(None) num_elements = 100 cluster = data_service_test_base.TestCluster(num_workers=2) ds = dataset_ops.Dataset.range(num_elements) ds = ds.shuffle(num_elements, seed=shuffle_seed) ds = self.make_distributed_dataset(ds, cluster) output = self.getDatasetOutput(ds) # The output will be two sequences of range(num_elements) # non-deterministically interleaved together. If the orders of the elements # were the same, first_order and second_order computed below will be equal. first_order = {} second_order = {} for element in output: if element in first_order: second_order[element] = len(second_order) else: first_order[element] = len(first_order) if shuffle_seed is None: self.assertNotEqual(first_order, second_order) else: self.assertEqual(first_order, second_order) @combinations.generate(test_base.default_test_combinations()) def testMultipleEpochs(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 3 ds = self.make_distributed_range_dataset(num_elements, cluster) for _ in range(10): self.assertDatasetProduces(ds, list(range(num_elements))) @combinations.generate(test_base.default_test_combinations()) def testRepeatedDataset(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 10 num_repetitions = 5 ds = self.make_distributed_range_dataset(num_elements, cluster) ds = ds.repeat(num_repetitions) self.assertDatasetProduces(ds, expected_output=num_repetitions * list(range(num_elements))) @combinations.generate(test_base.default_test_combinations()) def testConcurrentEpoch(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 10 num_datasets = 3 get_nexts = [] results = [] for _ in range(num_datasets): ds = self.make_distributed_range_dataset(num_elements, cluster) get_nexts.append(self.getNext(ds)) results.append([]) for _ in range(num_elements): for dataset_ind in range(num_datasets): result = self.evaluate(get_nexts[dataset_ind]()) results[dataset_ind].append(result) for result in results: self.assertEqual(list(range(num_elements)), result) @combinations.generate(test_base.default_test_combinations()) def testMultiWorker(self): num_workers = 3 cluster = data_service_test_base.TestCluster(num_workers=num_workers) num_elements = 10 ds = self.make_distributed_range_dataset(num_elements, cluster) self.assertDatasetProduces(ds, num_workers * list(range(num_elements)), assert_items_equal=True) @combinations.generate(test_base.default_test_combinations()) def testMaxOutstandingRequests(self): num_workers = 3 cluster = data_service_test_base.TestCluster(num_workers=num_workers) num_elements = 10 ds = self.make_distributed_range_dataset(num_elements, cluster, max_outstanding_requests=1) self.assertDatasetProduces(ds, num_workers * list(range(num_elements)), assert_items_equal=True) @combinations.generate(test_base.eager_only_combinations()) def testInsideFunction(self): num_workers = 3 cluster = data_service_test_base.TestCluster(num_workers=num_workers) num_elements = 10 @def_function.function def f(): ds = self.make_distributed_range_dataset(num_elements, cluster) result = tensor_array_ops.TensorArray(dtypes.int64, size=num_workers * num_elements, dynamic_size=True) i = 0 for elem in ds: result = result.write(i, elem) i += 1 return result.stack() result = list(f().numpy()) 
self.assertCountEqual(num_workers * list(range(num_elements)), result) @combinations.generate(test_base.default_test_combinations()) def testEmptyJobNameDistribute(self): cluster = data_service_test_base.TestCluster(num_workers=1) with self.assertRaisesRegex(ValueError, "job_name must not be empty"): dataset_ops.Dataset.range(10).apply( data_service_ops.distribute(processing_mode="parallel_epochs", service=cluster.dispatcher.target, job_name="")) @combinations.generate(test_base.default_test_combinations()) def testEmptyJobNameFromDatasetId(self): cluster = data_service_test_base.TestCluster(num_workers=1) dataset_id = data_service_ops.register_dataset( cluster.dispatcher.target, dataset_ops.Dataset.range(10)) with self.assertRaisesRegex(ValueError, "job_name must not be empty"): data_service_ops.from_dataset_id(dataset_id=dataset_id, processing_mode="parallel_epochs", service=cluster.dispatcher.target, job_name="") @combinations.generate(test_base.default_test_combinations()) def testNonStringJobNameDistribute(self): cluster = data_service_test_base.TestCluster(num_workers=1) with self.assertRaisesRegex(ValueError, "job_name must be a string"): dataset_ops.Dataset.range(10).apply( data_service_ops.distribute( processing_mode="parallel_epochs", service=cluster.dispatcher.target, job_name=constant_op.constant("foo"))) @combinations.generate(test_base.default_test_combinations()) def testNonStringJobNameFromDatasetId(self): cluster = data_service_test_base.TestCluster(num_workers=1) dataset_id = data_service_ops.register_dataset( cluster.dispatcher.target, dataset_ops.Dataset.range(10)) with self.assertRaisesRegex(ValueError, "job_name must be a string"): data_service_ops.from_dataset_id( dataset_id=dataset_id, processing_mode="parallel_epochs", service=cluster.dispatcher.target, job_name=constant_op.constant("foo")) @combinations.generate(test_base.default_test_combinations()) def testSharedJobName(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 1000 def make_ds(): return dataset_ops.Dataset.range(num_elements).shuffle( num_elements) ds1 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name") ds2 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name") get_next_1 = self.getNext(ds1) get_next_2 = self.getNext(ds2) results = [] for _ in range(num_elements // 5): results.append(self.evaluate(get_next_1())) results.append(self.evaluate(get_next_2())) results += self.getIteratorOutput(get_next_1) results += self.getIteratorOutput(get_next_2) self.assertCountEqual(list(range(num_elements)), results) @combinations.generate(test_base.default_test_combinations()) def testDifferentJobNames(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 10 ds1 = self.make_distributed_range_dataset(num_elements, cluster, job_name="job_name1") ds2 = self.make_distributed_range_dataset(num_elements, cluster, job_name="job_name2") self.assertDatasetProduces(ds1, list(range(num_elements))) self.assertDatasetProduces(ds2, list(range(num_elements))) @combinations.generate(test_base.eager_only_combinations()) def testSharedJobNameMultiIteration(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 10 ds1 = self.make_distributed_range_dataset(num_elements, cluster, job_name="job_name") ds2 = self.make_distributed_range_dataset(num_elements, cluster, job_name="job_name") # iteration 1 self.assertDatasetProduces(ds1, list(range(num_elements))) self.assertDatasetProduces(ds2, []) # iteration 2 
self.assertDatasetProduces(ds2, list(range(num_elements))) self.assertDatasetProduces(ds1, []) @combinations.generate(test_base.default_test_combinations()) def testSharedJobNameRepeat(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 100 num_repetitions = 3 ds1 = self.make_distributed_range_dataset(num_elements, cluster, job_name="job_name") ds1 = ds1.repeat(num_repetitions) ds2 = self.make_distributed_range_dataset(num_elements, cluster, job_name="job_name") ds2 = ds2.repeat(num_repetitions) results = [] get_next_1 = self.getNext(ds1) get_next_2 = self.getNext(ds2) for _ in range((num_elements * num_repetitions) // 5): results.append(self.evaluate(get_next_1())) for _ in range((num_elements * num_repetitions) // 5): results.append(self.evaluate(get_next_2())) results += self.getIteratorOutput(get_next_1) results += self.getIteratorOutput(get_next_2) self.assertCountEqual(num_repetitions * list(range(num_elements)), results) @combinations.generate(test_base.eager_only_combinations()) def testSharedJobNameMultipleEpochs(self): cluster = data_service_test_base.TestCluster(num_workers=1) dataset = self.make_distributed_range_dataset(10, cluster, job_name="job_name") num_epochs = 5 for _ in range(num_epochs): get_next = self.getNext(dataset) self.assertEqual(self.getIteratorOutput(get_next), list(range(10))) @combinations.generate( combinations.times(test_base.eager_only_combinations(), combinations.combine(job_name=[None, "test"]))) def testGcUnusedJob(self, job_name): cluster = data_service_test_base.TestCluster( num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20) num_elements = 100 ds = self.make_distributed_range_dataset(num_elements, cluster, job_name=job_name) it = iter(ds) self.assertEqual(next(it).numpy(), 0) self.assertEqual(cluster.workers[0].num_tasks(), 1) del it while cluster.workers[0].num_tasks() > 0: time.sleep(0.1) @combinations.generate(test_base.eager_only_combinations()) def testDontGcUsedJob(self): cluster = data_service_test_base.TestCluster( num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20) num_elements = 10 it1 = iter( self.make_distributed_range_dataset(num_elements, cluster, job_name="test1")) it2 = iter( self.make_distributed_range_dataset(num_elements, cluster, job_name="test2")) it3 = iter( # this iterator keeps the task alive. pylint: disable=unused-variable self.make_distributed_range_dataset(num_elements, cluster, job_name="test2")) self.assertEqual(cluster.workers[0].num_tasks(), 2) del it1 del it2 # Check that only the first job is gced. The second job will not be gced # because there is still an outstanding iterator for it. while cluster.workers[0].num_tasks() > 1: time.sleep(0.1) self.assertEqual(cluster.workers[0].num_tasks(), 1) @combinations.generate(test_base.eager_only_combinations()) def testGcAndRecreate(self): cluster = data_service_test_base.TestCluster( num_workers=3, job_gc_check_interval_ms=50, job_gc_timeout_ms=20) num_elements = 1000 # Repeatedly create and garbage-collect the same job. for _ in range(3): ds = self.make_distributed_range_dataset(num_elements, cluster, job_name="test") it = iter(ds) for _ in range(50): next(it) del it # Wait for the task to be garbage-collected on all workers. 
while cluster.num_tasks_on_workers() > 0: time.sleep(0.1) @combinations.generate(test_base.default_test_combinations()) def testApplyDeterminismOption(self): elements = list(range(10)) cluster = data_service_test_base.TestCluster(num_workers=1) def dataset_fn(delay_ms): def interleave_fn(x): ds = dataset_ops.Dataset.from_tensors(x) if math_ops.equal(x, 0): ds = ds.apply(testing.sleep(delay_ms * 1000)) else: ds = ds.apply(testing.sleep(0)) return ds ds = dataset_ops.Dataset.from_tensor_slices(elements) ds = ds.interleave(interleave_fn, cycle_length=10, num_parallel_calls=10) opts = options_lib.Options() opts.deterministic = False ds = ds.with_options(opts) ds = self.make_distributed_dataset(ds, cluster) return ds self.checkDeterminism(dataset_fn=dataset_fn, expect_determinism=False, expected_elements=elements) def run_stateful(self, external_state_policy): num_elements = 10 ds = dataset_ops.Dataset.range(num_elements).map( lambda _: random_ops.random_uniform(())) options = options_lib.Options() options.experimental_external_state_policy = external_state_policy ds = ds.with_options(options) cluster = data_service_test_base.TestCluster(num_workers=3) ds = self.make_distributed_dataset(ds, cluster) self.getDatasetOutput(ds) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(external_state_policy=[ options_lib.ExternalStatePolicy.IGNORE, options_lib.ExternalStatePolicy.WARN ]))) def testStatefulNoError(self, external_state_policy): self.run_stateful(external_state_policy) @combinations.generate(test_base.default_test_combinations()) def testStatefulError(self): with self.assertRaises(errors.FailedPreconditionError): self.run_stateful(options_lib.ExternalStatePolicy.FAIL) @combinations.generate(test_base.default_test_combinations()) def testDistributeFromInterleave(self): cluster = data_service_test_base.TestCluster(num_workers=1) ds = dataset_ops.Dataset.range(2) def interleave_fn(x): dataset = dataset_ops.Dataset.range(10 * x, 10 * x + 2) dataset = self.make_distributed_dataset(dataset, cluster) return dataset ds = ds.interleave(interleave_fn, cycle_length=2) self.assertDatasetProduces(ds, [0, 10, 1, 11]) @combinations.generate(test_base.default_test_combinations()) def testDistributeNonStringAddresses(self): ds = dataset_ops.Dataset.range(10) with self.assertRaisesRegex(ValueError, "service must be a string"): ds = ds.apply( data_service_ops.distribute(processing_mode="parallel_epochs", service=1)) @combinations.generate(test_base.default_test_combinations()) def testDistributeEmptyAddress(self): ds = dataset_ops.Dataset.range(10) with self.assertRaisesWithLiteralMatch(ValueError, "service must not be empty"): ds = ds.apply( data_service_ops.distribute(processing_mode="parallel_epochs", service="")) @combinations.generate(test_base.default_test_combinations()) def testDistributeExplicitProtocol(self): cluster = data_service_test_base.TestCluster( num_workers=1, data_transfer_protocol="grpc") ds = dataset_ops.Dataset.range(10) ds = ds.apply( data_service_ops.distribute(processing_mode="parallel_epochs", service="grpc://" + cluster.dispatcher_address())) self.assertDatasetProduces(ds, list(range(10))) @combinations.generate(test_base.default_test_combinations()) def testDistributeInvalidProtocol(self): cluster = data_service_test_base.TestCluster(num_workers=1) ds = dataset_ops.Dataset.range(10) with self.assertRaisesRegex( errors.NotFoundError, "No credentials factory has been registered for protocol grp"): ds = ds.apply( 
data_service_ops.distribute(processing_mode="parallel_epochs", service="grp://" + cluster.dispatcher_address())) self.getDatasetOutput(ds) @combinations.generate(test_base.eager_only_combinations()) def testDistributeInvalidProcessingMode(self): ds = dataset_ops.Dataset.range(10) with self.assertRaisesRegex( ValueError, "should be a ShardingPolicy, `\"parallel_epochs\"`, or " "`\"distributed_epoch\"`. Got 'invalid'."): ds = ds.apply( data_service_ops.distribute(processing_mode="invalid", service="grpc://localhost:5000")) @combinations.generate(test_base.default_test_combinations()) def testZipDifferentProcessingModesDatasets(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 100 ds1 = dataset_ops.Dataset.range(num_elements) ds1 = self.make_distributed_dataset( ds1, cluster, processing_mode="distributed_epoch") ds2 = dataset_ops.Dataset.range(num_elements) ds2 = self.make_distributed_dataset(ds2, cluster, processing_mode="parallel_epochs") ds = dataset_ops.Dataset.zip((ds1, ds2)) self.assertDatasetProduces(ds, list( zip(range(num_elements), range(num_elements))), assert_items_equal=True) @combinations.generate(test_base.default_test_combinations()) def testZipDifferentProcessingModesDatasetsSharedJobName(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 100 ds1 = dataset_ops.Dataset.range(num_elements) ds1 = self.make_distributed_dataset( ds1, cluster, processing_mode="distributed_epoch", job_name="job_name") ds2 = dataset_ops.Dataset.range(num_elements) ds2 = self.make_distributed_dataset(ds2, cluster, processing_mode="parallel_epochs", job_name="job_name") ds = dataset_ops.Dataset.zip((ds1, ds2)) with self.assertRaisesRegex(errors.FailedPreconditionError, "but there is already an existing job"): self.getDatasetOutput(ds) @combinations.generate(test_base.default_test_combinations()) def testFromDatasetId(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 10 ds = dataset_ops.Dataset.range(num_elements) dataset_id = self.register_dataset(cluster.dispatcher_address(), ds) from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster, dataset_id, ds.element_spec) self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements))) @combinations.generate(test_base.default_test_combinations()) def testFromDatasetIdSharedJobs(self): cluster = data_service_test_base.TestCluster(num_workers=2) datasets = [ dataset_ops.Dataset.range(20, output_type=dtypes.int32), dataset_ops.Dataset.from_tensor_slices(list(range(20, 40))) ] dataset_ids = [] for ds in datasets: dataset_id = self.register_dataset(cluster.dispatcher_address(), ds) dataset_ids.append(dataset_id) # Read from both jobs in parallel, with 2 consumers for each job. 
data_service_datasets = [] for _ in range(2): for dataset, dataset_id in zip(datasets, dataset_ids): ds = self.from_dataset_id("distributed_epoch", cluster, dataset_id, dataset.element_spec, job_name="shared_job") data_service_datasets.append(ds) ds = dataset_ops.Dataset.from_tensor_slices(data_service_datasets) ds = ds.interleave(lambda x: x, cycle_length=len(data_service_datasets)) self.assertDatasetProduces(ds, list(range(40)), assert_items_equal=True) @combinations.generate(test_base.default_test_combinations()) def testRegisteringDatasetAsTfFunction(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 10 ds = dataset_ops.Dataset.range(num_elements) register_func = def_function.function(self.register_dataset) dataset_id = register_func( (constant_op.constant("grpc"), constant_op.constant(cluster.dispatcher_address())), ds) from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster, dataset_id, ds.element_spec) self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements))) @combinations.generate(test_base.default_test_combinations()) def testFromDatasetIdMultipleComponents(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 10 ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds}) dataset_id = self.register_dataset(cluster.dispatcher_address(), ds) from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster, dataset_id, ds.element_spec) output = self.getDatasetOutput(from_dataset_id_ds) for i in range(num_elements): self.assertEqual(i, output[i]["a"][0]) self.assertEqual(i, output[i]["a"][1]) self.assertEqual(i, output[i]["b"]) @combinations.generate(test_base.default_test_combinations()) def testFromDatasetIdWrongElementSpec(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 10 ds = dataset_ops.Dataset.range(num_elements) dataset_id = self.register_dataset(cluster.dispatcher_address(), ds) wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant) from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster, dataset_id, wrong_spec) if data_service_test_base.TRANSFER_PROTOCOL.value: with self.assertRaisesRegex(errors.InvalidArgumentError, "Data type mismatch at component 0"): self.evaluate(self.getNext(from_dataset_id_ds)()) else: with self.assertRaisesRegex(errors.FailedPreconditionError, "Expected a tensor of type variant"): self.evaluate(self.getNext(from_dataset_id_ds)()) @combinations.generate(test_base.default_test_combinations()) def testFromDatasetIdNotRegistered(self): cluster = data_service_test_base.TestCluster(num_workers=1) dataset_id = 0 element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant) from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster, dataset_id, element_spec) with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"): self.evaluate(self.getNext(from_dataset_id_ds)()) @combinations.generate(test_base.default_test_combinations()) def testCancellation(self): self.skipTest("b/162521601") sleep_microseconds = int(1e6) * 1000 cluster = data_service_test_base.TestCluster(num_workers=1) # Create a dataset which produces the first element quickly, and the second # element slowly. Fetching the first element triggers prefetching of the # second element, which we should be able to cancel. 
slow = dataset_ops.Dataset.range(1) slow = slow.apply(testing.sleep(sleep_microseconds)) ds = dataset_ops.Dataset.range(1).concatenate(slow) ds = self.make_distributed_dataset(ds, cluster) ds = ds.prefetch(1) get_next = self.getNext(ds) self.assertEqual(0, self.evaluate(get_next())) # Without properly implemented cancellation, we will hang here while trying # to garbage collect the dataset iterator. @combinations.generate(test_base.default_test_combinations()) def testRegisterEquivalentDatasets(self): ds_1 = dataset_ops.Dataset.range(10) ds_2 = dataset_ops.Dataset.range(10) cluster = data_service_test_base.TestCluster(num_workers=1) id_1 = self.register_dataset(cluster.dispatcher_address(), ds_1) id_2 = self.register_dataset(cluster.dispatcher_address(), ds_2) self.assertEqual(self.evaluate(id_1), self.evaluate(id_2)) @combinations.generate(test_base.default_test_combinations()) def testRegisterDifferentDatasets(self): ds_1 = dataset_ops.Dataset.range(10) ds_2 = dataset_ops.Dataset.range(20) cluster = data_service_test_base.TestCluster(num_workers=1) id_1 = self.register_dataset(cluster.dispatcher_address(), ds_1) id_2 = self.register_dataset(cluster.dispatcher_address(), ds_2) self.assertNotEqual(self.evaluate(id_1), self.evaluate(id_2)) @combinations.generate(test_base.default_test_combinations()) def testTwoLevelDistribute(self): cluster_1_size = 3 cluster_1 = data_service_test_base.TestCluster( num_workers=cluster_1_size) cluster_2 = data_service_test_base.TestCluster(num_workers=1) num_sizes = 10 size_repeats = 5 strings = ["a" * i for i in range(num_sizes)] * size_repeats ds = dataset_ops.Dataset.from_tensor_slices(strings) ds = ds.shuffle(len(strings)) ds = self.make_distributed_dataset(ds, cluster_1) # Large enough so that all strings of the same size are windowed together. window_size = cluster_1_size * size_repeats batch_size = size_repeats def key_func(x): return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64) ds = ds.apply( grouping.group_by_window( key_func=key_func, reduce_func=lambda _, x: x.batch(batch_size), window_size=window_size)) ds = self.make_distributed_dataset(ds, cluster_2) get_next = self.getNext(ds) for _ in range(num_sizes): element = self.evaluate(get_next()) for _ in range(1, cluster_1_size): self.assertAllEqual(self.evaluate(get_next()), element) self.assertEmpty(self.getIteratorOutput(get_next)) @combinations.generate( combinations.times(test_base.default_test_combinations())) def testDistributeLargeGraph(self): cluster = data_service_test_base.TestCluster(num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False) # Larger than default OSS grpc message size limit of 4MB. 
tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32) ds = dataset_ops.Dataset.from_tensors(tensor) ds = self.make_distributed_dataset(ds, cluster) self.assertDatasetProduces(ds, [tensor]) @combinations.generate( combinations.times(test_base.graph_only_combinations(), combinations.combine(use_resource=False)) + combinations.times(test_base.default_test_combinations(), combinations.combine(use_resource=True))) def testVariables(self, use_resource): cluster = data_service_test_base.TestCluster(num_workers=1) if not use_resource: with variable_scope.variable_scope("foo", use_resource=False): v = variables.VariableV1(10, dtype=dtypes.int64) else: v = variables.Variable(10, dtype=dtypes.int64) ds = dataset_ops.Dataset.range(3) ds = ds.map(lambda x: x + v) ds = self.make_distributed_dataset(ds, cluster) self.evaluate(v.initializer) self.assertDatasetProduces(ds, list(range(10, 13)), requires_initialization=True) @combinations.generate(test_base.graph_only_combinations()) def testElementSpecGraphMode(self): cluster = data_service_test_base.TestCluster(num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False) num_elements = 10 ds = dataset_ops.Dataset.range(num_elements) dataset_id = data_service_ops.register_dataset( cluster.dispatcher_address(), ds) with self.assertRaisesRegex( ValueError, "In graph mode element_spec must be provided manually."): ds = data_service_ops.from_dataset_id("parallel_epochs", cluster.dispatcher_address(), dataset_id) @combinations.generate(test_base.eager_only_combinations()) def testFromDatasetIdDoesntRequireElementSpec(self): cluster = data_service_test_base.TestCluster( num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False, data_transfer_protocol="grpc") num_elements = 10 ds = dataset_ops.Dataset.range(num_elements) dataset_id = data_service_ops.register_dataset( cluster.dispatcher_address(), ds) ds = data_service_ops.from_dataset_id("parallel_epochs", cluster.dispatcher_address(), dataset_id) self.assertDatasetProduces(ds, list(range(num_elements))) @combinations.generate(test_base.eager_only_combinations()) def testElementSpecMixedMode(self): cluster = data_service_test_base.TestCluster(num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False) num_elements = 10 ds = dataset_ops.Dataset.range(num_elements) @def_function.function def get_dataset_id(): return data_service_ops.register_dataset( cluster.dispatcher_address(), ds) dataset_id = get_dataset_id() dataset_id_val = tensor_util.constant_value(dataset_id) with self.assertRaisesRegex( ValueError, "Failed to fetch element spec for dataset id " + str(dataset_id_val) + " from tf.data service. If the " "dataset was registered in graph mode or inside a " "tf.function, the `element_spec` must be specified as " "an argument to `from_dataset_id`."): ds = data_service_ops.from_dataset_id("parallel_epochs", cluster.dispatcher_address(), dataset_id) @combinations.generate(test_base.default_test_combinations()) def testNoShardingPolicy(self): cluster = data_service_test_base.TestCluster(num_workers=1) dataset = dataset_ops.Dataset.range(20) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.OFF) self.assertDatasetProduces(dataset, list(range(20))) @combinations.generate(test_base.default_test_combinations()) def testCardinality(self): cluster = data_service_test_base.TestCluster(num_workers=1) dataset = self.make_distributed_range_dataset(10, cluster) self.assertEqual(self.evaluate(dataset.cardinality()), dataset_ops.UNKNOWN)
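# Illustrative sketch (public API, assuming TF >= 2.4): the register /
# from_dataset_id pair exercised above lets one process register a dataset
# with the dispatcher while others consume it by id, passing element_spec
# explicitly when the consumer cannot fetch it (e.g. in graph mode).


def _sketch_register_and_consume():
  import tensorflow as tf  # assumes TF >= 2.4

  dispatcher = tf.data.experimental.service.DispatchServer()
  worker = tf.data.experimental.service.WorkerServer(
      tf.data.experimental.service.WorkerConfig(
          dispatcher_address=dispatcher.target.split("://")[1]))

  source = tf.data.Dataset.range(10)
  dataset_id = tf.data.experimental.service.register_dataset(
      dispatcher.target, source)
  consumed = tf.data.experimental.service.from_dataset_id(
      processing_mode="parallel_epochs",
      service=dispatcher.target,
      dataset_id=dataset_id,
      element_spec=source.element_spec)
  results = sorted(int(x) for x in consumed.as_numpy_iterator())
  assert results == list(range(10))
  del worker  # keep the worker referenced until reading completes
  return results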
class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.graph_only_combinations()) def testNoGradients(self): component = constant_op.constant([1.]) side = constant_op.constant(0.) add = lambda x: x + side dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add) value = dataset_ops.make_one_shot_iterator(dataset).get_next() self.assertIsNone(gradients_impl.gradients(value, component)[0]) self.assertIsNone(gradients_impl.gradients(value, side)[0]) self.assertIsNone( gradients_impl.gradients(value, [component, side])[0]) @combinations.generate(test_base.graph_only_combinations()) def testCapturingStateInOneShotRaisesException(self): var = variables.Variable(37.0, name="myvar") dataset = (dataset_ops.Dataset.from_tensor_slices( [0.0, 1.0, 2.0]).map(lambda x: x + var)) with self.assertRaisesRegex( ValueError, r"`Dataset.make_one_shot_iterator\(\)` does not support " "datasets that capture stateful objects.+myvar"): dataset_ops.make_one_shot_iterator(dataset) @combinations.generate(test_base.graph_only_combinations()) def testOneShotIterator(self): components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensor_slices(components).map( _map_fn).repeat(14)) get_next = iterator.get_next() self.assertEqual([c.shape[1:] for c in components], [t.shape for t in get_next]) with self.cached_session() as sess: for _ in range(14): for i in range(7): result = sess.run(get_next) for component, result_component in zip(components, result): self.assertAllEqual(component[i]**2, result_component) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @combinations.generate(test_base.graph_only_combinations()) def testOneShotIteratorCaptureByValue(self): components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) tensor_components = tuple( [ops.convert_to_tensor(c) for c in components]) def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensor_slices(tensor_components).map( _map_fn).repeat(14)) get_next = iterator.get_next() self.assertEqual([c.shape[1:] for c in components], [t.shape for t in get_next]) with self.cached_session() as sess: for _ in range(14): for i in range(7): result = sess.run(get_next) for component, result_component in zip(components, result): self.assertAllEqual(component[i]**2, result_component) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @combinations.generate(test_base.default_test_combinations()) def testOneShotIteratorInsideContainer(self): components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) def within_container(): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square( z) iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensor_slices(components).map( _map_fn).repeat(14)) return iterator.get_next() server = server_lib.Server.create_local_server() # Create two iterators within unique containers, and run them to # make sure that the resources aren't shared. # # The test below would fail if cname were the same across both # sessions. 
for j in range(2): with session.Session(server.target) as sess: cname = "iteration%d" % j with ops.container(cname): get_next = within_container() for _ in range(14): for i in range(7): result = sess.run(get_next) for component, result_component in zip( components, result): self.assertAllEqual(component[i]**2, result_component) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @combinations.generate(test_base.graph_only_combinations()) def testOneShotIteratorNonBlocking(self): dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x) iterator = dataset_ops.make_one_shot_iterator(dataset) next_element = iterator.get_next() # Create a session with a single thread to ensure that the # one-shot iterator initializer does not deadlock. config = config_pb2.ConfigProto(inter_op_parallelism_threads=1, use_per_session_threads=True) with session.Session(config=config) as sess: self.assertAllEqual([1, 4, 9], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) # Test with multiple threads invoking the one-shot iterator concurrently. with session.Session(config=config) as sess: results = [] def consumer_thread(): try: results.append(sess.run(next_element)) except errors.OutOfRangeError: results.append(None) num_threads = 8 threads = [ self.checkedThread(consumer_thread) for _ in range(num_threads) ] for t in threads: t.start() for t in threads: t.join() self.assertLen(results, num_threads) self.assertLen([None for r in results if r is None], num_threads - 1) self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None]) @combinations.generate(test_base.graph_only_combinations()) def testOneShotIteratorInitializerFails(self): # Define a dataset whose initialization will always fail. dataset = dataset_ops.Dataset.from_tensors(array_ops.gather([0], [4])) iterator = dataset_ops.make_one_shot_iterator(dataset) next_element = iterator.get_next() with self.cached_session() as sess: with self.assertRaisesRegex(errors.InvalidArgumentError, ""): sess.run(next_element) # Test that subsequent attempts to use the iterator also fail. with self.assertRaisesRegex(errors.InvalidArgumentError, ""): sess.run(next_element) with self.cached_session() as sess: def consumer_thread(): with self.assertRaisesRegex(errors.InvalidArgumentError, ""): sess.run(next_element) num_threads = 8 threads = [ self.checkedThread(consumer_thread) for _ in range(num_threads) ] for t in threads: t.start() for t in threads: t.join() @combinations.generate(test_base.graph_only_combinations()) def testSimpleSharedResource(self): components = (np.array(1, dtype=np.int64), np.array([1, 2, 3], dtype=np.int64), np.array(37.0, dtype=np.float64)) server = server_lib.Server.create_local_server() # Create two non-overlapping sessions that share the same iterator # resource on the same server, and verify that an action of the # first session (initializing the iterator) is visible in the # second session. with ops.Graph().as_default(): iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_tensors(components).map( lambda x, y, z: (x, y, z)), shared_name="shared_iterator") init_op = iterator.initializer get_next = iterator.get_next() with session.Session(server.target) as sess: sess.run(init_op) results = sess.run(get_next) for component, result_component in zip(components, results): self.assertAllEqual(component, result_component) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Re-initialize the iterator in the first session. 
sess.run(init_op) with ops.Graph().as_default(): # Re-define the iterator manually, without defining any of the # functions in this graph, to ensure that we are not # accidentally redefining functions with the same names in the # new graph. iterator = iterator_ops.Iterator.from_structure( shared_name="shared_iterator", output_types=(dtypes.int64, dtypes.int64, dtypes.float64), output_shapes=([], [3], [])) get_next = iterator.get_next() with session.Session(server.target) as sess: # Use the iterator without re-initializing in the second session. results = sess.run(get_next) for component, result_component in zip(components, results): self.assertAllEqual(component, result_component) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @combinations.generate(test_base.graph_only_combinations()) def testNotInitializedError(self): components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_tensors(components)) get_next = iterator.get_next() with self.cached_session() as sess: with self.assertRaisesRegex(errors.FailedPreconditionError, "iterator has not been initialized"): sess.run(get_next) @combinations.generate(test_base.graph_only_combinations()) def testReinitializableIterator(self): dataset_3 = dataset_ops.Dataset.from_tensors( constant_op.constant([1, 2, 3])) dataset_4 = dataset_ops.Dataset.from_tensors( constant_op.constant([4, 5, 6, 7])) iterator = iterator_ops.Iterator.from_structure( dataset_ops.get_legacy_output_types(dataset_3), [None]) dataset_3_init_op = iterator.make_initializer(dataset_3) dataset_4_init_op = iterator.make_initializer(dataset_4) get_next = iterator.get_next() self.assertEqual(dataset_ops.get_legacy_output_types(dataset_3), dataset_ops.get_legacy_output_types(iterator)) self.assertEqual(dataset_ops.get_legacy_output_types(dataset_4), dataset_ops.get_legacy_output_types(iterator)) self.assertEqual( [None], dataset_ops.get_legacy_output_shapes(iterator).as_list()) with self.cached_session() as sess: # The iterator is initially uninitialized. with self.assertRaises(errors.FailedPreconditionError): sess.run(get_next) # Initialize with one dataset. sess.run(dataset_3_init_op) self.assertAllEqual([1, 2, 3], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Initialize with a different dataset. sess.run(dataset_4_init_op) self.assertAllEqual([4, 5, 6, 7], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Reinitialize with the first dataset. 
sess.run(dataset_3_init_op) self.assertAllEqual([1, 2, 3], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @combinations.generate(test_base.graph_only_combinations()) def testReinitializableIteratorWithFunctions(self): def g(): for i in range(10): yield i iterator = iterator_ops.Iterator.from_structure(dtypes.int64, []) next_element = iterator.get_next() with self.cached_session() as sess: dataset_1 = dataset_ops.Dataset.from_generator( g, output_types=dtypes.int64) sess.run(iterator.make_initializer(dataset_1)) for expected in range(10): self.assertEqual(expected, sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) dataset_2 = dataset_ops.Dataset.from_generator( g, output_types=dtypes.int64) sess.run(iterator.make_initializer(dataset_2)) for expected in range(10): self.assertEqual(expected, sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) @combinations.generate(test_base.default_test_combinations()) def testReinitializableIteratorStaticErrors(self): # Non-matching structure for types and shapes. with self.assertRaises(TypeError): iterator = iterator_ops.Iterator.from_structure( (dtypes.int64, dtypes.float64), [None]) # Test validation of dataset argument. iterator = iterator_ops.Iterator.from_structure( (dtypes.int64, dtypes.float64)) # Incompatible structure. with self.assertRaises(ValueError): iterator.make_initializer( dataset_ops.Dataset.from_tensors( ((constant_op.constant([1, 2, 3], dtype=dtypes.int64), ), (constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64), )))) # Incompatible types. with self.assertRaises(TypeError): iterator.make_initializer( dataset_ops.Dataset.from_tensors( (constant_op.constant([1, 2, 3], dtype=dtypes.int32), constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float32)))) # Incompatible shapes. 
iterator = iterator_ops.Iterator.from_structure( (dtypes.int64, dtypes.float64), ([None], [])) with self.assertRaises(TypeError): iterator.make_initializer( dataset_ops.Dataset.from_tensors( (constant_op.constant([1, 2, 3], dtype=dtypes.int64), constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64)))) @combinations.generate(test_base.graph_only_combinations()) def testIteratorStringHandle(self): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40]) iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3) iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4) handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) feedable_iterator = iterator_ops.Iterator.from_string_handle( handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3), dataset_ops.get_legacy_output_shapes(dataset_3)) next_element = feedable_iterator.get_next() self.assertTrue( structure.are_compatible( dataset_ops.get_structure(dataset_3), dataset_ops.get_structure(feedable_iterator))) with self.cached_session() as sess: iterator_3_handle = sess.run(iterator_3.string_handle()) iterator_4_handle = sess.run(iterator_4.string_handle()) self.assertEqual( 10, sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle})) self.assertEqual( 1, sess.run(next_element, feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual( 20, sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle})) self.assertEqual( 2, sess.run(next_element, feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual( 30, sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle})) self.assertEqual( 3, sess.run(next_element, feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual( 40, sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle})) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element, feed_dict={handle_placeholder: iterator_3_handle}) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle}) @combinations.generate(test_base.graph_only_combinations()) def testIteratorStringHandleFuture(self): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40]) iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3) iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4) handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) feedable_iterator = iterator_ops.Iterator.from_string_handle( handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3), dataset_ops.get_legacy_output_shapes(dataset_3)) next_element = feedable_iterator.get_next() self.assertTrue( structure.are_compatible( dataset_ops.get_structure(dataset_3), dataset_ops.get_structure(feedable_iterator))) with self.cached_session() as sess: iterator_3_handle = sess.run(iterator_3.string_handle()) iterator_4_handle = sess.run(iterator_4.string_handle()) self.assertEqual( 10, sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle})) self.assertEqual( 1, sess.run(next_element, feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual( 20, sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle})) self.assertEqual( 2, sess.run(next_element, feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual( 30, sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle})) 
self.assertEqual( 3, sess.run(next_element, feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual( 40, sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle})) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element, feed_dict={handle_placeholder: iterator_3_handle}) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle}) @combinations.generate(test_base.graph_only_combinations()) def testIteratorStringHandleReuseTensorObject(self): dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset) initializable_iterator = dataset_ops.make_initializable_iterator( dataset) structure_iterator = iterator_ops.Iterator.from_structure( dataset_ops.get_legacy_output_types(dataset)) created_ops = len(ops.get_default_graph().get_operations()) self.assertIs(one_shot_iterator.string_handle(), one_shot_iterator.string_handle()) self.assertIs(initializable_iterator.string_handle(), initializable_iterator.string_handle()) self.assertIs(structure_iterator.string_handle(), structure_iterator.string_handle()) # Assert that getting the (default) string handle creates no ops. self.assertEqual(created_ops, len(ops.get_default_graph().get_operations())) # Specifying an explicit name will create a new op. handle_with_name = one_shot_iterator.string_handle(name="foo") self.assertEqual("foo", handle_with_name.op.name) self.assertIsNot(one_shot_iterator.string_handle(), handle_with_name) handle_with_same_name = one_shot_iterator.string_handle(name="foo") self.assertEqual("foo_1", handle_with_same_name.op.name) self.assertIsNot(handle_with_name, handle_with_same_name) @combinations.generate(test_base.graph_only_combinations()) def testIteratorStringHandleError(self): dataset_int_scalar = (dataset_ops.Dataset.from_tensor_slices( [1, 2, 3]).repeat()) dataset_float_vector = (dataset_ops.Dataset.from_tensors( [1.0, 2.0, 3.0])) handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) feedable_int_scalar = iterator_ops.Iterator.from_string_handle( handle_placeholder, dtypes.int32, []) feedable_int_vector = iterator_ops.Iterator.from_string_handle( handle_placeholder, dtypes.int32, [None]) feedable_int_any = iterator_ops.Iterator.from_string_handle( handle_placeholder, dtypes.int32) with self.cached_session() as sess: handle_int_scalar = sess.run( dataset_ops.make_one_shot_iterator( dataset_int_scalar).string_handle()) handle_float_vector = sess.run( dataset_ops.make_one_shot_iterator( dataset_float_vector).string_handle()) self.assertEqual( 1, sess.run(feedable_int_scalar.get_next(), feed_dict={handle_placeholder: handle_int_scalar})) self.assertEqual( 2, sess.run(feedable_int_any.get_next(), feed_dict={handle_placeholder: handle_int_scalar})) with self.assertRaises(errors.InvalidArgumentError): print( sess.run(feedable_int_vector.get_next(), feed_dict={handle_placeholder: handle_int_scalar})) with self.assertRaises(errors.InvalidArgumentError): print( sess.run( feedable_int_vector.get_next(), feed_dict={handle_placeholder: handle_float_vector})) @combinations.generate(test_base.graph_only_combinations()) def testRemoteIteratorUsingRemoteCallOpDirectSession(self): worker_config = config_pb2.ConfigProto() worker_config.device_count["CPU"] = 3 with ops.device("/job:localhost/replica:0/task:0/cpu:1"): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3) iterator_3_handle = 
iterator_3.string_handle() @function.Defun(dtypes.string) def _remote_fn(h): remote_iterator = iterator_ops.Iterator.from_string_handle( h, dataset_ops.get_legacy_output_types(dataset_3), dataset_ops.get_legacy_output_shapes(dataset_3)) return remote_iterator.get_next() with ops.device("/job:localhost/replica:0/task:0/cpu:0"): target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) remote_op = functional_ops.remote_call(args=[iterator_3_handle], Tout=[dtypes.int32], f=_remote_fn, target=target_placeholder) with self.session(config=worker_config) as sess: elem = sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" }) self.assertEqual(elem, [1]) # Fails when target is cpu:2 where the resource is not located. with self.assertRaises(errors.InvalidArgumentError): sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:2" }) elem = sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" }) self.assertEqual(elem, [2]) elem = sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" }) self.assertEqual(elem, [3]) with self.assertRaises(errors.OutOfRangeError): sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" }) @combinations.generate(test_base.graph_only_combinations()) def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self): s1 = server_lib.Server.create_local_server() s2 = server_lib.Server.create_local_server() s3 = server_lib.Server.create_local_server() cluster_def = cluster_pb2.ClusterDef() workers = cluster_def.job.add() workers.name = "worker" workers.tasks[0] = s1.target[len("grpc://"):] workers.tasks[1] = s2.target[len("grpc://"):] client = cluster_def.job.add() client.name = "client" client.tasks[0] = s3.target[len("grpc://"):] config = config_pb2.ConfigProto(cluster_def=cluster_def) worker_devices = [ "/job:worker/replica:0/task:%d/cpu:0" % i for i in range(2) ] itr_handles = [] for device in worker_devices: with ops.device(device): src = dataset_ops.Dataset.from_tensor_slices([device]) itr = dataset_ops.make_one_shot_iterator(src) itr_handles.append(itr.string_handle()) targets = dataset_ops.Dataset.from_tensor_slices(worker_devices) handles = dataset_ops.Dataset.from_tensor_slices(itr_handles) @function.Defun(dtypes.string) def loading_func(h): remote_itr = iterator_ops.Iterator.from_string_handle( h, dataset_ops.get_legacy_output_types(itr), dataset_ops.get_legacy_output_shapes(itr)) return remote_itr.get_next() def map_fn(target, handle): return functional_ops.remote_call(args=[handle], Tout=[dtypes.string], f=loading_func, target=target) with ops.device("/job:client"): client_dataset = dataset_ops.Dataset.zip( (targets, handles)).map(map_fn) itr = dataset_ops.make_initializable_iterator(client_dataset) n = itr.get_next() with session.Session(s3.target, config=config) as sess: sess.run(itr.initializer) expected_values = worker_devices for expected in expected_values: self.assertEqual((compat.as_bytes(expected), ), sess.run(n)) with self.assertRaises(errors.OutOfRangeError): sess.run(n) @combinations.generate(test_base.graph_only_combinations()) def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") with ops.device("/job:localhost/replica:0/task:0/cpu:0"): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3) iterator_3_handle = 
iterator_3.string_handle() def _encode_raw(byte_array): return bytes(bytearray(byte_array)) @function.Defun(dtypes.uint8) def _remote_fn(h): handle = script_ops.py_func(_encode_raw, [h], dtypes.string) remote_iterator = iterator_ops.Iterator.from_string_handle( handle, dataset_ops.get_legacy_output_types(dataset_3), dataset_ops.get_legacy_output_shapes(dataset_3)) return remote_iterator.get_next() with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) iterator_3_handle_uint8 = parsing_ops.decode_raw( input_bytes=iterator_3_handle, out_type=dtypes.uint8) remote_op = functional_ops.remote_call( args=[iterator_3_handle_uint8], Tout=[dtypes.int32], f=_remote_fn, target=target_placeholder) with self.cached_session() as sess: elem = sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) self.assertEqual(elem, [1]) elem = sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) self.assertEqual(elem, [2]) elem = sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) self.assertEqual(elem, [3]) with self.assertRaises(errors.OutOfRangeError): sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) @combinations.generate(test_base.graph_only_combinations()) def testRepeatedGetNextWarning(self): iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.range(10)) warnings.simplefilter("always") with warnings.catch_warnings(record=True) as w: for _ in range(100): iterator.get_next() self.assertEqual(100 - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD, len(w)) for warning in w: self.assertIn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE, str(warning.message)) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine( expected_element_structure=tensor_spec.TensorSpec( [], dtypes.float32), expected_output_classes=ops.Tensor, expected_output_types=dtypes.float32, expected_output_shapes=[[]]))) def testTensorIteratorStructure(self, expected_element_structure, expected_output_classes, expected_output_types, expected_output_shapes): tf_value_fn = lambda: constant_op.constant(37.0) tf_value = tf_value_fn() iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensors(tf_value)) self.assertTrue( structure.are_compatible(dataset_ops.get_structure(iterator), expected_element_structure)) self.assertEqual(expected_output_classes, dataset_ops.get_legacy_output_classes(iterator)) self.assertEqual(expected_output_types, dataset_ops.get_legacy_output_types(iterator)) self.assertEqual(expected_output_shapes, dataset_ops.get_legacy_output_shapes(iterator)) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine( expected_element_structure=sparse_tensor.SparseTensorSpec( [1], dtypes.int32), expected_output_classes=sparse_tensor.SparseTensor, expected_output_types=dtypes.int32, expected_output_shapes=[[1]]))) def testSparseTensorIteratorStructure(self, expected_element_structure, expected_output_classes, expected_output_types, expected_output_shapes): def tf_value_fn(): return sparse_tensor.SparseTensor(indices=[[0]], values=constant_op.constant( [0], dtype=dtypes.int32), dense_shape=[1]) tf_value = tf_value_fn() iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensors(tf_value)) self.assertTrue( 
structure.are_compatible(dataset_ops.get_structure(iterator), expected_element_structure)) self.assertEqual(expected_output_classes, dataset_ops.get_legacy_output_classes(iterator)) self.assertEqual(expected_output_types, dataset_ops.get_legacy_output_types(iterator)) self.assertEqual(expected_output_shapes, dataset_ops.get_legacy_output_shapes(iterator)) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(expected_element_structure={ "a": tensor_spec.TensorSpec([], dtypes.float32), "b": (tensor_spec.TensorSpec([1], dtypes.string), tensor_spec.TensorSpec([], dtypes.string)) }, expected_output_classes={ "a": ops.Tensor, "b": (ops.Tensor, ops.Tensor) }, expected_output_types={ "a": dtypes.float32, "b": (dtypes.string, dtypes.string) }, expected_output_shapes={ "a": [], "b": ([1], []) }))) def testNestedTensorIteratorStructure(self, expected_element_structure, expected_output_classes, expected_output_types, expected_output_shapes): def tf_value_fn(): return { "a": constant_op.constant(37.0), "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar")) } tf_value = tf_value_fn() iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensors(tf_value)) self.assertTrue( structure.are_compatible(dataset_ops.get_structure(iterator), expected_element_structure)) self.assertEqual(expected_output_classes, dataset_ops.get_legacy_output_classes(iterator)) self.assertEqual(expected_output_types, dataset_ops.get_legacy_output_types(iterator)) self.assertEqual(expected_output_shapes, dataset_ops.get_legacy_output_shapes(iterator)) @combinations.generate(test_base.graph_only_combinations()) def testIteratorGetNextName(self): with ops.Graph().as_default(): iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensors(37.0)) next_element = iterator.get_next(name="overridden_name") self.assertEqual("overridden_name", next_element.op.name) @combinations.generate( combinations.combine(tf_api_version=[1, 2], mode="eager", execution_mode=[context.ASYNC, context.SYNC])) def testIteratorEagerIteration(self, execution_mode): with context.eager_mode(), context.execution_mode(execution_mode): val = 0 dataset = dataset_ops.Dataset.range(10) iterator = iter(dataset) for foo in iterator: self.assertEqual(val, foo.numpy()) val += 1 @combinations.generate(test_base.eager_only_combinations()) def testOwnedIteratorFunction(self): queue = data_flow_ops.FIFOQueue(10, dtypes.int64) @def_function.function def fn(): dataset = dataset_ops.Dataset.range(10) iterator = iter(dataset) for _ in range(10): queue.enqueue(next(iterator)) fn() for i in range(10): self.assertEqual(queue.dequeue().numpy(), i) @combinations.generate(test_base.eager_only_combinations()) def testOwnedIteratorFunctionError(self): # In this test we verify that a function that raises an error ends up # properly deallocating the iterator resource. 
queue = data_flow_ops.FIFOQueue(10, dtypes.int64) queue.enqueue(0) def init_fn(n): return n def next_fn(_): ds = dataset_ops.Dataset.range(0) return next(iter(ds)) def finalize_fn(n): queue.enqueue(0) return n @def_function.function def fn(): output_signature = tensor_spec.TensorSpec((), dtypes.int64) dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn, output_signature) iterator = iter(dataset) next(iterator) with self.assertRaises(errors.OutOfRangeError): fn() self.assertEqual(queue.size().numpy(), 2) @combinations.generate(test_base.eager_only_combinations()) def testLimitedRetracing(self): trace_count = [0] @def_function.function def f(iterator): trace_count[0] += 1 counter = np.int64(0) for elem in iterator: counter += elem return counter dataset = dataset_ops.Dataset.range(5) dataset2 = dataset_ops.Dataset.range(10) for _ in range(10): self.assertEqual(self.evaluate(f(iter(dataset))), 10) self.assertEqual(self.evaluate(f(iter(dataset2))), 45) self.assertEqual(trace_count[0], 1) @combinations.generate(test_base.eager_only_combinations()) def testNestedFunctionsIteratorResource(self): @def_function.function def sum_dataset(ds): it = iter(ds) @def_function.function def next_element(it): return next(it) total = 0 for _ in range(10): total += next_element(it) return total ds = dataset_ops.Dataset.range(10) self.assertEqual(sum_dataset(ds).numpy(), 45) self.assertEqual(sum_dataset(ds).numpy(), 45) @combinations.generate(test_base.default_test_combinations()) def testNestedAutomaticControlDependencies(self): counter_var = variables.Variable(0) def map_fn(x): counter_var.assign_add(1) return x def dataset_fn(): return dataset_ops.Dataset.range(10).map(map_fn) @def_function.function def fn(): it = iter(dataset_fn()) for _ in range(10): _ = next(it) return counter_var self.evaluate(counter_var.initializer) self.assertEqual(self.evaluate(fn()), 10)
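# The preceding eager-mode tests exercise two iterator contracts: `iter(dataset)`
# creates an "owned" iterator whose resource is deallocated as soon as the
# Python object goes out of scope (even when the enclosing tf.function raises),
# and automatic control dependencies keep stateful ops such as `assign_add`
# executing in program order. A minimal sketch of that contract, assuming eager
# execution; `_sketch_owned_iterator_sum` is hypothetical and not used by the
# tests:


def _sketch_owned_iterator_sum():
  """Drains a small dataset inside a function; the iterator dies on exit."""
  dataset = dataset_ops.Dataset.range(3)

  @def_function.function
  def drain():
    it = iter(dataset)  # Owned iterator: freed once `it` is garbage-collected.
    total = np.int64(0)
    for elem in it:
      total += elem
    return total

  return drain()  # Yields 0 + 1 + 2 == 3; the iterator resource does not leak.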
class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(num_parallel_calls=[None, 1, 2], num_parallel_batches=None) + combinations.combine(num_parallel_calls=None, num_parallel_batches=10))) def testMapAndBatch(self, num_parallel_calls, num_parallel_batches): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) def dataset_fn(batch_size, count): dataset = dataset_ops.Dataset.from_tensor_slices( components).repeat(count).apply( batching.map_and_batch( map_func=_map_fn, batch_size=batch_size, num_parallel_calls=num_parallel_calls, num_parallel_batches=num_parallel_batches)) return dataset # Batch of a finite input, where the batch_size divides the # total number of elements. dataset = dataset_fn(14, 28) get_next = self.getNext(dataset) self.assertEqual([[None] + list(c.shape[1:]) for c in components], [ shape.as_list() for shape in dataset_ops.get_legacy_output_shapes(dataset) ]) num_batches = (28 * 7) // 14 for i in range(num_batches): result = self.evaluate(get_next()) for component, result_component in zip(components, result): for j in range(14): self.assertAllEqual(component[(i * 14 + j) % 7]**2, result_component[j]) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) # Batch of a finite input, where the batch_size does not # divide the total number of elements. get_next = self.getNext(dataset_fn(8, 14)) # We expect (num_batches - 1) full-sized batches. num_batches = int(math.ceil((14 * 7) / 8)) for i in range(num_batches - 1): result = self.evaluate(get_next()) for component, result_component in zip(components, result): for j in range(8): self.assertAllEqual(component[(i * 8 + j) % 7]**2, result_component[j]) result = self.evaluate(get_next()) for component, result_component in zip(components, result): for j in range((14 * 7) % 8): self.assertAllEqual( component[((num_batches - 1) * 8 + j) % 7]**2, result_component[j]) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) # Batch of an empty input should fail straight away. self.assertDatasetProduces(dataset_fn(8, 0), expected_output=[]) # Empty batch should be an initialization time error. 
with self.assertRaises(errors.InvalidArgumentError): self.assertDatasetProduces(dataset_fn(0, 14), expected_output=[]) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(drop_remainder=[True, False]))) def testMapAndBatchPartialBatch(self, drop_remainder): dataset = (dataset_ops.Dataset.range(10).apply( batching.map_and_batch(lambda x: array_ops.reshape(x * x, [1]), batch_size=4, drop_remainder=drop_remainder))) if drop_remainder: self.assertEqual( [4, 1], dataset_ops.get_legacy_output_shapes(dataset).as_list()) else: self.assertEqual( [None, 1], dataset_ops.get_legacy_output_shapes(dataset).as_list()) expected_output = [[[0], [1], [4], [9]], [[16], [25], [36], [49]]] if not drop_remainder: expected_output.append([[64], [81]]) self.assertDatasetProduces(dataset, expected_output=expected_output) @combinations.generate(test_base.default_test_combinations()) def testMapAndBatchYieldsPartialBatch(self): dataset = (dataset_ops.Dataset.range(10).apply( batching.map_and_batch(lambda x: array_ops.reshape(x * x, [1]), 4))) self.assertEqual( [None, 1], dataset_ops.get_legacy_output_shapes(dataset).as_list()) expected_output = [[[0], [1], [4], [9]], [[16], [25], [36], [49]], [[64], [81]]] self.assertDatasetProduces(dataset, expected_output=expected_output) @combinations.generate(test_base.default_test_combinations()) def testMapAndBatchParallelGetNext(self): dataset = dataset_ops.Dataset.range(50000).apply( batching.map_and_batch(lambda x: x, batch_size=100)) if context.executing_eagerly(): iterator = iter(dataset) get_next = iterator._next_internal # pylint: disable=protected-access else: iterator = dataset_ops.make_one_shot_iterator(dataset) get_next = iterator.get_next elements = [] for _ in range(100): elements.append(get_next) for i in range(5): got = self.evaluate([element() for element in elements]) got.sort(key=lambda x: x[0]) expected = [] for j in range(100): expected.append( range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100)) self.assertAllEqual(got, expected) with self.assertRaises(errors.OutOfRangeError): self.evaluate([element() for element in elements]) @combinations.generate(test_base.default_test_combinations()) def testMapAndBatchParallelGetNextDropRemainder(self): dataset = dataset_ops.Dataset.range(49999).apply( batching.map_and_batch(lambda x: x, batch_size=100, drop_remainder=True)) if context.executing_eagerly(): iterator = iter(dataset) get_next = iterator._next_internal # pylint: disable=protected-access else: iterator = dataset_ops.make_one_shot_iterator(dataset) get_next = iterator.get_next elements = [] for _ in range(100): elements.append(get_next) for i in range(4): got = self.evaluate([element() for element in elements]) got.sort(key=lambda x: x[0]) expected = [] for j in range(100): expected.append( range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100)) self.assertAllEqual(got, expected) with self.assertRaises(errors.OutOfRangeError): self.evaluate([element() for element in elements]) @combinations.generate(test_base.default_test_combinations()) def testMapAndBatchSparse(self): def _sparse(i): return sparse_tensor.SparseTensorValue(indices=[[0]], values=(i * [1]), dense_shape=[1]) dataset = dataset_ops.Dataset.range(10).apply( batching.map_and_batch(_sparse, 5)) self.assertDatasetProduces( dataset, expected_output=[ sparse_tensor.SparseTensorValue( indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4], dense_shape=[5, 1]) for i in range(2) ]) 
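  # `map_and_batch` fuses the map and batch stages so that output batches can
  # be assembled in parallel; semantically it should match an unfused
  # map().batch() pipeline. The helper below is a hypothetical sketch of that
  # unfused equivalent (not exercised by the tests in this class):

  def _sketchUnfusedEquivalent(self, map_func, batch_size,
                               drop_remainder=False):
    """Returns a transformation equivalent to an unfused `map_and_batch`."""

    def _apply_fn(dataset):
      return dataset.map(map_func).batch(
          batch_size, drop_remainder=drop_remainder)

    return _apply_fn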
@combinations.generate(test_base.default_test_combinations()) def testMapAndBatchFails(self): """Tests that an error raised in the input pipeline surfaces when the dataset is evaluated.""" with self.assertRaisesRegex(errors.InvalidArgumentError, "oops"): dataset = dataset_ops.Dataset.from_tensors( array_ops.check_numerics( constant_op.constant(1.0) / constant_op.constant(0.0), "oops")) dataset = dataset.apply(batching.map_and_batch(lambda x: x, 14)) get_next = self.getNext(dataset, requires_initialization=True) self.evaluate(get_next()) @combinations.generate(test_base.default_test_combinations()) def testMapAndBatchShapeMismatch(self): """Tests that batching elements of mismatched shapes raises InvalidArgumentError.""" def generator(): yield [1] yield [2] yield [3] yield [[4, 5, 6]] dataset = dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int32) batch_size = 4 dataset = dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) self.assertDatasetProduces( dataset, expected_error=(errors.InvalidArgumentError, "number of elements does not match")) @combinations.generate(test_base.default_test_combinations()) def testMapAndBatchImplicitDispose(self): # Tests whether a map and batch dataset will be cleaned up correctly when # the pipeline does not run it until exhaustion. # The pipeline is TensorSliceDataset -> RepeatDataset(1000) -> # MapAndBatchDataset(f=_map_fn, batch_size=100). components = (np.arange(1000), np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis], np.array(37.0) * np.arange(1000)) def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat( 1000).apply(batching.map_and_batch(_map_fn, batch_size=100)) dataset = dataset.prefetch(5) get_next = self.getNext(dataset) for _ in range(3): self.evaluate(get_next()) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(threshold=[0, 5, 10, 90, 95, 99]))) def testMapAndBatchMapError(self, threshold): def raising_py_fn(i): if i >= threshold: raise StopIteration() else: return i dataset = dataset_ops.Dataset.range(100).apply( batching.map_and_batch( lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64), batch_size=10)) get_next = self.getNext(dataset) for i in range(threshold // 10): self.assertAllEqual([i * 10 + j for j in range(10)], self.evaluate(get_next())) for i in range(threshold // 10, 10): with self.assertRaises(errors.InvalidArgumentError): self.evaluate(get_next()) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(element=False, dtype=dtypes.bool) + combinations.combine( element=-42, dtype=[dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64 ]) + combinations.combine( element=42, dtype=[dtypes.uint8, dtypes.uint16]) + combinations.combine( element=42.0, dtype=[dtypes.float16, dtypes.float32, dtypes.float64]) + combinations.combine(element=b"hello", dtype=[dtypes.string]))) def testMapAndBatchTypes(self, element, dtype): def gen(): yield element dataset = dataset_ops.Dataset.from_generator( gen, dtype).repeat(100).apply( batching.map_and_batch(lambda x: x, batch_size=10)) get_next = self.getNext(dataset) for _ in range(10): self.assertAllEqual([element for _ in range(10)], self.evaluate(get_next())) @combinations.generate(test_base.default_test_combinations()) def testShortCircuitIdentity(self): map_fn = lambda x: x dataset = 
self.structuredDataset(None).repeat().apply( batching.map_and_batch(map_fn, batch_size=10)) get_next = self.getNext(dataset) expected = map_fn( self.evaluate(self.structuredElement(None, shape=[10]))) self.assertAllEqual(expected, self.evaluate(get_next())) @combinations.generate(test_base.default_test_combinations()) def testShortCircuitReplicate(self): map_fn = lambda x: (x, x) dataset = self.structuredDataset(None).repeat().apply( batching.map_and_batch(map_fn, batch_size=10)) get_next = self.getNext(dataset) expected = map_fn( self.evaluate(self.structuredElement(None, shape=[10]))) self.assertAllEqual(expected, self.evaluate(get_next())) @combinations.generate(test_base.default_test_combinations()) def testShortCircuitSwap(self): map_fn = lambda x, y: (y, x) dataset = self.structuredDataset((None, None)).repeat().apply( batching.map_and_batch(map_fn, batch_size=10)) get_next = self.getNext(dataset) expected = map_fn( *self.evaluate(self.structuredElement((None, None), shape=[10]))) self.assertAllEqual(expected, self.evaluate(get_next())) @combinations.generate(test_base.default_test_combinations()) def testShortCircuitProject(self): map_fn = lambda x, y: x dataset = self.structuredDataset((None, None)).repeat().apply( batching.map_and_batch(map_fn, batch_size=10)) get_next = self.getNext(dataset) expected = map_fn( *self.evaluate(self.structuredElement((None, None), shape=[10]))) self.assertAllEqual(expected, self.evaluate(get_next())) @combinations.generate(test_base.default_test_combinations()) def testShortCircuitCapturedInput(self): captured_t = variables.Variable(42) dataset = self.structuredDataset(None).repeat().apply( batching.map_and_batch(lambda x: captured_t, batch_size=10)) self.evaluate(variables.global_variables_initializer()) get_next = self.getNext(dataset, requires_initialization=True) self.assertAllEqual([42] * 10, self.evaluate(get_next())) @combinations.generate(test_base.default_test_combinations()) def testMapAndBatchControlFlow(self): def map_fn(x): previous_control_flow_v2_value = control_flow_util.ENABLE_CONTROL_FLOW_V2 control_flow_util.ENABLE_CONTROL_FLOW_V2 = True return_value = control_flow_ops.cond(x < 50, lambda: x + 1, lambda: x * x) control_flow_util.ENABLE_CONTROL_FLOW_V2 = previous_control_flow_v2_value return return_value dataset = dataset_ops.Dataset.range(100).apply( batching.map_and_batch(map_fn, batch_size=10)) get_next = self.getNext(dataset) for i in range(10): if i < 5: self.assertAllEqual([i * 10 + j + 1 for j in range(10)], self.evaluate(get_next())) else: self.assertAllEqual([((i * 10) + j) * ((i * 10) + j) for j in range(10)], self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next())
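# The testShortCircuit* cases above cover map functions that tf.data can
# "short-circuit", i.e. implement structurally without invoking a user-defined
# function per element. A hypothetical reference table of those patterns
# (illustrative only; the optimizer's actual pattern matching lives in C++):

_SKETCH_SHORT_CIRCUIT_PATTERNS = {
    "identity": lambda x: x,
    "replicate": lambda x: (x, x),
    "swap": lambda x, y: (y, x),
    "project": lambda x, y: x,
}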
def keras_model_type_combinations():
  return combinations.combine(model_type=KERAS_MODEL_TYPES)
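# Helpers such as `keras_model_type_combinations` return lists of
# named-parameter dicts: `combinations.combine` enumerates one axis, `+`
# concatenates alternative axes, and `combinations.times` (used throughout the
# tests below) takes their cross product. Illustrative only:
#
#   combinations.times(combinations.combine(mode=["eager", "graph"]),
#                      combinations.combine(sloppy=[False, True]))
#   # -> four test configurations.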
class ParallelInterleaveTest(test_base.DatasetTestBase, parameterized.TestCase): def setUp(self): self.error = None self.repeat_count = 2 # Set up threading events used to sequence when items are produced that # are subsequently interleaved. These events allow us to deterministically # simulate slowdowns and force sloppiness. self.read_coordination_events = {} self.write_coordination_events = {} # input values [4, 5, 6] are the common case for the tests; set defaults for i in range(4, 7): self.read_coordination_events[i] = threading.Semaphore(0) self.write_coordination_events[i] = threading.Event() def dataset_fn(self, input_values, cycle_length, block_length, sloppy, buffer_output_elements, prefetch_input_elements): def map_py_fn(x): self.write_coordination_events[x].wait() self.write_coordination_events[x].clear() self.read_coordination_events[x].release() if self.error: err = self.error self.error = None raise err # pylint: disable=raising-bad-type return x * x def map_fn(x): return script_ops.py_func(map_py_fn, [x], x.dtype) def interleave_fn(x): dataset = dataset_ops.Dataset.from_tensors(x) dataset = dataset.repeat(x) return dataset.map(map_fn) return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( self.repeat_count).apply( interleave_ops.parallel_interleave(interleave_fn, cycle_length, block_length, sloppy, buffer_output_elements, prefetch_input_elements)) def _interleave(self, lists, cycle_length, block_length): """Python implementation of interleave used for testing.""" num_open = 0 # `all_iterators` acts as a queue of iterators over each element of `lists`. all_iterators = [iter(l) for l in lists] # `open_iterators` are the iterators whose elements are currently being # interleaved. open_iterators = [] for i in range(cycle_length): if all_iterators: open_iterators.append(all_iterators.pop(0)) num_open += 1 else: open_iterators.append(None) while num_open or all_iterators: for i in range(cycle_length): if open_iterators[i] is None: if all_iterators: open_iterators[i] = all_iterators.pop(0) num_open += 1 else: continue for _ in range(block_length): try: yield next(open_iterators[i]) except StopIteration: open_iterators[i] = None num_open -= 1 break @combinations.generate( combinations.times( combinations.combine( input_lists=[[[4, 4, 4, 4], [5, 5, 5, 5, 5], [ 6, 6, 6, 6, 6, 6 ], [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]], expected_elements=[[ 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6 ]], cycle_length=1, block_length=1) + combinations.combine( input_lists=[[[4, 4, 4, 4], [5, 5, 5, 5, 5], [ 6, 6, 6, 6, 6, 6 ], [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]], expected_elements=[[ 4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5, 6, 6 ]], cycle_length=2, block_length=1) + combinations.combine( input_lists=[[[4] * 4, [5] * 5, [6] * 6] * 2], expected_elements=[[ 4, 4, 5, 5, 4, 4, 5, 5, 5, 6, 6, 4, 4, 6, 6, 4, 4, 6, 6, 5, 5, 6, 6, 5, 5, 6, 6, 5, 6, 6 ]], cycle_length=2, block_length=2) + combinations.combine( input_lists=[[[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6]]], expected_elements=[[ 4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6 ]], cycle_length=2, block_length=1))) def testPythonImplementation(self, input_lists, expected_elements, cycle_length, block_length): for index, (expected, produced) in enumerate( zip_longest( expected_elements, self._interleave(input_lists, cycle_length, block_length))): self.assertEqual( expected, produced, 
"Values differ at %s. %s != %s" % (index, expected, produced)) def _clear_coordination_events(self): for i in range(4, 7): self.read_coordination_events[i] = threading.Semaphore(0) self.write_coordination_events[i].clear() def _allow_all_map_threads(self): for i in range(4, 7): self.write_coordination_events[i].set() @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(sloppy=[False, True], prefetch_input_elements=[0, 1]))) def testSingleThreaded(self, sloppy, prefetch_input_elements): # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and # `Dataset.flat_map()` and is single-threaded. No synchronization required. self.skipTest("b/131722904") self._clear_coordination_events() next_element = self.getNext( self.dataset_fn(input_values=np.int64([4, 5, 6]), cycle_length=1, block_length=1, sloppy=sloppy, buffer_output_elements=1, prefetch_input_elements=prefetch_input_elements)) for expected_element in self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 1): self.write_coordination_events[expected_element].set() self.assertEqual(expected_element * expected_element, self.evaluate(next_element())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) @combinations.generate(test_base.default_test_combinations()) def testSingleThreadedRagged(self): # Tests a sequence with wildly different elements per iterator. self.skipTest("b/131722904") self._clear_coordination_events() next_element = self.getNext( self.dataset_fn(input_values=np.int64([3, 7, 4]), cycle_length=2, block_length=1, sloppy=False, buffer_output_elements=1, prefetch_input_elements=1)) # Add coordination values for 3 and 7 self.read_coordination_events[3] = threading.Semaphore(0) self.write_coordination_events[3] = threading.Event() self.read_coordination_events[7] = threading.Semaphore(0) self.write_coordination_events[7] = threading.Event() for expected_element in self._interleave([[3] * 3, [7] * 7, [4] * 4] * self.repeat_count, 2, 1): self.write_coordination_events[expected_element].set() output = self.evaluate(next_element()) self.assertEqual(expected_element * expected_element, output) with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(sloppy=[False, True]))) def testTwoThreadsNoContention(self, sloppy): # num_threads > 1. # Explicit coordination should result in `Dataset.interleave()` behavior self.skipTest("b/131722904") self._clear_coordination_events() done_first_event = False next_element = self.getNext( self.dataset_fn(input_values=np.int64([4, 5, 6]), cycle_length=2, block_length=1, sloppy=sloppy, buffer_output_elements=1, prefetch_input_elements=1)) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, 1)): self.write_coordination_events[expected_element].set() if done_first_event: # First event starts the worker threads. 
self.read_coordination_events[expected_element].acquire() actual_element = self.evaluate(next_element()) if not done_first_event: self.read_coordination_events[expected_element].acquire() done_first_event = True self.assertEqual( expected_element * expected_element, actual_element, "At index %s: %s expected, got: %s" % (i, expected_element, actual_element)) with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(sloppy=[False, True]))) def testTwoThreadsNoContentionWithRaces(self, sloppy): """Tests where all the workers race in producing elements. Note: this is in contrast with the previous test which carefully sequences the execution of the map functions. Args: sloppy: Whether to be sloppy or not. """ self.skipTest("b/131722904") self._clear_coordination_events() done_first_event = False next_element = self.getNext( self.dataset_fn(input_values=np.int64([4, 5, 6]), cycle_length=2, block_length=1, sloppy=sloppy, buffer_output_elements=1, prefetch_input_elements=1)) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, 1)): if done_first_event: # First event starts the worker threads. self._allow_all_map_threads() self.read_coordination_events[expected_element].acquire() else: self.write_coordination_events[expected_element].set() time.sleep( 0.5) # Sleep to consistently "avoid" the race condition. actual_element = self.evaluate(next_element()) if not done_first_event: done_first_event = True self.assertTrue( self.read_coordination_events[expected_element].acquire( False)) self.assertEqual( expected_element * expected_element, actual_element, "At index %s: %s expected, got: %s" % (i, expected_element, actual_element)) with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(sloppy=[False, True]))) def testTwoThreadsNoContentionBlockLength(self, sloppy): # num_threads > 1. # Explicit coordination should result in `Dataset.interleave()` behavior self.skipTest("b/131722904") self._clear_coordination_events() done_first_event = False next_element = self.getNext( self.dataset_fn(input_values=np.int64([4, 5, 6]), cycle_length=2, block_length=2, sloppy=sloppy, buffer_output_elements=1, prefetch_input_elements=1)) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, 2)): self.write_coordination_events[expected_element].set() if done_first_event: # First event starts the worker threads. self.read_coordination_events[expected_element].acquire() actual_element = self.evaluate(next_element()) if not done_first_event: done_first_event = True self.read_coordination_events[expected_element].acquire() self.assertEqual( expected_element * expected_element, actual_element, "At index %s: %s expected, got: %s" % (i, expected_element, actual_element)) with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(sloppy=[False, True]))) def testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy): """Tests where all the workers race in producing elements. Note: this is in contrast with the previous test which carefully sequences the execution of the map functions. Args: sloppy: Whether to be sloppy or not. 
""" self.skipTest("b/131722904") self._clear_coordination_events() done_first_event = False next_element = self.getNext( self.dataset_fn(input_values=np.int64([4, 5, 6]), cycle_length=2, block_length=2, sloppy=sloppy, buffer_output_elements=1, prefetch_input_elements=1)) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, 2)): if done_first_event: # First event starts the worker threads. self._allow_all_map_threads() self.read_coordination_events[expected_element].acquire() else: self.write_coordination_events[expected_element].set() time.sleep( 0.5) # Sleep to consistently "avoid" the race condition. actual_element = self.evaluate(next_element()) if not done_first_event: done_first_event = True self.assertTrue( self.read_coordination_events[expected_element].acquire( False)) self.assertEqual( expected_element * expected_element, actual_element, "At index %s: %s expected, got: %s" % (i, expected_element, actual_element)) with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(sloppy=[False, True]))) def testEmptyInput(self, sloppy): # Empty input. self._clear_coordination_events() next_element = self.getNext( self.dataset_fn(input_values=np.int64([]), cycle_length=2, block_length=3, sloppy=sloppy, buffer_output_elements=1, prefetch_input_elements=0)) with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(sloppy=[False, True]))) def _testNonEmptyInputIntoEmptyOutputs(self, sloppy): # Non-empty input leading to empty output. self._clear_coordination_events() next_element = self.getNext( self.dataset_fn(input_values=np.int64([0, 0, 0]), cycle_length=2, block_length=3, sloppy=sloppy, buffer_output_elements=1, prefetch_input_elements=0)) with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(sloppy=[False, True], prefetch_input_elements=[1, 0]))) def testPartiallyEmptyOutputs(self, sloppy, prefetch_input_elements): race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds # Mixture of non-empty and empty interleaved datasets. self.skipTest("b/131722904") self._clear_coordination_events() done_first_event = False next_element = self.getNext( self.dataset_fn(input_values=np.int64([4, 0, 6]), cycle_length=2, block_length=1, sloppy=sloppy, buffer_output_elements=1, prefetch_input_elements=prefetch_input_elements)) for i, expected_element in enumerate( self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)): self.write_coordination_events[expected_element].set() # First event starts the worker threads. Additionally, when running the # sloppy case with prefetch_input_elements=0, we get stuck if we wait # for the read coordination event for certain event orderings in the # presence of finishing iterators. 
if done_first_event and not (sloppy and (i in race_indices)): self.read_coordination_events[expected_element].acquire() actual_element = self.evaluate(next_element()) if not done_first_event or (sloppy and (i in race_indices)): done_first_event = True self.read_coordination_events[expected_element].acquire() self.assertEqual( expected_element * expected_element, actual_element, "At index %s: %s expected, got: %s" % (i, expected_element, actual_element)) @combinations.generate(test_base.default_test_combinations()) def testDelayedOutputSloppy(self): # Explicitly control the sequence of events to ensure we correctly avoid # head-of-line blocking. self.skipTest("b/131722904") self._clear_coordination_events() next_element = self.getNext( self.dataset_fn(input_values=np.int64([4, 5, 6]), cycle_length=2, block_length=1, sloppy=True, buffer_output_elements=1, prefetch_input_elements=0)) mis_ordering = [ 4, 4, 5, 4, 5, 5, 4, 5, 6, 6, 6, 5, 4, 4, 6, 6, 4, 4, 6, 5, 6, 6, 6, 6, 5, 5, 5, 5, 6, 6 ] for element in mis_ordering: self.write_coordination_events[element].set() self.assertEqual(element * element, self.evaluate(next_element())) self.assertTrue( self.read_coordination_events[element].acquire(False)) with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) @combinations.generate(test_base.default_test_combinations()) def testBlockLengthWithContentionSloppy(self): self.skipTest("b/131722904") self._clear_coordination_events() done_first_event = False next_element = self.getNext( self.dataset_fn(input_values=np.int64([4, 5, 6]), cycle_length=2, block_length=1, sloppy=True, buffer_output_elements=1, prefetch_input_elements=1)) # Test against a generating sequence that differs from the uncontended # case, in order to prove sloppy correctness. for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, cycle_length=2, block_length=3)): self.write_coordination_events[expected_element].set() if done_first_event: # First event starts the worker threads. self.read_coordination_events[expected_element].acquire() actual_element = self.evaluate(next_element()) if not done_first_event: self.read_coordination_events[expected_element].acquire() done_first_event = True self.assertEqual( expected_element * expected_element, actual_element, "At index %s: %s expected, got: %s" % (i, expected_element, actual_element)) with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(sloppy=[False, True]))) def testEarlyExit(self, sloppy): # Exiting without consuming all input should not block self.skipTest("b/131722904") self._clear_coordination_events() next_element = self.getNext( self.dataset_fn(input_values=np.int64([4, 5, 6]), cycle_length=3, block_length=2, sloppy=sloppy, buffer_output_elements=1, prefetch_input_elements=0)) for i in range(4, 7): self.write_coordination_events[i].set() elem = self.evaluate(next_element()) # Start all workers # Allow the one successful worker to progress beyond the py_func again. 
elem = int(math.sqrt(elem)) self.write_coordination_events[elem].set() self.read_coordination_events[elem].acquire() # Allow the prefetch to succeed for i in range(4, 7): self.read_coordination_events[i].acquire() self.write_coordination_events[i].set() @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(sloppy=[False, True]))) def testTooManyReaders(self, sloppy=False): def interleave_fn(x): dataset = dataset_ops.Dataset.from_tensors(x) dataset = dataset.repeat(math_ops.cast(x, dtype=dtypes.int64)) return dataset dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6]) dataset = dataset.repeat(self.repeat_count) dataset = dataset.apply( interleave_ops.parallel_interleave(interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy)) get_next = self.getNext(dataset) output_values = [] for _ in range(30): output_values.append(self.evaluate(get_next())) expected_values = self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2) self.assertCountEqual(output_values, expected_values) @combinations.generate(test_base.default_test_combinations()) def testSparse(self): def _map_fn(i): return sparse_tensor.SparseTensor(indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) def _interleave_fn(x): return dataset_ops.Dataset.from_tensor_slices( sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) dataset = dataset_ops.Dataset.range(10).map(_map_fn).apply( interleave_ops.parallel_interleave(_interleave_fn, cycle_length=1)) get_next = self.getNext(dataset) for i in range(10): for j in range(2): expected = [i, 0] if j % 2 == 0 else [0, -i] self.assertAllEqual(expected, self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) @combinations.generate(test_base.default_test_combinations()) def testErrorsInOutputFn(self): self.skipTest("b/131722904") self._clear_coordination_events() next_element = self.getNext( self.dataset_fn(input_values=np.int64([4, 5, 6]), cycle_length=2, block_length=1, sloppy=False, buffer_output_elements=1, prefetch_input_elements=0)) except_on_element_indices = set([3]) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, 1)): if i in except_on_element_indices: self.error = ValueError() self.write_coordination_events[expected_element].set() with self.assertRaises(errors.InvalidArgumentError): self.evaluate(next_element()) else: self.write_coordination_events[expected_element].set() actual_element = self.evaluate(next_element()) self.assertEqual( expected_element * expected_element, actual_element, "At index %s: %s expected, got: %s" % (i, expected_element, actual_element)) with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) @combinations.generate(test_base.default_test_combinations()) def testErrorsInInputFn(self): def map_py_fn(x): if x == 5: raise ValueError() return x def map_fn(x): return script_ops.py_func(map_py_fn, [x], x.dtype) def interleave_fn(x): dataset = dataset_ops.Dataset.from_tensors(x) dataset = dataset.repeat(x) return dataset def dataset_fn(input_values, cycle_length, block_length, sloppy, buffer_output_elements, prefetch_input_elements): return dataset_ops.Dataset.from_tensor_slices(input_values).map( map_fn).repeat(self.repeat_count).apply( interleave_ops.parallel_interleave( interleave_fn, cycle_length, block_length, sloppy, buffer_output_elements, prefetch_input_elements)) next_element = self.getNext( dataset_fn(input_values=np.int64([4, 5, 6]), 
cycle_length=2, block_length=1, sloppy=False, buffer_output_elements=1, prefetch_input_elements=0)) for i, expected_element in enumerate( self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)): if expected_element == 5: with self.assertRaises(errors.InvalidArgumentError): self.evaluate(next_element()) else: actual_element = self.evaluate(next_element()) self.assertEqual( expected_element, actual_element, "At index %s: %s expected, got: %s" % (i, expected_element, actual_element)) with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) @combinations.generate(test_base.default_test_combinations()) def testErrorsInInterleaveFn(self): def map_py_fn(x): if x == 5: raise ValueError() return x def interleave_fn(x): dataset = dataset_ops.Dataset.from_tensors(x) y = script_ops.py_func(map_py_fn, [x], x.dtype) dataset = dataset.repeat(y) return dataset def dataset_fn(input_values, cycle_length, block_length, sloppy, buffer_output_elements, prefetch_input_elements): return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( self.repeat_count).apply( interleave_ops.parallel_interleave( interleave_fn, cycle_length, block_length, sloppy, buffer_output_elements, prefetch_input_elements)) next_element = self.getNext( dataset_fn(input_values=np.int64([4, 5, 6]), cycle_length=2, block_length=1, sloppy=False, buffer_output_elements=1, prefetch_input_elements=0)) for i, expected_element in enumerate( self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)): if expected_element == 5: with self.assertRaises(errors.InvalidArgumentError): self.evaluate(next_element()) else: actual_element = self.evaluate(next_element()) self.assertEqual( expected_element, actual_element, "At index %s: %s expected, got: %s" % (i, expected_element, actual_element)) with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) @combinations.generate(test_base.default_test_combinations()) def testShutdownRace(self): dataset = dataset_ops.Dataset.range(20) map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1)) dataset = dataset.apply( interleave_ops.parallel_interleave(map_fn, cycle_length=3, sloppy=False, buffer_output_elements=1, prefetch_input_elements=0)) dataset = dataset.batch(32) results = [] for _ in range(2): elements = [] next_element = self.getNext(dataset) try: while True: elements.extend(self.evaluate(next_element())) except errors.OutOfRangeError: pass results.append(elements) self.assertAllEqual(results[0], results[1]) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(sloppy=[None, True, False], global_determinism=[True, False]))) def testDeterminismConfiguration(self, sloppy, global_determinism): if sloppy is None: expect_determinism = global_determinism else: expect_determinism = not sloppy elements = list(range(1000)) def dataset_fn(delay_ms): def interleave_fn(x): ds = dataset_ops.Dataset.from_tensors(x) if math_ops.equal(x, 0): ds = ds.apply(testing.sleep(delay_ms * 1000)) else: ds = ds.apply(testing.sleep(0)) return ds dataset = dataset_ops.Dataset.from_tensor_slices(elements) dataset = dataset.apply( interleave_ops.parallel_interleave(interleave_fn, cycle_length=10, sloppy=sloppy)) opts = dataset_ops.Options() opts.experimental_deterministic = global_determinism dataset = dataset.with_options(opts) return dataset self.checkDeterminism(dataset_fn, expect_determinism, elements)
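# Rejection resampling, exercised by the class below, accepts an element of
# class `c` with probability proportional to target_dist[c] / initial_dist[c],
# so the accepted stream converges to `target_dist` regardless of the input
# mix. A minimal NumPy sketch of the acceptance rule (hypothetical helper;
# assumes strictly positive `initial_dist` entries):


def _sketch_acceptance_probs(initial_dist, target_dist):
  """Per-class acceptance probabilities, scaled so the largest is 1.0."""
  ratios = np.asarray(target_dist, dtype=np.float64) / np.asarray(initial_dist)
  return ratios / ratios.max()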
class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(initial_known=[True, False]))) def testDistribution(self, initial_known): classes = np.random.randint(5, size=(10000, )) # Uniformly sampled target_dist = [0.9, 0.05, 0.05, 0.0, 0.0] initial_dist = [0.2] * 5 if initial_known else None classes = math_ops.cast(classes, dtypes.int64) # needed for Windows build. dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle( 200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat() get_next = self.getNext( dataset.rejection_resample(target_dist=target_dist, initial_dist=initial_dist, class_func=lambda c, _: c, seed=27)) returned = [] while len(returned) < 2000: returned.append(self.evaluate(get_next())) returned_classes, returned_classes_and_data = zip(*returned) _, returned_data = zip(*returned_classes_and_data) self.assertAllEqual( [compat.as_bytes(str(c)) for c in returned_classes], returned_data) total_returned = len(returned_classes) class_counts = np.array([ len([True for v in returned_classes if v == c]) for c in range(5) ]) returned_dist = class_counts / total_returned self.assertAllClose(target_dist, returned_dist, atol=1e-2) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(only_initial_dist=[True, False]))) def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist): init_dist = [0.5, 0.5] target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0] num_classes = len(init_dist) # We don't need many samples to test that this works. num_samples = 100 data_np = np.random.choice(num_classes, num_samples, p=init_dist) dataset = dataset_ops.Dataset.from_tensor_slices(data_np) # Reshape distribution. dataset = dataset.rejection_resample(class_func=lambda x: x, target_dist=target_dist, initial_dist=init_dist) get_next = self.getNext(dataset) returned = [] with self.assertRaises(errors.OutOfRangeError): while True: returned.append(self.evaluate(get_next())) @combinations.generate(test_base.default_test_combinations()) def testRandomClasses(self): init_dist = [0.25, 0.25, 0.25, 0.25] target_dist = [0.0, 0.0, 0.0, 1.0] num_classes = len(init_dist) # We don't need many samples to test a dirac-delta target distribution. num_samples = 100 data_np = np.random.choice(num_classes, num_samples, p=init_dist) dataset = dataset_ops.Dataset.from_tensor_slices(data_np) # Apply a random mapping that preserves the data distribution. def _remap_fn(_): return math_ops.cast( random_ops.random_uniform([1]) * num_classes, dtypes.int32)[0] dataset = dataset.map(_remap_fn) # Reshape distribution. 
dataset = dataset.rejection_resample(class_func=lambda x: x, target_dist=target_dist, initial_dist=init_dist) get_next = self.getNext(dataset) returned = [] with self.assertRaises(errors.OutOfRangeError): while True: returned.append(self.evaluate(get_next())) classes, _ = zip(*returned) bincount = np.bincount(np.array(classes), minlength=num_classes).astype( np.float32) / len(classes) self.assertAllClose(target_dist, bincount, atol=1e-2) @combinations.generate(test_base.default_test_combinations()) def testExhaustion(self): init_dist = [0.5, 0.5] target_dist = [0.9, 0.1] dataset = dataset_ops.Dataset.range(10000) dataset = dataset.rejection_resample(class_func=lambda x: x % 2, target_dist=target_dist, initial_dist=init_dist) get_next = self.getNext(dataset) returned = [] with self.assertRaises(errors.OutOfRangeError): while True: returned.append(self.evaluate(get_next())) classes, _ = zip(*returned) bincount = np.bincount(np.array(classes), minlength=len(init_dist)).astype( np.float32) / len(classes) self.assertAllClose(target_dist, bincount, atol=1e-2) @parameterized.parameters( ("float32", "float64"), ("float64", "float32"), ("float64", "float64"), ("float64", None), ) def testOtherDtypes(self, target_dtype, init_dtype): target_dist = np.array([0.5, 0.5], dtype=target_dtype) if init_dtype is None: init_dist = None else: init_dist = np.array([0.5, 0.5], dtype=init_dtype) dataset = dataset_ops.Dataset.range(10) dataset = dataset.rejection_resample(class_func=lambda x: x % 2, target_dist=target_dist, initial_dist=init_dist) get_next = self.getNext(dataset) self.evaluate(get_next())
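# --- Illustrative sketch, not part of the test suite above. Minimal usage of
# the rejection-resampling transformation exercised by RejectionResampleTest,
# in the method form the tests call (older releases spell it
# `tf.data.experimental.rejection_resample` applied via `Dataset.apply`).
# Output elements are (class, original_element) pairs, which is what the
# `zip(*returned)` assertions above rely on.
import numpy as np
import tensorflow as tf

def rejection_resample_sketch():
  data = np.random.randint(2, size=10000).astype(np.int64)
  ds = tf.data.Dataset.from_tensor_slices(data).repeat()
  ds = ds.rejection_resample(
      class_func=lambda x: x,
      target_dist=[0.9, 0.1],
      initial_dist=[0.5, 0.5])
  classes = [c.numpy() for c, _ in ds.take(1000)]
  # Close to [0.9, 0.1] up to sampling noise.
  return np.bincount(classes, minlength=2) / len(classes)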
def all_cluster_configurations(): with_work_dir = combinations.combine(work_dir=TMP_WORK_DIR, fault_tolerant_mode=[True, False]) without_work_dir = combinations.combine(work_dir=NO_WORK_DIR, fault_tolerant_mode=False) return with_work_dir + without_work_dir
class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase): def __init__(self, methodName="runTest"): # pylint: disable=invalid-name super(LocalReplicateTest, self).__init__(methodName) self._device0 = "/device:CPU:0" self._device1 = "/device:CPU:1" self._device2 = "/device:CPU:2" @combinations.generate( combinations.combine(tf_api_version=[1], mode=["graph", "eager"])) def testBasic(self): with ops.device(self._device0): dataset0 = dataset_ops.Dataset.range(100) replicated_ds = distribute.replicate(dataset0, [self._device1, self._device2]) dataset1 = replicated_ds[self._device1] dataset2 = replicated_ds[self._device2] with ops.device(self._device0): self.assertDatasetProduces(dataset0, range(100)) with ops.device(self._device1): self.assertDatasetProduces(dataset1, range(100)) with ops.device(self._device2): self.assertDatasetProduces(dataset2, range(100)) @combinations.generate( combinations.combine(tf_api_version=[1], mode=["graph", "eager"])) def testVariableInput(self): with ops.device(self._device0): counter_var = variable_scope.get_variable( "counter", (), dtypes.int32, use_resource=True) dataset0 = dataset_ops.Dataset.range(100).map( lambda _: counter_var.assign_add(1)) # We don't support stateful ops in functions as of now. with self.assertRaises(errors.FailedPreconditionError): replicated_ds = distribute.replicate(dataset0, [self._device1, self._device2]) self.evaluate(replicated_ds[self._device1]._variant_tensor) @combinations.generate( combinations.combine(tf_api_version=[1], mode=["graph", "eager"])) def testAllowStatefulOp(self): with compat.forward_compatibility_horizon(2019, 9, 12): with ops.device(self._device0): dataset0 = dataset_ops.Dataset.range(100).map( lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda [], minval=1, maxval=10, dtype=dtypes.float32)) opt = dataset_ops.Options() opt.experimental_allow_stateful = True dataset0 = dataset0.with_options(opt) replicated_ds = distribute.replicate(dataset0, [self._device1, self._device2]) dataset1 = replicated_ds[self._device1] dataset2 = replicated_ds[self._device2] with ops.device(self._device0): get_next0 = self.getNext(dataset0) with ops.device(self._device1): get_next1 = self.getNext(dataset1) with ops.device(self._device2): get_next2 = self.getNext(dataset2) for _ in range(100): get_next0() get_next1() get_next2()
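# --- Illustrative sketch, not part of the test suite above. `replicate` has
# no public tf.data endpoint, so this sketch imports the same internal module
# the tests use; treat it only as an illustration of the call shape. The
# return value is a dict mapping each requested device to a copy of the
# input dataset placed on that device, which is how LocalReplicateTest
# indexes it.
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops

def replicate_sketch(devices):
  ds = dataset_ops.Dataset.range(100)
  replicated = distribute.replicate(ds, devices)
  return [replicated[device] for device in devices]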
def reduce_fn(x, y): name, dataset_fn, expected_output = y return x + combinations.combine( dataset_fn=combinations.NamedObject(name, dataset_fn), expected_output=[expected_output])
class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(count=[32, 34], padded_shapes=[[None], [25]], drop_remainder=[True, False]))) def testPaddedBatchDataset(self, count, padded_shapes, drop_remainder): seq_lens = np.random.randint(20, size=(count, )).astype(np.int32) batch_size = 4 dataset = dataset_ops.Dataset.from_tensor_slices(seq_lens).map( lambda x: array_ops.fill([x], x)).padded_batch( batch_size=batch_size, drop_remainder=drop_remainder, padded_shapes=padded_shapes) num_full_batches = len(seq_lens) // batch_size get_next = self.getNext(dataset) for i in range(num_full_batches): result = self.evaluate(get_next()) padded_len = padded_shapes[0] if padded_len is None or padded_len == -1: padded_len = np.max(result) if result.size > 0 else 0 self.assertEqual((batch_size, padded_len), result.shape) for j in range(batch_size): seq_len = seq_lens[(i * batch_size) + j] self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len) self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len)) if not drop_remainder and len(seq_lens) % batch_size > 0: result = self.evaluate(get_next()) padded_len = padded_shapes[0] if padded_len is None or padded_len == -1: padded_len = np.max(result) if result.size > 0 else 0 self.assertEqual((len(seq_lens) % batch_size, padded_len), result.shape) for j in range(len(seq_lens) % batch_size): seq_len = seq_lens[num_full_batches * batch_size + j] self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len) self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len)) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) @combinations.generate(test_base.default_test_combinations()) def testPaddedBatchShortPadding(self): dataset = (dataset_ops.Dataset.from_tensor_slices( [6, 5, 5, 5, 5]).map(lambda x: array_ops.fill([x], x)).padded_batch( batch_size=4, padded_shapes=[5])) self.assertDatasetProduces(dataset, expected_error=(errors.DataLossError, '')) @combinations.generate(test_base.default_test_combinations()) def testPaddedBatchEmptyTensors(self): dataset = (dataset_ops.Dataset.from_tensor_slices( [0, 0, 0, 0]).map(lambda x: array_ops.fill([x], x)).padded_batch( batch_size=4, padded_shapes=[-1])) self.assertDatasetProduces(dataset, expected_output=[[[], [], [], []]]) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(padding_values=[(-1, '<end>', { 'structure': '' }), (-1, '<end>', None)]))) def testPaddedBatchDatasetNonDefaultPadding(self, padding_values): def fill_tuple(x): filled = array_ops.fill([x], x) return (filled, string_ops.as_string(filled), { 'structure': string_ops.as_string(filled) }) random_seq_lens = np.random.randint(20, size=(32, )).astype(np.int32) dataset = (dataset_ops.Dataset.from_tensor_slices(random_seq_lens).map( fill_tuple).padded_batch(4, padded_shapes=([-1], [-1], { 'structure': [-1] }), padding_values=padding_values)) get_next = self.getNext(dataset) for i in range(8): result = self.evaluate(get_next()) padded_len = np.max(result[0]) self.assertEqual((4, padded_len), result[0].shape) self.assertEqual((4, padded_len), result[1].shape) self.assertEqual((4, padded_len), result[2]['structure'].shape) for j in range(4): seq_len = random_seq_lens[(i * 4) + j] self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len) self.assertAllEqual(result[0][j, 
seq_len:], [-1] * (padded_len - seq_len)) self.assertAllEqual(result[1][j, :seq_len], [compat.as_bytes(str(seq_len))] * seq_len) self.assertAllEqual(result[1][j, seq_len:], [b'<end>'] * (padded_len - seq_len)) self.assertAllEqual(result[2]['structure'][j, :seq_len], [compat.as_bytes(str(seq_len))] * seq_len) self.assertAllEqual(result[2]['structure'][j, seq_len:], [b''] * (padded_len - seq_len)) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) @combinations.generate(test_base.default_test_combinations()) def testPaddedBatchDatasetUnicode(self): # See GitHub issue 16149 def generator(): data = [[u'Простой', u'тест', u'юникода'], [u'никогда', u'не', u'бывает', u'простым']] for seq in data: yield seq, [0, 1, 2, 3] dataset = dataset_ops.Dataset.from_generator( generator, (dtypes.string, dtypes.int32), (tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None ]))) padded_dataset = dataset.padded_batch(2, padded_shapes=([None], [None]), padding_values=('', 0)) next_element = self.getNext(padded_dataset) self.evaluate(next_element()) @combinations.generate(test_base.graph_only_combinations()) def testPaddedBatchDatasetShapeSpecifications(self): int_placeholder = array_ops.placeholder(dtypes.int32) float_placeholder = array_ops.placeholder(dtypes.float32) string_placeholder = array_ops.placeholder(dtypes.string) input_dataset = dataset_ops.Dataset.from_tensors( (int_placeholder, float_placeholder, string_placeholder)) # Test different ways of specifying the `padded_shapes` argument. dynamic_padding_from_tensor_shapes = input_dataset.padded_batch( 32, padded_shapes=(tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None, None]), tensor_shape.TensorShape([37]))) dynamic_padding_from_lists = input_dataset.padded_batch( 32, padded_shapes=([None], [None, None], [37])) dynamic_padding_from_lists_with_minus_one = input_dataset.padded_batch( 32, padded_shapes=([-1], [-1, -1], [37])) dynamic_padding_from_tensors = input_dataset.padded_batch( 32, padded_shapes=(constant_op.constant([-1], dtype=dtypes.int64), constant_op.constant([-1, -1], dtype=dtypes.int64), constant_op.constant([37], dtype=dtypes.int64))) for dataset in [ dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists, dynamic_padding_from_lists_with_minus_one, dynamic_padding_from_tensors ]: dataset_output_shapes = dataset_ops.get_legacy_output_shapes( dataset) self.assertEqual([None, None], dataset_output_shapes[0].as_list()) self.assertEqual([None, None, None], dataset_output_shapes[1].as_list()) self.assertEqual([None, 37], dataset_output_shapes[2].as_list()) @combinations.generate(test_base.default_test_combinations()) def testPaddedBatchSparseError(self): def _map_fn(i): return sparse_tensor.SparseTensorValue(indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i with self.assertRaises(TypeError): _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10) @combinations.generate(test_base.default_test_combinations()) def testPaddedBatchShapeError(self): with self.assertRaisesRegexp( ValueError, r'The padded shape \(1,\) is not compatible with the ' r'corresponding input component shape \(\).'): _ = dataset_ops.Dataset.range(10).padded_batch(5, padded_shapes=[1]) with self.assertRaisesRegexp( ValueError, r'The padded shape \(1,\) is not compatible with the ' r'corresponding input component shape \(3,\).'): _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch( 5, padded_shapes=[1]) with self.assertRaisesRegexp( ValueError, r'Padded shape .* must be a 1-D tensor ' r'of tf.int64 
values, but its shape was \(2, 2\).'): _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch( 5, padded_shapes=[[1, 1], [1, 1]]) with self.assertRaisesRegexp( TypeError, r'Padded shape .* must be a 1-D tensor ' r'of tf.int64 values, but its element type was float32.'): _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch( 5, padded_shapes=constant_op.constant([1.5, 2., 3.])) with self.assertRaisesRegexp( ValueError, r'The padded shape \(1,\) is not compatible with the ' r'corresponding input component shape \(\).'): shape_as_tensor = constant_op.constant([1], dtype=dtypes.int64) _ = dataset_ops.Dataset.range(10).padded_batch( 5, padded_shapes=shape_as_tensor) @combinations.generate(test_base.graph_only_combinations()) def testPaddedBatchShapeErrorPlaceholder(self): with self.assertRaisesRegexp( ValueError, r'The padded shape \((\?|None), (\?|None)\) is not compatible with the ' r'corresponding input component shape \(\).'): shape_as_tensor = array_ops.placeholder(dtypes.int64, shape=[2]) _ = dataset_ops.Dataset.range(10).padded_batch( 5, padded_shapes=shape_as_tensor)
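# --- Illustrative sketch, not part of the test suite above. The smallest
# interesting `padded_batch` call: variable-length vectors padded with the
# default value 0 up to the longest element in each batch. Written against
# the public API rather than the internal `dataset_ops` module used above.
import tensorflow as tf

def padded_batch_sketch():
  ds = tf.data.Dataset.range(1, 6).map(lambda x: tf.fill([x], x))
  # `padded_shapes=[None]` pads the single unknown dimension per batch.
  ds = ds.padded_batch(batch_size=2, padded_shapes=[None])
  # Yields [[1, 0], [2, 2]], then [[3, 3, 3, 0], [4, 4, 4, 4]], then the
  # remainder [[5, 5, 5, 5, 5]].
  return [batch.numpy().tolist() for batch in ds]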
class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testOptimizationStatefulFunction(self): dataset = dataset_ops.Dataset.range( 10).map(lambda _: random_ops.random_uniform([])).batch(10) options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False dataset = dataset.with_options(options) get_next = self.getNext(dataset) self.evaluate(get_next()) # TODO(b/123902160) @combinations.generate(test_base.graph_only_combinations()) def testOptimizationLargeInputFromTensor(self): input_t = array_ops.placeholder(dtypes.int32, (None, None, None)) dataset = dataset_ops.Dataset.from_tensors(input_t) options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False dataset = dataset.with_options(options) iterator = dataset_ops.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)}) self.evaluate(get_next) # TODO(b/123902160) @combinations.generate(test_base.graph_only_combinations()) def testOptimizationLargeInputFromTensorSlices(self): input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None)) dataset = dataset_ops.Dataset.from_tensor_slices(input_t) options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False dataset = dataset.with_options(options) iterator = dataset_ops.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)}) self.evaluate(get_next) @combinations.generate(test_base.default_test_combinations()) def testOptimizationNestedDataset(self): def flat_map_fn(_): dataset = dataset_ops.Dataset.from_tensors(0) dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"])) dataset = dataset.skip(0) # Should be removed by noop elimination dataset = dataset.cache() return dataset dataset = dataset_ops.Dataset.range(1) dataset = dataset.flat_map(flat_map_fn) options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False options.experimental_optimization.noop_elimination = True dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, expected_output=[0]) @combinations.generate(test_base.default_test_combinations()) def testOptimizationNestedDatasetWithModifiedRetval(self): def flat_map_fn(_): dataset = dataset_ops.Dataset.from_tensors(0) dataset = dataset.apply(testing.assert_next(["MapAndBatch"])) # Should be fused by map and batch fusion dataset = dataset.map(lambda x: x) dataset = dataset.batch(1) return dataset dataset = dataset_ops.Dataset.range(1) dataset = dataset.flat_map(flat_map_fn) options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False options.experimental_optimization.map_and_batch_fusion = True dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, expected_output=[[0]]) @combinations.generate( combinations.times( test_base.default_test_combinations(), _disable_intra_op_parallelism_test_combinations(), combinations.combine(apply_autotune=[None, True, False]))) def testOptimizationDisableIntraOpParallelism(self, dataset_fn, expected_output, apply_autotune): dataset = dataset_fn() dataset = dataset.apply(testing.assert_next(["MaxIntraOpParallelism"])) if 
apply_autotune is not None: options = dataset_ops.Options() options.experimental_optimization.autotune = apply_autotune dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, expected_output=expected_output) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(autotune=False, autotune_buffers=False) + combinations.combine(autotune=True, autotune_buffers=False) + combinations.combine(autotune=True, autotune_buffers=True), combinations.combine(first_buffer_sizes=[(1, -1, -1, 4), (2, -1, 3, -1), (2, 1, -1, -1)]), combinations.combine(second_buffer_sizes=[(1, -1, -1, 4), (2, -1, 3, -1), (2, 1, -1, -1)])) ) def testOptimizationAutotuneBuffers(self, autotune, autotune_buffers, first_buffer_sizes, second_buffer_sizes): dataset = dataset_ops.Dataset.range(10) for buffer_size in first_buffer_sizes: dataset = dataset.prefetch(buffer_size=buffer_size) dataset = dataset.map(lambda x: x + 1) for buffer_size in second_buffer_sizes: dataset = dataset.prefetch(buffer_size=buffer_size) options = dataset_ops.Options() options.experimental_optimization.autotune = autotune options.experimental_optimization.autotune_buffers = autotune_buffers dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, expected_output=list(range(1, 11))) @combinations.generate(test_base.default_test_combinations()) def testOptimizationThreadPoolDataset(self): dataset = dataset_ops.Dataset.range(10).batch(10) dataset = threadpool.override_threadpool( dataset, threadpool.PrivateThreadPool( 2, display_name="private_thread_pool_%d" % 2)) options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False dataset = dataset.with_options(options) self.assertDatasetProduces( dataset, expected_output=[list(range(10))], requires_initialization=True) # Reference variables are not supported in eager mode. @combinations.generate( combinations.times(test_base.graph_only_combinations(), _captured_refvar_test_combinations())) def testOptimizationWithCapturedRefVar(self, dataset_fn): """Tests that default optimizations are disabled with ref variables.""" variable = variable_scope.get_variable( "v", initializer=0, use_resource=False) assign_op = variable.assign_add(1) # Check that warning is logged. warnings.simplefilter("always") with warnings.catch_warnings(record=True) as w: unoptimized_dataset = dataset_fn(variable) options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False options.experimental_optimization.noop_elimination = True options.experimental_optimization.map_and_batch_fusion = True optimized_dataset = unoptimized_dataset.with_options(options) optimized_it = dataset_ops.make_initializable_iterator(optimized_dataset) self.assertGreaterEqual(len(w), 1) graph_rewrites = options._graph_rewrites() expected = ( "tf.data graph rewrites are not compatible with " "tf.Variable. The following rewrites will be disabled: %s." " To enable rewrites, use resource variables instead by " "calling `tf.enable_resource_variables()` at the start of the " "program." % (", ".join(graph_rewrites.enabled + graph_rewrites.default))) self.assertTrue(any(expected in str(warning) for warning in w)) # Check that outputs are the same in the optimized and unoptimized cases, # when the variable value is changing. 
unoptimized_it = dataset_ops.make_initializable_iterator( unoptimized_dataset) with ops.control_dependencies([assign_op]): unoptimized_output = unoptimized_it.get_next() optimized_output = optimized_it.get_next() self.evaluate(variable.initializer) self.evaluate((unoptimized_it.initializer, optimized_it.initializer)) while True: try: unoptimized, optimized = self.evaluate((unoptimized_output, optimized_output)) self.assertEqual(unoptimized, optimized) except errors.OutOfRangeError: break @combinations.generate(test_base.default_test_combinations()) def testOptimizationDefault(self): """Tests the optimization settings by default.""" options = dataset_ops.Options() expected_optimizations_enabled = [] expected_optimizations_disabled = [] expected_optimizations_default = [ "map_and_batch_fusion", "noop_elimination", "shuffle_and_repeat_fusion", ] graph_rewrites = options._graph_rewrites() self.assertEqual(set(graph_rewrites.enabled), set(expected_optimizations_enabled)) self.assertEqual(set(graph_rewrites.disabled), set(expected_optimizations_disabled)) self.assertEqual(set(graph_rewrites.default), set(expected_optimizations_default)) options.experimental_optimization.apply_default_optimizations = True graph_rewrites = options._graph_rewrites() self.assertEqual(set(graph_rewrites.enabled), set(expected_optimizations_enabled)) self.assertEqual(set(graph_rewrites.disabled), set(expected_optimizations_disabled)) self.assertEqual(set(graph_rewrites.default), set(expected_optimizations_default)) options.experimental_optimization.apply_default_optimizations = False expected_optimizations_default = [] graph_rewrites = options._graph_rewrites() self.assertEqual(set(graph_rewrites.enabled), set(expected_optimizations_enabled)) self.assertEqual(set(graph_rewrites.disabled), set(expected_optimizations_disabled)) self.assertEqual(set(graph_rewrites.default), set(expected_optimizations_default)) @combinations.generate(test_base.default_test_combinations()) def testOptimizationEnabled(self): """Tests the optimization settings by enabling all.""" options = dataset_ops.Options() options.experimental_optimization.filter_fusion = True options.experimental_optimization.filter_with_random_uniform_fusion = True options.experimental_optimization.hoist_random_uniform = True options.experimental_optimization.map_and_batch_fusion = True options.experimental_optimization.map_and_filter_fusion = True options.experimental_optimization.map_parallelization = True options.experimental_optimization.map_fusion = True options.experimental_optimization.noop_elimination = True options.experimental_optimization.parallel_batch = True options.experimental_optimization.shuffle_and_repeat_fusion = True options.experimental_optimization.map_vectorization.enabled = True options.experimental_optimization.autotune_buffers = True options.experimental_deterministic = False options.experimental_stats.latency_all_edges = True options.experimental_slack = True expected_optimizations_enabled = [ "filter_fusion", "filter_with_random_uniform_fusion", "hoist_random_uniform", "map_and_batch_fusion", "map_and_filter_fusion", "map_parallelization", "map_fusion", "noop_elimination", "parallel_batch", "shuffle_and_repeat_fusion", "map_vectorization", "autotune_buffer_sizes", "make_sloppy", "latency_all_edges", "slack", "disable_prefetch_legacy_autotune", ] expected_optimizations_disabled = [] expected_optimizations_default = [] graph_rewrites = options._graph_rewrites() self.assertEqual(set(graph_rewrites.enabled), set(expected_optimizations_enabled)) 
self.assertEqual(set(graph_rewrites.disabled), set(expected_optimizations_disabled)) self.assertEqual(set(graph_rewrites.default), set(expected_optimizations_default)) @combinations.generate(test_base.default_test_combinations()) def testOptimizationDisabled(self): """Tests the optimization settings by disabling all.""" options = dataset_ops.Options() options.experimental_optimization.filter_fusion = False options.experimental_optimization.filter_with_random_uniform_fusion = False options.experimental_optimization.hoist_random_uniform = False options.experimental_optimization.map_and_batch_fusion = False options.experimental_optimization.map_and_filter_fusion = False options.experimental_optimization.map_parallelization = False options.experimental_optimization.map_fusion = False options.experimental_optimization.noop_elimination = False options.experimental_optimization.parallel_batch = False options.experimental_optimization.shuffle_and_repeat_fusion = False options.experimental_optimization.map_vectorization.enabled = False options.experimental_optimization.autotune = False options.experimental_deterministic = True options.experimental_stats.latency_all_edges = False options.experimental_slack = False expected_optimizations_enabled = [] expected_optimizations_disabled = [ "filter_fusion", "filter_with_random_uniform_fusion", "hoist_random_uniform", "map_and_batch_fusion", "map_and_filter_fusion", "map_parallelization", "map_fusion", "noop_elimination", "parallel_batch", "shuffle_and_repeat_fusion", "map_vectorization", "autotune_buffer_sizes", "make_sloppy", "latency_all_edges", "slack", "disable_prefetch_legacy_autotune", ] expected_optimizations_default = [] graph_rewrites = options._graph_rewrites() self.assertEqual(set(graph_rewrites.enabled), set(expected_optimizations_enabled)) self.assertEqual(set(graph_rewrites.disabled), set(expected_optimizations_disabled)) self.assertEqual(set(graph_rewrites.default), set(expected_optimizations_default)) @combinations.generate(test_base.default_test_combinations()) def testAutotuningDefaults(self): options = dataset_ops.Options() # Check defaults autotune, algorithm, cpu_budget, ram_budget = options._autotune_settings() self.assertTrue(autotune) self.assertEqual(algorithm, optimization_options._AutotuneAlgorithm.HILL_CLIMB) self.assertEqual(cpu_budget, 0) self.assertEqual(ram_budget, 0) @combinations.generate(test_base.default_test_combinations()) def testAutotuningSettings(self): options = dataset_ops.Options() options.experimental_optimization.autotune_cpu_budget = 1000 options.experimental_optimization.autotune_ram_budget = 999999999 options.experimental_optimization.autotune_buffers = True self.assertIn("autotune_buffer_sizes", options._graph_rewrites().enabled) self.assertIn("disable_prefetch_legacy_autotune", options._graph_rewrites().enabled) autotune, algorithm, cpu_budget, ram_budget = options._autotune_settings() self.assertTrue(autotune) self.assertEqual(algorithm, optimization_options._AutotuneAlgorithm.GRADIENT_DESCENT) self.assertEqual(cpu_budget, 1000) self.assertEqual(ram_budget, 999999999)
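# --- Illustrative sketch, not part of the test suite above. How a user sets
# the graph-rewrite toggles whose plumbing is asserted on in
# OptimizeDatasetTest. Attribute availability varies across releases, so the
# exact option names here are assumptions tied to the era of these tests.
import tensorflow as tf

def optimization_options_sketch():
  ds = tf.data.Dataset.range(10).map(lambda x: x + 1).batch(2)
  opts = tf.data.Options()
  opts.experimental_optimization.map_and_batch_fusion = True
  opts.experimental_optimization.noop_elimination = True
  # Options left unset fall back to the defaults that
  # testOptimizationDefault checks via `_graph_rewrites()`.
  return ds.with_options(opts)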
class BucketBySequenceLengthTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(param_no_padding=[True, False]))) def testBucketDropRemainder(self, param_no_padding): boundaries = [10, 20, 30] batch_sizes = [10, 8, 4, 2] lengths = [8, 13, 25, 35] n_bucket_elements = [28, 7, 6, 5] n_expected_batches = 5 # Expected sequence lengths of the individual batches. expected_lengths = [] # Expected sum of all batches with an equal sequence length. # <seq-length>: <expected-total-sum> expected_sums = {} # Expected batch sizes of batches depending on the sequence length. # <seq-length>: [batch1_size, ..., batchN_size] expected_batch_sizes = {} for length, batch_size, bucket_elements in zip(lengths, batch_sizes, n_bucket_elements): # Calculate the expected sum across all batches of a specific sequence # length. expected_sums[length] = \ (bucket_elements - bucket_elements % batch_size) * length # Calculate the expected occurrence of individual batch sizes. expected_batch_sizes[length] = \ [batch_size] * (bucket_elements // batch_size) # Calculate the expected occurrence of individual sequence lengths. expected_lengths.extend([length] * (bucket_elements // batch_size)) def build_dataset(sparse): def _generator(): # Produce `bucket_elements` full-length records for each bucket. elements = [] for bucket_elements, length in zip(n_bucket_elements, lengths): # Using only full sequences (as opposed to the strategy employed in # `testBucket`) makes checking the sum a lot easier. record_len = length for _ in range(bucket_elements): elements.append([1] * record_len) random.shuffle(elements) for el in elements: yield (_format_record(el, sparse), ) dataset = dataset_ops.Dataset.from_generator( _generator, (_get_record_type(sparse), ), (_get_record_shape(sparse), )) if sparse: dataset = dataset.map(lambda x: (_to_sparse_tensor(x), )) return dataset def _test_bucket_by_padding(no_padding): dataset = build_dataset(sparse=no_padding) dataset = dataset.bucket_by_sequence_length( element_length_func=_element_length_fn, bucket_boundaries=boundaries, bucket_batch_sizes=batch_sizes, no_padding=no_padding, drop_remainder=True) get_next = self.getNext(dataset) batches = [] for _ in range(n_expected_batches): batch, = self.evaluate(get_next()) batches.append(batch) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) generated_lengths = [] # <seq-length>: <total-sum> generated_sums = {} # <seq-length>: [<batch_size>, ...] generated_batch_sizes = {} for length, batch_size, bucket_elements in zip( lengths, batch_sizes, n_bucket_elements): # Initialize the sum across all batches. generated_sums[length] = 0 # Initialize the individual batch sizes. generated_batch_sizes[length] = [] for batch in batches: shape = batch.dense_shape if no_padding else batch.shape length = shape[1] generated_lengths.append(length) batch_size = shape[0] generated_batch_sizes[length].append(batch_size) batch_sum = batch.values.sum() if no_padding else batch.sum() generated_sums[length] += batch_sum for l in lengths: # Make sure the sum of the batch contents is correct for the individual # sequence lengths. self.assertEqual( generated_sums[l], expected_sums[l], "Tensor sums did not match! " "expected: {}, generated: {}".format( expected_sums, generated_sums)) # Make sure the individual batch sizes are generated as expected. self.assertEqual( sorted(generated_batch_sizes[l]), sorted(expected_batch_sizes[l]), "Batch-sizes did not match! 
" "expected: {}, generated: {}".format( sorted(expected_batch_sizes[l]), sorted(generated_batch_sizes[l]))) # Make sure the generated sequence lengths appear as often as expected. self.assertEqual( sorted(generated_lengths), sorted(expected_lengths), "The generated sequence lengths did not match! " "expected: {}, generated: {}".format( sorted(expected_lengths), sorted(generated_lengths))) _test_bucket_by_padding(param_no_padding) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(param_no_padding=[True, False]))) def testBucket(self, param_no_padding): boundaries = [10, 20, 30] batch_sizes = [10, 8, 4, 2] lengths = [8, 13, 25, 35] def build_dataset(sparse): def _generator(): # Produce 1 batch for each bucket elements = [] for batch_size, length in zip(batch_sizes, lengths): record_len = length - 1 for _ in range(batch_size): elements.append([1] * record_len) record_len = length random.shuffle(elements) for el in elements: yield (_format_record(el, sparse), ) dataset = dataset_ops.Dataset.from_generator( _generator, (_get_record_type(sparse), ), (_get_record_shape(sparse), )) if sparse: dataset = dataset.map(lambda x: (_to_sparse_tensor(x), )) return dataset def _test_bucket_by_padding(no_padding): dataset = build_dataset(sparse=no_padding) dataset = dataset.bucket_by_sequence_length( element_length_func=_element_length_fn, bucket_boundaries=boundaries, bucket_batch_sizes=batch_sizes, no_padding=no_padding) get_next = self.getNext(dataset) batches = [] for _ in range(4): batch, = self.evaluate(get_next()) batches.append(batch) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) batch_sizes_val = [] lengths_val = [] for batch in batches: shape = batch.dense_shape if no_padding else batch.shape batch_size = shape[0] length = shape[1] batch_sizes_val.append(batch_size) lengths_val.append(length) if not context.executing_eagerly(): sum_check = batch.values.sum( ) if no_padding else batch.sum() self.assertEqual(sum_check, batch_size * length - 1) self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) self.assertEqual(sorted(lengths), sorted(lengths_val)) _test_bucket_by_padding(param_no_padding) @combinations.generate(test_base.default_test_combinations()) def testPadToBoundary(self): boundaries = [10, 20, 30] batch_sizes = [10, 8, 4, 2] lengths = [8, 13, 25] def element_gen(): # Produce 1 batch for each bucket elements = [] for batch_size, length in zip(batch_sizes[:-1], lengths): for _ in range(batch_size): elements.append([1] * length) random.shuffle(elements) for el in elements: yield (el, ) for _ in range(batch_sizes[-1]): el = [1] * (boundaries[-1] + 5) yield (el, ) element_len = lambda el: array_ops.shape(el)[0] dataset = dataset_ops.Dataset.from_generator(element_gen, (dtypes.int64, ), ([None], )) dataset = dataset.bucket_by_sequence_length( element_length_func=element_len, bucket_boundaries=boundaries, bucket_batch_sizes=batch_sizes, pad_to_bucket_boundary=True) get_next = self.getNext(dataset) batches = [] for _ in range(3): batch, = self.evaluate(get_next()) batches.append(batch) with self.assertRaisesOpError("bucket_boundaries"): self.evaluate(get_next()) batch_sizes_val = [] lengths_val = [] for batch in batches: batch_size = batch.shape[0] length = batch.shape[1] batch_sizes_val.append(batch_size) lengths_val.append(length) batch_sizes = batch_sizes[:-1] self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) 
self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) self.assertEqual([boundary - 1 for boundary in sorted(boundaries)], sorted(lengths_val)) @combinations.generate(test_base.default_test_combinations()) def testPadToBoundaryNoExtraneousPadding(self): boundaries = [3, 7, 11] batch_sizes = [2, 2, 2, 2] lengths = range(1, 11) def element_gen(): for length in lengths: yield ([1] * length, ) element_len = lambda element: array_ops.shape(element)[0] dataset = dataset_ops.Dataset.from_generator(element_gen, (dtypes.int64, ), ([None], )) dataset = dataset.bucket_by_sequence_length( element_length_func=element_len, bucket_boundaries=boundaries, bucket_batch_sizes=batch_sizes, pad_to_bucket_boundary=True) get_next = self.getNext(dataset) batches = [] for _ in range(5): batch, = self.evaluate(get_next()) batches.append(batch) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) self.assertAllEqual(batches[0], [[1, 0], [1, 1]]) self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 0, 0]]) self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1]]) self.assertAllEqual( batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) self.assertAllEqual( batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(param_no_padding=[True, False]))) def testTupleElements(self, param_no_padding): def build_dataset(sparse): def _generator(): text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]] label = [1, 2, 1, 2] for x, y in zip(text, label): yield (_format_record(x, sparse), y) dataset = dataset_ops.Dataset.from_generator( generator=_generator, output_types=(_get_record_type(sparse), dtypes.int32), output_shapes=(_get_record_shape(sparse), tensor_shape.TensorShape([]))) if sparse: dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y)) return dataset def _test_tuple_elements_by_padding(no_padding): dataset = build_dataset(sparse=no_padding) dataset = dataset.bucket_by_sequence_length( element_length_func=_element_length_fn, bucket_batch_sizes=[2, 2, 2], bucket_boundaries=[0, 8], no_padding=no_padding) shapes = dataset_ops.get_legacy_output_shapes(dataset) self.assertEqual([None, None], shapes[0].as_list()) self.assertEqual([None], shapes[1].as_list()) _test_tuple_elements_by_padding(param_no_padding) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(param_drop_remainder=[True, False]))) def testBucketSparse(self, param_drop_remainder): # pylint: disable=g-doc-args """Tests bucketing of sparse tensors (case where `no_padding` == True). Test runs on following dataset: [ [0], [0, 1], [0, 1, 2] ... [0, ..., max_len - 1] ] Sequences are bucketed by length and batched with `batch_size` < `bucket_size`. 
""" min_len = 0 max_len = 100 batch_size = 7 bucket_size = 10 def _build_dataset(): input_data = [range(i + 1) for i in range(min_len, max_len)] def generator_fn(): for record in input_data: yield _format_record(record, sparse=True) dataset = dataset_ops.Dataset.from_generator( generator=generator_fn, output_types=_get_record_type(sparse=True)) dataset = dataset.map(_to_sparse_tensor) return dataset def _compute_expected_batches(drop_remainder): """Computes expected batch outputs and stores in a set.""" all_expected_sparse_tensors = set() for bucket_start_len in range(min_len, max_len, bucket_size): if drop_remainder: batch_offsets = [0] else: batch_offsets = range(0, bucket_size, batch_size) for batch_offset in batch_offsets: batch_start_len = bucket_start_len + batch_offset batch_end_len = min(batch_start_len + batch_size, bucket_start_len + bucket_size) expected_indices = [] expected_values = [] for length in range(batch_start_len, batch_end_len): for val in range(length + 1): expected_indices.append( (length - batch_start_len, val)) expected_values.append(val) expected_sprs_tensor = (tuple(expected_indices), tuple(expected_values)) all_expected_sparse_tensors.add(expected_sprs_tensor) return all_expected_sparse_tensors def _compute_batches(dataset): """Computes actual batch outputs of dataset and stores in a set.""" batch = self.getNext(dataset) all_sparse_tensors = set() with self.assertRaises(errors.OutOfRangeError): while True: output = self.evaluate(batch()) sprs_tensor = (tuple([ tuple(idx) for idx in output.indices ]), tuple(output.values)) all_sparse_tensors.add(sprs_tensor) return all_sparse_tensors dataset = _build_dataset() boundaries = range(min_len + bucket_size + 1, max_len, bucket_size) dataset = dataset.bucket_by_sequence_length( element_length_func=_element_length_fn, bucket_boundaries=boundaries, bucket_batch_sizes=[batch_size] * (len(boundaries) + 1), no_padding=True, drop_remainder=param_drop_remainder) batches = _compute_batches(dataset) expected_batches = _compute_expected_batches(param_drop_remainder) self.assertEqual(batches, expected_batches) @combinations.generate(test_base.default_test_combinations()) def testCardinality(self): boundaries = [3, 7, 11] batch_sizes = [2, 2, 2, 2] lengths = range(1, 11) def element_gen(): for length in lengths: yield ([1] * length, ) element_len = lambda element: array_ops.shape(element)[0] dataset = dataset_ops.Dataset.from_generator(element_gen, (dtypes.int64, ), ([None], )).repeat() dataset = dataset.bucket_by_sequence_length( element_length_func=element_len, bucket_boundaries=boundaries, bucket_batch_sizes=batch_sizes, pad_to_bucket_boundary=True) self.assertEqual(self.evaluate(dataset.cardinality()), dataset_ops.INFINITE)
def reduce_fn(x, y): name, dataset_fn = y return x + combinations.combine( dataset_fn=combinations.NamedObject(name, dataset_fn))
class ReplicateClusterTest(test_base.DatasetTestBase, parameterized.TestCase): def setUp(self): super(ReplicateClusterTest, self).setUp() # Start the local server. worker_config = config_pb2.ConfigProto() worker_config.device_count["CPU"] = 2 worker, _ = test_util.create_local_cluster(3, 0, worker_config=worker_config) self._device0 = "/job:worker/replica:0/task:0/device:CPU:0" self._device1 = "/job:worker/replica:0/task:1/device:CPU:0" self._device2 = "/job:worker/replica:0/task:2/device:CPU:0" self._target = worker[0].target @combinations.generate( combinations.combine(tf_api_version=[1], mode=["graph"])) def testBasic(self): with ops.device(self._device0): dataset0 = dataset_ops.Dataset.range(100) replicated_ds = distribute.replicate(dataset0, [self._device1, self._device2]) dataset1 = replicated_ds[self._device1] dataset2 = replicated_ds[self._device2] with ops.device(self._device0): get_next = self.getNext(dataset0) with ops.device(self._device1): get_next1 = self.getNext(dataset1) with ops.device(self._device2): get_next2 = self.getNext(dataset2) with session.Session(self._target) as sess: for i in range(100): self.assertEqual(i, sess.run(get_next())) self.assertEqual(i, sess.run(get_next1())) self.assertEqual(i, sess.run(get_next2())) @combinations.generate( combinations.combine(tf_api_version=[1], mode=["graph"])) def testMap(self): with ops.device(self._device0): dataset0 = dataset_ops.Dataset.range(100).map(lambda x: x * 2) replicated_ds = distribute.replicate(dataset0, [self._device1, self._device2]) dataset1 = replicated_ds[self._device1] dataset2 = replicated_ds[self._device2] with ops.device(self._device0): get_next = self.getNext(dataset0) with ops.device(self._device1): get_next1 = self.getNext(dataset1) with ops.device(self._device2): get_next2 = self.getNext(dataset2) with session.Session(self._target) as sess: for i in range(100): self.assertEqual(i * 2, sess.run(get_next())) self.assertEqual(i * 2, sess.run(get_next1())) self.assertEqual(i * 2, sess.run(get_next2())) @combinations.generate( combinations.combine(tf_api_version=[1], mode=["graph"])) def testVariableInput(self): with ops.device(self._device0): counter_var = variable_scope.get_variable("counter", (), dtypes.int32, use_resource=True) dataset0 = dataset_ops.Dataset.range(100).map( lambda _: counter_var.assign_add(1)) replicated_ds = distribute.replicate(dataset0, [self._device1, self._device2]) dataset1 = replicated_ds[self._device1] with ops.device(self._device1): it1 = dataset_ops.make_initializable_iterator(dataset1) # We don't support stateful ops across processes in functions as of now. with session.Session(self._target) as sess: with self.assertRaises(errors.OpError): sess.run(it1.initializer)
class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(num_elements=[14, 15], window_size=[2]) + combinations.combine(num_elements=[100], window_size=[3]))) def testTakeWhileDataset(self, num_elements, window_size): def _predicate_func(elem): return array_ops.shape(elem)[0] > (window_size - 1) dataset = dataset_ops.Dataset.range(num_elements).batch(window_size) dataset = dataset.take_while(predicate=_predicate_func).flat_map( dataset_ops.Dataset.from_tensor_slices) expected_num_elements = int(num_elements / window_size) * window_size self.assertDatasetProduces(dataset, np.arange(expected_num_elements)) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(num_elements=[10], upper_bound=[2]) + combinations.combine(num_elements=[16], upper_bound=[7]) + combinations.combine(num_elements=[100], upper_bound=[99]) + combinations.combine(num_elements=[100], upper_bound=[101]) + combinations.combine(num_elements=[0], upper_bound=[1]))) def testTakeWhileDatasetRange(self, num_elements, upper_bound): dataset = dataset_ops.Dataset.range(num_elements).take_while( lambda x: x < upper_bound) self.assertDatasetProduces(dataset, np.arange(min(num_elements, upper_bound))) @combinations.generate(test_base.default_test_combinations()) def testTakeWhileDatasetString(self): def not_equal(string): return lambda x: math_ops.not_equal(x, constant_op.constant(string)) string = ["this", "is", "the", "test", "for", "strings"] dataset = dataset_ops.Dataset.from_tensor_slices(string).take_while( predicate=not_equal("test")) next_element = self.getNext(dataset) self.assertEqual(b"this", self.evaluate(next_element())) self.assertEqual(b"is", self.evaluate(next_element())) self.assertEqual(b"the", self.evaluate(next_element())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(size=[5], index=[3]) + combinations.combine(size=[10], index=[0]) + combinations.combine(size=[100], index=[5]) + combinations.combine(size=[8], index=[7]))) def testTakeWhileDatasetShortCircuit(self, size, index): def _predicate_func(data_elem): return data_elem boolean_array = [True] * size boolean_array[index] = False dataset = dataset_ops.Dataset.from_tensor_slices( boolean_array).take_while(predicate=_predicate_func) next_element = self.getNext(dataset) for _ in range(index): self.assertTrue(self.evaluate(next_element())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) @combinations.generate(test_base.default_test_combinations()) def testTakeWhileDatasetWithRepeat(self): dataset = dataset_ops.Dataset.range(10).take_while( predicate=lambda x: x < 2).repeat(5) self.assertDatasetProduces(dataset, np.tile([0, 1], 5)) @combinations.generate(test_base.default_test_combinations()) def testTakeWhileDatasetStops(self): dataset = dataset_ops.Dataset.range(10) dataset = dataset.take_while( lambda x: math_ops.logical_not(math_ops.equal(x, 5))) self.assertDatasetProduces(dataset, range(5))
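# --- Illustrative sketch, not part of the test suite above. `take_while`
# stops at the first element that fails the predicate; nothing after that
# point is produced, even elements that would satisfy the predicate again
# (the short-circuit behavior testTakeWhileDatasetShortCircuit pins down).
# Uses the method form the tests call; older releases spell it
# `tf.data.experimental.take_while`.
import tensorflow as tf

def take_while_sketch():
  ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 10, 4, 5])
  ds = ds.take_while(lambda x: x < 5)
  return [x.numpy() for x in ds]  # [1, 2, 3] -- the trailing 4 is dropped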
class ClusterCombinationTest(test.TestCase, parameterized.TestCase): # For this test we need to use `framework.test_combinations` because our # `generate` eats the cluster parameters. # # Note that the cluster parameters are consumed by `ClusterCombination`, # which is passed via `test_combinations` below. @framework_combinations.generate( # pylint: disable=redundant-keyword-arg framework_combinations.combine(distribution=[ combinations.NamedDistribution( "HasClusterParams", lambda: None, has_chief=True, num_workers=2), ]), test_combinations=(combinations.ClusterCombination(),)) def testClusterParams(self, distribution, has_chief, num_workers): self.assertTrue(has_chief) self.assertEqual(num_workers, 2) @framework_combinations.generate( # pylint: disable=redundant-keyword-arg framework_combinations.combine(distribution=[ combinations.NamedDistribution("NoClusterParams", lambda: None), ]), test_combinations=(combinations.ClusterCombination(),)) def testClusterParamsHasDefault(self, distribution, has_chief, num_workers): self.assertFalse(has_chief) self.assertEqual(num_workers, 1) @framework_combinations.generate( # pylint: disable=redundant-keyword-arg framework_combinations.combine(v=1), test_combinations=(combinations.ClusterCombination(),)) def testClusterParamsNoStrategy(self, v, has_chief, num_workers): self.assertFalse(has_chief) self.assertEqual(num_workers, 1) @framework_combinations.generate( # pylint: disable=redundant-keyword-arg framework_combinations.combine(distribution=[ combinations.NamedDistribution( "WithClusterParams", lambda: None, has_chief=True, num_workers=2), combinations.NamedDistribution("WithoutClusterParams", lambda: None), ]), test_combinations=(combinations.ClusterCombination(),)) def testClusterParamsAreOptional(self, distribution): # If the combinations library doesn't raise an exception, the test passes. pass @framework_combinations.generate( # pylint: disable=redundant-keyword-arg framework_combinations.combine( ds1=combinations.NamedDistribution( "Strategy1", lambda: None, has_chief=True, num_workers=0), ds2=combinations.NamedDistribution( "Strategy2", lambda: None, has_chief=False, num_workers=1), ds3=combinations.NamedDistribution( "Strategy3", lambda: None, has_chief=True, num_workers=0), ), test_combinations=(combinations.ClusterCombination(),)) def testMultipleDistributionSingleWorker(self, ds1, ds2, ds3): # If the combinations library doesn't raise an exception, the test passes. pass @combinations.generate(combinations.combine(num_workers=2,)) def testUseWithoutStrategy(self): # There's no perfect way to check if the test runs in a subprocess. We # approximate by checking the presence of TF_CONFIG, which is normally not # set in the main process. self.assertNotEqual(os.getenv("TF_CONFIG"), "")
class RaggedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(make_dataset=[ _make_scalar_ds, _make_vector_ds, _make_matrix_ds1, _make_matrix_ds2, _make_ragged_ds, _make_5dtensor_ds, _make_dict_ds, _make_tuple_ds, _make_matrix_ds_fully_defined, ], nrows=[0, 20, 23], batch_size=[4], drop_remainder=[True, False]))) def testBasic(self, make_dataset, nrows, batch_size, drop_remainder): dataset = make_dataset(nrows) # Get the unbatched rows (so we can check expected values). get_next = self.getNext(dataset) rows = [ nest.map_structure(_to_list, self.evaluate(get_next())) for _ in range(nrows) ] # Batch the dataset, and check that batches match slices from `rows`. batched_dataset = dataset.apply( batching.dense_to_ragged_batch(batch_size, drop_remainder)) get_next = self.getNext(batched_dataset) for start_row in range(0, nrows, batch_size): end_row = start_row + batch_size if end_row > nrows and drop_remainder: break end_row = min(end_row, nrows) result = self.evaluate(get_next()) # Use nest for potentially nested datasets. nest.map_structure_up_to( result, lambda a, *b: self.assertAllEqual(a, list(b)), result, *rows[start_row:end_row]) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) @combinations.generate(test_base.default_test_combinations()) def testWithStructuredElements(self): nrows = 20 batch_size = 4 def make_structure(x): return { 'dense': array_ops.fill([x], x), 'ragged': ragged_concat_ops.stack( [array_ops.stack([x]), array_ops.stack([x, x])]), 'sparse': sparse_tensor.SparseTensor([[x]], [x], [100]) } dataset = dataset_ops.Dataset.from_tensor_slices(np.arange(nrows)) dataset = dataset.map(make_structure) dataset = dataset.apply(batching.dense_to_ragged_batch(batch_size)) get_next = self.getNext(dataset) for i in range(0, nrows, batch_size): result = self.evaluate(get_next()) rows = range(i, i + batch_size) self.assertAllEqual(result['dense'], [[r] * r for r in rows]) self.assertAllEqual(result['ragged'], [[[r], [r, r]] for r in rows]) self.assertAllEqual(result['sparse'].indices, list(enumerate(rows))) self.assertAllEqual(result['sparse'].values, rows) self.assertAllEqual(result['sparse'].dense_shape, [4, 100])
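# --- Illustrative sketch, not part of the test suite above.
# `dense_to_ragged_batch` is the padding-free sibling of `padded_batch`:
# each batch of variable-length rows comes back as one tf.RaggedTensor.
import tensorflow as tf

def ragged_batch_sketch():
  ds = tf.data.Dataset.range(1, 6).map(lambda x: tf.fill([x], x))
  ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(batch_size=2))
  # First batch prints as <tf.RaggedTensor [[1], [2, 2]]>.
  return list(ds)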