def __init__(self, config=None, start=True):
  """Builds a tf.data service dispatch server.

  Args:
    config: (Optional.) Either a `tf.data.experimental.service.DispatcherConfig`
      or an already-built `service_config_pb2.DispatcherConfig` proto. When
      `None` (or otherwise falsy), a default `DispatcherConfig()` is used.
    start: (Optional.) Boolean. When True (the default), the server is started
      right after it is created.

  Raises:
    ValueError: If `config.fault_tolerant_mode` is set while `config.work_dir`
      is empty.
  """
  if not config:
    config = DispatcherConfig()

  # Fault-tolerant mode is only accepted together with a work dir.
  if config.fault_tolerant_mode and not config.work_dir:
    raise ValueError(
        "Cannot enable fault tolerant mode without configuring a work dir. "
        "Make sure to set `work_dir` in the `config` object passed to "
        "`DispatcherServer`.")

  # The caller's object is kept as-is, whichever of the two forms it was.
  self._config = config

  if isinstance(config, service_config_pb2.DispatcherConfig):
    # Already in proto form; forward it unchanged.
    proto = config
  else:
    # Copy the Python-side fields into a proto, which is serialized below and
    # handed to `_pywrap_server_lib`.
    proto = service_config_pb2.DispatcherConfig(
        port=config.port,
        protocol=config.protocol,
        work_dir=config.work_dir,
        fault_tolerant_mode=config.fault_tolerant_mode,
        worker_addresses=config.worker_addresses,
        job_gc_check_interval_ms=config.job_gc_check_interval_ms,
        job_gc_timeout_ms=config.job_gc_timeout_ms,
    )

  self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
      proto.SerializeToString())
  if start:
    self._server.start()
def testGcClient(self):
  """A client that stops heartbeating is garbage-collected by the dispatcher.

  The dispatcher times clients out after 50ms (`client_timeout_ms`), while the
  client only refreshes its tasks every 10s (`task_refresh_interval_hint_ms`).
  After the 3s sleep, the client is therefore no longer known to the
  dispatcher, and reading further elements fails with "Unknown job client id".
  """
  dispatcher = server_lib.DispatchServer(
      service_config_pb2.DispatcherConfig(
          protocol="grpc",
          job_gc_check_interval_ms=50,
          job_gc_timeout_ms=20,
          client_timeout_ms=50))
  dispatcher_address = dispatcher.target.split("://")[1]
  # Bound to a local so the worker is not released before the test finishes.
  _ = server_lib.WorkerServer(
      server_lib.WorkerConfig(
          dispatcher_address=dispatcher_address, heartbeat_interval_ms=100))

  num_elements = 1000
  dataset = dataset_ops.Dataset.range(num_elements)
  dataset = dataset.apply(
      data_service_ops._distribute(
          # Qualified like the other tests in this file; a bare
          # `ShardingPolicy` is not guaranteed to be imported here.
          processing_mode=data_service_ops.ShardingPolicy.OFF,
          service=dispatcher.target,
          task_refresh_interval_hint_ms=10000))
  get_next = self.getNext(dataset)

  # The client does not heartbeat in 10 seconds. It will be garbage-collected.
  # The first read happens while the client is still registered; the error is
  # expected once the client has been collected during the sleep.
  with self.assertRaisesRegex(errors.NotFoundError, "Unknown job client id"):
    self.evaluate(get_next())
    time.sleep(3)
    self.getIteratorOutput(get_next)
def testKeepClientAliveBeforeReading(self):
  """A heartbeating client is kept alive even if it delays its first read."""
  dispatcher_config = service_config_pb2.DispatcherConfig(
      protocol="grpc",
      job_gc_check_interval_ms=50,
      job_gc_timeout_ms=20,
      client_timeout_ms=1000)
  dispatcher = server_lib.DispatchServer(dispatcher_config)

  worker_config = server_lib.WorkerConfig(
      dispatcher_address=dispatcher.target.split("://")[1],
      heartbeat_interval_ms=100)
  # Bound to a local so the worker is not released before the test finishes.
  _ = server_lib.WorkerServer(worker_config)

  num_elements = 1000
  dataset = dataset_ops.Dataset.range(num_elements)
  dataset = dataset.apply(
      data_service_ops._distribute(
          processing_mode=data_service_ops.ShardingPolicy.OFF,
          service=dispatcher.target,
          task_refresh_interval_hint_ms=100))
  get_next = self.getNext(dataset)

  # The 100ms task refresh keeps the client heartbeating, so pausing 3s
  # before the first read must not get it garbage-collected.
  time.sleep(3)
  actual = self.getIteratorOutput(get_next)
  self.assertEqual(actual, list(range(num_elements)))
def __init__(self, config=None, start=True):
  """Builds a tf.data service dispatch server.

  Args:
    config: (Optional.) A `tf.data.experimental.service.DispatcherConfig`
      configuration. When `None` (or otherwise falsy), a default
      `DispatcherConfig()` is used.
    start: (Optional.) Boolean. When True (the default), the server is started
      right after it is created.

  Raises:
    ValueError: If `config.fault_tolerant_mode` is set while `config.work_dir`
      is empty.
  """
  if not config:
    config = DispatcherConfig()

  # Fault-tolerant mode is only accepted together with a work dir.
  if config.fault_tolerant_mode and not config.work_dir:
    raise ValueError(
        "Cannot enable fault tolerant mode without configuring a work_dir")

  self._config = config

  # Copy the config fields into a proto, which is serialized and handed to
  # `_pywrap_server_lib`.
  proto = service_config_pb2.DispatcherConfig(
      port=config.port,
      protocol=config.protocol,
      work_dir=config.work_dir,
      fault_tolerant_mode=config.fault_tolerant_mode,
      job_gc_check_interval_ms=config.job_gc_check_interval_ms,
      job_gc_timeout_ms=config.job_gc_timeout_ms,
      cache_policy=config.cache_policy,
      cache_format=config.cache_format,
      cache_compression=config.cache_compression,
      cache_ops_parallelism=config.cache_ops_parallelism,
      cache_path=config.cache_path,
      scaling_policy=config.scaling_policy,
      log_dir=config.log_dir,
      log_dumps_interval_ms=config.log_dumps_interval_ms,
  )
  serialized = proto.SerializeToString()
  self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(serialized)

  if start:
    self._server.start()
def _start_dispatcher(self, worker_addresses, port=0):
  """Creates and starts a fault-tolerant gRPC dispatcher.

  The new server is stored on `self._dispatcher`. Its work dir and deployment
  mode come from `self._work_dir` and `self._deployment_mode`.

  Args:
    worker_addresses: Worker addresses to place in the dispatcher config.
    port: Port forwarded to the dispatcher config. Defaults to 0.
  """
  dispatcher_config = service_config_pb2.DispatcherConfig(
      port=port,
      protocol="grpc",
      work_dir=self._work_dir,
      fault_tolerant_mode=True,
      worker_addresses=worker_addresses,
      deployment_mode=self._deployment_mode,
  )
  self._dispatcher = server_lib.DispatchServer(dispatcher_config, start=True)