def testGcClient(self):
  """Checks that a client which stops heartbeating is garbage-collected.

  The dispatcher is configured with `client_timeout_ms=50`, while the dataset
  is distributed with `task_refresh_interval_hint_ms=10000`. Reading from the
  dataset is therefore expected to fail with `NotFoundError` ("Unknown job
  client id").
  """
  dispatcher_config = service_config_pb2.DispatcherConfig(
      protocol="grpc",
      job_gc_check_interval_ms=50,
      job_gc_timeout_ms=20,
      client_timeout_ms=50)
  dispatcher = server_lib.DispatchServer(dispatcher_config)
  # `dispatcher.target` has the form "<protocol>://<address>"; keep the address.
  dispatcher_address = dispatcher.target.split("://")[1]
  # The worker is never used directly; it is only bound to a name.
  _ = server_lib.WorkerServer(
      server_lib.WorkerConfig(
          dispatcher_address=dispatcher_address, heartbeat_interval_ms=100))

  num_elements = 1000
  dataset = dataset_ops.Dataset.range(num_elements)
  dataset = dataset.apply(
      data_service_ops._distribute(
          # Qualified through `data_service_ops`, matching the sibling test
          # `testKeepClientAliveBeforeReading`.
          processing_mode=data_service_ops.ShardingPolicy.OFF,
          service=dispatcher.target,
          task_refresh_interval_hint_ms=10000))
  get_next = self.getNext(dataset)

  # The client does not heartbeat in 10 seconds. It will be garbage-collected.
  # The error may surface on the first read or only after the sleep, so all
  # reads are kept inside the assertion context.
  with self.assertRaisesRegex(errors.NotFoundError, "Unknown job client id"):
    self.evaluate(get_next())
    time.sleep(3)
    self.getIteratorOutput(get_next)
def testKeepClientAliveBeforeReading(self):
  """Checks that a heartbeating client survives a delay before reading.

  The dataset is distributed with `task_refresh_interval_hint_ms=100` and the
  dispatcher uses `client_timeout_ms=1000`. After sleeping for 3 seconds the
  full range must still be readable.
  """
  dispatcher_config = service_config_pb2.DispatcherConfig(
      protocol="grpc",
      job_gc_check_interval_ms=50,
      job_gc_timeout_ms=20,
      client_timeout_ms=1000)
  dispatcher = server_lib.DispatchServer(dispatcher_config)
  dispatcher_address = dispatcher.target.split("://")[1]
  worker_config = server_lib.WorkerConfig(
      dispatcher_address=dispatcher_address, heartbeat_interval_ms=100)
  _ = server_lib.WorkerServer(worker_config)

  element_count = 1000
  ds = dataset_ops.Dataset.range(element_count)
  distribute_fn = data_service_ops._distribute(
      processing_mode=data_service_ops.ShardingPolicy.OFF,
      service=dispatcher.target,
      task_refresh_interval_hint_ms=100)
  ds = ds.apply(distribute_fn)
  get_next = self.getNext(ds)

  # The client regularly heartbeats in 100 milliseconds. It should not be
  # garbage-collected even if it does not start reading in 3 seconds.
  time.sleep(3)
  actual = self.getIteratorOutput(get_next)
  self.assertEqual(actual, list(range(element_count)))
def testStopStartDispatcher(self):
  """Checks that `start()` on a stopped dispatcher raises `RuntimeError`."""
  server = server_lib.DispatchServer()
  server._stop()
  expected_error = "Server cannot be started after it has been stopped"
  with self.assertRaisesRegex(RuntimeError, expected_error):
    server.start()
def testStopStartWorker(self):
  """Checks that `start()` on a stopped worker raises `RuntimeError`."""
  dispatch_server = server_lib.DispatchServer(0)
  worker_server = server_lib.WorkerServer(0, dispatch_server._address)
  worker_server._stop()
  expected_error = "Server cannot be started after it has been stopped"
  with self.assertRaisesRegex(RuntimeError, expected_error):
    worker_server.start()
