def testOneLocalWorker(self):
  """A local read from a cluster with one local worker yields one copy."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=1, num_remote_workers=5)
  element_count = 10
  dataset = self.make_distributed_range_dataset(
      element_count, cluster, target_workers="local")
  self.assertDatasetProduces(dataset, list(range(element_count)))
def testAddRemoteWorkersMidJob(self, num_local_workers, num_remote_workers):
  """Workers added mid-job join the job; tagged ones are not read from."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=num_local_workers,
      num_remote_workers=num_remote_workers,
      worker_tags=[_COLOCATED_WORKER_TAG])
  element_count = 300
  dataset = self.make_distributed_range_dataset(element_count, cluster)
  get_next = self.getNext(dataset)
  results = [self.evaluate(get_next()) for _ in range(100)]

  # Starts four remote workers mid-job. The job will only read from the two
  # untagged (non-TPU) ones among them.
  for tags in (None, [_COLOCATED_WORKER_TAG], None, [_COLOCATED_WORKER_TAG]):
    cluster.start_remote_worker(worker_tags=tags)
  expect_num_workers_to_read = num_local_workers + 2

  # Waits for the new workers to register with the dispatcher.
  expected_total = num_local_workers + num_remote_workers + 4
  while cluster._dispatcher._num_workers() < expected_total:
    time.sleep(10 / 1000)  # 10ms
  results += self.getIteratorOutput(get_next)
  self.assertCountEqual(
      results, expect_num_workers_to_read * list(range(element_count)))
def testInvalidTag(self):
  """Creating a cluster with an empty worker tag raises a RuntimeError."""
  with self.assertRaisesRegex(RuntimeError, "Worker tags cannot be empty."):
    _ = multi_process_cluster.MultiProcessCluster(
        num_local_workers=1,
        num_remote_workers=3,
        worker_tags=["", _COLOCATED_WORKER_TAG])
def testAnonymousJobWithDifferentTargetWorkers(self):
  """Anonymous jobs may each use a different `target_workers` setting."""
  num_local_workers, num_remote_workers = 3, 3
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers, num_remote_workers)
  element_count = 10
  base = dataset_ops.Dataset.range(element_count)
  datasets = {}
  for target_workers in ("AUTO", "ANY", "LOCAL"):
    datasets[target_workers] = self.make_distributed_dataset(
        base, cluster, target_workers=target_workers)

  total_workers = num_local_workers + num_remote_workers
  # "AUTO" and "ANY" read from every worker; "LOCAL" only from local ones.
  for key, copies in (("AUTO", total_workers), ("ANY", total_workers),
                      ("LOCAL", num_local_workers)):
    self.assertDatasetProduces(
        datasets[key],
        copies * list(range(element_count)),
        assert_items_equal=True)
def testEmptyDataset(self, num_local_workers, num_remote_workers):
  """A zero-element dataset produces no output under local reads."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=num_local_workers,
      num_remote_workers=num_remote_workers)
  dataset = self.make_distributed_range_dataset(
      0, cluster, target_workers="LOCAL")
  self.assertDatasetProduces(dataset, [])
def testRepeatDistributedDataset(self, num_remote_workers, job_name):
  """`repeat` applied after distribution replays the local read in order."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=1, num_remote_workers=num_remote_workers)
  repetitions = 3
  dataset = self.make_distributed_range_dataset(
      10, cluster, job_name=job_name, target_workers="LOCAL")
  dataset = dataset.repeat(repetitions)
  self.assertDatasetProduces(dataset, repetitions * list(range(10)))
def testReadFromLocalWorker(self, num_remote_workers):
  """With every remote worker tagged, only the local worker is read."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=1,
      num_remote_workers=num_remote_workers,
      worker_tags=[_COLOCATED_WORKER_TAG])
  element_count = 100
  dataset = self.make_distributed_range_dataset(element_count, cluster)
  # Only the (untagged) local worker produces elements.
  self.assertDatasetProduces(dataset, list(range(element_count)))
def testNoLocalWorker(self):
  """Local reads fail when the cluster has no local workers."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=0, num_remote_workers=3)
  dataset = self.make_distributed_range_dataset(
      10, cluster, target_workers="LOCAL")
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "no local worker is found"):
    get_next = self.getNext(dataset)
    self.evaluate(get_next())
