def testStartServersLate(self):
  """Checks that the client retries while the servers are not yet started.

  The dataset is created and iterated before the dispatcher and worker have
  been started; a background thread starts them after a short delay. The
  test passes if iteration still yields the full range.
  """
  try:
    import portpicker  # pylint: disable=g-import-not-at-top
    dispatcher_port = portpicker.pick_unused_port()
  except Exception:  # pylint: disable=broad-except
    # Catch `Exception` instead of using a bare `except:` so that
    # KeyboardInterrupt/SystemExit are not converted into a skipped test.
    # `skipTest` raises the skip exception itself, so no `raise` is needed.
    self.skipTest(
        "Flakes in portpicker library do not represent "
        "TensorFlow errors.")

  # Both servers are constructed but intentionally left unstarted.
  dispatcher = server_lib.DispatchServer(
      server_lib.DispatcherConfig(port=dispatcher_port), start=False)
  worker = server_lib.WorkerServer(
      server_lib.WorkerConfig(
          dispatcher_address=_address_from_target(dispatcher.target),
          port=0),
      start=False)

  def start_servers():
    # Delay startup so the client starts reading before the servers exist.
    time.sleep(1)
    dispatcher.start()
    worker.start()

  start_servers_thread = threading.Thread(target=start_servers, daemon=True)
  start_servers_thread.start()

  num_elements = 10
  ds = _make_distributed_range_dataset(num_elements, dispatcher)
  results = [elem.numpy() for elem in ds]
  self.assertEqual(list(range(num_elements)), results)
  start_servers_thread.join()
def _start_dispatcher(self, worker_addresses):
  """Creates and starts a gRPC dispatcher, storing it in `self._dispatcher`.

  Args:
    worker_addresses: Forwarded to `DispatcherConfig` as `worker_addresses`.
  """
  # A fresh directory under the test temp root is used as the work dir.
  dispatcher_work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
  dispatcher_config = server_lib.DispatcherConfig(
      port=0,
      work_dir=dispatcher_work_dir,
      protocol="grpc",
      worker_addresses=worker_addresses)
  self._dispatcher = server_lib.DispatchServer(dispatcher_config, start=True)
def _start_dispatcher(self, worker_addresses, port=0):
  """Creates and starts a fault-tolerant gRPC dispatcher.

  The dispatcher uses `self._work_dir` as its work directory and is stored
  in `self._dispatcher`.

  Args:
    worker_addresses: Forwarded to `DispatcherConfig` as `worker_addresses`.
    port: The port to pass to `DispatcherConfig`. Defaults to 0.
  """
  dispatcher_config = server_lib.DispatcherConfig(
      port=port,
      work_dir=self._work_dir,
      protocol="grpc",
      worker_addresses=worker_addresses,
      fault_tolerant_mode=True)
  self._dispatcher = server_lib.DispatchServer(dispatcher_config, start=True)
def restart_dispatcher(self):
  """Replaces `self.dispatcher` with a new dispatcher on the same port.

  The current dispatcher is stopped first; the new one reuses its port,
  work directory and fault-tolerance setting.
  """
  # The port is taken from the current address before the old server stops.
  address_parts = self.dispatcher_address().split(":")
  port = int(address_parts[1])
  self.dispatcher._stop()
  old_config = self.dispatcher._config
  new_config = server_lib.DispatcherConfig(
      port=port,
      work_dir=old_config.work_dir,
      fault_tolerant_mode=old_config.fault_tolerant_mode)
  self.dispatcher = server_lib.DispatchServer(new_config)
def __init__(self,
             num_workers,
             dispatcher_port=0,
             work_dir=TMP_WORK_DIR,
             fault_tolerant_mode=True,
             job_gc_check_interval_ms=None,
             job_gc_timeout_ms=None,
             worker_shutdown_quiet_period_ms=0,
             start=True,
             data_transfer_protocol=None):
  """Creates a tf.data service test cluster.

  Args:
    num_workers: The number of workers to initially add to the cluster.
    dispatcher_port: The port to use for the dispatcher.
    work_dir: The work directory to use for the dispatcher. If set to
      `TMP_WORK_DIR`, the cluster will create a new temporary directory to
      use as the work directory. If set to `NO_WORK_DIR`, no work directory
      will be used.
    fault_tolerant_mode: Whether the dispatcher should write its state to a
      journal so that it can recover from restarts.
    job_gc_check_interval_ms: How often the dispatcher should scan through
      to delete old and unused jobs, in milliseconds.
    job_gc_timeout_ms: How long a job needs to be unused before it becomes a
      candidate for garbage collection, in milliseconds.
    worker_shutdown_quiet_period_ms: When shutting down a worker, how long
      to wait for the gRPC server to process the final requests.
    start: Whether to immediately start the servers in the cluster. If
      `False`, the servers can be started later by calling
      `start_dispatcher()` and `start_workers()`.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. The default can be controlled via the
      tf_data_service_test_transfer_protocol flag.
  """
  if work_dir == TMP_WORK_DIR:
    work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
  self._worker_shutdown_quiet_period_ms = worker_shutdown_quiet_period_ms
  # Fall back to the flag value when the caller did not request a protocol.
  self._data_transfer_protocol = (
      data_transfer_protocol or TRANSFER_PROTOCOL.value)

  dispatcher_config = server_lib.DispatcherConfig(
      port=dispatcher_port,
      work_dir=work_dir,
      protocol=PROTOCOL,
      fault_tolerant_mode=fault_tolerant_mode,
      job_gc_check_interval_ms=job_gc_check_interval_ms,
      job_gc_timeout_ms=job_gc_timeout_ms)
  self.dispatcher = server_lib.DispatchServer(dispatcher_config, start=start)

  self.workers = []
  for _ in range(num_workers):
    self.add_worker(start=start)