def testStartWorkerWithPortConfig(self):
  """Checks that a worker reports the port given in its `WorkerConfig`."""
  dispatch_server = server_lib.DispatchServer()
  chosen_port = pick_unused_port()
  worker_config = server_lib.WorkerConfig(
      dispatch_server._address, port=chosen_port)
  worker_server = server_lib.WorkerServer(worker_config, start=True)
  self.assertEqual(worker_server._address,
                   "localhost:{}".format(chosen_port))
def start_dispatch_server(self, name="", port=0):
  """Creates a fault-tolerant dispatch server with a per-name work directory.

  Args:
    name: Path component used for the server's work directory. If a test
      starts multiple independent dispatch servers, it should give them
      different `name` values.
    port: The port passed to the dispatch server.

  Returns:
    The newly created `server_lib.DispatchServer`.
  """
  temp_root = self.get_temp_dir()
  server_work_dir = os.path.join(temp_root, "work_dir_", name)
  return server_lib.DispatchServer(
      port=port,
      protocol=server_lib.DEFAULT_PROTOCOL,
      work_dir=server_work_dir,
      fault_tolerant_mode=True)
def _start_dispatcher(self, worker_addresses):
  """Starts a gRPC dispatcher in a new temp work dir, kept on `self`.

  Args:
    worker_addresses: Forwarded to `DispatcherConfig(worker_addresses=...)`.
  """
  fresh_work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
  dispatcher_config = server_lib.DispatcherConfig(
      port=0,
      work_dir=fresh_work_dir,
      protocol="grpc",
      worker_addresses=worker_addresses)
  self._dispatcher = server_lib.DispatchServer(dispatcher_config, start=True)
def _start_dispatcher(self, worker_addresses, port=0):
  """Starts a fault-tolerant gRPC dispatcher and stores it on `self`.

  The dispatcher uses `self._work_dir` as its work directory.

  Args:
    worker_addresses: Forwarded to `DispatcherConfig(worker_addresses=...)`.
    port: The port passed to the dispatcher config.
  """
  dispatcher_config = server_lib.DispatcherConfig(
      port=port,
      work_dir=self._work_dir,
      protocol="grpc",
      worker_addresses=worker_addresses,
      fault_tolerant_mode=True)
  self._dispatcher = server_lib.DispatchServer(dispatcher_config, start=True)
def testProfileWorker(self):
  """Checks that the profiler client can reach the worker's address."""
  dispatch_server = server_lib.DispatchServer()
  worker_server = server_lib.WorkerServer(
      server_lib.WorkerConfig(dispatch_server._address))
  # Test the profilers are successfully started and connected to profiler
  # service on the worker. Since there is no op running, it is expected to
  # return UnavailableError with no trace events collected string.
  with self.assertRaises(errors.UnavailableError) as raised:
    profiler_client.trace(
        worker_server._address, tempfile.mkdtemp(), duration_ms=10)
  error_text = str(raised.exception)
  self.assertStartsWith(error_text, "No trace event was collected")
def restart_dispatcher(self):
  """Stops `dispatcher` and creates a new dispatcher with the same port.

  The replacement reuses the previous dispatcher's `work_dir` and
  `fault_tolerant_mode` settings.
  """
  # The address is split on ":" and the second piece is used as the port.
  previous_port = int(self.dispatcher_address().split(":")[1])
  self.dispatcher._stop()
  previous_config = self.dispatcher._config
  replacement_config = server_lib.DispatcherConfig(
      port=previous_port,
      work_dir=previous_config.work_dir,
      fault_tolerant_mode=previous_config.fault_tolerant_mode)
  self.dispatcher = server_lib.DispatchServer(replacement_config)