def testCoordinatedRead(self, num_local_workers, num_remote_workers):
  """Coordinated reads deliver elements in consistent consumer groups."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=num_local_workers,
      num_remote_workers=num_remote_workers,
      worker_tags=[_COLOCATED_WORKER_TAG])
  consumer_count = 4
  dataset = self.make_coordinated_read_dataset(cluster, consumer_count)
  get_next = self.getNext(dataset)
  results = [self.evaluate(get_next()) for _ in range(200)]
  self.checkCoordinatedReadGroups(results, consumer_count)
def testLocalWorkers(self, num_local_workers, num_remote_workers):
  """Local reads receive one copy of the dataset per local worker."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=num_local_workers,
      num_remote_workers=num_remote_workers)
  element_count = 10
  dataset = self.make_distributed_range_dataset(
      element_count, cluster, target_workers="LOCAL")
  self.assertDatasetProduces(
      dataset,
      num_local_workers * list(range(element_count)),
      assert_items_equal=True)
def testUnusedTags(self):
  """Tags with no special meaning do not restrict which workers are read."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=1,
      num_remote_workers=3,
      worker_tags=["Unused tag 1", "Unused tag 2", "Unused tag 3"])
  element_count = 100
  dataset = self.make_distributed_range_dataset(element_count, cluster)
  # The tags have no effect: tf.data service reads from all four workers.
  self.assertDatasetProduces(
      dataset, 4 * list(range(element_count)), assert_items_equal=True)
def testNonLocalRead(self, num_local_workers, num_remote_workers):
  """Ensures the remote workers are running and producing data."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=num_local_workers,
      num_remote_workers=num_remote_workers)
  element_count = 10
  dataset = self.make_distributed_range_dataset(element_count, cluster)
  worker_count = num_local_workers + num_remote_workers
  self.assertDatasetProduces(
      dataset,
      worker_count * list(range(element_count)),
      assert_items_equal=True)
def testDynamicSharding(self, num_local_workers, num_remote_workers):
  """With dynamic sharding, each element is produced exactly once."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=num_local_workers,
      num_remote_workers=num_remote_workers)
  element_count = 100
  dataset = self.make_distributed_range_dataset(
      element_count,
      cluster,
      processing_mode=data_service_ops.ShardingPolicy.DYNAMIC,
      target_workers="LOCAL")
  self.assertDatasetProduces(
      dataset, list(range(element_count)), assert_items_equal=True)
def testCluster(self, num_local_workers, num_remote_workers):
  """Each worker (local or remote) produces one full copy of the dataset."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=num_local_workers,
      num_remote_workers=num_remote_workers)
  worker_count = num_local_workers + num_remote_workers
  if worker_count == 0:
    # Nothing to read from; skip the assertion.
    return
  element_count = 10
  dataset = self.make_distributed_range_dataset(element_count, cluster)
  self.assertDatasetProduces(
      dataset,
      worker_count * list(range(element_count)),
      assert_items_equal=True)
def testNoLocalWorker(self):
  """Local reads report a clear error when no local worker exists."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=0, num_remote_workers=3)
  dataset = self.make_distributed_range_dataset(
      10, cluster, target_workers="LOCAL")
  with self.assertRaisesRegex(
      errors.InvalidArgumentError,
      "Local reads require local tf.data workers, but no local worker is "
      "found."):
    self.getDatasetOutput(dataset)
def testMultipleEpochs(self, num_remote_workers):
  """A fresh iterator restarts the infinite dataset from the beginning."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=1, num_remote_workers=num_remote_workers)
  num_epochs, num_steps = 5, 5
  dataset = self._make_distributed_infinite_range_dataset(cluster)
  for _ in range(num_epochs):
    # Each new iterator garbage-collects the previous one.
    get_next = self.getNext(dataset)
    for step in range(num_steps):
      self.assertEqual(self.evaluate(get_next()), step)
