Example #1
    def testGcClient(self):
        dispatcher = server_lib.DispatchServer(
            service_config_pb2.DispatcherConfig(protocol="grpc",
                                                job_gc_check_interval_ms=50,
                                                job_gc_timeout_ms=20,
                                                client_timeout_ms=50))
        dispatcher_address = dispatcher.target.split("://")[1]
        _ = server_lib.WorkerServer(
            server_lib.WorkerConfig(dispatcher_address=dispatcher_address,
                                    heartbeat_interval_ms=100))

        num_elements = 1000
        dataset = dataset_ops.Dataset.range(num_elements)
        dataset = dataset.apply(
            data_service_ops._distribute(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=dispatcher.target,
                task_refresh_interval_hint_ms=10000))
        get_next = self.getNext(dataset)

        # With task_refresh_interval_hint_ms=10000, the client does not
        # heartbeat for 10 seconds, so it exceeds client_timeout_ms and is
        # garbage-collected.
        with self.assertRaisesRegex(errors.NotFoundError,
                                    "Unknown job client id"):
            self.evaluate(get_next())
            time.sleep(3)
            self.getIteratorOutput(get_next)

    def testKeepClientAliveBeforeReading(self):
        dispatcher = server_lib.DispatchServer(
            service_config_pb2.DispatcherConfig(protocol="grpc",
                                                job_gc_check_interval_ms=50,
                                                job_gc_timeout_ms=20,
                                                client_timeout_ms=1000))
        dispatcher_address = dispatcher.target.split("://")[1]
        _ = server_lib.WorkerServer(
            server_lib.WorkerConfig(dispatcher_address=dispatcher_address,
                                    heartbeat_interval_ms=100))

        num_elements = 1000
        dataset = dataset_ops.Dataset.range(num_elements)
        dataset = dataset.apply(
            data_service_ops._distribute(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=dispatcher.target,
                task_refresh_interval_hint_ms=100))
        get_next = self.getNext(dataset)

        # The client heartbeats every 100 milliseconds, so it should not be
        # garbage-collected even though it does not start reading for 3 seconds.
        time.sleep(3)
        self.assertEqual(self.getIteratorOutput(get_next),
                         list(range(num_elements)))
Example #3
 def testStopStartDispatcher(self):
     dispatcher = server_lib.DispatchServer()
     dispatcher._stop()
     with self.assertRaisesRegex(
             RuntimeError,
             "Server cannot be started after it has been stopped"):
         dispatcher.start()
Example #4
 def testStopStartWorker(self):
   dispatcher = server_lib.DispatchServer(0)
   worker = server_lib.WorkerServer(0, dispatcher._address)
   worker._stop()
   with self.assertRaisesRegex(
       RuntimeError, "Server cannot be started after it has been stopped"):
     worker.start()
Example #5
 def testStartWorkerWithPortConfig(self):
     dispatcher = server_lib.DispatchServer()
     port = pick_unused_port()
     worker = server_lib.WorkerServer(server_lib.WorkerConfig(
         dispatcher._address, port=port),
                                      start=True)
     self.assertEqual(worker._address, "localhost:{}".format(port))

 def start_dispatch_server(self, name="", port=0):
     # If a test starts multiple independent dispatch servers, it should give
     # them different `name` values.
     work_dir = os.path.join(self.get_temp_dir(), "work_dir_", name)
     return server_lib.DispatchServer(port=port,
                                      protocol=server_lib.DEFAULT_PROTOCOL,
                                      work_dir=work_dir,
                                      fault_tolerant_mode=True)
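
The `name` argument keeps the work directories of multiple dispatchers apart. The sketch below is a hypothetical test (not from the source) showing how two independent dispatch servers would be started with this helper.

  def testTwoIndependentDispatchers(self):
    # Hypothetical usage: distinct names give each dispatcher its own
    # work directory under the test's temp dir.
    dispatcher_a = self.start_dispatch_server(name="a")
    dispatcher_b = self.start_dispatch_server(name="b")
    self.assertNotEqual(dispatcher_a.target, dispatcher_b.target)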

 def _start_dispatcher(self, worker_addresses):
     work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
     self._dispatcher = server_lib.DispatchServer(
         server_lib.DispatcherConfig(port=0,
                                     work_dir=work_dir,
                                     protocol="grpc",
                                     worker_addresses=worker_addresses),
         start=True)

 def _start_dispatcher(self, worker_addresses, port=0):
   self._dispatcher = server_lib.DispatchServer(
       server_lib.DispatcherConfig(
           port=port,
           work_dir=self._work_dir,
           protocol="grpc",
           worker_addresses=worker_addresses,
           fault_tolerant_mode=True),
       start=True)

 def testProfileWorker(self):
   dispatcher = server_lib.DispatchServer()
   worker = server_lib.WorkerServer(
       server_lib.WorkerConfig(dispatcher._address))
   # Test that the profiler is successfully started and connected to the
   # profiler service on the worker. Since no op is running, the trace is
   # expected to raise UnavailableError with a "no trace event was collected"
   # message.
   with self.assertRaises(errors.UnavailableError) as error:
     profiler_client.trace(worker._address, tempfile.mkdtemp(), duration_ms=10)
   self.assertStartsWith(str(error.exception), "No trace event was collected")
Example #10
 def restart_dispatcher(self):
     """Stops `dispatcher` and creates a new dispatcher with the same port."""
     port = int(self.dispatcher_address().split(":")[1])
     self.dispatcher._stop()
     self.dispatcher = server_lib.DispatchServer(
         server_lib.DispatcherConfig(
             port=port,
             work_dir=self.dispatcher._config.work_dir,
             fault_tolerant_mode=self.dispatcher._config.fault_tolerant_mode
         ))
Example #11
    def __init__(self,
                 num_workers,
                 dispatcher_port=0,
                 work_dir=TMP_WORK_DIR,
                 fault_tolerant_mode=True,
                 job_gc_check_interval_ms=None,
                 job_gc_timeout_ms=None,
                 worker_shutdown_quiet_period_ms=0,
                 start=True,
                 data_transfer_protocol=None):
        """Creates a tf.data service test cluster.

        Args:
          num_workers: The number of workers to initially add to the cluster.
          dispatcher_port: The port to use for the dispatcher.
          work_dir: The work directory to use for the dispatcher. If set to
            `TMP_WORK_DIR`, the cluster will create a new temporary directory
            to use as the work directory. If set to `NO_WORK_DIR`, no work
            directory will be used.
          fault_tolerant_mode: Whether the dispatcher should write its state to
            a journal so that it can recover from restarts.
          job_gc_check_interval_ms: How often the dispatcher should scan through
            to delete old and unused jobs, in milliseconds.
          job_gc_timeout_ms: How long a job needs to be unused before it becomes
            a candidate for garbage collection, in milliseconds.
          worker_shutdown_quiet_period_ms: When shutting down a worker, how long
            to wait for the gRPC server to process the final requests.
          start: Whether to immediately start the servers in the cluster. If
            `False`, the servers can be started later by calling
            `start_dispatcher()` and `start_workers()`.
          data_transfer_protocol: (Optional.) The protocol to use for
            transferring data with the tf.data service. The default can be
            controlled via the `tf_data_service_test_transfer_protocol` flag.
        """
        if work_dir == TMP_WORK_DIR:
            work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
        self._worker_shutdown_quiet_period_ms = worker_shutdown_quiet_period_ms
        if not data_transfer_protocol:
            data_transfer_protocol = TRANSFER_PROTOCOL.value
        self._data_transfer_protocol = data_transfer_protocol
        self.dispatcher = server_lib.DispatchServer(
            server_lib.DispatcherConfig(
                port=dispatcher_port,
                work_dir=work_dir,
                protocol=PROTOCOL,
                fault_tolerant_mode=fault_tolerant_mode,
                job_gc_check_interval_ms=job_gc_check_interval_ms,
                job_gc_timeout_ms=job_gc_timeout_ms),
            start=start)

        self.workers = []
        for _ in range(num_workers):
            self.add_worker(start=start)
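
For orientation, here is a minimal sketch of how a cluster built by this constructor is consumed. The `TestCluster` class name and the `dispatcher_address()` accessor are assumptions (only the constructor is shown above), and `data_service_ops.distribute` is the public counterpart of the private `_distribute` used in Example #1; treat it as illustrative rather than the exact test API.

# Illustrative sketch; `TestCluster` is an assumed name for the class whose
# constructor is documented above.
cluster = TestCluster(num_workers=2,
                      fault_tolerant_mode=True,
                      job_gc_check_interval_ms=50,
                      job_gc_timeout_ms=20)
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(
    data_service_ops.distribute(
        processing_mode=data_service_ops.ShardingPolicy.OFF,
        service="{}://{}".format(PROTOCOL, cluster.dispatcher_address())))
# With ShardingPolicy.OFF, each of the two workers processes the full range,
# so iterating over `dataset` yields every element twice.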
Example #12
 def _start_dispatcher(self, worker_addresses, port=0):
     if port == 0:
         port = test_util.pick_unused_port()
     self._dispatcher = server_lib.DispatchServer(
         service_config_pb2.DispatcherConfig(
             port=port,
             protocol="grpc",
             work_dir=self._work_dir,
             fault_tolerant_mode=True,
             worker_addresses=worker_addresses,
             deployment_mode=self._deployment_mode),
         start=True)
Example #13
 def __init__(self,
              num_local_workers,
              num_remote_workers,
              dispatcher_port=0,
              worker_shutdown_quiet_period_ms=0):
     work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
     self._worker_shutdown_quiet_period_ms = worker_shutdown_quiet_period_ms
     self._dispatcher = server_lib.DispatchServer(
         server_lib.DispatcherConfig(port=dispatcher_port,
                                     work_dir=work_dir,
                                     protocol="grpc"),
         start=True)
     self._local_workers = self.start_local_workers(num_local_workers)
     self.start_remote_workers(num_remote_workers)

  def restart_dispatcher(self):
    """Stops `dispatcher` and creates a new dispatcher with the same port.

    Restarting is supported only when the dispatcher is configured with
    `fault_tolerant_mode=True`.
    """
    if not self.dispatcher._config.fault_tolerant_mode:
      raise ValueError(
          "Trying to restart the dispatcher without fault-tolerance.")
    port = int(self.dispatcher_address().split(":")[1])
    self.dispatcher._stop()
    self.dispatcher = server_lib.DispatchServer(
        server_lib.DispatcherConfig(
            port=port,
            work_dir=self.dispatcher._config.work_dir,
            fault_tolerant_mode=self.dispatcher._config.fault_tolerant_mode))
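
A caller typically uses `restart_dispatcher` mid-iteration to check that the journaled state lets an in-flight job continue. The sketch below reuses the assumed `TestCluster` from the sketch after Example #11 and is illustrative only.

# Illustrative only: read half the elements, restart the dispatcher, then
# finish reading. This relies on fault_tolerant_mode=True journaling state.
cluster = TestCluster(num_workers=1, fault_tolerant_mode=True)
num_elements = 100
ds = dataset_ops.Dataset.range(num_elements)
ds = ds.apply(
    data_service_ops.distribute(
        processing_mode=data_service_ops.ShardingPolicy.OFF,
        service="{}://{}".format(PROTOCOL, cluster.dispatcher_address())))
iterator = iter(ds)
results = [next(iterator).numpy() for _ in range(num_elements // 2)]
cluster.restart_dispatcher()
results += [elem.numpy() for elem in iterator]
assert results == list(range(num_elements))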
Example #15
  def create_cluster(self, num_workers):
    """Creates a cluster of tf.data service servers.

    Args:
      num_workers: The number of workers in the cluster.

    Returns:
      A string for connecting to the tf.data service.
    """
    self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
    self._servers = []
    for _ in range(num_workers):
      self._servers.append(
          server_lib.WorkerServer(
              port=0,
              dispatcher_address=self._dispatcher._address,
              protocol=PROTOCOL))

    return "{0}://{1}".format(PROTOCOL, self._dispatcher._address)
Example #16
 def testDispatcherPreemption(self):
     self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
     self._worker = server_lib.WorkerServer(
         port=0,
         dispatcher_address=self._dispatcher._address,
         protocol=PROTOCOL)
     num_elements = 100
     ds = dataset_ops.Dataset.range(num_elements)
     ds = _make_distributed_dataset(
         ds, "{}://{}".format(PROTOCOL, self._dispatcher._address))
     iterator = iter(ds)
     results = []
     results.append(next(iterator).numpy())
     self._dispatcher._stop()
     # After the dispatcher dies, the worker should continue providing the rest
     # of the dataset's elements.
     for _ in range(num_elements - 1):
         results.append(next(iterator).numpy())
     self.assertEqual(results, list(range(num_elements)))
Example #17
    def testCancellation(self):
        self.skipTest("b/162521601")
        sleep_microseconds = int(1e6) * 1000

        self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
        self._worker = server_lib.WorkerServer(
            port=0,
            dispatcher_address=self._dispatcher._address,
            protocol=PROTOCOL)
        # Create a dataset which produces the first element quickly, and the second
        # element slowly. Fetching the first element triggers prefetching of the
        # second element, which we should be able to cancel.
        slow = dataset_ops.Dataset.range(1)
        slow = slow.apply(testing.sleep(sleep_microseconds))
        ds = dataset_ops.Dataset.range(1).concatenate(slow)
        ds = _make_distributed_dataset(
            ds, "{}://{}".format(PROTOCOL, self._dispatcher._address))
        ds = ds.prefetch(1)
        get_next = self.getNext(ds, requires_initialization=True)
        self.assertEqual(0, self.evaluate(get_next()))

    def __init__(self,
                 num_workers,
                 dispatcher_port=0,
                 work_dir=TMP_WORK_DIR,
                 fault_tolerant_mode=True,
                 job_gc_check_interval_ms=None,
                 job_gc_timeout_ms=None,
                 start=True):
        """Creates a tf.data service test cluster.

        Args:
          num_workers: The number of workers to initially add to the cluster.
          dispatcher_port: The port to use for the dispatcher.
          work_dir: The work directory to use for the dispatcher. If set to
            `TMP_WORK_DIR`, the cluster will create a new temporary directory
            to use as the work directory. If set to `NO_WORK_DIR`, no work
            directory will be used.
          fault_tolerant_mode: Whether the dispatcher should write its state to
            a journal so that it can recover from restarts.
          job_gc_check_interval_ms: How often the dispatcher should scan through
            to delete old and unused jobs, in milliseconds.
          job_gc_timeout_ms: How long a job needs to be unused before it becomes
            a candidate for garbage collection, in milliseconds.
          start: Whether to immediately start the servers in the cluster. If
            `False`, the servers can be started later by calling
            `start_dispatcher()` and `start_workers()`.
        """
        if work_dir == TMP_WORK_DIR:
            work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
        self.dispatcher = server_lib.DispatchServer(
            server_lib.DispatcherConfig(
                port=dispatcher_port,
                work_dir=work_dir,
                fault_tolerant_mode=fault_tolerant_mode,
                job_gc_check_interval_ms=job_gc_check_interval_ms,
                job_gc_timeout_ms=job_gc_timeout_ms),
            start=start)

        self.workers = []
        for _ in range(num_workers):
            self.add_worker(start=start)
Example #19
    def testRestartWorker(self, use_same_port):
        self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
        self._worker = server_lib.WorkerServer(
            port=0,
            dispatcher_address=self._dispatcher._address,
            protocol=PROTOCOL)
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements)
        ds = _make_distributed_dataset(
            ds, "{}://{}".format(PROTOCOL, self._dispatcher._address))
        iterator = iter(ds)
        # Read halfway through the dataset.
        midpoint = num_elements // 2
        for i in range(midpoint):
            self.assertEqual(i, next(iterator).numpy())

        # Stop the original worker and start a new one.
        port = 0
        if use_same_port:
            port = int(self._worker._address.split(":")[1])
        self._worker._stop()
        self._new_worker = server_lib.WorkerServer(
            port=port,
            dispatcher_address=self._dispatcher._address,
            protocol=PROTOCOL)

        # There may have been some elements prefetched from the first worker
        # before it was stopped.
        while True:
            val = next(iterator).numpy()
            if val == 0:
                break

        # The dataset starts over now that we read from the new worker.
        # TODO(b/157086991): Iterate until end of sequence when we support
        # detecting lost workers.
        for i in range(1, num_elements // 2):
            val = next(iterator).numpy()
            self.assertEqual(i, val)
Example #20
 def testJoinWorker(self):
   dispatcher = server_lib.DispatchServer(0)
   worker = server_lib.WorkerServer(0, dispatcher._address)
   worker._stop()
   worker.join()
Example #21
 def testJoinDispatcher(self):
   dispatcher = server_lib.DispatchServer(0)
   dispatcher._stop()
   dispatcher.join()
Example #22
 def testStopDispatcher(self):
   dispatcher = server_lib.DispatchServer(0)
   dispatcher._stop()
   dispatcher._stop()
Example #23
 def testMultipleStartWorker(self):
   dispatcher = server_lib.DispatchServer(0)
   worker = server_lib.WorkerServer(0, dispatcher._address, start=True)
   worker.start()
Example #24
 def testStartWorker(self):
   dispatcher = server_lib.DispatchServer(0)
   worker = server_lib.WorkerServer(0, dispatcher._address, start=False)
   worker.start()