def __init__(self,
             num_workers,
             dispatcher_port=0,
             work_dir=TMP_WORK_DIR,
             fault_tolerant_mode=True,
             job_gc_check_interval_ms=None,
             job_gc_timeout_ms=None,
             worker_shutdown_quiet_period_ms=0,
             start=True,
             data_transfer_protocol=None):
  """Builds a tf.data service test cluster: one dispatcher plus workers.

  Args:
    num_workers: Number of workers added to the cluster at construction.
    dispatcher_port: Port for the dispatcher.
    work_dir: Work directory for the dispatcher. `TMP_WORK_DIR` makes the
      cluster create a fresh temporary directory. `NO_WORK_DIR` means no work
      directory is used.
    fault_tolerant_mode: Whether the dispatcher journals its state so it can
      recover from restarts.
    job_gc_check_interval_ms: Interval, in milliseconds, at which the
      dispatcher scans for old and unused jobs to delete.
    job_gc_timeout_ms: How long, in milliseconds, a job must be unused before
      it becomes a candidate for garbage collection.
    worker_shutdown_quiet_period_ms: When shutting down a worker, how long to
      wait for the gRPC server to process the final requests.
    start: Whether to start the servers immediately. If `False`, they can be
      started later via `start_dispatcher()` and `start_workers()`.
    data_transfer_protocol: (Optional.) Protocol for transferring data with
      the tf.data service. The default can be controlled via the
      tf_data_service_test_transfer_protocol flag.
  """
  if work_dir == TMP_WORK_DIR:
    work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())

  self._worker_shutdown_quiet_period_ms = worker_shutdown_quiet_period_ms
  # A falsy protocol falls back to the `TRANSFER_PROTOCOL` value.
  self._data_transfer_protocol = (
      data_transfer_protocol or TRANSFER_PROTOCOL.value)

  dispatcher_config = server_lib.DispatcherConfig(
      port=dispatcher_port,
      work_dir=work_dir,
      protocol=PROTOCOL,
      fault_tolerant_mode=fault_tolerant_mode,
      job_gc_check_interval_ms=job_gc_check_interval_ms,
      job_gc_timeout_ms=job_gc_timeout_ms)
  self.dispatcher = server_lib.DispatchServer(dispatcher_config, start=start)

  self.workers = []
  for _ in range(num_workers):
    self.add_worker(start=start)
def _start_dispatcher(self, worker_addresses, port=0):
  """Starts a fault-tolerant gRPC dispatcher and stores it on `self`.

  Args:
    worker_addresses: Forwarded to `DispatcherConfig(worker_addresses=...)`.
    port: Port for the dispatcher. A value of 0 is replaced by a port from
      `test_util.pick_unused_port()`.
  """
  bound_port = test_util.pick_unused_port() if port == 0 else port
  dispatcher_config = service_config_pb2.DispatcherConfig(
      port=bound_port,
      protocol="grpc",
      work_dir=self._work_dir,
      fault_tolerant_mode=True,
      worker_addresses=worker_addresses,
      deployment_mode=self._deployment_mode)
  self._dispatcher = server_lib.DispatchServer(dispatcher_config, start=True)
def __init__(self,
             num_local_workers,
             num_remote_workers,
             dispatcher_port=0,
             worker_shutdown_quiet_period_ms=0):
  """Builds a cluster with a dispatcher, local workers and remote workers.

  Args:
    num_local_workers: Count passed to `start_local_workers()`.
    num_remote_workers: Count passed to `start_remote_workers()`.
    dispatcher_port: Port for the dispatcher.
    worker_shutdown_quiet_period_ms: Stored on the instance as
      `_worker_shutdown_quiet_period_ms`.
  """
  dispatcher_work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
  self._worker_shutdown_quiet_period_ms = worker_shutdown_quiet_period_ms
  dispatcher_config = server_lib.DispatcherConfig(
      port=dispatcher_port, work_dir=dispatcher_work_dir, protocol="grpc")
  self._dispatcher = server_lib.DispatchServer(dispatcher_config, start=True)
  self._local_workers = self.start_local_workers(num_local_workers)
  self.start_remote_workers(num_remote_workers)