def testRepeatedDataset(self, num_local_workers, num_remote_workers):
  """`repeat` multiplies the per-local-worker output."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=num_local_workers,
      num_remote_workers=num_remote_workers)
  element_count = 10
  repetitions = 5
  dataset = self.make_distributed_range_dataset(
      element_count, cluster, target_workers="LOCAL")
  dataset = dataset.repeat(repetitions)
  expected = num_local_workers * repetitions * list(range(element_count))
  self.assertDatasetProduces(
      dataset, expected_output=expected, assert_items_equal=True)
def testCoordinatedRead(self):
  """Coordinated reads are rejected when targeting local workers only."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=3, num_remote_workers=3)
  dataset = self.make_distributed_dataset(
      dataset_ops.Dataset.range(10).repeat(),
      cluster,
      job_name="test_job",
      consumer_index=0,
      num_consumers=3,
      target_workers="LOCAL")
  with self.assertRaisesRegex(
      errors.InvalidArgumentError,
      "Coordinated reads require non-local workers"):
    self.getDatasetOutput(dataset)
def testReadFromLocalAndNonTpuWorkers(self, num_local_workers,
                                      num_remote_workers):
  """Reads come from the local workers plus the one untagged remote worker."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=num_local_workers,
      num_remote_workers=num_remote_workers,
      worker_tags=[_COLOCATED_WORKER_TAG])
  cluster.start_remote_worker(worker_tags=None)
  element_count = 100
  dataset = self.make_distributed_range_dataset(element_count, cluster)
  # Reads from the local workers and the non-colocated (untagged) worker.
  self.assertDatasetProduces(
      dataset, (num_local_workers + 1) * list(range(element_count)),
      assert_items_equal=True)
def testNoLocalWorkers(self):
  """Static sharding fails when the cluster has no local workers."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=0, num_remote_workers=3)
  files = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
  records = files.flat_map(readers.TFRecordDataset)
  dataset = self.make_distributed_dataset(
      records, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA)
  with self.assertRaisesRegex(
      errors.InvalidArgumentError,
      "Local reads or static sharding require local tf.data workers"):
    self.getDatasetOutput(dataset)
def testReadFromLocalWorker_StaticSharding(self):
  """Static sharding reads only the local worker's shard of the data."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=1,
      num_remote_workers=3,
      worker_addresses=["localhost:%port%"] * 5,
      worker_tags=[_COLOCATED_WORKER_TAG])
  cluster.start_remote_worker(worker_tags=None)
  element_count = 100
  dataset = self.make_distributed_range_dataset(
      element_count,
      cluster,
      processing_mode=data_service_ops.ShardingPolicy.FILE_OR_DATA)
  # With five statically-addressed workers, the local worker owns every
  # fifth element, and static sharding only reads from the local worker.
  self.assertDatasetProduces(dataset, list(range(0, element_count, 5)))
def testReadFromLocalAndNonTpuWorkers_DynamicSharding(
    self, num_local_workers, num_remote_workers):
  """Dynamic sharding over local + untagged workers yields each element once.

  Args:
    num_local_workers: Number of local workers in the cluster.
    num_remote_workers: Number of initial (tagged) remote workers.
  """
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=num_local_workers,
      # Fix: previously hard-coded to 3, silently ignoring the test
      # parameter `num_remote_workers` (compare
      # testReadFromLocalAndNonTpuWorkers, which passes it through). The
      # expected output is unaffected: dynamic sharding produces each
      # element exactly once regardless of worker count.
      num_remote_workers=num_remote_workers,
      worker_tags=[_COLOCATED_WORKER_TAG])
  # One extra untagged remote worker that the job may read from.
  cluster.start_remote_worker(worker_tags=None)
  num_elements = 100
  dataset = self.make_distributed_range_dataset(
      num_elements,
      cluster,
      processing_mode=data_service_ops.ShardingPolicy.DYNAMIC)
  self.assertDatasetProduces(
      dataset, list(range(num_elements)), assert_items_equal=True)
def testInconsistentTargetWorkers(self):
  """Joining a named job with different target workers is an error."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=3, num_remote_workers=3)
  base = dataset_ops.Dataset.range(10)
  datasets = []
  for target_workers in ("AUTO", "ANY", "LOCAL"):
    datasets.append(
        self.make_distributed_dataset(
            base, cluster, job_name="test_job",
            target_workers=target_workers))
  # The first dataset creates the job with AUTO; the later ones conflict.
  with self.assertRaisesRegex(
      errors.InvalidArgumentError,
      "but found an existing job with different parameters: "
      "Existing target workers: <AUTO>"):
    for dataset in datasets:
      self.getDatasetOutput(dataset)
