def testDispatcherNumWorkers(self):
  """Verifies the dispatcher's worker count tracks worker registrations."""
  dispatcher = server_lib.DispatchServer()
  self.assertEqual(0, dispatcher._num_workers())
  # Keep references to the workers so they stay alive for the assertions.
  workers = []
  for expected_count in (1, 2):
    workers.append(
        server_lib.WorkerServer(server_lib.WorkerConfig(dispatcher._address)))
    self.assertEqual(expected_count, dispatcher._num_workers())
def testStartServersLate(self):
  """Tests that the data service client retries instead of failing when the
  dataset is created before the dispatcher and worker are started."""
  try:
    import portpicker  # pylint: disable=g-import-not-at-top
    dispatcher_port = portpicker.pick_unused_port()
  except Exception:  # pylint: disable=broad-except
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt are not
    # swallowed. `skipTest` raises `SkipTest` itself, so the previous
    # `raise self.skipTest(...)` was redundant.
    self.skipTest("Flakes in portpicker library do not represent "
                  "TensorFlow errors.")
  dispatcher = server_lib.DispatchServer(
      server_lib.DispatcherConfig(port=dispatcher_port), start=False)
  worker = server_lib.WorkerServer(
      server_lib.WorkerConfig(
          dispatcher_address=_address_from_target(dispatcher.target), port=0),
      start=False)

  def start_servers():
    # Delay startup so the client is already retrying when the servers
    # finally come up.
    time.sleep(1)
    dispatcher.start()
    worker.start()

  start_servers_thread = threading.Thread(target=start_servers, daemon=True)
  start_servers_thread.start()

  num_elements = 10
  ds = _make_distributed_range_dataset(num_elements, dispatcher)
  results = [elem.numpy() for elem in ds]
  self.assertEqual(list(range(num_elements)), results)
  start_servers_thread.join()
def add_worker(self, start=True):
  """Creates a worker pointed at this cluster's dispatcher and records it.

  Args:
    start: Whether to start the worker immediately.
  """
  config = server_lib.WorkerConfig(
      dispatcher_address=self.dispatcher_address(),
      heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS,
      dispatcher_timeout_ms=1000)
  self.workers.append(server_lib.WorkerServer(config, start=start))
def testStartWorkerWithPortConfig(self):
  """A worker configured with an explicit port must bind to that port."""
  dispatcher = server_lib.DispatchServer()
  port = pick_unused_port()
  worker_config = server_lib.WorkerConfig(dispatcher._address, port=port)
  worker = server_lib.WorkerServer(worker_config, start=True)
  self.assertEqual(worker._address, "localhost:{}".format(port))
def testGcClient(self):
  """Verifies an idle client is garbage-collected after `client_timeout_ms`."""
  dispatcher = server_lib.DispatchServer(
      service_config_pb2.DispatcherConfig(
          protocol="grpc",
          job_gc_check_interval_ms=50,
          job_gc_timeout_ms=20,
          client_timeout_ms=50))
  dispatcher_address = dispatcher.target.split("://")[1]
  _ = server_lib.WorkerServer(
      server_lib.WorkerConfig(
          dispatcher_address=dispatcher_address, heartbeat_interval_ms=100))

  num_elements = 1000
  dataset = dataset_ops.Dataset.range(num_elements)
  dataset = dataset.apply(
      data_service_ops._distribute(
          # Qualified for consistency with the sibling tests in this file,
          # which reference `data_service_ops.ShardingPolicy`.
          processing_mode=data_service_ops.ShardingPolicy.OFF,
          service=dispatcher.target,
          task_refresh_interval_hint_ms=10000))
  get_next = self.getNext(dataset)

  # The client does not heartbeat in 10 seconds. It will be garbage-collected.
  with self.assertRaisesRegex(errors.NotFoundError, "Unknown job client id"):
    self.evaluate(get_next())
    time.sleep(3)
    self.getIteratorOutput(get_next)
def testKeepClientAliveBeforeReading(self):
  """A heartbeating client must survive GC even before it starts reading."""
  dispatcher_config = service_config_pb2.DispatcherConfig(
      protocol="grpc",
      job_gc_check_interval_ms=50,
      job_gc_timeout_ms=20,
      client_timeout_ms=1000)
  dispatcher = server_lib.DispatchServer(dispatcher_config)
  dispatcher_address = dispatcher.target.split("://")[1]
  _ = server_lib.WorkerServer(
      server_lib.WorkerConfig(
          dispatcher_address=dispatcher_address, heartbeat_interval_ms=100))

  num_elements = 1000
  dataset = dataset_ops.Dataset.range(num_elements)
  dataset = dataset.apply(
      data_service_ops._distribute(
          processing_mode=data_service_ops.ShardingPolicy.OFF,
          service=dispatcher.target,
          task_refresh_interval_hint_ms=100))
  get_next = self.getNext(dataset)

  # The client regularly heartbeats in 100 milliseconds. It should not be
  # garbage-collected even if it does not start reading in 3 seconds.
  time.sleep(3)
  self.assertEqual(self.getIteratorOutput(get_next), list(range(num_elements)))
def testStopStartWorker(self):
  """Starting a worker after it has been stopped must raise a RuntimeError."""
  dispatcher = server_lib.DispatchServer()
  worker_config = server_lib.WorkerConfig(dispatcher._address)
  worker = server_lib.WorkerServer(worker_config)
  worker._stop()
  expected_message = "Server cannot be started after it has been stopped"
  with self.assertRaisesRegex(RuntimeError, expected_message):
    worker.start()
