def testRestartWorker(self, use_same_port):
  """Checks that the dataset restarts from 0 after its worker is replaced.

  NOTE(review): a second `testRestartWorker` is defined later in this file and
  shadows this one, so this version never runs — confirm which is intended.

  Args:
    use_same_port: If True, the replacement worker reuses the stopped
      worker's port; otherwise it binds a fresh one (port=0).
  """
  self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
  self._worker = server_lib.WorkerServer(
      port=0, master_address=self._master._address, protocol=PROTOCOL)
  num_elements = 100
  ds = dataset_ops.Dataset.range(num_elements)
  ds = _make_distributed_dataset(ds, self._master._address)
  iterator = iter(ds)
  # Read halfway through the dataset.
  midpoint = num_elements // 2
  for i in range(midpoint):
    self.assertEqual(i, next(iterator).numpy())
  # Stop the original worker and start a new one.
  port = 0
  if use_same_port:
    # Reuse the old worker's port: its address is "host:port".
    port = int(self._worker._address.split(":")[1])
  self._worker._stop()
  self._new_worker = server_lib.WorkerServer(
      port=port, master_address=self._master._address, protocol=PROTOCOL)
  # The dataset starts over now that we read from the new worker.
  for i in range(num_elements):
    val = next(iterator).numpy()
    if val == midpoint and i != midpoint:
      # There may have been one last element prefetched from the first worker
      # before it was stopped.
      val = next(iterator).numpy()
    self.assertEqual(i, val)
def testRestartWorker(self, use_same_port):
  """Checks that the dataset restarts from 0 after its worker is replaced.

  Args:
    use_same_port: If True, the replacement worker reuses the stopped
      worker's port; otherwise it binds a fresh one (port=0).
  """
  self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
  self._worker = server_lib.WorkerServer(
      port=0, master_address=self._master._address, protocol=PROTOCOL)
  num_elements = 100
  distributed = _make_distributed_dataset(
      dataset_ops.Dataset.range(num_elements), self._master._address)
  it = iter(distributed)

  # Consume the first half of the dataset from the original worker.
  halfway = num_elements // 2
  for expected in range(halfway):
    self.assertEqual(expected, next(it).numpy())

  # Tear down the original worker, optionally reusing its port (the worker
  # address is "host:port") for the replacement.
  new_port = int(self._worker._address.split(":")[1]) if use_same_port else 0
  self._worker._stop()
  self._new_worker = server_lib.WorkerServer(
      port=new_port, master_address=self._master._address, protocol=PROTOCOL)

  # There may have been some elements prefetched from the first worker
  # before it was stopped; discard them until the restarted sequence begins.
  while next(it).numpy() != 0:
    pass

  # The dataset starts over now that we read from the new worker.
  # TODO(b/157086991): Iterate until end of sequence when we support
  # detecting lost workers.
  for expected in range(1, num_elements // 2):
    self.assertEqual(expected, next(it).numpy())
def testStopStartMaster(self):
  """Starting a master after stopping it raises a RuntimeError."""
  stopped_master = server_lib.MasterServer(0)
  stopped_master._stop()
  expected_message = "Server cannot be started after it has been stopped"
  with self.assertRaisesRegex(RuntimeError, expected_message):
    stopped_master.start()
def testMasterNumWorkers(self):
  """The master's reported worker count grows as workers are created."""
  master = server_lib.MasterServer(0)
  self.assertEqual(0, master._num_workers())
  workers = []  # Keep references so the worker servers stay alive.
  for expected_count in (1, 2):
    workers.append(server_lib.WorkerServer(0, master._address))
    self.assertEqual(expected_count, master._num_workers())
def create_cluster(self, num_workers):
  """Creates a tf.data service cluster with one master and several workers.

  The master and workers are stored on `self._master` / `self._servers` so
  they stay alive for the duration of the test.

  Args:
    num_workers: The number of workers in the cluster.

  Returns:
    The address of the master.
  """
  self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
  self._servers = [
      server_lib.WorkerServer(
          port=0, master_address=self._master._address, protocol=PROTOCOL)
      for _ in range(num_workers)
  ]
  return self._master._address
def testAddWorkerMidJob(self):
  """Checks that a worker added mid-job contributes a full copy of the data.

  Reads half the dataset from one worker, adds a second worker, then drains
  the iterator and expects every element to appear exactly twice (once per
  worker).
  """
  self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
  self._worker = server_lib.WorkerServer(
      port=0, master_address=self._master._address, protocol=PROTOCOL)
  num_elements = 100
  ds = dataset_ops.Dataset.range(num_elements)
  ds = _make_distributed_dataset(ds, self._master._address)
  iterator = iter(ds)
  results = []
  # Read halfway through the dataset.
  for _ in range(num_elements // 2):
    results.append(next(iterator).numpy())
  self._new_worker = server_lib.WorkerServer(
      port=0, master_address=self._master._address, protocol=PROTOCOL)
  # Wait for the new worker to register with the master. Bound the wait so a
  # registration failure fails the test instead of hanging the suite forever.
  deadline = time.time() + 60
  while self._master._num_workers() < 2:
    if time.time() > deadline:
      self.fail("Timed out waiting for the new worker to register")
    time.sleep(10 / 1000)  # 10ms
  for elem in iterator:
    results.append(elem.numpy())
  # Each element should be produced twice: once by each worker.
  self.assertCountEqual(2 * list(range(num_elements)), results)
def testJoinWorker(self):
  """join() on a stopped worker returns without raising."""
  master = server_lib.MasterServer(0)
  stopped_worker = server_lib.WorkerServer(0, master._address)
  stopped_worker._stop()
  stopped_worker.join()
def testJoinMaster(self):
  """join() on a stopped master returns without raising."""
  stopped_master = server_lib.MasterServer(0)
  stopped_master._stop()
  stopped_master.join()
def testStopMaster(self):
  """Stopping a master twice does not raise."""
  server = server_lib.MasterServer(0)
  for _ in range(2):
    server._stop()
def testMultipleStartWorker(self):
  """Calling start() on an already-started worker does not raise."""
  master = server_lib.MasterServer(0)
  running_worker = server_lib.WorkerServer(0, master._address, start=True)
  running_worker.start()
def testStartWorker(self):
  """A worker created with start=False can be started explicitly."""
  master = server_lib.MasterServer(0)
  unstarted_worker = server_lib.WorkerServer(0, master._address, start=False)
  unstarted_worker.start()
def testMultipleStartMaster(self):
  """Calling start() on an already-started master does not raise."""
  running_master = server_lib.MasterServer(0, start=True)
  running_master.start()
def testStartMaster(self):
  """A master created with start=False can be started explicitly."""
  unstarted_master = server_lib.MasterServer(0, start=False)
  unstarted_master.start()