def testMasterNumWorkers(self):
  """Checks that the master's worker count tracks each registration."""
  master = server_lib.MasterServer(0)
  self.assertEqual(0, master._num_workers())
  # Keep references so the worker servers stay alive for the assertions.
  workers = []
  for expected_count in (1, 2):
    workers.append(server_lib.WorkerServer(0, master._address))
    self.assertEqual(expected_count, master._num_workers())
def testAddWorkerMidJob(self):
  """Adds a second worker mid-read and expects every element to appear twice.

  Reads half of a distributed dataset from one worker, starts a second
  worker, waits for it to register with the master, then drains the
  iterator and asserts each element was produced twice (once per worker).
  """
  self._master = server_lib.MasterServer(PROTOCOL)
  master_address = self._master.target[len(PROTOCOL + "://"):]
  self._worker = server_lib.WorkerServer(
      PROTOCOL, master_address=master_address)
  num_elements = 100
  ds = dataset_ops.Dataset.range(num_elements)
  ds = _make_distributed_dataset(ds, self._master.target)
  iterator = iter(ds)
  results = []
  # Read halfway through the dataset.
  for _ in range(num_elements // 2):
    results.append(next(iterator).numpy())

  self._new_worker = server_lib.WorkerServer(
      PROTOCOL, master_address=master_address)

  # Wait for the new worker to register with the master. Fail fast with a
  # clear message instead of hanging the test run forever if registration
  # never happens.
  deadline = time.time() + 60
  while self._master.num_tasks() < 2:
    if time.time() > deadline:
      self.fail("Timed out waiting for the new worker to register.")
    time.sleep(10 / 1000)  # 10ms

  for elem in iterator:
    results.append(elem.numpy())

  self.assertCountEqual(2 * list(range(num_elements)), results)
def testStopStartMaster(self):
  """A stopped master must refuse to be started again."""
  master = server_lib.MasterServer(0)
  master._stop()
  expected_error = "Server cannot be started after it has been stopped"
  with self.assertRaisesRegex(RuntimeError, expected_error):
    master.start()
def testRestartWorker(self, use_same_port):
  """Restarts the worker mid-read; the dataset should start over.

  Args:
    use_same_port: Whether the replacement worker reuses the original
      worker's port.
  """
  self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
  self._worker = server_lib.WorkerServer(
      port=0, master_address=self._master._address, protocol=PROTOCOL)
  num_elements = 100
  ds = dataset_ops.Dataset.range(num_elements)
  ds = _make_distributed_dataset(ds, self._master._address)
  iterator = iter(ds)
  # Read halfway through the dataset.
  midpoint = num_elements // 2
  for i in range(midpoint):
    self.assertEqual(i, next(iterator).numpy())

  # Stop the original worker and start a new one.
  port = 0
  if use_same_port:
    # Reuse the old worker's port so the master sees the replacement at the
    # same address.
    port = int(self._worker._address.split(":")[1])
  self._worker._stop()
  self._new_worker = server_lib.WorkerServer(
      port=port, master_address=self._master._address, protocol=PROTOCOL)

  # The dataset starts over now that we read from the new worker.
  for i in range(num_elements):
    val = next(iterator).numpy()
    if val == midpoint and i != midpoint:
      # There may have been one last element prefetched from the first worker
      # before it was stopped.
      val = next(iterator).numpy()
    self.assertEqual(i, val)
def testRestartWorker(self, use_same_port):
  """Restarts the worker mid-read; the dataset should start over.

  Args:
    use_same_port: Whether the replacement worker reuses the original
      worker's port.
  """
  self._master = server_lib.MasterServer(PROTOCOL)
  master_address = self._master.target[len(PROTOCOL + "://"):]
  self._worker = server_lib.WorkerServer(PROTOCOL,
                                         master_address=master_address)
  num_elements = 100
  ds = dataset_ops.Dataset.range(num_elements)
  ds = ds.apply(
      data_service_ops._distribute(self._master.target,
                                   task_refresh_interval_hint_ms=20))
  token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
  iterator = data_service_ops.create_iterator(ds, token)
  # Read halfway through the dataset.
  for i in range(num_elements // 2):
    self.assertEqual(i, next(iterator).numpy())

  # Stop the original worker and start a new one.
  port = 0
  if use_same_port:
    # Reuse the old worker's port so the replacement comes up at the same
    # address.
    worker_address = self._worker.target[len(PROTOCOL + "://"):]
    port = int(worker_address.split(":")[1])
  self._worker.stop()
  self._new_worker = server_lib.WorkerServer(
      PROTOCOL, master_address=master_address, port=port)

  # There may be one last element prefetched from the first worker before it
  # was stopped.
  val = next(iterator).numpy()
  # val is either num_elements // 2 (a leftover prefetched element from the
  # old worker) or 0 (the new worker has already restarted the dataset).
  self.assertTrue(val == 0 or val == num_elements // 2)
  # If the first value read was 0, the restarted sequence has already begun,
  # so expect it to continue from 1; otherwise expect it to start at 0.
  start_val = 1 if val == 0 else 0
  # The dataset starts over now that we read from the new worker.
  for i in range(start_val, num_elements):
    self.assertEqual(i, next(iterator).numpy())
def testAddWorkerMidJob(self):
  """Adds a second worker mid-job and verifies the combined output."""
  self._master = server_lib.MasterServer(PROTOCOL)
  master_address = self._master.target[len(PROTOCOL + "://"):]
  self._worker = server_lib.WorkerServer(PROTOCOL,
                                         master_address=master_address)
  num_elements = 100
  ds = dataset_ops.Dataset.range(num_elements)
  ds = ds.apply(
      data_service_ops._distribute(self._master.target,
                                   task_refresh_interval_hint_ms=20))
  token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
  iterator = data_service_ops.create_iterator(ds, token)
  results = []
  # Read halfway through the dataset.
  for _ in range(num_elements // 2):
    results.append(next(iterator).numpy())

  self._new_worker = server_lib.WorkerServer(
      PROTOCOL, master_address=master_address)

  # Give the client time to notice the new task.
  time.sleep(50 / 1000)  # 50ms

  for elem in iterator:
    results.append(elem.numpy())

  # It is possible that reading from the first worker completes before the
  # client notices the second worker. We allow this to avoid flaky failures.
  if len(results) == num_elements:
    # Only the first worker was read: one full pass over the range.
    self.assertEqual(list(range(num_elements)), results)
  else:
    # Both workers contributed: every element appears twice, in some order.
    self.assertCountEqual(2 * list(range(num_elements)), results)
def create_cluster(self, num_workers):
  """Starts a master and `num_workers` workers for the tf.data service.

  Args:
    num_workers: The number of workers in the cluster.

  Returns:
    The address of the master.
  """
  self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
  self._servers = [
      server_lib.WorkerServer(
          port=0, master_address=self._master._address, protocol=PROTOCOL)
      for _ in range(num_workers)
  ]
  return self._master._address
def create_cluster(self, num_workers):
  """Starts a master and `num_workers` workers for the tf.data service.

  Args:
    num_workers: The number of workers in the cluster.

  Returns:
    A target for connecting to the service, e.g.
    "grpc+local://localhost:2000".
  """
  self._master = server_lib.MasterServer(PROTOCOL)
  prefix_len = len(PROTOCOL + "://")
  master_address = self._master.target[prefix_len:]
  self._servers = [
      server_lib.WorkerServer(PROTOCOL, master_address=master_address)
      for _ in range(num_workers)
  ]
  return self._master.target
def testStartMaster(self):
  """A master created with start=False can be started explicitly."""
  server = server_lib.MasterServer(0, start=False)
  server.start()
def testJoinWorker(self):
  """join() on a stopped worker completes without error."""
  master_server = server_lib.MasterServer(0)
  worker_server = server_lib.WorkerServer(0, master_server._address)
  worker_server._stop()
  worker_server.join()
def testJoinMaster(self):
  """join() on a stopped master completes without error."""
  server = server_lib.MasterServer(0)
  server._stop()
  server.join()
def testStopMaster(self):
  """Stopping a master twice should not raise."""
  server = server_lib.MasterServer(0)
  server._stop()
  # A second stop must be safe on an already-stopped server.
  server._stop()
def testMultipleStartWorker(self):
  """Starting an already-started worker should not raise."""
  master_server = server_lib.MasterServer(0)
  worker_server = server_lib.WorkerServer(
      0, master_server._address, start=True)
  worker_server.start()
def testStartWorker(self):
  """A worker created with start=False can be started explicitly."""
  master_server = server_lib.MasterServer(0)
  worker_server = server_lib.WorkerServer(
      0, master_server._address, start=False)
  worker_server.start()
def testMultipleStartMaster(self):
  """Starting an already-started master should not raise."""
  server = server_lib.MasterServer(0, start=True)
  server.start()
def testStartWorker(self):
  """A new worker's target should match the protocol://host:port shape."""
  master = server_lib.MasterServer(PROTOCOL)
  master_address = master.target[len(PROTOCOL + "://"):]
  worker = server_lib.WorkerServer(PROTOCOL, master_address)
  self.assertRegex(worker.target, PROTOCOL + "://.*:.*")
def testStartMaster(self):
  """A new master's target should match the protocol://host:port shape."""
  server = server_lib.MasterServer(PROTOCOL)
  self.assertRegex(server.target, PROTOCOL + "://.*:.*")