def restart_dispatcher(self):
  """Stops `dispatcher` and creates a new dispatcher with the same port.

  Restarting is supported only when the dispatcher is configured with
  `fault_tolerant_mode=True`.

  Raises:
    ValueError: If the current dispatcher is not in fault-tolerant mode.
  """
  if not self.dispatcher._config.fault_tolerant_mode:
    raise ValueError(
        "Trying to restart the dispatcher without fault-tolerance.")

  # The address is split on ":" and the second piece is used as the port.
  previous_port = int(self.dispatcher_address().split(":")[1])
  self.dispatcher._stop()
  previous_config = self.dispatcher._config
  replacement_config = server_lib.DispatcherConfig(
      port=previous_port,
      work_dir=previous_config.work_dir,
      fault_tolerant_mode=previous_config.fault_tolerant_mode)
  self.dispatcher = server_lib.DispatchServer(replacement_config)
def create_cluster(self, num_workers):
  """Builds a dispatcher plus `num_workers` workers for the tf.data service.

  The dispatcher is stored in `self._dispatcher` and the workers in
  `self._servers`.

  Args:
    num_workers: The number of workers in the cluster.

  Returns:
    A string for connecting to the tf.data service.
  """
  self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
  self._servers = []
  for _ in range(num_workers):
    worker = server_lib.WorkerServer(
        port=0,
        dispatcher_address=self._dispatcher._address,
        protocol=PROTOCOL)
    self._servers.append(worker)
  dispatcher_address = self._dispatcher._address
  return "{0}://{1}".format(PROTOCOL, dispatcher_address)
def testDispatcherPreemption(self):
  """Checks that reading continues after the dispatcher is stopped."""
  self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
  self._worker = server_lib.WorkerServer(
      port=0,
      dispatcher_address=self._dispatcher._address,
      protocol=PROTOCOL)

  element_count = 100
  service_target = "{}://{}".format(PROTOCOL, self._dispatcher._address)
  ds = _make_distributed_dataset(
      dataset_ops.Dataset.range(element_count), service_target)
  iterator = iter(ds)

  # Read one element while the dispatcher is still running.
  collected = [next(iterator).numpy()]
  self._dispatcher._stop()
  # After the dispatcher dies, the worker should continue providing the rest
  # of the dataset's elements.
  remaining = element_count - 1
  for _ in range(remaining):
    collected.append(next(iterator).numpy())
  self.assertEqual(collected, list(range(element_count)))
def testCancellation(self):
  """Reads the fast first element of a dataset whose second one is slow.

  Currently skipped unconditionally (see b/162521601).
  """
  self.skipTest("b/162521601")
  sleep_microseconds = int(1e6) * 1000

  self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
  self._worker = server_lib.WorkerServer(
      port=0,
      dispatcher_address=self._dispatcher._address,
      protocol=PROTOCOL)

  # Create a dataset which produces the first element quickly, and the second
  # element slowly. Fetching the first element triggers prefetching of the
  # second element, which we should be able to cancel.
  slow_part = dataset_ops.Dataset.range(1)
  slow_part = slow_part.apply(testing.sleep(sleep_microseconds))
  fast_part = dataset_ops.Dataset.range(1)
  combined = fast_part.concatenate(slow_part)
  service_target = "{}://{}".format(PROTOCOL, self._dispatcher._address)
  combined = _make_distributed_dataset(combined, service_target)
  combined = combined.prefetch(1)

  get_next = self.getNext(combined, requires_initialization=True)
  first_value = self.evaluate(get_next())
  self.assertEqual(0, first_value)
