def testRestartWorker(self, use_same_port):
    """Verify the dataset restarts from scratch when its worker is replaced.

    Reads half the elements, swaps the worker out (optionally reusing the old
    worker's port), and checks that iteration either continues with at most one
    prefetched element or starts over from zero.
    """
    self._master = server_lib.MasterServer(PROTOCOL)
    master_address = self._master.target[len(PROTOCOL + "://"):]
    self._worker = server_lib.WorkerServer(
        PROTOCOL, master_address=master_address)
    num_elements = 100
    midpoint = num_elements // 2
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.apply(
        data_service_ops._distribute(
            self._master.target, task_refresh_interval_hint_ms=20))
    token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
    iterator = data_service_ops.create_iterator(ds, token)
    # Read halfway through the dataset.
    for expected in range(midpoint):
        self.assertEqual(expected, next(iterator).numpy())
    # Stop the original worker and start a new one, optionally on the same
    # port the old worker used.
    new_port = 0
    if use_same_port:
        old_address = self._worker.target[len(PROTOCOL + "://"):]
        new_port = int(old_address.split(":")[1])
    self._worker.stop()
    self._new_worker = server_lib.WorkerServer(
        PROTOCOL, master_address=master_address, port=new_port)
    # There may be one last element prefetched from the first worker before it
    # was stopped, so the next value is either the midpoint or zero.
    val = next(iterator).numpy()
    self.assertTrue(val == 0 or val == midpoint)
    start_val = 1 if val == 0 else 0
    # The dataset starts over now that we read from the new worker.
    for expected in range(start_val, num_elements):
        self.assertEqual(expected, next(iterator).numpy())
def testAddWorkerMidJob(self):
    """Check that a worker added mid-job contributes elements to the client.

    Half the dataset is read with one worker; then a second worker joins.
    Depending on timing the client sees either one or two full epochs' worth
    of elements, and both outcomes are accepted to avoid flakiness.
    """
    self._master = server_lib.MasterServer(PROTOCOL)
    master_address = self._master.target[len(PROTOCOL + "://"):]
    self._worker = server_lib.WorkerServer(
        PROTOCOL, master_address=master_address)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.apply(
        data_service_ops._distribute(
            self._master.target, task_refresh_interval_hint_ms=20))
    token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
    iterator = data_service_ops.create_iterator(ds, token)
    # Read halfway through the dataset before the second worker exists.
    results = [next(iterator).numpy() for _ in range(num_elements // 2)]
    self._new_worker = server_lib.WorkerServer(
        PROTOCOL, master_address=master_address)
    # Give the client time to notice the new task.
    time.sleep(50 / 1000)  # 50ms
    results.extend(elem.numpy() for elem in iterator)
    # It is possible that reading from the first worker completes before the
    # client notices the second worker. We allow this to avoid flaky failures.
    if len(results) == num_elements:
        self.assertEqual(list(range(num_elements)), results)
    else:
        self.assertCountEqual(2 * list(range(num_elements)), results)
def testMultipleEpochs(self):
    """Each newly created job over the same dataset replays it from the start."""
    service = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(3)
    ds = ds.apply(data_service_ops.distribute(service))
    expected = list(range(3))
    for _ in range(10):
        token = data_service_ops.create_job(
            ds, processing_mode="parallel_epochs")
        it = data_service_ops.create_iterator(ds, token)
        self.assertEqual(expected, [t.numpy() for t in it])
def testDistributeBasic(self):
    """A single-worker cluster yields the dataset's elements in order."""
    element_count = 10
    service = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(element_count)
    ds = ds.apply(data_service_ops.distribute(service))
    token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
    it = data_service_ops.create_iterator(ds, token)
    produced = [t.numpy() for t in it]
    self.assertEqual(list(range(element_count)), produced)
def testMultiWorker(self):
    """With N workers, every element is produced N times (once per worker)."""
    worker_count = 3
    element_count = 10
    service = self.create_cluster(worker_count)
    ds = dataset_ops.Dataset.range(element_count)
    ds = ds.apply(data_service_ops.distribute(service))
    token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
    it = data_service_ops.create_iterator(ds, token)
    produced = [elem.numpy() for elem in it]
    # Order across workers is not deterministic, so compare as multisets.
    self.assertCountEqual(worker_count * list(range(element_count)), produced)
def run_stateful(self, external_state_policy):
    """Distribute a stateful (random-valued) dataset under the given policy.

    Builds a dataset whose map function draws random values, applies the
    requested external-state policy, distributes it across a 3-worker
    cluster, and pulls a single element to exercise the pipeline.
    """
    element_count = 10
    ds = dataset_ops.Dataset.range(element_count).map(
        lambda _: random_ops.random_uniform(()))
    opts = dataset_ops.Options()
    opts.experimental_external_state_policy = external_state_policy
    ds = ds.with_options(opts)
    service = self.create_cluster(3)
    ds = ds.apply(data_service_ops.distribute(service))
    token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
    iterator = data_service_ops.create_iterator(ds, token)
    next(iterator)
def f():
    """Drain a distributed dataset into a TensorArray and return it stacked.

    Closes over `num_elements`, `service`, and `num_workers` from the
    enclosing scope. NOTE(review): presumably traced as a tf.function by the
    caller, which is why a TensorArray (not a Python list) accumulates the
    elements — confirm against the surrounding test.
    """
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.apply(data_service_ops.distribute(service))
    token = data_service_ops.create_job(
        ds, processing_mode="parallel_epochs")
    it = data_service_ops.create_iterator(ds, token)
    collected = tensor_array_ops.TensorArray(
        dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
    idx = 0
    for elem in it:
        collected = collected.write(idx, elem)
        idx += 1
    return collected.stack()
def testConcurrentEpoch(self):
    """Independent jobs read concurrently each see a complete, ordered epoch."""
    element_count = 10
    dataset_count = 3
    service = self.create_cluster(1)
    iterators = []
    results = []
    for _ in range(dataset_count):
        ds = dataset_ops.Dataset.range(element_count)
        ds = ds.apply(data_service_ops.distribute(service))
        token = data_service_ops.create_job(
            ds, processing_mode="parallel_epochs")
        iterators.append(data_service_ops.create_iterator(ds, token))
        results.append([])
    # Interleave reads across the jobs, one element from each per pass.
    for _ in range(element_count):
        for dataset_ind, it in enumerate(iterators):
            results[dataset_ind].append(next(it).numpy())
    for per_dataset in results:
        self.assertEqual(list(range(element_count)), per_dataset)
def testSharedEpoch(self):
    """Several iterators sharing one job split a single epoch between them."""
    element_count = 10
    iterator_count = 3
    service = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(element_count)
    ds = ds.apply(data_service_ops.distribute(service))
    token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
    iterators = [
        data_service_ops.create_iterator(ds, token)
        for _ in range(iterator_count)
    ]
    collected = []
    # Alternate reading between the iterators.
    for _ in range(2):
        for it in iterators:
            collected.append(next(it).numpy())
    # Drain the rest of the elements.
    for it in iterators:
        for elem in it:
            collected.append(elem.numpy())
    # Together the iterators should have produced exactly one epoch.
    self.assertCountEqual(list(range(element_count)), collected)