class FaultToleranceTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
  """Tests tf.data service reads across dispatcher/worker stops and restarts."""

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherStop(self):
    """Reads one element, stops the dispatcher, then reads the rest."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []
    results.append(next(iterator).numpy())
    cluster.stop_dispatcher()
    # After the dispatcher dies, the worker should continue providing the rest
    # of the dataset's elements.
    for _ in range(num_elements - 1):
      results.append(next(iterator).numpy())
    self.assertEqual(results, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartBeforeReading(self):
    """Restarts the dispatcher after building the dataset, before reading."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    cluster.restart_dispatcher()

    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartDuringReading(self):
    """Restarts the dispatcher halfway through reading the dataset."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []
    for _ in range(num_elements // 2):
      results.append(next(iterator).numpy())
    cluster.restart_dispatcher()
    for elem in iterator:
      results.append(elem.numpy())

    self.assertEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartDuringDistributedEpoch(self):
    """Restarts the dispatcher mid-read in "distributed_epoch" mode."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, processing_mode="distributed_epoch")
    iterator = iter(ds)
    results = []
    for _ in range(num_elements // 2):
      results.append(next(iterator).numpy())
    cluster.restart_dispatcher()
    for elem in iterator:
      results.append(elem.numpy())

    self.assertEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartDuringDistributedEpochRepeat(self):
    """Restarts the dispatcher repeatedly while reading a repeated dataset."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    repetitions = 5
    # Total element counts after which the dispatcher is restarted.
    breakpoints = [50, 250, 450, 500]
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.repeat(repetitions)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")

    iterator = iter(ds)
    results = []
    for breakpoint_ in breakpoints:
      # Read up to the next breakpoint, then restart the dispatcher.
      for _ in range(len(results), breakpoint_):
        results.append(next(iterator).numpy())
      cluster.restart_dispatcher()

    self.assertCountEqual(repetitions * list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartBetweenIterations(self):
    """Restarts the dispatcher between two full passes over the dataset."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    # Size the dataset from `num_elements` so it always matches the expected
    # output below.
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    self.assertDatasetProduces(ds, list(range(num_elements)))
    cluster.restart_dispatcher()
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherManyRestarts(self):
    """Creates several datasets, restarting the dispatcher after each one."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements_start = 10
    num_elements_end = 15
    datasets = []
    for num_elements in range(num_elements_start, num_elements_end):
      datasets.append(
          self.make_distributed_range_dataset(num_elements, cluster))
      cluster.restart_dispatcher()
    for ds, num_elements in zip(datasets,
                                range(num_elements_start, num_elements_end)):
      self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherAndWorkerRestart(self):
    """Restarts the dispatcher and the single worker before each full read."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)

    cluster.restart_dispatcher()
    cluster.workers[0].restart()
    self.assertDatasetProduces(ds, list(range(num_elements)))
    cluster.restart_dispatcher()
    cluster.workers[0].restart()
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherAndMultiWorkerRestart(self):
    """Restarts the dispatcher and both workers, then drains one iterator."""
    num_workers = 2
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []

    cluster.restart_dispatcher()
    for worker_index in range(num_workers):
      cluster.workers[worker_index].restart()
    for elem in iterator:
      results.append(elem.numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), results)
    # The same iterator is reused after a second round of restarts; the
    # expected results are unchanged from the first assertion.
    cluster.restart_dispatcher()
    for worker_index in range(num_workers):
      cluster.workers[worker_index].restart()
    for elem in iterator:
      results.append(elem.numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testStartServersLate(self):
    """Creates the dataset before the servers start; the client must retry."""
    # Test that the data service client performs retries instead of failing when
    # the dataset is created before the master and worker are started.
    try:
      import portpicker  # pylint: disable=g-import-not-at-top
      dispatcher_port = portpicker.pick_unused_port()
    except Exception:  # pylint: disable=broad-except
      # Any portpicker failure (including the module being unavailable) is not
      # a TensorFlow error. `skipTest` raises SkipTest, so execution stops here.
      # Catching `Exception` rather than everything lets KeyboardInterrupt and
      # SystemExit propagate.
      self.skipTest("Flakes in portpicker library do not represent "
                    "TensorFlow errors.")
    cluster = data_service_test_base.TestCluster(
        num_workers=1, dispatcher_port=dispatcher_port, start=False)

    def start_servers():
      # Delay startup so the client has to wait for the servers.
      time.sleep(0.5)
      cluster.start_dispatcher()
      cluster.start_workers()

    start_servers_thread = threading.Thread(target=start_servers, daemon=True)
    start_servers_thread.start()

    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    results = [elem.numpy() for elem in ds]
    self.assertEqual(list(range(num_elements)), results)
    start_servers_thread.join()

  @combinations.generate(test_base.eager_only_combinations())
  def testAddWorkerMidJob(self):
    """Adds a second worker halfway through; its elements are also read."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 2 * multiprocessing.cpu_count() + 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []
    # Read halfway through the dataset.
    for _ in range(num_elements // 2):
      results.append(next(iterator).numpy())

    cluster.add_worker()
    # Wait for the new worker to register with the dispatcher.
    while cluster.num_registered_workers() < 2:
      time.sleep(10 / 1000)  # 10ms

    for elem in iterator:
      results.append(elem.numpy())

    self.assertCountEqual(2 * list(range(num_elements)), results)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(use_same_port=[True, False]),
                         data_service_test_base.all_cluster_configurations()))
  def testRestartWorker(self, use_same_port, work_dir, fault_tolerant_mode):
    """Restarts the only worker mid-read; reading restarts from element 0."""
    cluster = data_service_test_base.TestCluster(
        num_workers=1,
        work_dir=work_dir,
        fault_tolerant_mode=fault_tolerant_mode)
    num_elements = 2 * multiprocessing.cpu_count() + 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    # Read halfway through the dataset.
    midpoint = num_elements // 2
    for i in range(midpoint):
      self.assertEqual(i, next(iterator).numpy())

    # Stop the original worker and start a new one.
    cluster.workers[0].restart(use_same_port=use_same_port)

    # There may have been some elements prefetched from the first worker
    # before it was stopped.
    while True:
      val = next(iterator).numpy()
      if val == 0:
        break

    # The dataset starts over now that we read from the new worker.
    # TODO(b/157086991): Iterate until end of sequence when we support
    # detecting lost workers.
    for i in range(1, num_elements // 2):
      val = next(iterator).numpy()
      self.assertEqual(i, val)

  @combinations.generate(test_base.eager_only_combinations())
  def testChangeProcessingModeAfterRestart(self):
    """Reusing a job name with a different processing mode must fail."""
    self.skipTest("b/170910141")
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    range_dataset = dataset_ops.Dataset.range(num_elements)
    ds = range_dataset.apply(
        data_service_ops.distribute(
            processing_mode="parallel_epochs",
            service=cluster.dispatcher_address(),
            job_name="test"))
    iterator = iter(ds)
    for i in range(num_elements // 2):
      self.assertEqual(i, next(iterator).numpy())
    cluster.restart_dispatcher()
    ds = range_dataset.apply(
        data_service_ops.distribute(
            processing_mode="distributed_epoch",
            service=cluster.dispatcher_address(),
            job_name="test"))
    with self.assertRaisesOpError("already an existing job with that name "
                                  "using processing mode <parallel_epochs>"):
      next(iter(ds)).numpy()

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(work_dir=[TMP_WORK_DIR, NO_WORK_DIR])))
  def testDistributeLargeGraphThenRegisterWorker(self, work_dir):
    """Distributes a large graph first, then adds the worker that serves it."""
    cluster = data_service_test_base.TestCluster(
        num_workers=0, work_dir=work_dir, fault_tolerant_mode=False)
    # Larger than default OSS grpc message size limit of 4MB.
    tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
    ds = dataset_ops.Dataset.from_tensors(tensor)
    ds = self.make_distributed_dataset(ds, cluster)
    it = iter(ds)
    cluster.add_worker()
    self.assertAllEqual(next(it), tensor)
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
  """Tests for the tf.data service `distribute`/`from_dataset_id` ops.

  NOTE(review): this file defines a second class with the same name
  `DataServiceOpsTest` later on. That definition rebinds the module-level name,
  so the test loader may never collect the tests in this class. Confirm whether
  one of them should be renamed.
  """

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         data_service_test_base.all_cluster_configurations()))
  def testDistributeBasic(self, work_dir, fault_tolerant_mode):
    """Distributes a range dataset under each cluster configuration."""
    cluster = data_service_test_base.TestCluster(
        num_workers=1,
        work_dir=work_dir,
        fault_tolerant_mode=fault_tolerant_mode)
    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(compression=[None, "AUTO"])))
  def testDistributeCompression(self, compression):
    """Distributes a range dataset with and without "AUTO" compression."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, compression=compression)
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeInvalidCompression(self):
    """An unknown compression string raises ValueError."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    with self.assertRaisesRegex(ValueError, "Invalid compression argument"):
      self.make_distributed_range_dataset(10, cluster, compression="foo")

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeSparse(self):
    """Distributes a single SparseTensor element."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    element = sparse_tensor.SparseTensor(
        indices=[[0]],
        values=constant_op.constant([0], dtype=dtypes.int32),
        dense_shape=[1])
    ds = dataset_ops.Dataset.from_tensors(element)
    ds = self.make_distributed_dataset(ds, cluster)
    results = [sparse_ops.sparse_tensor_to_dense(elem) for elem in ds]
    self.assertAllEqual(results, [[0]])

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeRagged(self):
    """Distributes ragged batches built from ranges of varying length."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    ds = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
    ds = ds.map(math_ops.range)
    ds = ds.apply(batching.dense_to_ragged_batch(2))
    ds = self.make_distributed_dataset(ds, cluster)
    results = [elem.to_tensor() for elem in ds]
    self.assertAllEqual(results[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]])
    self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]])
    self.assertAllEqual(results[2], [[0, 1, 2, 3, 4, 5, 6, 7]])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              init_source=["textfile", "keyvaluetensor", "dataset"])))
  def testDistributeLookupTable(self, init_source):
    """Maps range(3) through a StaticHashTable whose default value is -1."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    initializer = self.lookupTableInitializer(init_source, [10, 11])
    table = lookup_ops.StaticHashTable(initializer, -1)
    ds = dataset_ops.Dataset.range(3)
    ds = ds.map(table.lookup)
    ds = self.make_distributed_dataset(ds, cluster)
    self.evaluate(lookup_ops.tables_initializer())
    self.assertDatasetProduces(ds, [10, 11, -1], requires_initialization=True)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(value_rank=[0, 1])))
  def testDistributeMutableHashTable(self, value_rank):
    """Maps range(3) through a MutableHashTable with scalar or vector values."""

    def value(v):
      # Nest `v` into a list `value_rank` times.
      for _ in range(value_rank):
        v = [v, v]
      return v

    v1 = value(10)
    v2 = value(11)
    default_value = value(-1)

    cluster = data_service_test_base.TestCluster(num_workers=1)
    table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.int64,
                                        default_value)
    self.evaluate(table.insert([0, 1], [v1, v2]))
    ds = dataset_ops.Dataset.range(3)
    ds = ds.map(table.lookup)
    ds = self.make_distributed_dataset(ds, cluster)
    self.assertDatasetProduces(
        ds, [v1, v2, default_value], requires_initialization=True)

  @combinations.generate(test_base.default_test_combinations())
  def testDifferentShuffleOrders(self):
    """Two workers shuffling the same dataset produce different orders."""
    random_seed.set_random_seed(None)
    num_elements = 100
    cluster = data_service_test_base.TestCluster(num_workers=2)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.shuffle(num_elements)
    ds = self.make_distributed_dataset(ds, cluster)
    output = self.getDatasetOutput(ds)

    # The output will be two sequences of range(num_elements)
    # non-deterministically interleaved together. If the orders of the elements
    # were the same, first_order and second_order computed below will be equal.
    first_order = {}
    second_order = {}
    for element in output:
      if element in first_order:
        second_order[element] = len(second_order)
      else:
        first_order[element] = len(first_order)
    self.assertNotEqual(first_order, second_order)

  @combinations.generate(test_base.default_test_combinations())
  def testMultipleEpochs(self):
    """Iterates the same distributed dataset ten times."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 3
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    for _ in range(10):
      self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
  def testRepeatedDataset(self):
    """Applies `repeat` on top of a distributed dataset."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    num_repetitions = 5
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    ds = ds.repeat(num_repetitions)
    self.assertDatasetProduces(
        ds, expected_output=num_repetitions * list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
  def testConcurrentEpoch(self):
    """Reads several independent distributed datasets in lockstep."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    num_datasets = 3
    get_nexts = []
    results = []
    for _ in range(num_datasets):
      ds = self.make_distributed_range_dataset(num_elements, cluster)
      get_nexts.append(self.getNext(ds))
      results.append([])

    for _ in range(num_elements):
      for dataset_ind in range(num_datasets):
        result = self.evaluate(get_nexts[dataset_ind]())
        results[dataset_ind].append(result)
    for result in results:
      self.assertEqual(list(range(num_elements)), result)

  @combinations.generate(test_base.default_test_combinations())
  def testMultiWorker(self):
    """Each of three workers produces the full range."""
    num_workers = 3
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    self.assertDatasetProduces(
        ds, num_workers * list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testMaxOutstandingRequests(self):
    """Reads from three workers with max_outstanding_requests=1."""
    num_workers = 3
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    num_elements = 10
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, max_outstanding_requests=1)
    self.assertDatasetProduces(
        ds, num_workers * list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testInsideFunction(self):
    """Builds and iterates a distributed dataset inside a tf.function."""
    num_workers = 3
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    num_elements = 10

    @def_function.function
    def f():
      ds = self.make_distributed_range_dataset(num_elements, cluster)
      result = tensor_array_ops.TensorArray(
          dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
      i = 0
      for elem in ds:
        result = result.write(i, elem)
        i += 1
      return result.stack()

    result = list(f().numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), result)

  @combinations.generate(test_base.default_test_combinations())
  def testSharedJobName(self):
    """Two datasets with the same job name jointly read one range."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 1000

    def make_ds():
      return dataset_ops.Dataset.range(num_elements).shuffle(num_elements)

    ds1 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name")
    ds2 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name")
    get_next_1 = self.getNext(ds1)
    get_next_2 = self.getNext(ds2)
    results = []
    for _ in range(num_elements // 5):
      results.append(self.evaluate(get_next_1()))
      results.append(self.evaluate(get_next_2()))
    results += self.getIteratorOutput(get_next_1)
    results += self.getIteratorOutput(get_next_2)
    self.assertCountEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.default_test_combinations())
  def testDifferentJobNames(self):
    """Datasets with different job names each read the full range."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name1")
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name2")
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameMultiIteration(self):
    """With a shared job name, the second reader of an iteration gets nothing."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    # iteration 1
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, [])
    # iteration 2
    self.assertDatasetProduces(ds2, list(range(num_elements)))
    self.assertDatasetProduces(ds1, [])

  @combinations.generate(test_base.default_test_combinations())
  def testSharedJobNameRepeat(self):
    """Shared job name combined with `repeat` on both readers."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    num_repetitions = 3
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds1 = ds1.repeat(num_repetitions)
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds2 = ds2.repeat(num_repetitions)
    results = []
    get_next_1 = self.getNext(ds1)
    get_next_2 = self.getNext(ds2)
    for _ in range((num_elements * num_repetitions) // 5):
      results.append(self.evaluate(get_next_1()))
    for _ in range((num_elements * num_repetitions) // 5):
      results.append(self.evaluate(get_next_2()))
    results += self.getIteratorOutput(get_next_1)
    results += self.getIteratorOutput(get_next_2)
    self.assertCountEqual(num_repetitions * list(range(num_elements)), results)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(job_name=[None, "test"])))
  def testGcUnusedJob(self, job_name):
    """After the only iterator is deleted, the worker's task is removed."""
    cluster = data_service_test_base.TestCluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 100
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, job_name=job_name)
    it = iter(ds)
    self.assertEqual(next(it).numpy(), 0)
    self.assertEqual(cluster.workers[0].num_tasks(), 1)
    del it
    # Poll until the task count drops to zero.
    while cluster.workers[0].num_tasks() > 0:
      time.sleep(0.1)

  @combinations.generate(test_base.eager_only_combinations())
  def testDontGcUsedJob(self):
    """A job that still has a live iterator keeps its task."""
    cluster = data_service_test_base.TestCluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 10
    it1 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test1"))
    it2 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    it3 = iter(  # this iterator keeps the task alive. pylint: disable=unused-variable
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    self.assertEqual(cluster.workers[0].num_tasks(), 2)
    del it1
    del it2
    # Check that only the first job is gced. The second job will not be gced
    # because there is still an outstanding iterator for it.
    while cluster.workers[0].num_tasks() > 1:
      time.sleep(0.1)
    self.assertEqual(cluster.workers[0].num_tasks(), 1)

  @combinations.generate(test_base.default_test_combinations())
  def testApplyDeterminismOption(self):
    """`experimental_deterministic=False` is honored through the service."""
    elements = list(range(10))
    cluster = data_service_test_base.TestCluster(num_workers=1)

    def dataset_fn(delay_ms):

      def interleave_fn(x):
        # Only element 0 is delayed; the others are emitted immediately.
        ds = dataset_ops.Dataset.from_tensors(x)
        if math_ops.equal(x, 0):
          ds = ds.apply(testing.sleep(delay_ms * 1000))
        else:
          ds = ds.apply(testing.sleep(0))
        return ds

      ds = dataset_ops.Dataset.from_tensor_slices(elements)
      ds = ds.interleave(interleave_fn, cycle_length=10, num_parallel_calls=10)
      opts = dataset_ops.Options()
      opts.experimental_deterministic = False
      ds = ds.with_options(opts)
      ds = self.make_distributed_dataset(ds, cluster)
      return ds

    self.checkDeterminism(
        dataset_fn=dataset_fn,
        expect_determinism=False,
        expected_elements=elements)

  def run_stateful(self, external_state_policy):
    """Distributes a random_uniform map under the given external state policy.

    Args:
      external_state_policy: value assigned to
        `Options.experimental_external_state_policy`.
    """
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements).map(
        lambda _: random_ops.random_uniform(()))

    options = dataset_ops.Options()
    options.experimental_external_state_policy = external_state_policy
    ds = ds.with_options(options)

    cluster = data_service_test_base.TestCluster(num_workers=3)
    ds = self.make_distributed_dataset(ds, cluster)
    self.getDatasetOutput(ds)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(external_state_policy=[
              distribute_options.ExternalStatePolicy.IGNORE,
              distribute_options.ExternalStatePolicy.WARN
          ])))
  def testStatefulNoError(self, external_state_policy):
    """IGNORE and WARN policies let a stateful dataset be distributed."""
    self.run_stateful(external_state_policy)

  @combinations.generate(test_base.default_test_combinations())
  def testStatefulError(self):
    """The FAIL policy raises FailedPreconditionError for a stateful dataset."""
    with self.assertRaises(errors.FailedPreconditionError):
      self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeFromInterleave(self):
    """Calling `make_distributed_dataset` inside an interleave fn succeeds."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    ds = dataset_ops.Dataset.range(2)

    def interleave_fn(_):
      dataset = dataset_ops.Dataset.range(2)
      # The distributed dataset is deliberately discarded; only its
      # construction inside the interleave function is being exercised.
      self.make_distributed_dataset(dataset, cluster)
      return dataset

    ds = ds.interleave(interleave_fn, cycle_length=2)
    self.assertDatasetProduces(ds, [0, 0, 1, 1])

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeNonStringAddresses(self):
    """A non-string `service` raises ValueError."""
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError, "service must be a string"):
      ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=1))

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeEmptyAddress(self):
    """An empty `service` string raises ValueError."""
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesWithLiteralMatch(ValueError,
                                           "service must not be empty"):
      ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=""))

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeExplicitProtocol(self):
    """A "grpc://" prefix on the service address is accepted."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    ds = dataset_ops.Dataset.range(10)
    ds = ds.apply(
        data_service_ops.distribute(
            processing_mode="parallel_epochs",
            service="grpc://" + cluster.dispatcher_address()))
    self.assertDatasetProduces(ds, list(range(10)))

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeInvalidProtocol(self):
    """An unknown protocol prefix raises NotFoundError."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(
        errors.NotFoundError,
        "No credentials factory has been registered for protocol grp"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs",
              service="grp://" + cluster.dispatcher_address()))
      self.getDatasetOutput(ds)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeInvalidProcessingMode(self):
    """An unknown processing mode raises ValueError."""
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError,
                                "invalid is not a valid processing mode"):
      ds.apply(
          data_service_ops.distribute(
              processing_mode="invalid", service="grpc://localhost:5000"))

  @combinations.generate(test_base.default_test_combinations())
  def testZipDifferentProcessingModesDatasets(self):
    """Zips a "distributed_epoch" dataset with a "parallel_epochs" dataset."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds1 = dataset_ops.Dataset.range(num_elements)
    ds1 = self.make_distributed_dataset(
        ds1, cluster, processing_mode="distributed_epoch")
    ds2 = dataset_ops.Dataset.range(num_elements)
    ds2 = self.make_distributed_dataset(
        ds2, cluster, processing_mode="parallel_epochs")
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    self.assertDatasetProduces(
        ds,
        list(zip(range(num_elements), range(num_elements))),
        assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testZipDifferentProcessingModesDatasetsSharedJobName(self):
    """Sharing a job name across different processing modes is rejected."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds1 = dataset_ops.Dataset.range(num_elements)
    ds1 = self.make_distributed_dataset(
        ds1, cluster, processing_mode="distributed_epoch", job_name="job_name")
    ds2 = dataset_ops.Dataset.range(num_elements)
    ds2 = self.make_distributed_dataset(
        ds2, cluster, processing_mode="parallel_epochs", job_name="job_name")
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "but there is already an existing job"):
      self.getDatasetOutput(ds)

  @combinations.generate(test_base.default_test_combinations())
  def testFromDatasetId(self):
    """Registers a dataset and reads it back with `from_dataset_id`."""
    cluster = data_service_test_base.TestCluster(num_workers=1)

    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(
        cluster.dispatcher_address(), ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.dispatcher_address(), dataset_id,
        ds.element_spec)
    self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
  def testFromDatasetIdMultipleComponents(self):
    """Reads back a dataset whose elements are nested dicts/tuples."""
    cluster = data_service_test_base.TestCluster(num_workers=1)

    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
    dataset_id = data_service_ops.register_dataset(
        cluster.dispatcher_address(), ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.dispatcher_address(), dataset_id,
        ds.element_spec)
    output = self.getDatasetOutput(from_dataset_id_ds)
    for i in range(num_elements):
      self.assertEqual(i, output[i]["a"][0])
      self.assertEqual(i, output[i]["a"][1])
      self.assertEqual(i, output[i]["b"])

  @combinations.generate(test_base.default_test_combinations())
  def testFromDatasetIdWrongElementSpec(self):
    """A mismatched element_spec raises FailedPreconditionError on read."""
    cluster = data_service_test_base.TestCluster(num_workers=1)

    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(
        cluster.dispatcher_address(), ds)
    wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.dispatcher_address(), dataset_id, wrong_spec)
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "Expected a tensor of type variant"):
      self.evaluate(self.getNext(from_dataset_id_ds)())

  @combinations.generate(test_base.default_test_combinations())
  def testFromDatasetIdNotRegistered(self):
    """Reading an unregistered dataset id raises NotFoundError."""
    cluster = data_service_test_base.TestCluster(num_workers=1)

    dataset_id = 0
    element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.dispatcher_address(), dataset_id,
        element_spec)
    with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
      self.evaluate(self.getNext(from_dataset_id_ds)())

  @combinations.generate(test_base.default_test_combinations())
  def testCancellation(self):
    """Pending prefetches must be cancellable when the iterator is destroyed."""
    self.skipTest("b/162521601")
    sleep_microseconds = int(1e6) * 1000

    cluster = data_service_test_base.TestCluster(num_workers=1)
    # Create a dataset which produces the first element quickly, and the second
    # element slowly. Fetching the first element triggers prefetching of the
    # second element, which we should be able to cancel.
    slow = dataset_ops.Dataset.range(1)
    slow = slow.apply(testing.sleep(sleep_microseconds))
    ds = dataset_ops.Dataset.range(1).concatenate(slow)
    ds = self.make_distributed_dataset(ds, cluster)
    ds = ds.prefetch(1)
    get_next = self.getNext(ds)
    self.assertEqual(0, self.evaluate(get_next()))
    # Without properly implemented cancellation, we will hang here while trying
    # to garbage collect the dataset iterator.

  @combinations.generate(test_base.default_test_combinations())
  def testRegisterEquivalentDatasets(self):
    """Registering two equivalent datasets yields the same id."""
    ds_1 = dataset_ops.Dataset.range(10)
    ds_2 = dataset_ops.Dataset.range(10)
    cluster = data_service_test_base.TestCluster(num_workers=1)
    id_1 = data_service_ops.register_dataset(cluster.dispatcher_address(), ds_1)
    id_2 = data_service_ops.register_dataset(cluster.dispatcher_address(), ds_2)
    self.assertEqual(self.evaluate(id_1), self.evaluate(id_2))

  @combinations.generate(test_base.default_test_combinations())
  def testRegisterDifferentDatasets(self):
    """Registering two different datasets yields different ids."""
    ds_1 = dataset_ops.Dataset.range(10)
    ds_2 = dataset_ops.Dataset.range(20)
    cluster = data_service_test_base.TestCluster(num_workers=1)
    id_1 = data_service_ops.register_dataset(cluster.dispatcher_address(), ds_1)
    id_2 = data_service_ops.register_dataset(cluster.dispatcher_address(), ds_2)
    self.assertNotEqual(self.evaluate(id_1), self.evaluate(id_2))

  @combinations.generate(test_base.default_test_combinations())
  def testTwoLevelDistribute(self):
    """Feeds one distributed dataset through a second tf.data service cluster."""
    cluster_1_size = 3
    cluster_1 = data_service_test_base.TestCluster(num_workers=cluster_1_size)
    cluster_2 = data_service_test_base.TestCluster(num_workers=1)
    num_sizes = 10
    size_repeats = 5
    strings = ["a" * i for i in range(num_sizes)] * size_repeats
    ds = dataset_ops.Dataset.from_tensor_slices(strings)
    ds = ds.shuffle(len(strings))
    ds = self.make_distributed_dataset(ds, cluster_1)
    # Large enough so that all strings of the same size are windowed together.
    window_size = cluster_1_size * size_repeats
    batch_size = size_repeats

    def key_func(x):
      return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64)

    ds = ds.apply(
        grouping.group_by_window(
            key_func=key_func,
            reduce_func=lambda _, x: x.batch(batch_size),
            window_size=window_size))
    ds = self.make_distributed_dataset(ds, cluster_2)

    get_next = self.getNext(ds)
    for _ in range(num_sizes):
      element = self.evaluate(get_next())
      for _ in range(1, cluster_1_size):
        self.assertAllEqual(self.evaluate(get_next()), element)
    self.assertEmpty(self.getIteratorOutput(get_next))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations()))
  def testDistributeLargeGraph(self):
    """Distributes a dataset whose graph exceeds the default gRPC limit."""
    cluster = data_service_test_base.TestCluster(
        num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
    # Larger than default OSS grpc message size limit of 4MB.
    tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
    ds = dataset_ops.Dataset.from_tensors(tensor)
    ds = self.make_distributed_dataset(ds, cluster)
    self.assertDatasetProduces(ds, [tensor])
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
  """Tests for distributing datasets through the tf.data service."""

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         data_service_test_base.all_cluster_configurations()))
  def testDistributeBasic(self, work_dir, fault_tolerant_mode):
    """Reads `range(10)` through a one-worker cluster in every configuration.

    Args:
      work_dir: Dispatcher work directory supplied by the combination.
      fault_tolerant_mode: Whether the dispatcher runs fault tolerant.
    """
    cluster = data_service_test_base.TestCluster(
        num_workers=1, work_dir=work_dir,
        fault_tolerant_mode=fault_tolerant_mode)
    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(compression=[None, "AUTO"])))
  def testDistributeCompression(self, compression):
    """The output is unchanged with and without ("AUTO") compression."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, compression=compression)
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(compression=[None, "AUTO"])))
  def testFromDatasetIdOmitsCompression(self, compression):
    """`from_dataset_id` reads correctly without being given `compression`.

    `compression` is passed only to `register_dataset`; the reader side omits
    it, and `compat.forward_compatible` is patched to return True throughout.
    """
    cluster = data_service_test_base.TestCluster(
        num_workers=1, data_transfer_protocol="grpc")
    dataset = dataset_ops.Dataset.from_tensor_slices(
        list("abcdefghijklmnopqrstuvwxyz"))

    def to_upper(x):
      # Inside the NumPy function the element arrives as bytes, hence decode.
      return script_ops.numpy_function(
          func=lambda s: s.decode("utf-8").upper(),
          inp=[x],
          Tout=dtypes.string)

    dataset = dataset.map(to_upper, num_parallel_calls=dataset_ops.AUTOTUNE)
    with mock.patch.object(compat, "forward_compatible", return_value=True):
      dataset_id = data_service_ops.register_dataset(
          cluster.dispatcher.target, dataset=dataset, compression=compression)
      dataset = data_service_ops.from_dataset_id(
          processing_mode=ShardingPolicy.OFF,
          service=cluster.dispatcher.target,
          dataset_id=dataset_id,
          element_spec=dataset.element_spec)
      self.assertDatasetProduces(dataset, list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))

  # Eager-only as querying `element_spec` is only supported in the eager mode.
  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(compression=[None, "AUTO"])))
  def testFromDatasetIdOmitsElementSpecAndCompression(self, compression):
    """`from_dataset_id` works given neither `element_spec` nor compression."""
    cluster = data_service_test_base.TestCluster(
        num_workers=1, data_transfer_protocol="grpc")
    dataset = dataset_ops.Dataset.from_tensor_slices(
        list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
    with mock.patch.object(compat, "forward_compatible", return_value=True):
      dataset_id = data_service_ops.register_dataset(
          cluster.dispatcher.target, dataset=dataset, compression=compression)
      dataset = data_service_ops.from_dataset_id(
          processing_mode=ShardingPolicy.OFF,
          service=cluster.dispatcher.target,
          dataset_id=dataset_id)
      self.assertDatasetProduces(dataset, list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))

  def _testCompressionMismatch(self, dataset):
    """Asserts that reading fails when writer/reader compression disagree.

    `dataset` is registered uncompressed through the private
    `_register_dataset` while `compat.forward_compatible` is patched to
    False. It is then read back through `_from_dataset_id`, whose
    `compression` is left at its "AUTO" default. The read is expected to
    raise `errors.InvalidArgumentError`.

    Args:
      dataset: The `tf.data.Dataset` to register and read back.
    """
    cluster = data_service_test_base.TestCluster(
        num_workers=1, data_transfer_protocol="grpc")
    with mock.patch.object(compat, "forward_compatible", return_value=False):
      dataset_id = data_service_ops._register_dataset(
          cluster.dispatcher.target, dataset=dataset, compression=None)
      # `compression` is "AUTO" by default.
      dataset = data_service_ops._from_dataset_id(
          processing_mode=ShardingPolicy.OFF,
          service=cluster.dispatcher.target,
          dataset_id=dataset_id,
          element_spec=dataset.element_spec)
      with self.assertRaises(errors.InvalidArgumentError):
        self.getDatasetOutput(dataset)

  # `combinations.times` with a single argument returns it unchanged, so the
  # default combinations are passed to `generate` directly.
  @combinations.generate(test_base.default_test_combinations())
  def testCompressionDtypeMismatch(self):
    """Compression mismatch is detected for a string-typed dataset."""
    dataset = dataset_ops.Dataset.from_tensor_slices(
        list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
    self._testCompressionMismatch(dataset)

  @combinations.generate(test_base.default_test_combinations())
  def testCompressionShapeMismatch(self):
    """Compression mismatch is detected for a dataset of shape-[2] elements."""
    dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], [3, 4]])
    self._testCompressionMismatch(dataset)

  # Only test eager mode since nested datasets are not allowed in graph mode.
