class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
  """Fault-tolerance tests for the tf.data service.

  Covers client behavior when the dispatcher and/or workers are stopped,
  restarted, or added while jobs are running.
  """

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherStop(self):
    """Reading continues from the worker after the dispatcher is stopped."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []
    # Read one element first so the client has already connected to the worker.
    results.append(next(iterator).numpy())
    cluster.stop_dispatcher()
    # After the dispatcher dies, the worker should continue providing the rest
    # of the dataset's elements.
    for _ in range(num_elements - 1):
      results.append(next(iterator).numpy())
    self.assertEqual(results, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartBeforeReading(self):
    """A dispatcher restart before any reads should not affect the dataset."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    cluster.restart_dispatcher()
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartDuringReading(self):
    """A mid-read dispatcher restart should not lose or duplicate elements."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []
    # Consume half the elements, restart the dispatcher, then drain the rest.
    for _ in range(num_elements // 2):
      results.append(next(iterator).numpy())
    cluster.restart_dispatcher()
    for elem in iterator:
      results.append(elem.numpy())
    self.assertEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartDuringDistributedEpoch(self):
    """Same as above, but for "distributed_epoch" processing mode."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, processing_mode="distributed_epoch")
    iterator = iter(ds)
    results = []
    for _ in range(num_elements // 2):
      results.append(next(iterator).numpy())
    cluster.restart_dispatcher()
    for elem in iterator:
      results.append(elem.numpy())
    self.assertEqual(list(range(num_elements)), results)
@combinations.generate(test_base.eager_only_combinations()) def testDispatcherRestartDuringDistributedEpochRepeat(self): cluster = self.create_cluster(num_workers=1) num_elements = 100 repetitions = 5 breakpoints = [50, 250, 450, 500] ds = dataset_ops.Dataset.range(num_elements) ds = ds.repeat(repetitions) ds = self.make_distributed_dataset(ds, cluster, processing_mode="distributed_epoch") iterator = iter(ds) results = [] for breakpoint in breakpoints: for _ in range(len(results), breakpoint): results.append(next(iterator).numpy()) cluster.restart_dispatcher() self.assertCountEqual(repetitions * list(range(num_elements)), results) @combinations.generate(test_base.eager_only_combinations()) def testDispatcherRestartBetweenIterations(self): cluster = self.create_cluster(num_workers=1) num_elements = 100 ds = self.make_distributed_range_dataset(100, cluster) self.assertDatasetProduces(ds, list(range(num_elements))) cluster.restart_dispatcher() self.assertDatasetProduces(ds, list(range(num_elements))) @combinations.generate(test_base.eager_only_combinations()) def testDispatcherManyRestarts(self): cluster = self.create_cluster(num_workers=1) num_elements_start = 10 num_elements_end = 15 datasets = [] for num_elements in range(num_elements_start, num_elements_end): datasets.append( self.make_distributed_range_dataset(num_elements, cluster)) cluster.restart_dispatcher() for ds, num_elements in zip( datasets, range(num_elements_start, num_elements_end)): self.assertDatasetProduces(ds, list(range(num_elements))) @combinations.generate(test_base.eager_only_combinations()) def testDispatcherAndWorkerRestart(self): cluster = self.create_cluster(num_workers=1) num_elements = 100 ds = self.make_distributed_range_dataset(num_elements, cluster) cluster.restart_dispatcher() cluster.restart_worker() self.assertDatasetProduces(ds, list(range(num_elements))) cluster.restart_dispatcher() cluster.restart_worker() self.assertDatasetProduces(ds, list(range(num_elements))) 
  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(workers_to_add=[1, 3])))
  def testRoundRobinAddWorkers(self, workers_to_add):
    """Round-robin consumers keep consistent groups when workers are added.

    With round-robin reads, every group of `num_consumers` elements should be
    consecutive values starting at a multiple of `num_consumers`, even while
    new workers join mid-read.
    """
    starting_workers = 3
    cluster = self.create_cluster(num_workers=starting_workers)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    num_consumers = 7
    ds = dataset_ops.Dataset.range(100000000)
    ds = ds.repeat()
    consumers = []
    for consumer_index in range(num_consumers):
      consumers.append(
          self.make_distributed_dataset(
              ds,
              cluster,
              job_name="test",
              consumer_index=consumer_index,
              num_consumers=num_consumers))
    # Use parallel interleave to read from consumers in parallel.
    ds = dataset_ops.Dataset.from_tensor_slices(consumers)
    ds = ds.interleave(
        lambda x: x,
        cycle_length=num_consumers,
        num_parallel_calls=num_consumers)
    get_next = self.getNext(ds, requires_initialization=True)
    results = []
    # Each worker restarts its range at 0, so counting zeros tells us how many
    # workers we have read from so far.
    zeros_seen = 0
    for _ in range(50):
      results.append(self.evaluate(get_next()))
      if results[-1] == 0:
        zeros_seen += 1
    for _ in range(workers_to_add):
      cluster.add_worker()
    # Read until all new workers have joined.
    while zeros_seen < starting_workers + workers_to_add:
      results.append(self.evaluate(get_next()))
      if results[-1] == 0:
        zeros_seen += 1
    # Read some more.
    for _ in range(100):
      results.append(self.evaluate(get_next()))
    for i in range(0, len(results), num_consumers):
      self.assertEqual(0, results[i] % num_consumers)
      # Check that each group of `num_consumers` results are consecutive.
      for offset in range(1, num_consumers):
        if i + offset < len(results):
          self.assertEqual(results[i] + offset, results[i + offset])

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherAndMultiWorkerRestart(self):
    """Restarting dispatcher and every worker mid-read replays the data."""
    num_workers = 2
    cluster = self.create_cluster(num_workers=num_workers)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []
    cluster.restart_dispatcher()
    for worker_index in range(num_workers):
      cluster.restart_worker(worker_index=worker_index)
    for elem in iterator:
      results.append(elem.numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), results)
    # Second restart cycle; the iterator is already exhausted, so the drain
    # below adds nothing and the same assertion must still hold.
    cluster.restart_dispatcher()
    for worker_index in range(num_workers):
      cluster.restart_worker(worker_index=worker_index)
    for elem in iterator:
      results.append(elem.numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testStartServersLate(self):
    """The client retries until servers come up, instead of failing fast."""
    # Test that the data service client performs retries instead of failing when
    # the dataset is created before the master and worker are started.
    try:
      import portpicker  # pylint: disable=g-import-not-at-top
      dispatcher_port = portpicker.pick_unused_port()
    # NOTE(review): bare `except:` deliberately catches any portpicker flake,
    # per the skip message below — confirm narrowing to Exception is safe
    # before changing it.
    except:
      raise self.skipTest("Flakes in portpicker library do not represent "
                          "TensorFlow errors.")
    cluster = self.create_cluster(
        num_workers=1, dispatcher_port=dispatcher_port, start=False)

    def start_servers():
      # Delay startup so the dataset is created before the servers exist.
      time.sleep(0.5)
      cluster.start_dispatcher()
      cluster.start_workers()

    start_servers_thread = threading.Thread(target=start_servers, daemon=True)
    start_servers_thread.start()
    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    results = [elem.numpy() for elem in ds]
    self.assertEqual(list(range(num_elements)), results)
    start_servers_thread.join()

  @combinations.generate(test_base.eager_only_combinations())
  def testAddWorkerMidJob(self):
    """A worker added mid-job contributes a full copy of the dataset."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []
    # Read halfway through the dataset.
    for _ in range(num_elements // 2):
      results.append(next(iterator).numpy())
    cluster.add_worker()
    # Wait for the new worker to register with the dispatcher.
    while cluster.num_registered_workers() < 2:
      time.sleep(10 / 1000)  # 10ms
    for elem in iterator:
      results.append(elem.numpy())
    self.assertCountEqual(2 * list(range(num_elements)), results)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(use_same_port=[True, False]),
                         data_service_test_base.all_cluster_configurations()))
  def testRestartWorker(self, use_same_port, work_dir, fault_tolerant_mode):
    """A restarted worker restarts its dataset from the beginning."""
    cluster = self.create_cluster(
        num_workers=1,
        work_dir=work_dir,
        fault_tolerant_mode=fault_tolerant_mode)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    # Read halfway through the dataset.
    midpoint = num_elements // 2
    for i in range(midpoint):
      self.assertEqual(i, next(iterator).numpy())
    # Stop the original worker and start a new one.
    cluster.restart_worker(use_same_port=use_same_port)
    # There may have been some elements prefetched from the first worker
    # before it was stopped.
    while True:
      val = next(iterator).numpy()
      if val == 0:
        break
    # The dataset starts over now that we read from the new worker.
    # TODO(b/157086991): Iterate until end of sequence when we support
    # detecting lost workers.
    for i in range(1, num_elements // 2):
      val = next(iterator).numpy()
      self.assertEqual(i, val)

  @combinations.generate(test_base.eager_only_combinations())
  def testChangeProcessingModeAfterRestart(self):
    """Reusing a job name with a different processing mode is an error."""
    self.skipTest("b/170910141")
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    range_dataset = dataset_ops.Dataset.range(num_elements)
    ds = range_dataset.apply(
        data_service_ops.distribute(
            processing_mode="parallel_epochs",
            service=cluster.target,
            job_name="test"))
    iterator = iter(ds)
    for i in range(num_elements // 2):
      self.assertEqual(i, next(iterator).numpy())
    cluster.restart_dispatcher()
    # Same job name, different processing mode: must be rejected.
    ds = range_dataset.apply(
        data_service_ops.distribute(
            processing_mode="distributed_epoch",
            service=cluster.target,
            job_name="test"))
    with self.assertRaisesOpError("already an existing job with that name "
                                  "using processing mode <parallel_epochs>"):
      next(iter(ds)).numpy()

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(work_dir=[TMP_WORK_DIR, NO_WORK_DIR])))
  def testDistributeLargeGraphThenRegisterWorker(self, work_dir):
    """A graph larger than the gRPC limit reaches a later-registered worker."""
    cluster = self.create_cluster(
        num_workers=0, work_dir=work_dir, fault_tolerant_mode=False)
    # Larger than default OSS grpc message size limit of 4MB.
    tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
    ds = dataset_ops.Dataset.from_tensors(tensor)
    ds = self.make_distributed_dataset(ds, cluster)
    it = iter(ds)
    cluster.add_worker()
    self.assertAllEqual(next(it), tensor)
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
  """Functional tests for tf.data service distribution ops.

  Covers basic distribution, element types, job sharing, garbage collection,
  processing modes, and dataset registration.
  """

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         data_service_test_base.all_cluster_configurations()))
  def testDistributeBasic(self, work_dir, fault_tolerant_mode):
    """Distributing a range dataset produces all its elements in order."""
    cluster = self.create_cluster(
        num_workers=1,
        work_dir=work_dir,
        fault_tolerant_mode=fault_tolerant_mode)
    num_elements = 10
    # NOTE(review): the literal 10 duplicates num_elements — keep them in sync.
    ds = self.make_distributed_range_dataset(10, cluster)
    results = [elem.numpy() for elem in ds]
    self.assertEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeSparse(self):
    """Sparse tensors round-trip through the data service."""
    cluster = self.create_cluster(num_workers=1)
    element = sparse_tensor.SparseTensor(
        indices=[[0]],
        values=constant_op.constant([0], dtype=dtypes.int32),
        dense_shape=[1])
    ds = dataset_ops.Dataset.from_tensors(element)
    ds = self.make_distributed_dataset(ds, cluster)
    results = [sparse_ops.sparse_tensor_to_dense(elem) for elem in ds]
    self.assertAllEqual(results, [[0]])

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeRagged(self):
    """Ragged tensors round-trip through the data service."""
    cluster = self.create_cluster(num_workers=1)
    ds = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
    ds = ds.map(math_ops.range)
    ds = ds.apply(batching.dense_to_ragged_batch(2))
    ds = self.make_distributed_dataset(ds, cluster)
    results = [elem.to_tensor() for elem in ds]
    self.assertAllEqual(results[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]])
    self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]])
    self.assertAllEqual(results[2], [[0, 1, 2, 3, 4, 5, 6, 7]])

  @combinations.generate(test_base.eager_only_combinations())
  def testDifferentShuffleOrders(self):
    """Two workers shuffle the same dataset in different orders."""
    random_seed.set_random_seed(None)
    num_elements = 100
    cluster = self.create_cluster(num_workers=2)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.shuffle(num_elements)
    ds = self.make_distributed_dataset(ds, cluster)
    output = [elem.numpy() for elem in ds]
    # The output will be two sequences of range(num_elements)
    # non-deterministically interleaved together. If the orders of the elements
    # were the same, first_order and second_order computed below will be equal.
    first_order = {}
    second_order = {}
    for element in output:
      if element in first_order:
        second_order[element] = len(second_order)
      else:
        first_order[element] = len(first_order)
    self.assertNotEqual(first_order, second_order)

  @combinations.generate(test_base.eager_only_combinations())
  def testMultipleEpochs(self):
    """A distributed dataset can be iterated multiple times."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 3
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    for _ in range(10):
      self.assertEqual(
          list(range(num_elements)), [elem.numpy() for elem in ds])

  @combinations.generate(test_base.eager_only_combinations())
  def testRepeatedDataset(self):
    """repeat() applied after distribution replays the elements."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    num_repetitions = 5
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    ds = ds.repeat(num_repetitions)
    self.assertDatasetProduces(
        ds, expected_output=num_repetitions * list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testConcurrentEpoch(self):
    """Several independent datasets can be read concurrently."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    num_datasets = 3
    iterators = []
    results = []
    for _ in range(num_datasets):
      ds = self.make_distributed_range_dataset(num_elements, cluster)
      iterators.append(iter(ds))
      results.append([])
    # Round-robin one element from each iterator per pass.
    for _ in range(num_elements):
      for dataset_ind in range(num_datasets):
        result = next(iterators[dataset_ind]).numpy()
        results[dataset_ind].append(result)
    for result in results:
      self.assertEqual(list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedEpoch(self):
    """Multiple iterators of one dataset share a single epoch (unimplemented)."""
    self.skipTest("Not yet implemented")
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    num_iterators = 3
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    result = []
    iterators = []
    for _ in range(num_iterators):
      iterators.append(iter(ds))
    # Alternate reading between the iterators.
    for _ in range(2):
      for it in iterators:
        result.append(next(it).numpy())
    # Drain the rest of the elements.
    for it in iterators:
      for elem in it:
        result.append(elem.numpy())
    self.assertCountEqual(list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testMultiWorker(self):
    """Each worker produces its own full copy of the dataset."""
    num_workers = 3
    cluster = self.create_cluster(num_workers=num_workers)
    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    results = [elem.numpy() for elem in ds]
    self.assertCountEqual(num_workers * list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testMaxOutstandingRequests(self):
    """All elements arrive even when limited to one outstanding request."""
    num_workers = 3
    cluster = self.create_cluster(num_workers=num_workers)
    num_elements = 10
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, max_outstanding_requests=1)
    self.assertCountEqual(num_workers * list(range(num_elements)),
                          self.getDatasetOutput(ds))

  @combinations.generate(test_base.eager_only_combinations())
  def testInsideFunction(self):
    """A distributed dataset can be consumed inside a tf.function."""
    num_workers = 3
    cluster = self.create_cluster(num_workers=num_workers)
    num_elements = 10

    @def_function.function
    def f():
      ds = self.make_distributed_range_dataset(num_elements, cluster)
      # Accumulate elements into a TensorArray so the function returns a
      # single stacked tensor.
      result = tensor_array_ops.TensorArray(
          dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
      i = 0
      for elem in ds:
        result = result.write(i, elem)
        i += 1
      return result.stack()

    result = list(f().numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobName(self):
    """Two datasets with the same job name share one stream of elements."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100

    def make_ds():
      return dataset_ops.Dataset.range(num_elements).shuffle(num_elements)

    ds1 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name")
    ds2 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name")
    iter1 = iter(ds1)
    iter2 = iter(ds2)
    results = []
    # Interleave reads between the two consumers, then drain both.
    for _ in range(num_elements // 5):
      results.append(next(iter1).numpy())
      results.append(next(iter2).numpy())
    for elem in iter1:
      results.append(elem.numpy())
    for elem in iter2:
      results.append(elem.numpy())
    self.assertCountEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDifferentJobNames(self):
    """Different job names create independent jobs with full output each."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name1")
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name2")
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameMultiIteration(self):
    """With a shared job, whichever dataset reads first gets the elements."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    # iteration 1
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, [])
    # iteration 2
    self.assertDatasetProduces(ds2, list(range(num_elements)))
    self.assertDatasetProduces(ds1, [])

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameRepeat(self):
    """Repeated shared-job datasets split the repeated elements between them."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    num_repetitions = 3
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds1 = ds1.repeat(num_repetitions)
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds2 = ds2.repeat(num_repetitions)
    results = []
    iter1 = iter(ds1)
    iter2 = iter(ds2)
    for _ in range((num_elements * num_repetitions) // 5):
      results.append(next(iter1).numpy())
    for _ in range((num_elements * num_repetitions) // 5):
      results.append(next(iter2).numpy())
    for elem in iter1:
      results.append(elem.numpy())
    for elem in iter2:
      results.append(elem.numpy())
    self.assertCountEqual(num_repetitions * list(range(num_elements)), results)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(job_name=[None, "test"])))
  def testGcUnusedJob(self, job_name):
    """A job with no outstanding iterators is garbage-collected."""
    cluster = self.create_cluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 100
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, job_name=job_name)
    it = iter(ds)
    self.assertEqual(next(it).numpy(), 0)
    self.assertEqual(cluster.num_tasks_on_worker(), 1)
    # Dropping the iterator releases the job; poll until GC removes the task.
    del it
    while cluster.num_tasks_on_worker() > 0:
      time.sleep(0.1)

  @combinations.generate(test_base.eager_only_combinations())
  def testDontGcUsedJob(self):
    """A job with a live iterator is not garbage-collected."""
    cluster = self.create_cluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 10
    it1 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test1"))
    it2 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    it3 = iter(  # this iterator keeps the task alive. pylint: disable=unused-variable
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    self.assertEqual(2, cluster.num_tasks_on_worker())
    del it1
    del it2
    # Check that only the first job is gced. The second job will not be gced
    # because there is still an outstanding iterator for it.
    while cluster.num_tasks_on_worker() > 1:
      time.sleep(0.1)
    self.assertEqual(1, cluster.num_tasks_on_worker())

  @combinations.generate(test_base.eager_only_combinations())
  def testApplyDeterminismOption(self):
    """experimental_deterministic=False is honored through the data service."""
    elements = list(range(10))
    cluster = self.create_cluster(num_workers=1)

    def dataset_fn(delay_ms):

      def interleave_fn(x):
        ds = dataset_ops.Dataset.from_tensors(x)
        # Delay only element 0 so nondeterministic ordering is observable.
        if math_ops.equal(x, 0):
          ds = ds.apply(testing.sleep(delay_ms * 1000))
        else:
          ds = ds.apply(testing.sleep(0))
        return ds

      ds = dataset_ops.Dataset.from_tensor_slices(elements)
      ds = ds.interleave(interleave_fn, cycle_length=10, num_parallel_calls=10)
      opts = dataset_ops.Options()
      opts.experimental_deterministic = False
      ds = ds.with_options(opts)
      ds = self.make_distributed_dataset(ds, cluster)
      return ds

    self.checkDeterminism(
        dataset_fn=dataset_fn,
        expect_determinism=False,
        expected_elements=elements)

  def run_stateful(self, external_state_policy):
    """Distributes a stateful (random) dataset under the given policy."""
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements).map(
        lambda _: random_ops.random_uniform(()))
    options = dataset_ops.Options()
    options.experimental_external_state_policy = external_state_policy
    ds = ds.with_options(options)
    cluster = self.create_cluster(num_workers=3)
    ds = self.make_distributed_dataset(ds, cluster)
    next(iter(ds))

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(external_state_policy=[
              distribute_options.ExternalStatePolicy.IGNORE,
              distribute_options.ExternalStatePolicy.WARN
          ])))
  def testStatefulNoError(self, external_state_policy):
    """IGNORE and WARN policies allow distributing stateful datasets."""
    self.run_stateful(external_state_policy)

  @combinations.generate(test_base.eager_only_combinations())
  def testStatefulError(self):
    """The FAIL policy rejects stateful datasets."""
    with self.assertRaises(errors.FailedPreconditionError):
      self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochTensorSlices(self):
    """distributed_epoch splits a tensor-slices dataset across workers."""
    cluster = self.create_cluster(num_workers=2)
    vals = [5, 1, 2, 4]
    ds = dataset_ops.Dataset.from_tensor_slices(vals)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(ds, vals, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochInterleave(self):
    """distributed_epoch works with interleave."""
    cluster = self.create_cluster(num_workers=2)
    elements = [1, 5, 0]
    ds = dataset_ops.Dataset.from_tensor_slices(elements)
    ds = ds.interleave(lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(ds, elements, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochParallelInterleave(self):
    """distributed_epoch works with parallel interleave."""
    cluster = self.create_cluster(num_workers=2)
    elements = [1, 5, 0]
    ds = dataset_ops.Dataset.from_tensor_slices(elements)
    ds = ds.interleave(
        lambda x: dataset_ops.Dataset.from_tensor_slices([x]),
        num_parallel_calls=dataset_ops.AUTOTUNE)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(ds, elements, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochFlatMap(self):
    """distributed_epoch works with flat_map."""
    cluster = self.create_cluster(num_workers=2)
    elements = [1, 5, 0]
    ds = dataset_ops.Dataset.from_tensor_slices(elements)
    ds = ds.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(ds, elements, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochRepeat(self):
    """distributed_epoch works with repeat."""
    cluster = self.create_cluster(num_workers=2)
    num_repeats = 5
    num_elements = 20
    ds = dataset_ops.Dataset.range(num_elements).repeat(num_repeats)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(
        ds, num_repeats * list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochShuffleAndRepeat(self):
    """distributed_epoch works with shuffle followed by repeat."""
    cluster = self.create_cluster(num_workers=2)
    num_repeats = 5
    num_elements = 20
    ds = dataset_ops.Dataset.range(num_elements).shuffle(num_elements).repeat(
        num_repeats)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(
        ds, num_repeats * list(range(num_elements)), assert_items_equal=True)

  # NOTE(review): unlike every sibling test, this method has no
  # @combinations.generate decorator — confirm whether that is intentional.
  def testDistributeFromInterleave(self):
    """Creating a distributed dataset inside an interleave fn does not break."""
    cluster = self.create_cluster(num_workers=1)
    ds = dataset_ops.Dataset.range(2)

    def interleave_fn(_):
      dataset = dataset_ops.Dataset.range(2)
      # NOTE(review): the distributed dataset is constructed but its result is
      # discarded; the undistributed `dataset` is returned — confirm intended.
      self.make_distributed_dataset(dataset, cluster)
      return dataset

    ds = ds.interleave(interleave_fn, cycle_length=2)
    self.assertDatasetProduces(ds, [0, 0, 1, 1])

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpoch(self):
    """distributed_epoch over a plain range produces each element once."""
    cluster = self.create_cluster(num_workers=2)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(
        ds, list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeNonStringAddresses(self):
    """A non-string service address raises ValueError."""
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError, "service must be a string"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=1))

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeEmptyAddress(self):
    """An empty service address raises ValueError."""
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesWithLiteralMatch(ValueError,
                                           "service must not be empty"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=""))

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeInvalidProcessingMode(self):
    """An unknown processing mode raises ValueError."""
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError,
                                "invalid is not a valid processing mode"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="invalid", service="grpc://localhost:5000"))

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetId(self):
    """A registered dataset can be read back via from_dataset_id."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(cluster.target, ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, ds.element_spec)
    self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetIdMultipleComponents(self):
    """from_dataset_id preserves nested (dict/tuple) element structure."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
    dataset_id = data_service_ops.register_dataset(cluster.target, ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, ds.element_spec)
    output = self.getDatasetOutput(from_dataset_id_ds)
    for i in range(num_elements):
      self.assertEqual(i, output[i]["a"][0])
      self.assertEqual(i, output[i]["a"][1])
      self.assertEqual(i, output[i]["b"])

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetIdWrongElementSpec(self):
    """A mismatched element_spec fails at read time."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(cluster.target, ds)
    wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, wrong_spec)
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "Expected a tensor of type variant"):
      self.evaluate(self.getNext(from_dataset_id_ds)())

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetIdNotRegistered(self):
    """An unregistered dataset id fails with NotFoundError."""
    cluster = self.create_cluster(num_workers=1)
    dataset_id = 0
    element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, element_spec)
    with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
      self.evaluate(self.getNext(from_dataset_id_ds)())

  @combinations.generate(test_base.default_test_combinations())
  def testCancellation(self):
    """In-flight prefetch requests can be cancelled (currently skipped)."""
    self.skipTest("b/162521601")
    sleep_microseconds = int(1e6) * 1000
    cluster = self.create_cluster(num_workers=1)
    # Create a dataset which produces the first element quickly, and the second
    # element slowly. Fetching the first element triggers prefetching of the
    # second element, which we should be able to cancel.
    slow = dataset_ops.Dataset.range(1)
    slow = slow.apply(testing.sleep(sleep_microseconds))
    ds = dataset_ops.Dataset.range(1).concatenate(slow)
    ds = self.make_distributed_dataset(ds, cluster)
    ds = ds.prefetch(1)
    get_next = self.getNext(ds, requires_initialization=True)
    self.assertEqual(0, self.evaluate(get_next()))
    # Without properly implemented cancellation, we will hang here while trying
    # to garbage collect the dataset iterator.
  @combinations.generate(test_base.eager_only_combinations())
  def testRegisterEquivalentDatasets(self):
    """Registering structurally identical datasets yields the same id."""
    ds_1 = dataset_ops.Dataset.range(10)
    ds_2 = dataset_ops.Dataset.range(10)
    cluster = self.create_cluster(num_workers=1)
    id_1 = data_service_ops.register_dataset(cluster.target, ds_1)
    id_2 = data_service_ops.register_dataset(cluster.target, ds_2)
    self.assertEqual(id_1.numpy(), id_2.numpy())

  @combinations.generate(test_base.eager_only_combinations())
  def testRegisterDifferentDatasets(self):
    """Registering different datasets yields different ids."""
    ds_1 = dataset_ops.Dataset.range(10)
    ds_2 = dataset_ops.Dataset.range(20)
    cluster = self.create_cluster(num_workers=1)
    id_1 = data_service_ops.register_dataset(cluster.target, ds_1)
    id_2 = data_service_ops.register_dataset(cluster.target, ds_2)
    self.assertNotEqual(id_1.numpy(), id_2.numpy())

  @combinations.generate(test_base.eager_only_combinations())
  def testTwoLevelDistribute(self):
    """A dataset can be distributed through two chained clusters."""
    cluster_1_size = 3
    cluster_1 = self.create_cluster(num_workers=cluster_1_size)
    cluster_2 = self.create_cluster(num_workers=1)
    num_sizes = 10
    size_repeats = 5
    strings = ["a" * i for i in range(num_sizes)] * size_repeats
    ds = dataset_ops.Dataset.from_tensor_slices(strings)
    ds = ds.shuffle(len(strings))
    ds = self.make_distributed_dataset(ds, cluster_1)
    # Large enough so that all strings of the same size are windowed together.
    window_size = cluster_1_size * size_repeats
    batch_size = size_repeats

    def key_func(x):
      # Group strings by their length.
      return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64)

    ds = ds.apply(
        grouping.group_by_window(
            key_func=key_func,
            reduce_func=lambda _, x: x.batch(batch_size),
            window_size=window_size))
    ds = self.make_distributed_dataset(ds, cluster_2)
    it = iter(ds)
    # Each size produces cluster_1_size identical batches (one per worker).
    for _ in range(num_sizes):
      element = next(it).numpy()
      for _ in range(1, cluster_1_size):
        self.assertAllEqual(next(it).numpy(), element)
    self.assertEmpty(list(it))

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations()))
  def testDistributeLargeGraph(self):
    """A graph larger than the gRPC message limit can be distributed."""
    cluster = self.create_cluster(
        num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
    # Larger than default OSS grpc message size limit of 4MB.
    tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
    ds = dataset_ops.Dataset.from_tensors(tensor)
    ds = self.make_distributed_dataset(ds, cluster)
    self.assertDatasetProduces(ds, [tensor])
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
  """Tests for distributing datasets via the tf.data service."""

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         data_service_test_base.all_cluster_configurations()))
  def testDistributeBasic(self, work_dir, fault_tolerant_mode):
    cluster = self.create_cluster(
        num_workers=1,
        work_dir=work_dir,
        fault_tolerant_mode=fault_tolerant_mode)
    num_elements = 10
    # Use `num_elements` (was a duplicated literal 10) so the dataset size and
    # the expected output cannot drift apart.
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    results = [elem.numpy() for elem in ds]
    self.assertEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeSparse(self):
    # Sparse elements should round-trip through the service unchanged.
    cluster = self.create_cluster(num_workers=1)
    element = sparse_tensor.SparseTensor(
        indices=[[0]],
        values=constant_op.constant([0], dtype=dtypes.int32),
        dense_shape=[1])
    ds = dataset_ops.Dataset.from_tensors(element)
    ds = self.make_distributed_dataset(ds, cluster)
    results = [sparse_ops.sparse_tensor_to_dense(elem) for elem in ds]
    self.assertAllEqual(results, [[0]])

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeRagged(self):
    # Ragged elements should round-trip through the service unchanged.
    cluster = self.create_cluster(num_workers=1)
    ds = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
    ds = ds.map(math_ops.range)
    ds = ds.apply(batching.dense_to_ragged_batch(2))
    ds = self.make_distributed_dataset(ds, cluster)
    results = [elem.to_tensor() for elem in ds]
    self.assertAllEqual(results[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]])
    self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]])
    self.assertAllEqual(results[2], [[0, 1, 2, 3, 4, 5, 6, 7]])

  @combinations.generate(test_base.eager_only_combinations())
  def testDifferentShuffleOrders(self):
    # Two workers processing the same shuffled dataset should use different
    # (unseeded) shuffle orders.
    random_seed.set_random_seed(None)
    num_elements = 100
    cluster = self.create_cluster(num_workers=2)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.shuffle(num_elements)
    ds = self.make_distributed_dataset(ds, cluster)
    output = [elem.numpy() for elem in ds]

    # The output will be two sequences of range(num_elements)
    # non-deterministically interleaved together. If the orders of the elements
    # were the same, first_order and second_order computed below will be equal.
    first_order = {}
    second_order = {}
    for element in output:
      if element in first_order:
        second_order[element] = len(second_order)
      else:
        first_order[element] = len(first_order)
    self.assertNotEqual(first_order, second_order)

  @combinations.generate(test_base.eager_only_combinations())
  def testMultipleEpochs(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 3
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    # Each iteration over the dataset starts a fresh job.
    for _ in range(10):
      self.assertEqual(
          list(range(num_elements)), [elem.numpy() for elem in ds])

  @combinations.generate(test_base.eager_only_combinations())
  def testRepeatedDataset(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    num_repetitions = 5
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    ds = ds.repeat(num_repetitions)
    self.assertDatasetProduces(
        ds, expected_output=num_repetitions * list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testConcurrentEpoch(self):
    # Interleave reads across several independent jobs on one worker.
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    num_datasets = 3
    iterators = []
    results = []
    for _ in range(num_datasets):
      ds = self.make_distributed_range_dataset(num_elements, cluster)
      iterators.append(iter(ds))
      results.append([])

    for _ in range(num_elements):
      for dataset_ind in range(num_datasets):
        result = next(iterators[dataset_ind]).numpy()
        results[dataset_ind].append(result)
    for result in results:
      self.assertEqual(list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedEpoch(self):
    self.skipTest("Not yet implemented")
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    num_iterators = 3
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    result = []
    iterators = []
    for _ in range(num_iterators):
      iterators.append(iter(ds))

    # Alternate reading between the iterators.
    for _ in range(2):
      for it in iterators:
        result.append(next(it).numpy())

    # Drain the rest of the elements.
    for it in iterators:
      for elem in it:
        result.append(elem.numpy())

    self.assertCountEqual(list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testMultiWorker(self):
    # In parallel_epochs mode every worker produces the full dataset.
    num_workers = 3
    cluster = self.create_cluster(num_workers=num_workers)
    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    results = [elem.numpy() for elem in ds]
    self.assertCountEqual(num_workers * list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testMaxOutstandingRequests(self):
    num_workers = 3
    cluster = self.create_cluster(num_workers=num_workers)
    num_elements = 10
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, max_outstanding_requests=1)
    self.assertCountEqual(num_workers * list(range(num_elements)),
                          self.getDatasetOutput(ds))

  @combinations.generate(test_base.eager_only_combinations())
  def testInsideFunction(self):
    # Reading a distributed dataset must also work inside a tf.function.
    num_workers = 3
    cluster = self.create_cluster(num_workers=num_workers)
    num_elements = 10

    @def_function.function
    def f():
      ds = self.make_distributed_range_dataset(num_elements, cluster)
      result = tensor_array_ops.TensorArray(
          dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
      i = 0
      for elem in ds:
        result = result.write(i, elem)
        i += 1
      return result.stack()

    result = list(f().numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobName(self):
    # Two datasets with the same job name share one job: together they consume
    # each element exactly once.
    cluster = self.create_cluster(num_workers=1)
    num_elements = 1000

    def make_ds():
      return dataset_ops.Dataset.range(num_elements).shuffle(num_elements)

    ds1 = self.make_distributed_dataset(make_ds(), cluster,
                                        job_name="job_name")
    ds2 = self.make_distributed_dataset(make_ds(), cluster,
                                        job_name="job_name")
    iter1 = iter(ds1)
    iter2 = iter(ds2)
    results = []
    for _ in range(num_elements // 5):
      results.append(next(iter1).numpy())
      results.append(next(iter2).numpy())
    for elem in iter1:
      results.append(elem.numpy())
    for elem in iter2:
      results.append(elem.numpy())
    self.assertCountEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDifferentJobNames(self):
    # Distinct job names create independent jobs that each see all elements.
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name1")
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name2")
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameMultiIteration(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    # iteration 1: ds1 drains the shared job, leaving nothing for ds2.
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, [])
    # iteration 2: roles reversed.
    self.assertDatasetProduces(ds2, list(range(num_elements)))
    self.assertDatasetProduces(ds1, [])

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameRepeat(self):
    # repeat() restarts the shared job, so the two consumers together see
    # num_repetitions full passes.
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    num_repetitions = 3
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds1 = ds1.repeat(num_repetitions)
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds2 = ds2.repeat(num_repetitions)
    results = []
    iter1 = iter(ds1)
    iter2 = iter(ds2)
    for _ in range((num_elements * num_repetitions) // 5):
      results.append(next(iter1).numpy())
    for _ in range((num_elements * num_repetitions) // 5):
      results.append(next(iter2).numpy())
    for elem in iter1:
      results.append(elem.numpy())
    for elem in iter2:
      results.append(elem.numpy())
    self.assertCountEqual(num_repetitions * list(range(num_elements)),
                          results)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_workers=[1, 3], num_consumers=[1, 2, 5])))
  def testRoundRobin(self, num_workers, num_consumers):
    cluster = self.create_cluster(num_workers=num_workers)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    # Round robin requires an infinite dataset.
    ds = dataset_ops.Dataset.range(10000000)
    ds = ds.repeat()
    consumers = []
    for consumer_index in range(num_consumers):
      consumers.append(
          self.make_distributed_dataset(
              ds,
              cluster,
              job_name="test",
              consumer_index=consumer_index,
              num_consumers=num_consumers))
    # Use parallel interleave to read from consumers in parallel.
    ds = dataset_ops.Dataset.from_tensor_slices(consumers)
    ds = ds.interleave(
        lambda x: x,
        cycle_length=num_consumers,
        num_parallel_calls=num_consumers)
    ds = ds.take(1000)
    results = self.getDatasetOutput(ds, requires_initialization=True)

    for i in range(0, len(results), num_consumers):
      self.assertEqual(0, results[i] % num_consumers)
      # Check that each group of `num_consumers` results are consecutive.
      for offset in range(1, num_consumers):
        if i + offset < len(results):
          self.assertEqual(results[i] + offset, results[i + offset])

  @combinations.generate(test_base.default_test_combinations())
  def testRoundRobinBucketizing(self):
    # Tests a common use case for round robin reads. At each step, all
    # consumers should get batches with the same bucket size.
    cluster = self.create_cluster(num_workers=4)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    num_elements = 100
    low_bucket_max = 30
    mid_bucket_max = 60
    bucket_boundaries = [low_bucket_max, mid_bucket_max]
    batch_size = 10
    num_consumer_hosts = 3
    replicas_per_consumer_host = 5
    num_consumers = num_consumer_hosts * replicas_per_consumer_host
    bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
    # Set up the dataset that will run on the tf.data workers.
    ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
    ds = ds.shuffle(num_elements)
    ds = ds.repeat()
    ds = ds.apply(
        grouping.bucket_by_sequence_length(
            lambda x: x,
            bucket_boundaries,
            bucket_batch_sizes,
            drop_remainder=True))
    # Group `num_consumers` batches of the same bucket together so that one
    # round of reads serves every consumer from the same bucket.
    ds = ds.apply(
        grouping.group_by_window(
            lambda x: math_ops.cast(x[1], dtypes.int64),
            lambda _, x: dataset_ops.Dataset.from_tensors(x),
            window_size=num_consumers))
    ds = ds.flat_map(lambda x: x)

    # Set up the per-consumer-host datasets. During each global step, we pull
    # `replicas_per_consumer_host` batches from each of these datasets.
    host_datasets = []
    for host_index in range(num_consumer_hosts):
      per_replica_datasets = []
      for i in range(replicas_per_consumer_host):
        consumer_index = host_index * replicas_per_consumer_host + i
        per_replica_datasets.append(
            self.make_distributed_dataset(
                ds,
                cluster,
                job_name="test",
                consumer_index=consumer_index,
                num_consumers=num_consumers))
      host_dataset = dataset_ops.Dataset.from_tensor_slices(
          per_replica_datasets)
      host_dataset = host_dataset.interleave(
          lambda x: x,
          cycle_length=len(per_replica_datasets),
          num_parallel_calls=len(per_replica_datasets),
          deterministic=True)
      host_datasets.append(host_dataset)

    # Use parallel interleave to read from host datasets in parallel.
    ds = dataset_ops.Dataset.from_tensor_slices(host_datasets)
    ds = ds.interleave(
        lambda x: x,
        block_length=replicas_per_consumer_host,
        cycle_length=len(host_datasets),
        num_parallel_calls=len(host_datasets),
        deterministic=True)

    num_rounds = 10
    get_next = self.getNext(ds, requires_initialization=True)
    results = []
    for _ in range(num_rounds * num_consumers):
      results.append(self.evaluate(get_next()))

    def get_bucket(elem):
      # Index of the bucket that `elem` falls into.
      bucket_ind = 0
      while (bucket_ind < len(bucket_boundaries) and
             elem >= bucket_boundaries[bucket_ind]):
        bucket_ind += 1
      return bucket_ind

    # Check that the batches for each step contain elements from the same
    # bucket. Note: `i` already advances in steps of `num_consumers`, so we
    # slice directly from `i`. The previous
    # `results[num_consumers * i:num_consumers * (i + 1)]` indexed past the
    # end of `results` for every group after the first, producing empty
    # slices and silently skipping the check.
    for i in range(0, len(results), num_consumers):
      batches = results[i:i + num_consumers]
      bucket_inds = [get_bucket(batch[0]) for batch in batches]
      for bucket_ind in bucket_inds[1:]:
        self.assertEqual(bucket_inds[0], bucket_ind)

  @combinations.generate(test_base.v1_only_combinations())
  def testRoundRobinFiniteV1(self):
    # Finite datasets are rejected for round robin reads (V1 error surface).
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, job_name="test", consumer_index=0, num_consumers=1)

    with self.assertRaisesRegex(
        errors.FailedPreconditionError, "Encountered end of sequence on a "
        "round-robin read iterator"):
      self.getDatasetOutput(ds, requires_initialization=True)

  @combinations.generate(test_base.v2_only_combinations())
  def testRoundRobinFiniteV2(self):
    # Finite datasets are rejected for round robin reads (V2 error surface).
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, job_name="test", consumer_index=0, num_consumers=1)

    with self.assertRaisesRegex(
        errors.FailedPreconditionError, "Round robin reads "
        "require that the input dataset has infinite "
        "cardinality, but the dataset has cardinality " + str(num_elements)):
      self.getDatasetOutput(ds, requires_initialization=True)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(job_name=[None, "test"])))
  def testGcUnusedJob(self, job_name):
    # A job with no outstanding iterators should be garbage-collected.
    cluster = self.create_cluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 100
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, job_name=job_name)
    it = iter(ds)
    self.assertEqual(next(it).numpy(), 0)
    self.assertEqual(cluster.num_tasks_on_worker(), 1)
    del it
    # Poll until the dispatcher garbage-collects the now-unused job.
    while cluster.num_tasks_on_worker() > 0:
      time.sleep(0.1)

  @combinations.generate(test_base.eager_only_combinations())
  def testDontGcUsedJob(self):
    # A job with a live iterator must not be garbage-collected.
    cluster = self.create_cluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 10
    it1 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test1"))
    it2 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    it3 = iter(  # this iterator keeps the task alive. pylint: disable=unused-variable
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    self.assertEqual(2, cluster.num_tasks_on_worker())
    del it1
    del it2
    # Check that only the first job is gced. The second job will not be gced
    # because there is still an outstanding iterator for it.
    while cluster.num_tasks_on_worker() > 1:
      time.sleep(0.1)
    self.assertEqual(1, cluster.num_tasks_on_worker())

  @combinations.generate(test_base.eager_only_combinations())
  def testApplyDeterminismOption(self):
    # experimental_deterministic=False should propagate through distribution.
    elements = list(range(10))
    cluster = self.create_cluster(num_workers=1)

    def dataset_fn(delay_ms):

      def interleave_fn(x):
        ds = dataset_ops.Dataset.from_tensors(x)
        if math_ops.equal(x, 0):
          ds = ds.apply(testing.sleep(delay_ms * 1000))
        else:
          ds = ds.apply(testing.sleep(0))
        return ds

      ds = dataset_ops.Dataset.from_tensor_slices(elements)
      ds = ds.interleave(interleave_fn, cycle_length=10,
                         num_parallel_calls=10)
      opts = dataset_ops.Options()
      opts.experimental_deterministic = False
      ds = ds.with_options(opts)
      ds = self.make_distributed_dataset(ds, cluster)
      return ds

    self.checkDeterminism(
        dataset_fn=dataset_fn,
        expect_determinism=False,
        expected_elements=elements)

  def run_stateful(self, external_state_policy):
    """Distributes a stateful (random) dataset under the given policy."""
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements).map(
        lambda _: random_ops.random_uniform(()))
    options = dataset_ops.Options()
    options.experimental_external_state_policy = external_state_policy
    ds = ds.with_options(options)
    cluster = self.create_cluster(num_workers=3)
    ds = self.make_distributed_dataset(ds, cluster)
    next(iter(ds))

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(external_state_policy=[
              distribute_options.ExternalStatePolicy.IGNORE,
              distribute_options.ExternalStatePolicy.WARN
          ])))
  def testStatefulNoError(self, external_state_policy):
    self.run_stateful(external_state_policy)

  @combinations.generate(test_base.eager_only_combinations())
  def testStatefulError(self):
    # FAIL policy rejects stateful datasets at distribution time.
    with self.assertRaises(errors.FailedPreconditionError):
      self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochTensorSlices(self):
    cluster = self.create_cluster(num_workers=2)
    vals = [5, 1, 2, 4]
    ds = dataset_ops.Dataset.from_tensor_slices(vals)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(ds, vals, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochInterleave(self):
    cluster = self.create_cluster(num_workers=2)
    elements = [1, 5, 0]
    ds = dataset_ops.Dataset.from_tensor_slices(elements)
    ds = ds.interleave(
        lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(ds, elements, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochParallelInterleave(self):
    cluster = self.create_cluster(num_workers=2)
    elements = [1, 5, 0]
    ds = dataset_ops.Dataset.from_tensor_slices(elements)
    ds = ds.interleave(
        lambda x: dataset_ops.Dataset.from_tensor_slices([x]),
        num_parallel_calls=dataset_ops.AUTOTUNE)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(ds, elements, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochFlatMap(self):
    cluster = self.create_cluster(num_workers=2)
    elements = [1, 5, 0]
    ds = dataset_ops.Dataset.from_tensor_slices(elements)
    ds = ds.flat_map(
        lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(ds, elements, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochRepeat(self):
    cluster = self.create_cluster(num_workers=2)
    num_repeats = 5
    num_elements = 20
    ds = dataset_ops.Dataset.range(num_elements).repeat(num_repeats)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(
        ds, num_repeats * list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochForeverRepeat(self):
    # With an infinite repeat, both workers should keep producing elements in
    # roughly even proportion.
    cluster = self.create_cluster(num_workers=2)
    num_elements = 20
    elements_to_read = 1000
    ds = dataset_ops.Dataset.range(num_elements).repeat()
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    it = iter(ds)
    results = {}
    for _ in range(elements_to_read):
      val = next(it).numpy()
      if val not in results:
        results[val] = 0
      results[val] += 1
    for i in range(num_elements):
      self.assertGreater(results[i], elements_to_read / num_elements / 2)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochForeverRepeatFewElements(self):
    num_workers = 5
    cluster = self.create_cluster(num_workers=num_workers)
    # Less than the number of workers, so that some workers get zero elements
    # on the first repetition.
    num_elements = 1
    ds = dataset_ops.Dataset.range(num_elements).repeat()
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    it = iter(ds)
    for _ in range(100):
      self.assertEqual(next(it).numpy(), 0)

    # Stop all but one worker and check that we can still read.
    for i in range(num_workers - 1):
      cluster.workers[i]._stop()
    for _ in range(100):
      self.assertEqual(next(it).numpy(), 0)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochShuffleAndRepeat(self):
    cluster = self.create_cluster(num_workers=2)
    num_repeats = 5
    num_elements = 20
    ds = dataset_ops.Dataset.range(num_elements).shuffle(
        num_elements).repeat(num_repeats)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(
        ds, num_repeats * list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeFromInterleave(self):
    # Note: added the previously-missing `combinations.generate` decorator so
    # this runs under the same eager-only harness as the sibling tests instead
    # of as a bare test method.
    cluster = self.create_cluster(num_workers=1)
    ds = dataset_ops.Dataset.range(2)

    def interleave_fn(_):
      dataset = dataset_ops.Dataset.range(2)
      # Creating (but not using) a distributed dataset inside an interleave
      # function should not interfere with the returned dataset.
      self.make_distributed_dataset(dataset, cluster)
      return dataset

    ds = ds.interleave(interleave_fn, cycle_length=2)
    self.assertDatasetProduces(ds, [0, 0, 1, 1])

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpoch(self):
    cluster = self.create_cluster(num_workers=2)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(
        ds, list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeNonStringAddresses(self):
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError, "service must be a string"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=1))

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeEmptyAddress(self):
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesWithLiteralMatch(ValueError,
                                           "service must not be empty"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=""))

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeInvalidProcessingMode(self):
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError,
                                "invalid is not a valid processing mode"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="invalid", service="grpc://localhost:5000"))

  @combinations.generate(test_base.eager_only_combinations())
  def testZipDifferentProcessingModesDatasets(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds1 = dataset_ops.Dataset.range(num_elements)
    ds1 = self.make_distributed_dataset(
        ds1, cluster, processing_mode="distributed_epoch")
    ds2 = dataset_ops.Dataset.range(num_elements)
    ds2 = self.make_distributed_dataset(
        ds2, cluster, processing_mode="parallel_epochs")
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    self.assertDatasetProduces(
        ds,
        list(zip(range(num_elements), range(num_elements))),
        assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testZipDifferentProcessingModesDatasetsSharedJobName(self):
    # Two processing modes cannot share one job name.
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds1 = dataset_ops.Dataset.range(num_elements)
    ds1 = self.make_distributed_dataset(
        ds1, cluster, processing_mode="distributed_epoch",
        job_name="job_name")
    ds2 = dataset_ops.Dataset.range(num_elements)
    ds2 = self.make_distributed_dataset(
        ds2, cluster, processing_mode="parallel_epochs", job_name="job_name")
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "but there is already an existing job"):
      self.getDatasetOutput(ds)

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetId(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(cluster.target, ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, ds.element_spec)
    self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))
  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetIdMultipleComponents(self):
    # `from_dataset_id` should preserve nested element structure (dict of
    # tuple / tensor components).
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
    dataset_id = data_service_ops.register_dataset(cluster.target, ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, ds.element_spec)
    output = self.getDatasetOutput(from_dataset_id_ds)
    for i in range(num_elements):
      self.assertEqual(i, output[i]["a"][0])
      self.assertEqual(i, output[i]["a"][1])
      self.assertEqual(i, output[i]["b"])

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetIdWrongElementSpec(self):
    # Reading with a mismatched element_spec should fail at iteration time.
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(cluster.target, ds)
    wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, wrong_spec)
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "Expected a tensor of type variant"):
      self.evaluate(self.getNext(from_dataset_id_ds)())

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetIdNotRegistered(self):
    # Reading from an id that was never registered should raise NotFound.
    cluster = self.create_cluster(num_workers=1)
    dataset_id = 0
    element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, element_spec)
    with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
      self.evaluate(self.getNext(from_dataset_id_ds)())

  @combinations.generate(test_base.default_test_combinations())
  def testCancellation(self):
    self.skipTest("b/162521601")
    sleep_microseconds = int(1e6) * 1000
    cluster = self.create_cluster(num_workers=1)
    # Create a dataset which produces the first element quickly, and the
    # second element slowly. Fetching the first element triggers prefetching
    # of the second element, which we should be able to cancel.
    slow = dataset_ops.Dataset.range(1)
    slow = slow.apply(testing.sleep(sleep_microseconds))
    ds = dataset_ops.Dataset.range(1).concatenate(slow)
    ds = self.make_distributed_dataset(ds, cluster)
    ds = ds.prefetch(1)
    get_next = self.getNext(ds, requires_initialization=True)
    self.assertEqual(0, self.evaluate(get_next()))
    # Without properly implemented cancellation, we will hang here while
    # trying to garbage collect the dataset iterator.

  @combinations.generate(test_base.eager_only_combinations())
  def testRegisterEquivalentDatasets(self):
    # Registering two structurally identical datasets should yield the same
    # dataset id.
    ds_1 = dataset_ops.Dataset.range(10)
    ds_2 = dataset_ops.Dataset.range(10)
    cluster = self.create_cluster(num_workers=1)
    id_1 = data_service_ops.register_dataset(cluster.target, ds_1)
    id_2 = data_service_ops.register_dataset(cluster.target, ds_2)
    self.assertEqual(id_1.numpy(), id_2.numpy())

  @combinations.generate(test_base.eager_only_combinations())
  def testRegisterDifferentDatasets(self):
    # Structurally different datasets must be assigned distinct dataset ids.
    ds_1 = dataset_ops.Dataset.range(10)
    ds_2 = dataset_ops.Dataset.range(20)
    cluster = self.create_cluster(num_workers=1)
    id_1 = data_service_ops.register_dataset(cluster.target, ds_1)
    id_2 = data_service_ops.register_dataset(cluster.target, ds_2)
    self.assertNotEqual(id_1.numpy(), id_2.numpy())

  @combinations.generate(test_base.default_test_combinations())
  def testDistributedEpochOnZippedDataset(self):
    # distributed_epoch requires a splittable source; ZipDataset is not.
    ds_1 = dataset_ops.Dataset.range(10)
    ds_2 = dataset_ops.Dataset.range(10)
    cluster = self.create_cluster(num_workers=1)
    ds_3 = dataset_ops.Dataset.zip((ds_1, ds_2))
    ds_3 = self.make_distributed_dataset(
        ds_3, cluster, processing_mode="distributed_epoch")
    error_regex = "Cannot create a split provider for dataset " + \
        "of type ZipDataset"
    with self.assertRaisesRegex(errors.UnimplementedError, error_regex):
      self.getDatasetOutput(ds_3, requires_initialization=True)

  @combinations.generate(test_base.default_test_combinations())
  def testDistributedEpochOnDistributedDataset(self):
    # A DataServiceDataset cannot itself be split for distributed_epoch.
    cluster_1 = self.create_cluster(num_workers=1)
    cluster_2 = self.create_cluster(num_workers=1)
    num_sizes = 10
    size_repeats = 5
    numbers = [1 * i for i in range(num_sizes)] * size_repeats
    ds = dataset_ops.Dataset.from_tensor_slices(numbers)
    ds = self.make_distributed_dataset(
        ds, cluster_1, processing_mode="parallel_epochs")
    ds = ds.map(lambda x: x + 1)
    ds = self.make_distributed_dataset(
        ds, cluster_2, processing_mode="distributed_epoch")
    error_regex = "Cannot create a split provider for dataset " + \
        "of type DataServiceDataset"
    with self.assertRaisesRegex(errors.UnimplementedError, error_regex):
      self.getDatasetOutput(ds, requires_initialization=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testTwoLevelDistribute(self):
    # Distributes a shuffled string dataset over cluster_1 (3 workers), then
    # groups the combined output by string length on cluster_2. Each length
    # should be produced once per cluster_1 worker, as a full batch.
    cluster_1_size = 3
    cluster_1 = self.create_cluster(num_workers=cluster_1_size)
    cluster_2 = self.create_cluster(num_workers=1)
    num_sizes = 10
    size_repeats = 5
    # Strings of lengths 0..num_sizes-1, each length repeated size_repeats
    # times.
    strings = ["a" * i for i in range(num_sizes)] * size_repeats
    ds = dataset_ops.Dataset.from_tensor_slices(strings)
    ds = ds.shuffle(len(strings))
    ds = self.make_distributed_dataset(ds, cluster_1)
    # Large enough so that all strings of the same size are windowed together.
    window_size = cluster_1_size * size_repeats
    batch_size = size_repeats

    def key_func(x):
      # Group by string length.
      return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64)

    ds = ds.apply(
        grouping.group_by_window(
            key_func=key_func,
            reduce_func=lambda _, x: x.batch(batch_size),
            window_size=window_size))
    ds = self.make_distributed_dataset(ds, cluster_2)

    it = iter(ds)
    for _ in range(num_sizes):
      element = next(it).numpy()
      # Each size appears once per cluster_1 worker; all copies identical.
      for _ in range(1, cluster_1_size):
        self.assertAllEqual(next(it).numpy(), element)
    self.assertEmpty(list(it))

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations()))
  def testDistributeLargeGraph(self):
    # A graph with a large embedded constant must survive the dispatcher
    # round-trip; fault_tolerant_mode is disabled to skip journal writes.
    cluster = self.create_cluster(
        num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
    # Larger than default OSS grpc message size limit of 4MB.
    tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
    ds = dataset_ops.Dataset.from_tensors(tensor)
    ds = self.make_distributed_dataset(ds, cluster)
    self.assertDatasetProduces(ds, [tensor])
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
  """Fault-tolerance tests for tf.data service dispatcher/worker restarts."""

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherStop(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []
    results.append(next(iterator).numpy())
    cluster.stop_dispatcher()
    # After the dispatcher dies, the worker should continue providing the rest
    # of the dataset's elements.
    for _ in range(num_elements - 1):
      results.append(next(iterator).numpy())
    self.assertEqual(results, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartBeforeReading(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    cluster.restart_dispatcher()
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartDuringReading(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []
    for _ in range(num_elements // 2):
      results.append(next(iterator).numpy())
    cluster.restart_dispatcher()
    for elem in iterator:
      results.append(elem.numpy())
    self.assertEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartBetweenIterations(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    # Fix: pass `num_elements` (was a hard-coded 100) so the dataset size and
    # the expected output below cannot silently drift apart.
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    self.assertDatasetProduces(ds, list(range(num_elements)))
    cluster.restart_dispatcher()
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherManyRestarts(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements_start = 10
    num_elements_end = 15
    datasets = []
    # Register several datasets, restarting the dispatcher after each one.
    for num_elements in range(num_elements_start, num_elements_end):
      datasets.append(
          self.make_distributed_range_dataset(num_elements, cluster))
      cluster.restart_dispatcher()
    for ds, num_elements in zip(datasets,
                                range(num_elements_start, num_elements_end)):
      self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherAndWorkerRestart(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(ds, cluster)
    cluster.restart_dispatcher()
    cluster.restart_worker()
    self.assertDatasetProduces(ds, list(range(num_elements)))
    cluster.restart_dispatcher()
    cluster.restart_worker()
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testStartServersLate(self):
    # Test that the data service client performs retries instead of failing
    # when the dataset is created before the master and worker are started.
    try:
      import portpicker  # pylint: disable=g-import-not-at-top
      dispatcher_port = portpicker.pick_unused_port()
    except Exception:  # pylint: disable=broad-except
      # Fix: narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
      # still propagate; any portpicker flake still just skips the test.
      raise self.skipTest("Flakes in portpicker library do not represent "
                          "TensorFlow errors.")
    cluster = self.create_cluster(
        num_workers=1, dispatcher_port=dispatcher_port, start=False)

    def start_servers():
      time.sleep(0.5)
      cluster.start_dispatcher()
      cluster.start_workers()

    start_servers_thread = threading.Thread(target=start_servers, daemon=True)
    start_servers_thread.start()

    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    results = [elem.numpy() for elem in ds]
    self.assertEqual(list(range(num_elements)), results)
    start_servers_thread.join()

  @combinations.generate(test_base.eager_only_combinations())
  def testAddWorkerMidJob(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []
    # Read halfway through the dataset.
    for _ in range(num_elements // 2):
      results.append(next(iterator).numpy())
    cluster.add_worker()
    # Wait for the new worker to register with the dispatcher.
    while cluster.num_registered_workers() < 2:
      time.sleep(10 / 1000)  # 10ms
    for elem in iterator:
      results.append(elem.numpy())
    # Both workers process the full range, so each element appears twice.
    self.assertCountEqual(2 * list(range(num_elements)), results)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(use_same_port=[True, False]),
                         data_service_test_base.all_cluster_configurations()))
  def testRestartWorker(self, use_same_port, work_dir, fault_tolerant_mode):
    cluster = self.create_cluster(
        num_workers=1,
        work_dir=work_dir,
        fault_tolerant_mode=fault_tolerant_mode)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    # Read halfway through the dataset.
    midpoint = num_elements // 2
    for i in range(midpoint):
      self.assertEqual(i, next(iterator).numpy())
    # Stop the original worker and start a new one.
    cluster.restart_worker(use_same_port=use_same_port)
    # There may have been some elements prefetched from the first worker
    # before it was stopped.
    while True:
      val = next(iterator).numpy()
      if val == 0:
        break
    # The dataset starts over now that we read from the new worker.
    # TODO(b/157086991): Iterate until end of sequence when we support
    # detecting lost workers.
    for i in range(1, num_elements // 2):
      val = next(iterator).numpy()
      self.assertEqual(i, val)

  @combinations.generate(test_base.eager_only_combinations())
  def testChangeProcessingModeAfterRestart(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    range_dataset = dataset_ops.Dataset.range(num_elements)
    ds = range_dataset.apply(
        data_service_ops.distribute(
            processing_mode="parallel_epochs",
            service=cluster.target,
            job_name="test"))
    iterator = iter(ds)
    for i in range(num_elements // 2):
      self.assertEqual(i, next(iterator).numpy())
    cluster.restart_dispatcher()
    # Re-using the same job name with a different processing mode must fail.
    ds = range_dataset.apply(
        data_service_ops.distribute(
            processing_mode="distributed_epoch",
            service=cluster.target,
            job_name="test"))
    with self.assertRaisesOpError(
        "already an existing job with that name "
        "using processing mode <parallel_epochs>"):
      next(iter(ds)).numpy()

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(work_dir=[TMP_WORK_DIR, NO_WORK_DIR])))
  def testDistributeLargeGraphThenRegisterWorker(self, work_dir):
    cluster = self.create_cluster(
        num_workers=0, work_dir=work_dir, fault_tolerant_mode=False)
    # Larger than default OSS grpc message size limit of 4MB.
    tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
    ds = dataset_ops.Dataset.from_tensors(tensor)
    ds = self.make_distributed_dataset(ds, cluster)
    it = iter(ds)
    cluster.add_worker()
    self.assertAllEqual(next(it), tensor)