def testProfileWorker(self):
  """Checks the worker's profiler service accepts trace connections."""
  dispatcher = server_lib.DispatchServer()
  worker = server_lib.WorkerServer(
      server_lib.WorkerConfig(dispatcher._address))
  # The profilers start and connect to the profiler service on the worker.
  # Since no op is running, tracing is expected to surface an
  # UnavailableError with a "no trace event was collected" message.
  logdir = tempfile.mkdtemp()
  with self.assertRaises(errors.UnavailableError) as error:
    profiler_client.trace(worker._address, logdir, duration_ms=10)
  self.assertStartsWith(str(error.exception), "No trace event was collected")
def start_workers(self):
  """Starts the configured number of workers, signals readiness, and blocks
  until every worker has shut down."""
  self._workers = [
      server_lib.WorkerServer(
          server_lib.WorkerConfig(dispatcher_address=self._dispatcher_address),
          start=True)
      for _ in range(self._num_workers)
  ]
  # Notify the controlling process that the workers are up before blocking.
  self._pipe_writer.send("Remote workers are ready.")
  for worker in self._workers:
    worker.join()
def restart_worker(self, worker_index=0, use_same_port=True):
  """Replaces the worker at index `worker_index` with a new worker.

  Args:
    worker_index: Index into `self.workers` of the worker to replace.
    use_same_port: Whether the replacement reuses the old worker's port so
      existing clients can reconnect to the same address.
  """
  old_worker = self.workers[worker_index]
  port = int(old_worker._address.split(":")[1]) if use_same_port else 0
  old_worker._stop()
  replacement_config = server_lib.WorkerConfig(
      dispatcher_address=self.dispatcher_address(),
      port=port,
      heartbeat_interval_ms=old_worker._config.heartbeat_interval_ms)
  self.workers[worker_index] = server_lib.WorkerServer(replacement_config)
def testDistributeLargeGraphThenRegisterWorker(self, work_dir):
  """A worker that registers late must still receive a very large dataset
  graph from the dispatcher."""
  dispatcher = self.start_dispatch_server(
      work_dir=work_dir, fault_tolerant_mode=False)
  late_worker = server_lib.WorkerServer(
      server_lib.WorkerConfig(
          dispatcher_address=self.dispatcher_address(dispatcher), port=0),
      start=False)
  # Larger than default OSS grpc message size limit of 4MB.
  large_tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
  dataset = self.make_distributed_dataset(
      dataset_ops.Dataset.from_tensors(large_tensor), dispatcher)
  iterator = iter(dataset)
  # Only now does the worker come up and fetch the (large) dataset graph.
  late_worker.start()
  self.assertAllEqual(next(iterator), large_tensor)
def _make_worker(dispatcher_address, shutdown_quiet_period_ms=0, port=0):
  """Creates a worker server (not started).

  Args:
    dispatcher_address: Address of the dispatcher the worker registers with.
    shutdown_quiet_period_ms: Quiet period before the worker shuts down.
    port: Port for the worker to bind; 0 picks an unused port.

  Returns:
    A `server_lib.WorkerServer` constructed with `start=False`.
  """
  # Borrow the default worker address from a default-constructed config.
  defaults = server_lib.WorkerConfig(dispatcher_address=dispatcher_address)
  worker_options = dict(
      dispatcher_address=dispatcher_address,
      worker_address=defaults.worker_address,
      port=port,
      protocol=PROTOCOL,
      heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS,
      dispatcher_timeout_ms=TEST_DISPATCHER_TIMEOUT_MS,
      data_transfer_protocol=None,
      shutdown_quiet_period_ms=shutdown_quiet_period_ms)
  config_proto = service_config_pb2.WorkerConfig(**worker_options)
  return server_lib.WorkerServer(config_proto, start=False)
def start_worker_server(self, dispatcher, port=0):
  """Starts and returns a worker registered with `dispatcher`."""
  worker_config = server_lib.WorkerConfig(
      dispatcher_address=_address_from_target(dispatcher.target),
      port=port)
  return server_lib.WorkerServer(worker_config)
def start_worker_server(self, dispatcher, port=0):
  """Starts and returns a worker registered with `dispatcher`."""
  worker_config = server_lib.WorkerConfig(
      dispatcher_address=self.dispatcher_address(dispatcher),
      port=port,
      heartbeat_interval_ms=200)
  return server_lib.WorkerServer(worker_config)
def start_worker_server(self, dispatcher, port=0):
  """Starts and returns a worker registered with `dispatcher`."""
  worker_config = server_lib.WorkerConfig(
      dispatcher_address=_address_from_target(dispatcher.target),
      port=port,
      heartbeat_interval_ms=200)
  return server_lib.WorkerServer(worker_config)
def testJoinWorker(self):
  """join() on a stopped worker must return without raising."""
  dispatcher = server_lib.DispatchServer()
  worker_config = server_lib.WorkerConfig(dispatcher._address)
  worker = server_lib.WorkerServer(worker_config)
  worker._stop()
  worker.join()
def testMultipleStartWorker(self):
  """Calling start() on an already-started worker must be a no-op."""
  dispatcher = server_lib.DispatchServer()
  worker_config = server_lib.WorkerConfig(dispatcher._address)
  worker = server_lib.WorkerServer(worker_config, start=True)
  worker.start()
def testStartWorker(self):
  """A worker created with start=False must start cleanly on demand."""
  dispatcher = server_lib.DispatchServer()
  worker_config = server_lib.WorkerConfig(dispatcher._address)
  worker = server_lib.WorkerServer(worker_config, start=False)
  worker.start()