コード例 #1
0
  def __init__(self, config=None, start=True):
    """Builds a dispatch server and, by default, starts it.

    Args:
      config: (Optional.) A `tf.data.experimental.service.DispatcherConfig`
        configuration, or an already-built
        `service_config_pb2.DispatcherConfig` proto, which is used as-is.
        If falsy (e.g. `None`), a default `DispatcherConfig()` is used.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to True.

    Raises:
      ValueError: If `config.fault_tolerant_mode` is set while
        `config.work_dir` is empty.
    """
    if not config:
      config = DispatcherConfig()

    # Fault tolerant mode is rejected up front when no work dir is configured.
    missing_work_dir = config.fault_tolerant_mode and not config.work_dir
    if missing_work_dir:
      raise ValueError(
          "Cannot enable fault tolerant mode without configuring a work dir. "
          "Make sure to set `work_dir` in the `config` object passed to "
          "`DispatcherServer`.")
    self._config = config

    if isinstance(config, service_config_pb2.DispatcherConfig):
      # Already a proto: no field-by-field conversion is needed.
      config_proto = config
    else:
      # Attributes copied one-to-one from `config` into the proto.
      copied_fields = (
          "port",
          "protocol",
          "work_dir",
          "fault_tolerant_mode",
          "worker_addresses",
          "job_gc_check_interval_ms",
          "job_gc_timeout_ms",
      )
      config_proto = service_config_pb2.DispatcherConfig(
          **{name: getattr(config, name) for name in copied_fields})

    # The native server is constructed from the serialized proto.
    serialized_config = config_proto.SerializeToString()
    self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
        serialized_config)
    if start:
      self._server.start()
コード例 #2
0
    def testGcClient(self):
        """Expects a `NotFoundError` once the job client has been collected.

        The dispatcher is configured with `client_timeout_ms=50`, while the
        dataset is distributed with `task_refresh_interval_hint_ms=10000`.
        After one element is read and the test sleeps for 3 seconds, further
        reads are expected to fail with "Unknown job client id".
        """
        dispatcher_config = service_config_pb2.DispatcherConfig(
            protocol="grpc",
            job_gc_check_interval_ms=50,
            job_gc_timeout_ms=20,
            client_timeout_ms=50)
        dispatcher = server_lib.DispatchServer(dispatcher_config)

        # `target` has the form "<protocol>://<address>"; keep the address.
        dispatcher_address = dispatcher.target.split("://")[1]
        worker_config = server_lib.WorkerConfig(
            dispatcher_address=dispatcher_address,
            heartbeat_interval_ms=100)
        # NOTE(review): bound to a throwaway name; presumably this keeps the
        # worker referenced for the rest of the test -- confirm.
        _ = server_lib.WorkerServer(worker_config)

        num_elements = 1000
        source = dataset_ops.Dataset.range(num_elements)
        distribute_fn = data_service_ops._distribute(
            processing_mode=ShardingPolicy.OFF,
            service=dispatcher.target,
            task_refresh_interval_hint_ms=10000)
        dataset = source.apply(distribute_fn)
        get_next = self.getNext(dataset)

        # The client does not heartbeat in 10 seconds, so it is expected to be
        # garbage-collected during the sleep below.
        with self.assertRaisesRegex(
                errors.NotFoundError, "Unknown job client id"):
            self.evaluate(get_next())
            time.sleep(3)
            self.getIteratorOutput(get_next)
コード例 #3
0
    def testKeepClientAliveBeforeReading(self):
        """Expects all elements to be readable after an idle period.

        The dispatcher is configured with `client_timeout_ms=1000` and the
        dataset is distributed with `task_refresh_interval_hint_ms=100`. The
        test sleeps for 3 seconds before reading anything and then expects
        the full `range(num_elements)` output.
        """
        dispatcher_config = service_config_pb2.DispatcherConfig(
            protocol="grpc",
            job_gc_check_interval_ms=50,
            job_gc_timeout_ms=20,
            client_timeout_ms=1000)
        dispatcher = server_lib.DispatchServer(dispatcher_config)

        # `target` has the form "<protocol>://<address>"; keep the address.
        dispatcher_address = dispatcher.target.split("://")[1]
        worker_config = server_lib.WorkerConfig(
            dispatcher_address=dispatcher_address,
            heartbeat_interval_ms=100)
        # NOTE(review): bound to a throwaway name; presumably this keeps the
        # worker referenced for the rest of the test -- confirm.
        _ = server_lib.WorkerServer(worker_config)

        num_elements = 1000
        source = dataset_ops.Dataset.range(num_elements)
        distribute_fn = data_service_ops._distribute(
            processing_mode=data_service_ops.ShardingPolicy.OFF,
            service=dispatcher.target,
            task_refresh_interval_hint_ms=100)
        dataset = source.apply(distribute_fn)
        get_next = self.getNext(dataset)

        # The client regularly heartbeats in 100 milliseconds. It should not
        # be garbage-collected even if it does not start reading in 3 seconds.
        time.sleep(3)
        expected_output = list(range(num_elements))
        self.assertEqual(self.getIteratorOutput(get_next), expected_output)
コード例 #4
0
    def __init__(self, config=None, start=True):
        """Builds a dispatch server and, by default, starts it.

        Args:
          config: (Optional.) A
            `tf.data.experimental.service.DispatcherConfig` configuration.
            If falsy (e.g. `None`), a default `DispatcherConfig()` is used.
            Every attribute listed in `copied_fields` below is read from it.
          start: (Optional.) Boolean, indicating whether to start the server
            after creating it. Defaults to True.

        Raises:
          ValueError: If `config.fault_tolerant_mode` is set while
            `config.work_dir` is empty.
        """
        if not config:
            config = DispatcherConfig()

        # Fault tolerant mode is rejected up front when no work dir is set.
        missing_work_dir = config.fault_tolerant_mode and not config.work_dir
        if missing_work_dir:
            raise ValueError(
                "Cannot enable fault tolerant mode without configuring a work_dir"
            )
        self._config = config

        # Attributes copied one-to-one from `config` into the proto.
        copied_fields = (
            "port",
            "protocol",
            "work_dir",
            "fault_tolerant_mode",
            "job_gc_check_interval_ms",
            "job_gc_timeout_ms",
            "cache_policy",
            "cache_format",
            "cache_compression",
            "cache_ops_parallelism",
            "cache_path",
            "scaling_policy",
            "log_dir",
            "log_dumps_interval_ms",
        )
        config_proto = service_config_pb2.DispatcherConfig(
            **{name: getattr(config, name) for name in copied_fields})

        # The native server is constructed from the serialized proto.
        serialized_config = config_proto.SerializeToString()
        self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
            serialized_config)
        if start:
            self._server.start()
コード例 #5
0
 def _start_dispatcher(self, worker_addresses, port=0):
   """Creates a started, fault-tolerant gRPC dispatcher on `self._dispatcher`.

   The dispatcher config uses `self._work_dir` as its work dir and
   `self._deployment_mode` as its deployment mode.

   Args:
     worker_addresses: Forwarded as the config's `worker_addresses`.
     port: Forwarded as the config's `port`. Defaults to 0.
   """
   dispatcher_config = service_config_pb2.DispatcherConfig(
       port=port,
       protocol="grpc",
       work_dir=self._work_dir,
       fault_tolerant_mode=True,
       worker_addresses=worker_addresses,
       deployment_mode=self._deployment_mode)
   self._dispatcher = server_lib.DispatchServer(
       dispatcher_config, start=True)