def testPreferLocalRead(self):
  """The iterator prefers the local worker even as remote workers join."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=1, num_remote_workers=0)
  element_count = 100
  dataset = self.make_distributed_range_dataset(element_count, cluster)
  get_next = self.getNext(dataset)
  self.assertEqual(self.evaluate(get_next()), 0)
  for step in range(1, 4):
    cluster.start_remote_worker()
    # Waits for the new worker to register with the dispatcher.
    while cluster._dispatcher._num_workers() <= step:
      time.sleep(10 / 1000)  # 10ms
    # The next element still comes from the local worker.
    self.assertEqual(self.evaluate(get_next()), step)
  # Remainder: elements 4..99 from the local worker plus one full copy from
  # each of the three remote workers.
  self.assertCountEqual(
      self.getIteratorOutput(get_next),
      list(range(4, element_count)) + 3 * list(range(element_count)))
def testMultipleEpochs_DispatcherRestart(self, num_remote_workers):
  """Workers re-create tasks after iterator deletion and dispatcher restart."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=1, num_remote_workers=num_remote_workers)
  step_count = 10
  dataset = self._make_distributed_infinite_range_dataset(
      cluster, job_name="shared_job_name")
  get_next = self.getNext(dataset)
  for step in range(step_count):
    self.assertEqual(self.evaluate(get_next()), step)
  # Deleting the iterator and restarting the dispatcher forces the worker to
  # re-create the task for the next iterator.
  del get_next
  cluster.restart_dispatcher()
  get_next = self.getNext(dataset)
  for step in range(step_count):
    self.assertEqual(self.evaluate(get_next()), step)
def testReadFromDeletedTask(self, num_remote_workers):
  """Reading from a deleted task raises FailedPreconditionError."""
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=1, num_remote_workers=num_remote_workers)
  step_count = 10
  dataset = self._make_distributed_infinite_range_dataset(
      cluster, job_name="shared_job_name")
  get_next = self.getNext(dataset)
  for step in range(step_count):
    self.assertEqual(self.evaluate(get_next()), step)
  # Re-creating the dataset resets the iterator index, so the second
  # iterator reads from the same task as the first, which has been deleted.
  dataset = self._make_distributed_infinite_range_dataset(
      cluster, job_name="shared_job_name")
  get_next = self.getNext(dataset)
  with self.assertRaisesRegex(errors.FailedPreconditionError,
                              "which has been deleted."):
    _ = self.evaluate(get_next())
def testMultipleConsumers(self):
  """Iterators sharing a named job collectively read every worker's copy."""
  num_local_workers, num_remote_workers = 1, 3
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=num_local_workers,
      num_remote_workers=num_remote_workers)
  element_count = 300
  consumer_count = 8
  iterators = []
  for _ in range(consumer_count):
    dataset = self.make_distributed_range_dataset(
        element_count, cluster, job_name="shared_job")
    iterators.append(self.getNext(dataset))

  results = []
  # Interleaves a few reads across every consumer ...
  for _ in range(10):
    results.extend(self.evaluate(it()) for it in iterators)
  # ... then drains each iterator in turn.
  for it in iterators:
    results.extend(self.getIteratorOutput(it))
  self.assertCountEqual(
      results,
      (num_local_workers + num_remote_workers) * list(range(element_count)))