def start_dispatch_server(self,
                          name="",
                          port=0,
                          work_dir=None,
                          fault_tolerant_mode=True):
  """Creates a dispatch server and returns it.

  Args:
    name: Used to build the default work directory path. If a test starts
      multiple independent dispatch servers, it should give them different
      `name` values.
    port: The port to pass to `DispatcherConfig`.
    work_dir: The dispatcher work directory. When `None`, a path under
      `self.get_temp_dir()` derived from `name` is used.
    fault_tolerant_mode: Forwarded to `DispatcherConfig`.

  Returns:
    The newly constructed `DispatchServer`.
  """
  if work_dir is None:
    work_dir = os.path.join(self.get_temp_dir(), "work_dir_", name)
  dispatcher_config = server_lib.DispatcherConfig(
      port=port,
      work_dir=work_dir,
      fault_tolerant_mode=fault_tolerant_mode)
  return server_lib.DispatchServer(dispatcher_config)
def __init__(self,
             num_local_workers,
             num_remote_workers,
             dispatcher_port=0,
             worker_shutdown_quiet_period_ms=0):
  """Starts a gRPC dispatcher, then the local and remote workers.

  Args:
    num_local_workers: Passed to `start_local_workers`; its return value is
      kept in `self._local_workers`.
    num_remote_workers: Passed to `start_remote_workers`.
    dispatcher_port: The port to pass to the dispatcher's config.
    worker_shutdown_quiet_period_ms: Stored on the instance as
      `_worker_shutdown_quiet_period_ms`.
  """
  dispatcher_work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
  self._worker_shutdown_quiet_period_ms = worker_shutdown_quiet_period_ms
  dispatcher_config = server_lib.DispatcherConfig(
      port=dispatcher_port, work_dir=dispatcher_work_dir, protocol="grpc")
  self._dispatcher = server_lib.DispatchServer(dispatcher_config, start=True)
  self._local_workers = self.start_local_workers(num_local_workers)
  self.start_remote_workers(num_remote_workers)
def restart_dispatcher(self):
  """Replaces `self.dispatcher` with a new dispatcher on the same port.

  Restarting is supported only when the dispatcher is configured with
  `fault_tolerant_mode=True`.

  Raises:
    ValueError: If the current dispatcher is not in fault-tolerant mode.
  """
  current_config = self.dispatcher._config
  if not current_config.fault_tolerant_mode:
    raise ValueError(
        "Trying to restart the dispatcher without fault-tolerance.")

  # The port is taken from the current address before the old server stops.
  address_parts = self.dispatcher_address().split(":")
  port = int(address_parts[1])
  self.dispatcher._stop()
  new_config = server_lib.DispatcherConfig(
      port=port,
      work_dir=current_config.work_dir,
      fault_tolerant_mode=current_config.fault_tolerant_mode)
  self.dispatcher = server_lib.DispatchServer(new_config)
def start_dispatch_server(self,
                          name="",
                          port=0,
                          work_dir=TMP_WORK_DIR,
                          fault_tolerant_mode=True,
                          job_gc_check_interval_ms=None,
                          job_gc_timeout_ms=None):
  """Creates a dispatch server and returns it.

  Args:
    name: Used to build the default work directory path. If a test starts
      multiple independent dispatch servers, it should give them different
      `name` values.
    port: The port to pass to `DispatcherConfig`.
    work_dir: The dispatcher work directory. When it is `TMP_WORK_DIR`, a
      path under `self.get_temp_dir()` derived from `name` is used instead.
    fault_tolerant_mode: Forwarded to `DispatcherConfig`.
    job_gc_check_interval_ms: Forwarded to `DispatcherConfig`.
    job_gc_timeout_ms: Forwarded to `DispatcherConfig`.

  Returns:
    The newly constructed `DispatchServer`.
  """
  if work_dir is TMP_WORK_DIR:
    work_dir = os.path.join(self.get_temp_dir(), "work_dir_", name)
  dispatcher_config = server_lib.DispatcherConfig(
      port=port,
      work_dir=work_dir,
      fault_tolerant_mode=fault_tolerant_mode,
      job_gc_check_interval_ms=job_gc_check_interval_ms,
      job_gc_timeout_ms=job_gc_timeout_ms)
  return server_lib.DispatchServer(dispatcher_config)