def __init__(self,
             num_workers,
             dispatcher_port=0,
             work_dir=TMP_WORK_DIR,
             fault_tolerant_mode=True,
             job_gc_check_interval_ms=None,
             job_gc_timeout_ms=None,
             start=True):
  """Builds a tf.data service test cluster: one dispatcher plus workers.

  Args:
    num_workers: Number of workers added to the cluster at construction.
    dispatcher_port: Port for the dispatcher.
    work_dir: Work directory for the dispatcher. `TMP_WORK_DIR` makes the
      cluster create a fresh temporary directory. `NO_WORK_DIR` means no work
      directory is used.
    fault_tolerant_mode: Whether the dispatcher journals its state so it can
      recover from restarts.
    job_gc_check_interval_ms: Interval, in milliseconds, at which the
      dispatcher scans for old and unused jobs to delete.
    job_gc_timeout_ms: How long, in milliseconds, a job must be unused before
      it becomes a candidate for garbage collection.
    start: Whether to start the servers immediately. If `False`, they can be
      started later via `start_dispatcher()` and `start_workers()`.
  """
  if work_dir == TMP_WORK_DIR:
    work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())

  dispatcher_config = server_lib.DispatcherConfig(
      port=dispatcher_port,
      work_dir=work_dir,
      fault_tolerant_mode=fault_tolerant_mode,
      job_gc_check_interval_ms=job_gc_check_interval_ms,
      job_gc_timeout_ms=job_gc_timeout_ms)
  self.dispatcher = server_lib.DispatchServer(dispatcher_config, start=start)

  self.workers = []
  for _ in range(num_workers):
    self.add_worker(start=start)
def testRestartWorker(self, use_same_port):
  """Checks that reading restarts from 0 once a replacement worker is up.

  Args:
    use_same_port: If true, the replacement worker reuses the original
      worker's port. Otherwise it is created with port 0.
  """
  self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
  self._worker = server_lib.WorkerServer(
      port=0,
      dispatcher_address=self._dispatcher._address,
      protocol=PROTOCOL)

  element_count = 100
  service_target = "{}://{}".format(PROTOCOL, self._dispatcher._address)
  ds = _make_distributed_dataset(
      dataset_ops.Dataset.range(element_count), service_target)
  iterator = iter(ds)

  # Read halfway through the dataset.
  midpoint = element_count // 2
  for expected in range(midpoint):
    self.assertEqual(expected, next(iterator).numpy())

  # Stop the original worker and start a new one.
  if use_same_port:
    new_port = int(self._worker._address.split(":")[1])
  else:
    new_port = 0
  self._worker._stop()
  self._new_worker = server_lib.WorkerServer(
      port=new_port,
      dispatcher_address=self._dispatcher._address,
      protocol=PROTOCOL)

  # There may have been some elements prefetched from the first worker
  # before it was stopped. Skip values until the first 0 shows up.
  while next(iterator).numpy() != 0:
    pass

  # The dataset starts over now that we read from the new worker.
  # TODO(b/157086991): Iterate until end of sequence when we support
  # detecting lost workers.
  for expected in range(1, midpoint):
    self.assertEqual(expected, next(iterator).numpy())
def testJoinWorker(self):
  """Checks that `join()` can be called on a worker that has been stopped."""
  dispatch_server = server_lib.DispatchServer(0)
  dispatcher_address = dispatch_server._address
  worker_server = server_lib.WorkerServer(0, dispatcher_address)
  worker_server._stop()
  worker_server.join()
def testJoinDispatcher(self):
  """Checks that `join()` can be called on a stopped dispatcher."""
  dispatch_server = server_lib.DispatchServer(0)
  dispatch_server._stop()
  dispatch_server.join()
def testStopDispatcher(self):
  """Checks that `_stop()` can be called twice on the same dispatcher."""
  dispatch_server = server_lib.DispatchServer(0)
  # Two consecutive stops, exactly as many as the test intends.
  for _ in range(2):
    dispatch_server._stop()
def testMultipleStartWorker(self):
  """Checks that `start()` works on a worker constructed with `start=True`."""
  dispatch_server = server_lib.DispatchServer(0)
  dispatcher_address = dispatch_server._address
  worker_server = server_lib.WorkerServer(0, dispatcher_address, start=True)
  worker_server.start()
def testStartWorker(self):
  """Checks that a worker constructed with `start=False` can be started."""
  dispatch_server = server_lib.DispatchServer(0)
  dispatcher_address = dispatch_server._address
  worker_server = server_lib.WorkerServer(0, dispatcher_address, start=False)
  worker_server.start()