# --- Element types, shuffling, epochs, multi-worker reads and job names. ---
  @combinations.generate(
      combinations.times(test_base.eager_only_combinations()))
  def testCompressionVariantMismatch(self):
    """Runs the compression-mismatch check on a nested (variant) dataset."""
    # Use a nested dataset as an example of a variant.
    dataset = dataset_ops.Dataset.from_tensors(
        dataset_ops.Dataset.range(10))
    self._testCompressionMismatch(dataset)

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeInvalidCompression(self):
    """An unrecognized `compression` value raises `ValueError`."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    with self.assertRaisesRegex(ValueError, "Invalid `compression` argument"):
      self.make_distributed_range_dataset(10, cluster, compression="foo")

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeSparse(self):
    """A one-element `SparseTensor` round-trips through the service."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    element = sparse_tensor.SparseTensor(
        indices=[[0]],
        values=constant_op.constant([0], dtype=dtypes.int32),
        dense_shape=[1])
    ds = dataset_ops.Dataset.from_tensors(element)
    ds = self.make_distributed_dataset(ds, cluster)
    results = [sparse_ops.sparse_tensor_to_dense(elem) for elem in ds]
    self.assertAllEqual(results, [[0]])

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeRagged(self):
    """Ragged batches of `range(n)` rows round-trip through the service."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    ds = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
    ds = ds.map(math_ops.range)
    ds = ds.apply(batching.dense_to_ragged_batch(2))
    ds = self.make_distributed_dataset(ds, cluster)
    # `to_tensor()` densifies each ragged batch; the expected values below show
    # short rows padded with zeros.
    results = [elem.to_tensor() for elem in ds]
    self.assertAllEqual(results[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]])
    self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]])
    self.assertAllEqual(results[2], [[0, 1, 2, 3, 4, 5, 6, 7]])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              init_source=["textfile", "keyvaluetensor", "dataset"])))
  def testDistributeLookupTable(self, init_source):
    """A `StaticHashTable` lookup inside the dataset runs on the service.

    Args:
      init_source: Which kind of table initializer
        `lookupTableInitializer` builds.
    """
    cluster = data_service_test_base.TestCluster(num_workers=1)
    initializer = self.lookupTableInitializer(init_source, [10, 11])
    table = lookup_ops.StaticHashTable(initializer, -1)
    ds = dataset_ops.Dataset.range(3)
    ds = ds.map(table.lookup)
    ds = self.make_distributed_dataset(ds, cluster)
    self.evaluate(lookup_ops.tables_initializer())
    # Presumably the initializer maps keys 0 and 1 to 10 and 11, so key 2
    # falls back to the default value -1.
    self.assertDatasetProduces(ds, [10, 11, -1], requires_initialization=True)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(value_rank=[0, 1])))
  def testDistributeMutableHashTable(self, value_rank):
    """A `MutableHashTable` lookup works for scalar and rank-1 values.

    Args:
      value_rank: How many times each table value is nested.
    """

    def value(v):
      # Wraps `v` as `[v, v]` once per rank, e.g. rank 1 turns 10 into [10, 10].
      for _ in range(value_rank):
        v = [v, v]
      return v

    v1 = value(10)
    v2 = value(11)
    default_value = value(-1)

    cluster = data_service_test_base.TestCluster(num_workers=1)
    table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.int64,
                                        default_value)
    self.evaluate(table.insert([0, 1], [v1, v2]))
    ds = dataset_ops.Dataset.range(3)
    ds = ds.map(table.lookup)
    ds = self.make_distributed_dataset(ds, cluster)
    self.assertDatasetProduces(
        ds, [v1, v2, default_value], requires_initialization=True)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(shuffle_seed=[None, 10])))
  def testShuffleOrder(self, shuffle_seed):
    """Two workers' shuffle orders differ without a seed, match with one.

    Args:
      shuffle_seed: Seed for `shuffle`; None means unseeded.
    """
    random_seed.set_random_seed(None)
    num_elements = 100
    cluster = data_service_test_base.TestCluster(num_workers=2)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.shuffle(num_elements, seed=shuffle_seed)
    ds = self.make_distributed_dataset(ds, cluster)
    output = self.getDatasetOutput(ds)

    # The output will be two sequences of range(num_elements)
    # non-deterministically interleaved together. If the orders of the elements
    # were the same, first_order and second_order computed below will be equal.
    first_order = {}
    second_order = {}
    for element in output:
      # The first sighting of a value is ranked in first_order, the second in
      # second_order.
      if element in first_order:
        second_order[element] = len(second_order)
      else:
        first_order[element] = len(first_order)
    if shuffle_seed is None:
      self.assertNotEqual(first_order, second_order)
    else:
      self.assertEqual(first_order, second_order)

  @combinations.generate(test_base.default_test_combinations())
  def testMultipleEpochs(self):
    """Iterating the same distributed dataset 10 times yields `range(3)` each time."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 3
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    for _ in range(10):
      self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
  def testRepeatedDataset(self):
    """`repeat()` on a distributed dataset replays the full range each time."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    num_repetitions = 5
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    ds = ds.repeat(num_repetitions)
    self.assertDatasetProduces(
        ds, expected_output=num_repetitions * list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
  def testConcurrentEpoch(self):
    """Three independent distributed datasets read round-robin stay in order."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    num_datasets = 3
    get_nexts = []
    results = []
    for _ in range(num_datasets):
      ds = self.make_distributed_range_dataset(num_elements, cluster)
      get_nexts.append(self.getNext(ds))
      results.append([])

    # Take one element from each dataset in turn.
    for _ in range(num_elements):
      for dataset_ind in range(num_datasets):
        result = self.evaluate(get_nexts[dataset_ind]())
        results[dataset_ind].append(result)
    for result in results:
      self.assertEqual(list(range(num_elements)), result)

  @combinations.generate(test_base.default_test_combinations())
  def testMultiWorker(self):
    """With 3 workers, each value of `range(10)` is produced 3 times."""
    num_workers = 3
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    self.assertDatasetProduces(
        ds, num_workers * list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testMaxOutstandingRequests(self):
    """All elements still arrive with `max_outstanding_requests=1`."""
    num_workers = 3
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    num_elements = 10
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, max_outstanding_requests=1)
    self.assertDatasetProduces(
        ds, num_workers * list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testInsideFunction(self):
    """A distributed dataset can be built and iterated inside a `tf.function`."""
    num_workers = 3
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    num_elements = 10

    @def_function.function
    def f():
      ds = self.make_distributed_range_dataset(num_elements, cluster)
      result = tensor_array_ops.TensorArray(
          dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
      i = 0
      for elem in ds:
        result = result.write(i, elem)
        i += 1
      return result.stack()

    result = list(f().numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), result)

  @combinations.generate(test_base.default_test_combinations())
  def testEmptyJobNameDistribute(self):
    """`distribute` rejects an empty `job_name`."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    with self.assertRaisesRegex(ValueError, "`job_name` must not be empty"):
      dataset_ops.Dataset.range(10).apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs",
              service=cluster.dispatcher.target,
              job_name=""))

  @combinations.generate(test_base.default_test_combinations())
  def testEmptyJobNameFromDatasetId(self):
    """`from_dataset_id` rejects an empty `job_name`."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    dataset_id = data_service_ops.register_dataset(
        cluster.dispatcher.target, dataset_ops.Dataset.range(10))
    with self.assertRaisesRegex(ValueError, "`job_name` must not be empty"):
      data_service_ops.from_dataset_id(
          dataset_id=dataset_id,
          processing_mode="parallel_epochs",
          service=cluster.dispatcher.target,
          job_name="")

  @combinations.generate(test_base.default_test_combinations())
  def testExplicitProtocolFromDatasetId(self):
    """`from_dataset_id` accepts an explicit `data_transfer_protocol="grpc"`."""
    cluster = data_service_test_base.TestCluster(
        num_workers=1, data_transfer_protocol="grpc")
    range_ds = dataset_ops.Dataset.range(10)
    dataset_id = data_service_ops.register_dataset(cluster.dispatcher.target,
                                                   range_ds)
    ds = data_service_ops.from_dataset_id(
        dataset_id=dataset_id,
        processing_mode="parallel_epochs",
        element_spec=range_ds.element_spec,
        service=cluster.dispatcher.target,
        data_transfer_protocol="grpc")
    self.assertDatasetProduces(ds, list(range(10)))

  @combinations.generate(test_base.default_test_combinations())
  def testNonStringJobNameDistribute(self):
    """`distribute` rejects a tensor-valued `job_name`."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    with self.assertRaisesRegex(ValueError, "`job_name` must be a string"):
      dataset_ops.Dataset.range(10).apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs",
              service=cluster.dispatcher.target,
              job_name=constant_op.constant("foo")))

  @combinations.generate(test_base.default_test_combinations())
  def testNonStringJobNameFromDatasetId(self):
    """`from_dataset_id` rejects a tensor-valued `job_name`."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    dataset_id = data_service_ops.register_dataset(
        cluster.dispatcher.target, dataset_ops.Dataset.range(10))
    with self.assertRaisesRegex(ValueError, "`job_name` must be a string"):
      data_service_ops.from_dataset_id(
          dataset_id=dataset_id,
          processing_mode="parallel_epochs",
          service=cluster.dispatcher.target,
          job_name=constant_op.constant("foo"))

  @combinations.generate(test_base.default_test_combinations())
  def testSharedJobName(self):
    """Two readers sharing a `job_name` together see each element exactly once."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 1000

    def make_ds():
      return dataset_ops.Dataset.range(num_elements).shuffle(num_elements)

    ds1 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name")
    ds2 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name")
    get_next_1 = self.getNext(ds1)
    get_next_2 = self.getNext(ds2)
    results = []
    # Alternate between the two readers for a while, then drain each one.
    for _ in range(num_elements // 5):
      results.append(self.evaluate(get_next_1()))
      results.append(self.evaluate(get_next_2()))
    results += self.getIteratorOutput(get_next_1)
    results += self.getIteratorOutput(get_next_2)
    self.assertCountEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.default_test_combinations())
  def testDifferentJobNames(self):
    """Datasets with distinct `job_name`s each receive the full range."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name1")
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name2")
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameMultiIteration(self):
    """Per iteration, whichever shared-job reader runs first gets everything.

    The other reader then produces nothing.
    """
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    # iteration 1
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, [])
    # iteration 2
    self.assertDatasetProduces(ds2, list(range(num_elements)))
    self.assertDatasetProduces(ds1, [])

  @combinations.generate(test_base.default_test_combinations())
  def testSharedJobNameRepeat(self):
    """Repeated readers of a shared job together see every repetition once."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    num_repetitions = 3
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds1 = ds1.repeat(num_repetitions)
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds2 = ds2.repeat(num_repetitions)
    results = []
    get_next_1 = self.getNext(ds1)
    get_next_2 = self.getNext(ds2)
    # Read a fifth of the total from each reader in turn, then drain both.
    for _ in range((num_elements * num_repetitions) // 5):
      results.append(self.evaluate(get_next_1()))
    for _ in range((num_elements * num_repetitions) // 5):
      results.append(self.evaluate(get_next_2()))
    results += self.getIteratorOutput(get_next_1)
    results += self.getIteratorOutput(get_next_2)
    self.assertCountEqual(num_repetitions * list(range(num_elements)), results)
# --- Job GC, client heartbeats, options, service addresses/protocols,
# --- processing modes, from_dataset_id and element specs.
  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameMultipleEpochs(self):
    """Each new iterator over a shared-job dataset yields a full `range(10)`."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    dataset = self.make_distributed_range_dataset(
        10, cluster, job_name="job_name")

    num_epochs = 5
    for _ in range(num_epochs):
      get_next = self.getNext(dataset)
      self.assertEqual(self.getIteratorOutput(get_next), list(range(10)))

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(job_name=[None, "test"])))
  def testGcUnusedJob(self, job_name):
    """A job's task leaves the worker after its only iterator is deleted.

    Args:
      job_name: Optional job name; tested both unnamed and named.
    """
    cluster = data_service_test_base.TestCluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 100
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, job_name=job_name)
    it = iter(ds)
    self.assertEqual(next(it).numpy(), 0)
    self.assertEqual(cluster.workers[0].num_tasks(), 1)
    del it
    # Poll with no explicit deadline: if GC never happens this loop spins until
    # an external test timeout (if any) stops it.
    while cluster.workers[0].num_tasks() > 0:
      time.sleep(0.1)

  @combinations.generate(test_base.eager_only_combinations())
  def testDontGcUsedJob(self):
    """A job that still has a live iterator is not garbage-collected."""
    cluster = data_service_test_base.TestCluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 10
    it1 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test1"))
    it2 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    it3 = iter(  # this iterator keeps the task alive. pylint: disable=unused-variable
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    # it2 and it3 share job "test2"; the check below expects one task per job.
    self.assertEqual(cluster.workers[0].num_tasks(), 2)
    del it1
    del it2
    # Check that only the first job is gced. The second job will not be gced
    # because there is still an outstanding iterator for it.
    while cluster.workers[0].num_tasks() > 1:
      time.sleep(0.1)
    self.assertEqual(cluster.workers[0].num_tasks(), 1)

  @combinations.generate(test_base.eager_only_combinations())
  def testGcAndRecreate(self):
    """The same named job can be garbage-collected and recreated repeatedly."""
    cluster = data_service_test_base.TestCluster(
        num_workers=3, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 1000
    # Repeatedly create and garbage-collect the same job.
    for _ in range(3):
      ds = self.make_distributed_range_dataset(
          num_elements, cluster, job_name="test")
      it = iter(ds)
      for _ in range(50):
        next(it)
      del it
      # Wait for the task to be garbage-collected on all workers.
      while cluster.num_tasks_on_workers() > 0:
        time.sleep(0.1)

  @combinations.generate(test_base.eager_only_combinations())
  def testGcClient(self):
    """A client that does not heartbeat in time is dropped by the dispatcher."""
    dispatcher = server_lib.DispatchServer(
        service_config_pb2.DispatcherConfig(
            protocol="grpc",
            job_gc_check_interval_ms=50,
            job_gc_timeout_ms=20,
            client_timeout_ms=50))
    dispatcher_address = dispatcher.target.split("://")[1]
    _ = server_lib.WorkerServer(
        server_lib.WorkerConfig(
            dispatcher_address=dispatcher_address, heartbeat_interval_ms=100))

    num_elements = 1000
    dataset = dataset_ops.Dataset.range(num_elements)
    # The 10000 ms refresh hint is far above the dispatcher's 50 ms
    # `client_timeout_ms`.
    dataset = dataset.apply(
        data_service_ops._distribute(
            processing_mode=ShardingPolicy.OFF,
            service=dispatcher.target,
            task_refresh_interval_hint_ms=10000))
    get_next = self.getNext(dataset)

    # The client does not heartbeat in 10 seconds. It will be garbage-collected.
    with self.assertRaisesRegex(errors.NotFoundError, "Unknown job client id"):
      self.evaluate(get_next())
      time.sleep(3)
      self.getIteratorOutput(get_next)

  @combinations.generate(test_base.eager_only_combinations())
  def testKeepClientAliveBeforeReading(self):
    """A heartbeating client survives a 3 s idle period before reading."""
    dispatcher = server_lib.DispatchServer(
        service_config_pb2.DispatcherConfig(
            protocol="grpc",
            job_gc_check_interval_ms=50,
            job_gc_timeout_ms=20,
            client_timeout_ms=1000))
    dispatcher_address = dispatcher.target.split("://")[1]
    _ = server_lib.WorkerServer(
        server_lib.WorkerConfig(
            dispatcher_address=dispatcher_address, heartbeat_interval_ms=100))

    num_elements = 1000
    dataset = dataset_ops.Dataset.range(num_elements)
    dataset = dataset.apply(
        data_service_ops._distribute(
            processing_mode=ShardingPolicy.OFF,
            service=dispatcher.target,
            task_refresh_interval_hint_ms=100))
    get_next = self.getNext(dataset)

    # The client regularly heartbeats in 100 milliseconds. It should not be
    # garbage-collected even if it does not start reading in 3 seconds.
    time.sleep(3)
    self.assertEqual(
        self.getIteratorOutput(get_next), list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
  def testApplyDeterminismOption(self):
    """`deterministic = False`, set before distributing, is honored."""
    elements = list(range(10))
    cluster = data_service_test_base.TestCluster(num_workers=1)

    def dataset_fn(delay_ms):

      def interleave_fn(x):
        ds = dataset_ops.Dataset.from_tensors(x)
        # Only element 0 is delayed; `testing.sleep` takes microseconds,
        # hence `delay_ms * 1000`.
        if math_ops.equal(x, 0):
          ds = ds.apply(testing.sleep(delay_ms * 1000))
        else:
          ds = ds.apply(testing.sleep(0))
        return ds

      ds = dataset_ops.Dataset.from_tensor_slices(elements)
      ds = ds.interleave(interleave_fn, cycle_length=10, num_parallel_calls=10)
      opts = options_lib.Options()
      opts.deterministic = False
      ds = ds.with_options(opts)
      ds = self.make_distributed_dataset(ds, cluster)
      return ds

    self.checkDeterminism(
        dataset_fn=dataset_fn,
        expect_determinism=False,
        expected_elements=elements)

  def run_stateful(self, external_state_policy):
    """Distributes a map that calls `random_uniform` and reads all output.

    Args:
      external_state_policy: An `options_lib.ExternalStatePolicy` value set as
        `experimental_external_state_policy` on the dataset options.
    """
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements).map(
        lambda _: random_ops.random_uniform(()))

    options = options_lib.Options()
    options.experimental_external_state_policy = external_state_policy
    ds = ds.with_options(options)

    cluster = data_service_test_base.TestCluster(num_workers=3)
    ds = self.make_distributed_dataset(ds, cluster)
    self.getDatasetOutput(ds)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(external_state_policy=[
              options_lib.ExternalStatePolicy.IGNORE,
              options_lib.ExternalStatePolicy.WARN
          ])))
  def testStatefulNoError(self, external_state_policy):
    """IGNORE and WARN policies let a stateful dataset be distributed."""
    self.run_stateful(external_state_policy)

  @combinations.generate(test_base.default_test_combinations())
  def testStatefulError(self):
    """The FAIL policy makes distributing a stateful dataset raise."""
    with self.assertRaises(errors.FailedPreconditionError):
      self.run_stateful(options_lib.ExternalStatePolicy.FAIL)

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeFromInterleave(self):
    """A distributed dataset can be created inside an `interleave` function."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    ds = dataset_ops.Dataset.range(2)

    def interleave_fn(x):
      dataset = dataset_ops.Dataset.range(10 * x, 10 * x + 2)
      dataset = self.make_distributed_dataset(dataset, cluster)
      return dataset

    ds = ds.interleave(interleave_fn, cycle_length=2)
    self.assertDatasetProduces(ds, [0, 10, 1, 11])

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeNonStringAddresses(self):
    """A non-string `service` address raises `ValueError`."""
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError, "`service` must be a string"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=1))

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeEmptyAddress(self):
    """An empty `service` address raises `ValueError`."""
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesWithLiteralMatch(ValueError,
                                           "`service` must not be empty"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=""))

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeExplicitProtocol(self):
    """A `grpc://`-prefixed service address is accepted."""
    cluster = data_service_test_base.TestCluster(
        num_workers=1, data_transfer_protocol="grpc")
    ds = dataset_ops.Dataset.range(10)
    ds = ds.apply(
        data_service_ops.distribute(
            processing_mode="parallel_epochs",
            service="grpc://" + cluster.dispatcher_address()))
    self.assertDatasetProduces(ds, list(range(10)))

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeInvalidProtocol(self):
    """An unknown protocol prefix (`grp://`) fails with `NotFoundError`."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(
        errors.NotFoundError,
        "No credentials factory has been registered for protocol grp"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs",
              service="grp://" + cluster.dispatcher_address()))
      self.getDatasetOutput(ds)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeInvalidProcessingMode(self):
    """An unknown `processing_mode` string raises `ValueError`."""
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(
        ValueError,
        "should be a `tf.data.experimental.service.ShardingPolicy`, "
        "`\"parallel_epochs\"`, or "
        "`\"distributed_epoch\"`. Got 'invalid'."):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="invalid", service="grpc://localhost:5000"))

  @combinations.generate(test_base.default_test_combinations())
  def testZipDifferentProcessingModesDatasets(self):
    """Zipping distributed_epoch and parallel_epochs datasets yields all pairs."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds1 = dataset_ops.Dataset.range(num_elements)
    ds1 = self.make_distributed_dataset(
        ds1, cluster, processing_mode="distributed_epoch")
    ds2 = dataset_ops.Dataset.range(num_elements)
    ds2 = self.make_distributed_dataset(
        ds2, cluster, processing_mode="parallel_epochs")
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    self.assertDatasetProduces(
        ds,
        list(zip(range(num_elements), range(num_elements))),
        assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testZipDifferentProcessingModesDatasetsSharedJobName(self):
    """Reusing a `job_name` with a different processing mode is rejected."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds1 = dataset_ops.Dataset.range(num_elements)
    ds1 = self.make_distributed_dataset(
        ds1, cluster, processing_mode="distributed_epoch",
        job_name="job_name")
    ds2 = dataset_ops.Dataset.range(num_elements)
    ds2 = self.make_distributed_dataset(
        ds2, cluster, processing_mode="parallel_epochs", job_name="job_name")
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "but there is already an existing job"):
      self.getDatasetOutput(ds)

  @combinations.generate(test_base.default_test_combinations())
  def testFromDatasetId(self):
    """A registered dataset can be read back by id."""
    cluster = data_service_test_base.TestCluster(num_workers=1)

    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = self.register_dataset(cluster.dispatcher_address(), ds)
    from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                              dataset_id, ds.element_spec)
    self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
  def testFromDatasetIdSharedJobs(self):
    """Four consumers with `job_name="shared_job"` read each element once."""
    cluster = data_service_test_base.TestCluster(num_workers=2)

    datasets = [
        dataset_ops.Dataset.range(20, output_type=dtypes.int32),
        dataset_ops.Dataset.from_tensor_slices(list(range(20, 40)))
    ]
    dataset_ids = []

    for ds in datasets:
      dataset_id = self.register_dataset(cluster.dispatcher_address(), ds)
      dataset_ids.append(dataset_id)

    # Read from both jobs in parallel, with 2 consumers for each job.
    data_service_datasets = []
    for _ in range(2):
      for dataset, dataset_id in zip(datasets, dataset_ids):
        ds = self.from_dataset_id(
            "distributed_epoch",
            cluster,
            dataset_id,
            dataset.element_spec,
            job_name="shared_job")
        data_service_datasets.append(ds)
    ds = dataset_ops.Dataset.from_tensor_slices(data_service_datasets)
    ds = ds.interleave(lambda x: x, cycle_length=len(data_service_datasets))

    self.assertDatasetProduces(ds, list(range(40)), assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testRegisteringDatasetAsTfFunction(self):
    """`register_dataset` also works when wrapped in a `tf.function`."""
    cluster = data_service_test_base.TestCluster(num_workers=1)

    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    register_func = def_function.function(self.register_dataset)
    # The address is passed as a tuple of tensors, presumably
    # (protocol, address); confirm against `TestBase.register_dataset`.
    dataset_id = register_func(
        (constant_op.constant("grpc"),
         constant_op.constant(cluster.dispatcher_address())), ds)
    from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                              dataset_id, ds.element_spec)
    self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
  def testFromDatasetIdMultipleComponents(self):
    """A nested dict/tuple element structure survives `from_dataset_id`."""
    cluster = data_service_test_base.TestCluster(num_workers=1)

    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
    dataset_id = self.register_dataset(cluster.dispatcher_address(), ds)
    from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                              dataset_id, ds.element_spec)
    output = self.getDatasetOutput(from_dataset_id_ds)
    for i in range(num_elements):
      self.assertEqual(i, output[i]["a"][0])
      self.assertEqual(i, output[i]["a"][1])
      self.assertEqual(i, output[i]["b"])

  @combinations.generate(test_base.default_test_combinations())
  def testFromDatasetIdWrongElementSpec(self):
    """Reading with a mismatched `element_spec` fails at iteration time."""
    cluster = data_service_test_base.TestCluster(num_workers=1)

    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = self.register_dataset(cluster.dispatcher_address(), ds)
    wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                              dataset_id, wrong_spec)

    # The expected error type and message depend on whether the
    # TRANSFER_PROTOCOL flag is set.
    if data_service_test_base.TRANSFER_PROTOCOL.value:
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  "Data type mismatch at component 0"):
        self.evaluate(self.getNext(from_dataset_id_ds)())
    else:
      with self.assertRaisesRegex(errors.FailedPreconditionError,
                                  "Expected a tensor of type variant"):
        self.evaluate(self.getNext(from_dataset_id_ds)())

  @combinations.generate(test_base.default_test_combinations())
  def testFromDatasetIdNotRegistered(self):
    """Reading an unregistered dataset id fails with `NotFoundError`."""
    cluster = data_service_test_base.TestCluster(num_workers=1)

    dataset_id = 0
    element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    with self.assertRaisesRegex(errors.NotFoundError, "Dataset id 0 not found"):
      from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                                dataset_id, element_spec)
      self.evaluate(self.getNext(from_dataset_id_ds)())

  @combinations.generate(test_base.default_test_combinations())
  def testCancellation(self):
    """An in-flight slow prefetch must not block iterator cleanup.

    Currently skipped (b/162521601).
    """
    self.skipTest("b/162521601")
    # int(1e6) microseconds per second, times 1000: a 1000-second sleep.
    sleep_microseconds = int(1e6) * 1000

    cluster = data_service_test_base.TestCluster(num_workers=1)
    # Create a dataset which produces the first element quickly, and the second
    # element slowly. Fetching the first element triggers prefetching of the
    # second element, which we should be able to cancel.
    slow = dataset_ops.Dataset.range(1)
    slow = slow.apply(testing.sleep(sleep_microseconds))
    ds = dataset_ops.Dataset.range(1).concatenate(slow)
    ds = self.make_distributed_dataset(ds, cluster)
    ds = ds.prefetch(1)
    get_next = self.getNext(ds)
    self.assertEqual(0, self.evaluate(get_next()))
    # Without properly implemented cancellation, we will hang here while trying
    # to garbage collect the dataset iterator.

  @combinations.generate(test_base.default_test_combinations())
  def testRegisterEquivalentDatasets(self):
    """Registering two identical `range(10)` datasets returns equal ids."""
    ds_1 = dataset_ops.Dataset.range(10)
    ds_2 = dataset_ops.Dataset.range(10)
    cluster = data_service_test_base.TestCluster(num_workers=1)
    id_1 = self.register_dataset(cluster.dispatcher_address(), ds_1)
    id_2 = self.register_dataset(cluster.dispatcher_address(), ds_2)
    self.assertEqual(self.evaluate(id_1), self.evaluate(id_2))

  @combinations.generate(test_base.default_test_combinations())
  def testRegisterDifferentDatasets(self):
    """Registering `range(10)` and `range(20)` returns different ids."""
    ds_1 = dataset_ops.Dataset.range(10)
    ds_2 = dataset_ops.Dataset.range(20)
    cluster = data_service_test_base.TestCluster(num_workers=1)
    id_1 = self.register_dataset(cluster.dispatcher_address(), ds_1)
    id_2 = self.register_dataset(cluster.dispatcher_address(), ds_2)
    self.assertNotEqual(self.evaluate(id_1), self.evaluate(id_2))

  @combinations.generate(test_base.default_test_combinations())
  def testTwoLevelDistribute(self):
    """Feeds the output of one service cluster into a second cluster.

    The read loop expects every length-grouped batch to appear
    `cluster_1_size` times in a row (presumably one copy per `cluster_1`
    worker; confirm against the default processing mode).
    """
    cluster_1_size = 3
    cluster_1 = data_service_test_base.TestCluster(
        num_workers=cluster_1_size)
    cluster_2 = data_service_test_base.TestCluster(num_workers=1)
    num_sizes = 10
    size_repeats = 5
    strings = ["a" * i for i in range(num_sizes)] * size_repeats
    ds = dataset_ops.Dataset.from_tensor_slices(strings)
    ds = ds.shuffle(len(strings))
    ds = self.make_distributed_dataset(ds, cluster_1)
    # Large enough so that all strings of the same size are windowed together.
    window_size = cluster_1_size * size_repeats
    batch_size = size_repeats

    def key_func(x):
      # Group key: the string's length as int64.
      return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64)

    ds = ds.apply(
        grouping.group_by_window(
            key_func=key_func,
            reduce_func=lambda _, x: x.batch(batch_size),
            window_size=window_size))
    ds = self.make_distributed_dataset(ds, cluster_2)

    get_next = self.getNext(ds)
    for _ in range(num_sizes):
      element = self.evaluate(get_next())
      for _ in range(1, cluster_1_size):
        self.assertAllEqual(self.evaluate(get_next()), element)
    self.assertEmpty(self.getIteratorOutput(get_next))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations()))
  def testDistributeLargeGraph(self):
    """Distributes a dataset whose graph embeds a 2x1000x1000 float32 tensor."""
    cluster = data_service_test_base.TestCluster(
        num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
    # Larger than default OSS grpc message size limit of 4MB.
    tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
    ds = dataset_ops.Dataset.from_tensors(tensor)
    ds = self.make_distributed_dataset(ds, cluster)
    self.assertDatasetProduces(ds, [tensor])

  @combinations.generate(
      combinations.times(test_base.graph_only_combinations(),
                         combinations.combine(use_resource=False)) +
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(use_resource=True)))
  def testVariables(self, use_resource):
    """A map that captures a variable can be distributed.

    Args:
      use_resource: True for a resource `Variable` (all modes); False for a
        `VariableV1` created with `use_resource=False` (graph mode only).
    """
    cluster = data_service_test_base.TestCluster(num_workers=1)
    if not use_resource:
      with variable_scope.variable_scope("foo", use_resource=False):
        v = variables.VariableV1(10, dtype=dtypes.int64)
    else:
      v = variables.Variable(10, dtype=dtypes.int64)

    ds = dataset_ops.Dataset.range(3)
    ds = ds.map(lambda x: x + v)
    ds = self.make_distributed_dataset(ds, cluster)
    self.evaluate(v.initializer)
    self.assertDatasetProduces(
        ds, list(range(10, 13)), requires_initialization=True)

  @combinations.generate(test_base.graph_only_combinations())
  def testElementSpecGraphMode(self):
    """In graph mode, omitting `element_spec` raises `ValueError`."""
    cluster = data_service_test_base.TestCluster(
        num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(
        cluster.dispatcher_address(), ds)
    with self.assertRaisesRegex(
        ValueError, "In graph mode `element_spec` must be provided manually."):
      ds = data_service_ops.from_dataset_id("parallel_epochs",
                                            cluster.dispatcher_address(),
                                            dataset_id)

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetIdDoesntRequireElementSpec(self):
    """In eager mode `from_dataset_id` can be called without `element_spec`."""
    cluster = data_service_test_base.TestCluster(
        num_workers=1,
        work_dir=NO_WORK_DIR,
        fault_tolerant_mode=False,
        data_transfer_protocol="grpc")
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)

    dataset_id = data_service_ops.register_dataset(
        cluster.dispatcher_address(), ds)
    ds = data_service_ops.from_dataset_id("parallel_epochs",
                                          cluster.dispatcher_address(),
                                          dataset_id)
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testElementSpecMixedMode(self):
    """Omitting `element_spec` fails if registration ran in a `tf.function`."""
    cluster = data_service_test_base.TestCluster(
        num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)

    @def_function.function
    def get_dataset_id():
      return data_service_ops.register_dataset(cluster.dispatcher_address(), ds)

    dataset_id = get_dataset_id()
    dataset_id_val = tensor_util.constant_value(dataset_id)

    with self.assertRaisesRegex(
        ValueError,
        "Failed to fetch element spec for dataset id " + str(dataset_id_val) +
        " from tf.data service. If the "
        "dataset was registered in graph mode or inside a "
        "tf.function, the `element_spec` must be specified as "
        "an argument to `from_dataset_id`."):
      ds = data_service_ops.from_dataset_id("parallel_epochs",
                                            cluster.dispatcher_address(),
                                            dataset_id)

  @combinations.generate(test_base.default_test_combinations())
  def testNoShardingPolicy(self):
    """`ShardingPolicy.OFF` reads the full dataset from one worker."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    dataset = dataset_ops.Dataset.range(20)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.OFF)
    self.assertDatasetProduces(dataset, list(range(20)))

  @combinations.generate(test_base.default_test_combinations())
  def testCardinality(self):
    """A distributed dataset reports `UNKNOWN` cardinality."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    dataset = self.make_distributed_range_dataset(10, cluster)
    self.assertEqual(self.evaluate(dataset.cardinality()), dataset_ops.UNKNOWN)