def testMasterNumWorkers(self):
  master = server_lib.MasterServer(0)
  self.assertEqual(0, master._num_workers())
  worker1 = server_lib.WorkerServer(0, master._address)  # pylint: disable=unused-variable
  self.assertEqual(1, master._num_workers())
  worker2 = server_lib.WorkerServer(0, master._address)  # pylint: disable=unused-variable
  self.assertEqual(2, master._num_workers())

def testDispatcherNumWorkers(self):
  dispatcher = server_lib.DispatchServer(0)
  self.assertEqual(0, dispatcher._num_workers())
  worker1 = server_lib.WorkerServer(0, dispatcher._address)  # pylint: disable=unused-variable
  self.assertEqual(1, dispatcher._num_workers())
  worker2 = server_lib.WorkerServer(0, dispatcher._address)  # pylint: disable=unused-variable
  self.assertEqual(2, dispatcher._num_workers())

def testAddWorkerMidJob(self):
  self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
  self._worker = server_lib.WorkerServer(
      port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
  num_elements = 100
  ds = dataset_ops.Dataset.range(num_elements)
  ds = _make_distributed_dataset(ds, self._dispatcher._address)
  iterator = iter(ds)
  results = []
  # Read halfway through the dataset.
  for _ in range(num_elements // 2):
    results.append(next(iterator).numpy())

  self._new_worker = server_lib.WorkerServer(
      port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)

  # Wait for the new worker to register with the dispatcher.
  while self._dispatcher._num_workers() < 2:
    time.sleep(10 / 1000)  # 10ms

  for elem in iterator:
    results.append(elem.numpy())

  self.assertCountEqual(2 * list(range(num_elements)), results)

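# `_make_distributed_dataset` is used throughout these tests but its
# definition is not shown here. Below is a minimal sketch of such a helper,
# assuming the `data_service_ops.distribute` transform and a
# "parallel_epochs" processing mode; the real helper may take different
# arguments.
def _make_distributed_dataset(dataset, service, job_name=None):
  """Distributes `dataset` via the tf.data service at `service`."""
  return dataset.apply(
      data_service_ops.distribute(
          "parallel_epochs", service, job_name=job_name))
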
def testRestartWorker(self, use_same_port):
  self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
  self._worker = server_lib.WorkerServer(
      port=0, master_address=self._master._address, protocol=PROTOCOL)
  num_elements = 100
  ds = dataset_ops.Dataset.range(num_elements)
  ds = _make_distributed_dataset(ds, self._master._address)
  iterator = iter(ds)
  # Read halfway through the dataset.
  midpoint = num_elements // 2
  for i in range(midpoint):
    self.assertEqual(i, next(iterator).numpy())

  # Stop the original worker and start a new one.
  port = 0
  if use_same_port:
    port = int(self._worker._address.split(":")[1])
  self._worker._stop()
  self._new_worker = server_lib.WorkerServer(
      port=port, master_address=self._master._address, protocol=PROTOCOL)

  # There may have been some elements prefetched from the first worker
  # before it was stopped.
  while True:
    val = next(iterator).numpy()
    if val == 0:
      break

  # The dataset starts over now that we read from the new worker.
  # TODO(b/157086991): Iterate until end of sequence when we support
  # detecting lost workers.
  for i in range(1, num_elements // 2):
    val = next(iterator).numpy()
    self.assertEqual(i, val)

def testRestartWorker(self, use_same_port):
  self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
  self._worker = server_lib.WorkerServer(
      port=0, master_address=self._master._address, protocol=PROTOCOL)
  num_elements = 100
  ds = dataset_ops.Dataset.range(num_elements)
  ds = _make_distributed_dataset(ds, self._master._address)
  iterator = iter(ds)
  # Read halfway through the dataset.
  midpoint = num_elements // 2
  for i in range(midpoint):
    self.assertEqual(i, next(iterator).numpy())

  # Stop the original worker and start a new one.
  port = 0
  if use_same_port:
    port = int(self._worker._address.split(":")[1])
  self._worker._stop()
  self._new_worker = server_lib.WorkerServer(
      port=port, master_address=self._master._address, protocol=PROTOCOL)

  # The dataset starts over now that we read from the new worker.
  for i in range(num_elements):
    val = next(iterator).numpy()
    if val == midpoint and i != midpoint:
      # There may have been one last element prefetched from the first worker
      # before it was stopped.
      val = next(iterator).numpy()
    self.assertEqual(i, val)

def testStartWorkerWithPortConfig(self):
  dispatcher = server_lib.DispatchServer()
  port = pick_unused_port()
  worker = server_lib.WorkerServer(
      server_lib.WorkerConfig(dispatcher._address, port=port), start=True)
  self.assertEqual(worker._address, "localhost:{}".format(port))

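# A minimal sketch of the `pick_unused_port` helper used above, assuming it
# wraps the `portpicker` library (testStartServersLate below calls
# `portpicker.pick_unused_port` directly); the real helper may differ.
def pick_unused_port():
  import portpicker  # pylint: disable=g-import-not-at-top
  return portpicker.pick_unused_port()
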
def testGcClient(self):
  dispatcher = server_lib.DispatchServer(
      service_config_pb2.DispatcherConfig(
          protocol="grpc",
          job_gc_check_interval_ms=50,
          job_gc_timeout_ms=20,
          client_timeout_ms=50))
  dispatcher_address = dispatcher.target.split("://")[1]
  _ = server_lib.WorkerServer(
      server_lib.WorkerConfig(
          dispatcher_address=dispatcher_address, heartbeat_interval_ms=100))

  num_elements = 1000
  dataset = dataset_ops.Dataset.range(num_elements)
  dataset = dataset.apply(
      data_service_ops._distribute(
          processing_mode=ShardingPolicy.OFF,
          service=dispatcher.target,
          task_refresh_interval_hint_ms=10000))
  get_next = self.getNext(dataset)

  # With a 10-second task refresh hint, the client does not heartbeat during
  # the sleep, so it exceeds the 50ms client timeout and is garbage-collected.
  with self.assertRaisesRegex(errors.NotFoundError, "Unknown job client id"):
    self.evaluate(get_next())
    time.sleep(3)
    self.getIteratorOutput(get_next)

def testKeepClientAliveBeforeReading(self):
  dispatcher = server_lib.DispatchServer(
      service_config_pb2.DispatcherConfig(
          protocol="grpc",
          job_gc_check_interval_ms=50,
          job_gc_timeout_ms=20,
          client_timeout_ms=1000))
  dispatcher_address = dispatcher.target.split("://")[1]
  _ = server_lib.WorkerServer(
      server_lib.WorkerConfig(
          dispatcher_address=dispatcher_address, heartbeat_interval_ms=100))

  num_elements = 1000
  dataset = dataset_ops.Dataset.range(num_elements)
  dataset = dataset.apply(
      data_service_ops._distribute(
          processing_mode=data_service_ops.ShardingPolicy.OFF,
          service=dispatcher.target,
          task_refresh_interval_hint_ms=100))
  get_next = self.getNext(dataset)

  # The client heartbeats every 100 milliseconds, so it should not be
  # garbage-collected even if it does not start reading for 3 seconds.
  time.sleep(3)
  self.assertEqual(
      self.getIteratorOutput(get_next), list(range(num_elements)))

def add_worker(self, start=True):
  self.workers.append(
      server_lib.WorkerServer(
          server_lib.WorkerConfig(
              dispatcher_address=self.dispatcher_address(),
              heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS,
              dispatcher_timeout_ms=1000),
          start=start))

def testStartServersLate(self):
  # Test that the data service client retries instead of failing when the
  # dataset is created before the dispatcher and worker are started.
  try:
    import portpicker  # pylint: disable=g-import-not-at-top
    dispatcher_port = portpicker.pick_unused_port()
  except:
    raise self.skipTest("Flakes in portpicker library do not represent "
                        "TensorFlow errors.")
  dispatcher = server_lib.DispatchServer(
      server_lib.DispatcherConfig(port=dispatcher_port), start=False)
  worker = server_lib.WorkerServer(
      server_lib.WorkerConfig(
          dispatcher_address=_address_from_target(dispatcher.target), port=0),
      start=False)

  def start_servers():
    time.sleep(1)
    dispatcher.start()
    worker.start()

  start_servers_thread = threading.Thread(target=start_servers, daemon=True)
  start_servers_thread.start()

  num_elements = 10
  ds = _make_distributed_range_dataset(num_elements, dispatcher)
  results = [elem.numpy() for elem in ds]
  self.assertEqual(list(range(num_elements)), results)
  start_servers_thread.join()

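# Sketches of two helpers referenced above. `_address_from_target` strips the
# protocol prefix from a server target; testGcClient above does the same
# inline with `dispatcher.target.split("://")[1]`. The range-dataset helper is
# a plausible wrapper over `_make_distributed_dataset` and may differ from the
# real one.
def _address_from_target(target):
  """Returns the address portion of a target like "grpc://localhost:5000"."""
  return target.split("://")[1]


def _make_distributed_range_dataset(num_elements, dispatcher):
  ds = dataset_ops.Dataset.range(num_elements)
  return _make_distributed_dataset(ds, dispatcher.target)
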
def testStopStartWorker(self):
  dispatcher = server_lib.DispatchServer(0)
  worker = server_lib.WorkerServer(0, dispatcher._address)
  worker._stop()
  with self.assertRaisesRegex(
      RuntimeError, "Server cannot be started after it has been stopped"):
    worker.start()

def testProfileWorker(self):
  dispatcher = server_lib.DispatchServer()
  worker = server_lib.WorkerServer(
      server_lib.WorkerConfig(dispatcher._address))
  # Test that the profiler is successfully started and connected to the
  # profiler service on the worker. Since no op is running, it is expected to
  # raise UnavailableError with a "no trace events collected" message.
  with self.assertRaises(errors.UnavailableError) as error:
    profiler_client.trace(
        worker._address, tempfile.mkdtemp(), duration_ms=10)
  self.assertStartsWith(
      str(error.exception), "No trace event was collected")

def start_workers(self):
  self._workers = []
  for _ in range(self._num_workers):
    worker = server_lib.WorkerServer(
        server_lib.WorkerConfig(dispatcher_address=self._dispatcher_address),
        start=True)
    self._workers.append(worker)
  self._pipe_writer.send("Remote workers are ready.")
  for worker in self._workers:
    worker.join()

def restart_worker(self, worker_index=0, use_same_port=True):
  """Replaces the worker at index `worker_index` with a new worker."""
  worker = self.workers[worker_index]
  port = 0
  if use_same_port:
    port = int(worker._address.split(":")[1])
  worker._stop()
  self.workers[worker_index] = server_lib.WorkerServer(
      server_lib.WorkerConfig(
          dispatcher_address=self.dispatcher_address(),
          port=port,
          heartbeat_interval_ms=worker._config.heartbeat_interval_ms))

def _make_worker(dispatcher_address, shutdown_quiet_period_ms=0, port=0):
  """Creates a worker server."""
  defaults = server_lib.WorkerConfig(dispatcher_address=dispatcher_address)
  config_proto = service_config_pb2.WorkerConfig(
      dispatcher_address=dispatcher_address,
      worker_address=defaults.worker_address,
      port=port,
      protocol=PROTOCOL,
      heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS,
      dispatcher_timeout_ms=TEST_DISPATCHER_TIMEOUT_MS,
      data_transfer_protocol=None,
      shutdown_quiet_period_ms=shutdown_quiet_period_ms)
  return server_lib.WorkerServer(config_proto, start=False)

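# Example usage of `_make_worker` (the dispatcher address below is
# hypothetical): the server is created with `start=False`, so callers start
# it explicitly once the rest of the cluster is up, e.g.
#
#   worker = _make_worker("localhost:5000", shutdown_quiet_period_ms=100)
#   worker.start()
#   worker.join()
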
def testDistributeLargeGraphThenRegisterWorker(self, work_dir):
  dispatcher = self.start_dispatch_server(
      work_dir=work_dir, fault_tolerant_mode=False)
  worker = server_lib.WorkerServer(
      server_lib.WorkerConfig(
          dispatcher_address=self.dispatcher_address(dispatcher), port=0),
      start=False)
  # Larger than the default OSS gRPC message size limit of 4MB.
  tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
  ds = dataset_ops.Dataset.from_tensors(tensor)
  ds = self.make_distributed_dataset(ds, dispatcher)
  it = iter(ds)
  worker.start()
  self.assertAllEqual(next(it), tensor)

def create_cluster(self, num_workers):
  """Creates a cluster of tf.data service servers.

  Args:
    num_workers: The number of workers in the cluster.

  Returns:
    The address of the master.
  """
  self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
  self._servers = []
  for _ in range(num_workers):
    self._servers.append(
        server_lib.WorkerServer(
            port=0, master_address=self._master._address, protocol=PROTOCOL))
  return self._master._address

def testDispatcherPreemption(self):
  self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
  self._worker = server_lib.WorkerServer(
      port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
  num_elements = 100
  ds = dataset_ops.Dataset.range(num_elements)
  ds = _make_distributed_dataset(
      ds, "{}://{}".format(PROTOCOL, self._dispatcher._address))
  iterator = iter(ds)
  results = []
  results.append(next(iterator).numpy())
  self._dispatcher._stop()
  # After the dispatcher dies, the worker should continue providing the rest
  # of the dataset's elements.
  for _ in range(num_elements - 1):
    results.append(next(iterator).numpy())
  self.assertEqual(results, list(range(num_elements)))

def create_cluster(self, num_workers):
  """Creates a cluster of tf.data service servers.

  Args:
    num_workers: The number of workers in the cluster.

  Returns:
    A string for connecting to the tf.data service.
  """
  self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
  self._servers = []
  for _ in range(num_workers):
    self._servers.append(
        server_lib.WorkerServer(
            port=0,
            dispatcher_address=self._dispatcher._address,
            protocol=PROTOCOL))
  return "{0}://{1}".format(PROTOCOL, self._dispatcher._address)

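# Sketch of how a test might consume `create_cluster` (hypothetical test
# body): distribute a small range dataset across the cluster and read it
# back, following the pattern in testAddWorkerMidJob above.
#
#   service = self.create_cluster(num_workers=2)
#   ds = dataset_ops.Dataset.range(10)
#   ds = _make_distributed_dataset(ds, service)
#   # With parallel-epochs processing, each worker produces a full copy.
#   self.assertCountEqual(2 * list(range(10)), [elem.numpy() for elem in ds])
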
def testCancellation(self):
  self.skipTest("b/162521601")
  sleep_microseconds = int(1e6) * 1000
  self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
  self._worker = server_lib.WorkerServer(
      port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
  # Create a dataset which produces the first element quickly, and the second
  # element slowly. Fetching the first element triggers prefetching of the
  # second element, which we should be able to cancel.
  slow = dataset_ops.Dataset.range(1)
  slow = slow.apply(testing.sleep(sleep_microseconds))
  ds = dataset_ops.Dataset.range(1).concatenate(slow)
  ds = _make_distributed_dataset(
      ds, "{}://{}".format(PROTOCOL, self._dispatcher._address))
  ds = ds.prefetch(1)
  get_next = self.getNext(ds, requires_initialization=True)
  self.assertEqual(0, self.evaluate(get_next()))

def testJoinWorker(self):
  dispatcher = server_lib.DispatchServer(0)
  worker = server_lib.WorkerServer(0, dispatcher._address)
  worker._stop()
  worker.join()

def start_worker_server(self, dispatcher, port=0):
  return server_lib.WorkerServer(
      server_lib.WorkerConfig(
          dispatcher_address=_address_from_target(dispatcher.target),
          port=port))

def testMultipleStartWorker(self):
  dispatcher = server_lib.DispatchServer()
  worker = server_lib.WorkerServer(
      server_lib.WorkerConfig(dispatcher._address), start=True)
  worker.start()

def start_worker_server(self, dispatcher, port=0):
  return server_lib.WorkerServer(
      port=port,
      dispatcher_address=_address_from_target(dispatcher.target),
      protocol=server_lib.DEFAULT_PROTOCOL)

def testStopWorker(self):
  dispatcher = server_lib.DispatchServer()
  worker = server_lib.WorkerServer(
      server_lib.WorkerConfig(dispatcher._address))
  worker._stop()
  # Stopping a second time should be a no-op.
  worker._stop()

def start_worker_server(self, dispatcher, port=0):
  return server_lib.WorkerServer(
      server_lib.WorkerConfig(
          dispatcher_address=self.dispatcher_address(dispatcher),
          port=port,
          heartbeat_interval_ms=200))

def testStartWorker(self):
  dispatcher = server_lib.DispatchServer()
  worker = server_lib.WorkerServer(
      server_lib.WorkerConfig(dispatcher._address), start=False)
  worker.start()

def start_worker_server(self, dispatcher, port=0):
  return server_lib.WorkerServer(
      server_lib.WorkerConfig(
          dispatcher_address=_address_from_target(dispatcher.target),
          port=port,
          heartbeat_interval_ms=200))

def testMultipleStartWorker(self):
  dispatcher = server_lib.DispatchServer(0)
  worker = server_lib.WorkerServer(0, dispatcher._address, start=True)
  worker.start()

def testStartWorker(self):
  dispatcher = server_lib.DispatchServer(0)
  worker = server_lib.WorkerServer(0, dispatcher._address, start=False)
  worker.start()