def testStartServersLate(self):
  """Verifies the client retries instead of failing when servers start late.

  The dataset is created before the dispatcher and worker are started, so
  reads must retry until the servers come up.
  """
  try:
    import portpicker  # pylint: disable=g-import-not-at-top
    dispatcher_port = portpicker.pick_unused_port()
  except Exception:  # Narrowed from bare `except:` so KeyboardInterrupt
    # and SystemExit still propagate instead of being turned into a skip.
    raise self.skipTest(
        "Flakes in portpicker library do not represent "
        "TensorFlow errors.")
  cluster = data_service_test_base.TestCluster(
      num_workers=1, dispatcher_port=dispatcher_port, start=False)

  def start_servers():
    # Delay startup so the client first observes unavailable servers.
    time.sleep(0.5)
    cluster.start_dispatcher()
    cluster.start_workers()

  start_servers_thread = threading.Thread(target=start_servers, daemon=True)
  start_servers_thread.start()

  num_elements = 10
  ds = self.make_distributed_range_dataset(num_elements, cluster)
  results = [elem.numpy() for elem in ds]
  self.assertEqual(list(range(num_elements)), results)
  start_servers_thread.join()
def testRoundRobinRestartWorker(self): num_workers = 3 # Set a shutdown quiet period to prevent workers from shutting down partway # through a round. cluster = data_service_test_base.TestCluster( num_workers, worker_shutdown_quiet_period_ms=2000) # Round robin reads can cause slow cluster shutdown. data_service_test_base.GLOBAL_CLUSTERS.add(cluster) num_consumers = 5 ds = self.make_round_robin_dataset(cluster, num_consumers) get_next = self.getNext(ds, requires_initialization=True) results = [] self.read(get_next, results, 20) cluster.workers[1].stop() # Check that we can continue to read even with a worker stopped. self.read(get_next, results, 20) cluster.workers[1].restart() # Read until we get results from the restarted worker, then read some more. while results[-1] != 0: results.append(self.evaluate(get_next())) self.read(get_next, results, 20) self.checkRoundRobinGroups(results, num_consumers)
def testDistributeDistributedEpochTensorSlices(self):
  """Checks distributed_epoch processing over a from_tensor_slices dataset."""
  cluster = data_service_test_base.TestCluster(num_workers=2)
  vals = [5, 1, 2, 4]
  sliced = dataset_ops.Dataset.from_tensor_slices(vals)
  distributed = self.make_distributed_dataset(
      sliced, cluster, processing_mode="distributed_epoch")
  self.assertDatasetProduces(distributed, vals, assert_items_equal=True)