def __init__(self,
             num_workers,
             dispatcher_port=0,
             work_dir=TMP_WORK_DIR,
             fault_tolerant_mode=True,
             job_gc_check_interval_ms=None,
             job_gc_timeout_ms=None,
             start=True):
  """Creates a tf.data service test cluster.

  Args:
    num_workers: The number of workers to initially add to the cluster.
    dispatcher_port: The port to use for the dispatcher.
    work_dir: The work directory to use for the dispatcher. If set to
      `TMP_WORK_DIR`, the cluster will create a new temporary directory to
      use as the work directory. If set to `NO_WORK_DIR`, no work directory
      will be used.
    fault_tolerant_mode: Whether the dispatcher should write its state to a
      journal so that it can recover from restarts.
    job_gc_check_interval_ms: How often the dispatcher should scan through
      to delete old and unused jobs, in milliseconds.
    job_gc_timeout_ms: How long a job needs to be unused before it becomes a
      candidate for garbage collection, in milliseconds.
    start: Whether to immediately start the servers in the cluster. If
      `False`, the servers can be started later by calling
      `start_dispatcher()` and `start_workers()`.
  """
  if work_dir == TMP_WORK_DIR:
    work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())

  dispatcher_config = server_lib.DispatcherConfig(
      port=dispatcher_port,
      work_dir=work_dir,
      fault_tolerant_mode=fault_tolerant_mode,
      job_gc_check_interval_ms=job_gc_check_interval_ms,
      job_gc_timeout_ms=job_gc_timeout_ms)
  self.dispatcher = server_lib.DispatchServer(dispatcher_config, start=start)

  self.workers = []
  for _ in range(num_workers):
    self.add_worker(start=start)
def testStartDispatcherWithWrongFaultTolerantConfig(self):
  """Expects a ValueError for fault-tolerant mode without a `work_dir`."""
  bad_config = server_lib.DispatcherConfig(fault_tolerant_mode=True)
  expected_message = (
      "Cannot enable fault tolerant mode without configuring a work_dir")
  with self.assertRaisesRegex(ValueError, expected_message):
    _ = server_lib.DispatchServer(config=bad_config, start=True)
def testStartDispatcherWithFaultTolerantConfig(self):
  """Starts a dispatcher with a work dir and fault-tolerant mode enabled.

  There is no explicit assertion; the test passes if constructing and
  starting the dispatcher does not raise.
  """
  import shutil  # pylint: disable=g-import-not-at-top

  temp_dir = tempfile.mkdtemp()
  # `mkdtemp` never removes the directory it creates, so register removal to
  # avoid leaking one directory per run. Cleanups run after the test method
  # returns; errors are ignored because removal is best-effort.
  self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
  config = server_lib.DispatcherConfig(
      work_dir=temp_dir, fault_tolerant_mode=True)
  dispatcher = server_lib.DispatchServer(  # pylint: disable=unused-variable
      config=config, start=True)
def testStartDispatcherWithWorkDirConfig(self):
  """Starts a dispatcher configured with only a work directory.

  There is no explicit assertion; the test passes if constructing and
  starting the dispatcher does not raise.
  """
  import shutil  # pylint: disable=g-import-not-at-top

  temp_dir = tempfile.mkdtemp()
  # `mkdtemp` never removes the directory it creates, so register removal to
  # avoid leaking one directory per run. Cleanups run after the test method
  # returns; errors are ignored because removal is best-effort.
  self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
  config = server_lib.DispatcherConfig(work_dir=temp_dir)
  dispatcher = server_lib.DispatchServer(  # pylint: disable=unused-variable
      config=config, start=True)
def testStartDispatcherWithPortConfig(self):
  """Checks that `target` reflects the explicitly configured port."""
  chosen_port = pick_unused_port()
  dispatcher = server_lib.DispatchServer(
      config=server_lib.DispatcherConfig(port=chosen_port), start=True)
  expected_target = "grpc://localhost:{}".format(chosen_port)
  self.assertEqual(dispatcher.target, expected_target)
def _start_dispatcher(self, dispatcher_port):
  """Creates and starts a gRPC dispatcher, storing it in `self._dispatcher`.

  Args:
    dispatcher_port: The port to pass to `DispatcherConfig`.
  """
  # A fresh directory under the test temp root is used as the work dir.
  dispatcher_work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
  dispatcher_config = server_lib.DispatcherConfig(
      port=dispatcher_port,
      work_dir=dispatcher_work_dir,
      protocol="grpc")
  self._dispatcher = server_lib.DispatchServer(dispatcher_config, start=True)