def testSharedJobName(self):
  """Two datasets sharing a job name jointly produce each element once."""
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 1000

  def make_ds():
    return dataset_ops.Dataset.range(num_elements).shuffle(num_elements)

  ds1 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name")
  ds2 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name")
  iter1 = iter(ds1)
  iter2 = iter(ds2)
  results = []
  # Alternate reads between the two iterators for a while, then drain each.
  for _ in range(num_elements // 5):
    results.append(next(iter1).numpy())
    results.append(next(iter2).numpy())
  for remaining in (iter1, iter2):
    for elem in remaining:
      results.append(elem.numpy())
  self.assertCountEqual(list(range(num_elements)), results)
def testSharedJobNameRepeat(self):
  """Repeated datasets sharing a job name split the job's elements."""
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 100
  num_repetitions = 3
  total = num_elements * num_repetitions
  ds1 = self.make_distributed_range_dataset(
      num_elements, cluster, job_name="job_name").repeat(num_repetitions)
  ds2 = self.make_distributed_range_dataset(
      num_elements, cluster, job_name="job_name").repeat(num_repetitions)
  iter1 = iter(ds1)
  iter2 = iter(ds2)
  results = []
  # Read a chunk from each iterator in turn, then drain both.
  for _ in range(total // 5):
    results.append(next(iter1).numpy())
  for _ in range(total // 5):
    results.append(next(iter2).numpy())
  for remaining in (iter1, iter2):
    for elem in remaining:
      results.append(elem.numpy())
  self.assertCountEqual(num_repetitions * list(range(num_elements)), results)
def testRegisterDifferentDatasets(self):
  """Registering two distinct datasets yields two distinct dataset ids."""
  ds_1 = dataset_ops.Dataset.range(10)
  ds_2 = dataset_ops.Dataset.range(20)
  cluster = data_service_test_base.TestCluster(num_workers=1)
  address = cluster.dispatcher_address()
  id_1 = data_service_ops.register_dataset(address, ds_1)
  id_2 = data_service_ops.register_dataset(address, ds_2)
  self.assertNotEqual(self.evaluate(id_1), self.evaluate(id_2))
def testRestartWorker(self, use_same_port, work_dir, fault_tolerant_mode):
  """Tests reading through a worker restart.

  Reads half the range, restarts the lone worker, then verifies the dataset
  starts over from the new worker (after draining any prefetched elements).
  """
  cluster = data_service_test_base.TestCluster(
      num_workers=1,
      work_dir=work_dir,
      fault_tolerant_mode=fault_tolerant_mode)
  num_elements = 100
  ds = self.make_distributed_range_dataset(num_elements, cluster)
  iterator = iter(ds)
  # Read halfway through the dataset.
  midpoint = num_elements // 2
  for i in range(midpoint):
    self.assertEqual(i, next(iterator).numpy())

  # Stop the original worker and start a new one.
  cluster.workers[0].restart(use_same_port=use_same_port)

  # There may have been some elements prefetched from the first worker
  # before it was stopped. Drain until we see 0, the first element the
  # restarted worker produces.
  while True:
    val = next(iterator).numpy()
    if val == 0:
      break

  # The dataset starts over now that we read from the new worker.
  # TODO(b/157086991): Iterate until end of sequence when we support
  # detecting lost workers.
  for i in range(1, num_elements // 2):
    val = next(iterator).numpy()
    self.assertEqual(i, val)
def testDispatcherRestartBetweenIterations(self):
  """A dataset remains readable after the dispatcher restarts between epochs."""
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 100
  # Use num_elements rather than a duplicated literal 100 so the dataset
  # size and the expected output cannot drift apart.
  ds = self.make_distributed_range_dataset(num_elements, cluster)
  self.assertDatasetProduces(ds, list(range(num_elements)))
  cluster.restart_dispatcher()
  self.assertDatasetProduces(ds, list(range(num_elements)))
def testDistributeCompression(self, compression):
  """A range dataset round-trips through the service under `compression`."""
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 10
  expected = list(range(num_elements))
  ds = self.make_distributed_range_dataset(
      num_elements, cluster, compression=compression)
  self.assertDatasetProduces(ds, expected)
def testRegisterDifferentDatasets(self):
  """Distinct registered datasets receive distinct dataset ids."""
  cluster = data_service_test_base.TestCluster(num_workers=1)
  first = dataset_ops.Dataset.range(10)
  second = dataset_ops.Dataset.range(20)
  first_id = data_service_ops.register_dataset(cluster.target, first)
  second_id = data_service_ops.register_dataset(cluster.target, second)
  self.assertNotEqual(first_id.numpy(), second_id.numpy())
def testApplyDeterminismOption(self):
  """Checks that `experimental_deterministic = False` survives distribution."""
  elements = list(range(10))
  cluster = data_service_test_base.TestCluster(num_workers=1)

  def dataset_fn(delay_ms):

    def interleave_fn(x):
      # Sleep only for element 0 — presumably so that, with determinism
      # disabled, faster elements can overtake it (confirm against
      # checkDeterminism).
      ds = dataset_ops.Dataset.from_tensors(x)
      if math_ops.equal(x, 0):
        ds = ds.apply(testing.sleep(delay_ms * 1000))
      else:
        ds = ds.apply(testing.sleep(0))
      return ds

    ds = dataset_ops.Dataset.from_tensor_slices(elements)
    ds = ds.interleave(interleave_fn, cycle_length=10, num_parallel_calls=10)
    # Explicitly allow nondeterministic ordering before distributing.
    opts = dataset_ops.Options()
    opts.experimental_deterministic = False
    ds = ds.with_options(opts)
    ds = self.make_distributed_dataset(ds, cluster)
    return ds

  self.checkDeterminism(
      dataset_fn=dataset_fn,
      expect_determinism=False,
      expected_elements=elements)
def testMultiWorker(self):
  """Each element is produced num_workers times across the cluster."""
  num_workers = 3
  cluster = data_service_test_base.TestCluster(num_workers=num_workers)
  num_elements = 10
  ds = self.make_distributed_range_dataset(num_elements, cluster)
  expected = num_workers * list(range(num_elements))
  self.assertDatasetProduces(ds, expected, assert_items_equal=True)
def testRoundRobinAddWorkers(self, workers_to_add):
  """Round-robin reads continue correctly while workers are added."""
  starting_workers = 3
  cluster = data_service_test_base.TestCluster(
      num_workers=starting_workers)
  # Round robin reads can cause slow cluster shutdown.
  data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
  num_consumers = 7
  ds = self.make_round_robin_dataset(cluster, num_consumers)
  get_next = self.getNext(ds, requires_initialization=True)
  results = []
  # NOTE(review): a 0 element presumably appears once per participating
  # worker, so counting zeros counts joined workers — confirm against
  # make_round_robin_dataset.
  zeros_seen = 0
  for _ in range(25):
    results.append(self.evaluate(get_next()))
    if results[-1] == 0:
      zeros_seen += 1
  for _ in range(workers_to_add):
    cluster.add_worker()
  # Read until all new workers have joined.
  while zeros_seen < starting_workers + workers_to_add:
    results.append(self.evaluate(get_next()))
    if results[-1] == 0:
      zeros_seen += 1
  # Read some more.
  for _ in range(25):
    results.append(self.evaluate(get_next()))
  self.checkRoundRobinGroups(results, num_consumers)
def testRoundRobinMultiStartStop(self): num_workers = 3 # Set a shutdown quiet period to prevent workers from shutting down partway # through a round. cluster = data_service_test_base.TestCluster( num_workers, worker_shutdown_quiet_period_ms=2000) # Round robin reads can cause slow cluster shutdown. data_service_test_base.GLOBAL_CLUSTERS.add(cluster) num_consumers = 5 ds = self.make_round_robin_dataset(cluster, num_consumers) get_next = self.getNext(ds, requires_initialization=True) results = [] self.read(get_next, results, 20) for i in range(num_workers): cluster.workers[i].stop() self.read(get_next, results, 20) cluster.workers[i].restart() self.read(get_next, results, 20) cluster.add_worker() cluster.restart_dispatcher() for i in range(num_workers): cluster.workers[i].stop() self.read(get_next, results, 20) self.checkRoundRobinGroups(results, num_consumers)
def testMultiWorker(self):
  """Eager iteration yields every element num_workers times."""
  num_workers = 3
  cluster = data_service_test_base.TestCluster(num_workers=num_workers)
  num_elements = 10
  ds = self.make_distributed_range_dataset(num_elements, cluster)
  produced = [elem.numpy() for elem in ds]
  self.assertCountEqual(num_workers * list(range(num_elements)), produced)
def testMultipleEpochs(self):
  """The same distributed dataset can be iterated many times over."""
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 3
  ds = self.make_distributed_range_dataset(num_elements, cluster)
  expected = list(range(num_elements))
  for _ in range(10):
    self.assertEqual(expected, [elem.numpy() for elem in ds])
def testDistributeBasic(self, work_dir, fault_tolerant_mode):
  """Basic read through a single-worker cluster with the given config."""
  cluster = data_service_test_base.TestCluster(
      num_workers=1,
      work_dir=work_dir,
      fault_tolerant_mode=fault_tolerant_mode)
  num_elements = 10
  expected = list(range(num_elements))
  ds = self.make_distributed_range_dataset(num_elements, cluster)
  self.assertDatasetProduces(ds, expected)
def testDistributeExplicitProtocol(self):
  """Distributes using an explicit grpc:// prefix on the service address."""
  cluster = data_service_test_base.TestCluster(num_workers=1)
  service = "grpc://" + cluster.dispatcher_address()
  ds = dataset_ops.Dataset.range(10).apply(
      data_service_ops.distribute(
          processing_mode="parallel_epochs", service=service))
  self.assertDatasetProduces(ds, list(range(10)))
def testDistributeCompression(self, compression):
  """Elements survive an eager round trip with the given compression."""
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 10
  ds = self.make_distributed_range_dataset(
      num_elements, cluster, compression=compression)
  self.assertEqual(
      list(range(num_elements)), [elem.numpy() for elem in ds])
def testDistributeDistributedEpoch(self):
  """distributed_epoch mode produces each range element exactly once."""
  cluster = data_service_test_base.TestCluster(num_workers=2)
  num_elements = 100
  base = dataset_ops.Dataset.range(num_elements)
  ds = self.make_distributed_dataset(
      base, cluster, processing_mode="distributed_epoch")
  self.assertDatasetProduces(
      ds, list(range(num_elements)), assert_items_equal=True)
def testDistributeLargeGraph(self):
  """A dataset graph exceeding grpc's default message size still works."""
  cluster = data_service_test_base.TestCluster(
      num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
  # Larger than default OSS grpc message size limit of 4MB.
  tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
  big = dataset_ops.Dataset.from_tensors(tensor)
  ds = self.make_distributed_dataset(big, cluster)
  self.assertDatasetProduces(ds, [tensor])
def testDistributedEpochOnDistributedDataset(self):
  """distributed_epoch over an already-distributed dataset is unimplemented."""
  cluster_1 = data_service_test_base.TestCluster(num_workers=1)
  cluster_2 = data_service_test_base.TestCluster(num_workers=1)
  num_sizes = 10
  size_repeats = 5
  numbers = list(range(num_sizes)) * size_repeats
  ds = dataset_ops.Dataset.from_tensor_slices(numbers)
  ds = self.make_distributed_dataset(
      ds, cluster_1, processing_mode="parallel_epochs")
  ds = ds.map(lambda x: x + 1)
  ds = self.make_distributed_dataset(
      ds, cluster_2, processing_mode="distributed_epoch")
  error_regex = ("Cannot create a split provider for dataset "
                 "of type DataServiceDataset")
  with self.assertRaisesRegex(errors.UnimplementedError, error_regex):
    self.getDatasetOutput(ds)
def testMaxOutstandingRequests(self):
  """All elements arrive even with max_outstanding_requests limited to 1."""
  num_workers = 3
  cluster = data_service_test_base.TestCluster(num_workers=num_workers)
  num_elements = 10
  ds = self.make_distributed_range_dataset(
      num_elements, cluster, max_outstanding_requests=1)
  expected = num_workers * list(range(num_elements))
  self.assertDatasetProduces(ds, expected, assert_items_equal=True)
def testDistributeDistributedEpochFlatMap(self):
  """distributed_epoch supports datasets containing flat_map."""
  cluster = data_service_test_base.TestCluster(num_workers=2)
  elements = [1, 5, 0]
  ds = dataset_ops.Dataset.from_tensor_slices(elements).flat_map(
      lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
  ds = self.make_distributed_dataset(
      ds, cluster, processing_mode="distributed_epoch")
  self.assertDatasetProduces(ds, elements, assert_items_equal=True)
def testRoundRobin(self, num_workers, num_consumers):
  """Round-robin consumers receive well-formed rotation groups."""
  cluster = data_service_test_base.TestCluster(num_workers=num_workers)
  # Round robin reads can cause slow cluster shutdown.
  data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
  ds = self.make_round_robin_dataset(cluster, num_consumers).take(100)
  produced = self.getDatasetOutput(ds)
  self.checkRoundRobinGroups(produced, num_consumers)
def testRepeatedDataset(self):
  """A distributed dataset can be repeated like any other dataset."""
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 10
  num_repetitions = 5
  ds = self.make_distributed_range_dataset(num_elements, cluster).repeat(
      num_repetitions)
  self.assertDatasetProduces(
      ds, expected_output=num_repetitions * list(range(num_elements)))
def testDifferentJobNames(self):
  """Datasets with different job names each read an independent full copy."""
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 10
  expected = list(range(num_elements))
  ds1 = self.make_distributed_range_dataset(
      num_elements, cluster, job_name="job_name1")
  ds2 = self.make_distributed_range_dataset(
      num_elements, cluster, job_name="job_name2")
  for ds in (ds1, ds2):
    self.assertDatasetProduces(ds, expected)
def testMaxOutstandingRequests(self):
  """Limiting outstanding requests to 1 still yields all elements."""
  num_workers = 3
  cluster = data_service_test_base.TestCluster(num_workers=num_workers)
  num_elements = 10
  ds = self.make_distributed_range_dataset(
      num_elements, cluster, max_outstanding_requests=1)
  produced = self.getDatasetOutput(ds)
  self.assertCountEqual(num_workers * list(range(num_elements)), produced)
def testFromDatasetIdNotRegistered(self):
  """from_dataset_id with an unregistered id fails with NotFoundError."""
  cluster = data_service_test_base.TestCluster(num_workers=1)
  unregistered_id = 0
  element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
  from_dataset_id_ds = data_service_ops.from_dataset_id(
      "parallel_epochs", cluster.target, unregistered_id, element_spec)
  with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
    self.evaluate(self.getNext(from_dataset_id_ds)())
def testDistributeBasic(self, work_dir, fault_tolerant_mode):
  """Eagerly iterating a basic distributed range yields the range in order."""
  cluster = data_service_test_base.TestCluster(
      num_workers=1,
      work_dir=work_dir,
      fault_tolerant_mode=fault_tolerant_mode)
  num_elements = 10
  ds = self.make_distributed_range_dataset(num_elements, cluster)
  produced = [elem.numpy() for elem in ds]
  self.assertEqual(list(range(num_elements)), produced)