def testStartServersLate(self):
        """Checks that a dataset created before its servers start still works.

        The data service client should perform retries instead of failing when
        the dataset is created before the dispatcher and worker are started.
        """
        try:
            import portpicker  # pylint: disable=g-import-not-at-top
            dispatcher_port = portpicker.pick_unused_port()
        except Exception:  # pylint: disable=broad-except
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are not swallowed and reported as a skip. `skipTest` raises
            # `unittest.SkipTest` itself; the `raise` only makes the exit
            # explicit to readers and linters.
            raise self.skipTest(
                "Flakes in portpicker library do not represent "
                "TensorFlow errors.")
        dispatcher = server_lib.DispatchServer(
            server_lib.DispatcherConfig(port=dispatcher_port), start=False)
        worker = server_lib.WorkerServer(
            server_lib.WorkerConfig(
                dispatcher_address=_address_from_target(dispatcher.target),
                port=0),
            start=False)

        def start_servers():
            # Delay startup so the dataset below is created before any server
            # is running.
            time.sleep(1)
            dispatcher.start()
            worker.start()

        start_servers_thread = threading.Thread(target=start_servers,
                                                daemon=True)
        start_servers_thread.start()
        try:
            num_elements = 10
            ds = _make_distributed_range_dataset(num_elements, dispatcher)
            results = [elem.numpy() for elem in ds]
            self.assertEqual(list(range(num_elements)), results)
        finally:
            # Join even if iteration or the assertion fails, so the servers
            # are not still being started in the background after the test.
            start_servers_thread.join()
 def _start_dispatcher(self, worker_addresses):
     """Starts a gRPC dispatcher that uses a fresh temporary work directory.

     The started server is stored on `self._dispatcher`.

     Args:
       worker_addresses: Worker addresses passed to the dispatcher config.
     """
     temp_work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
     dispatcher_config = server_lib.DispatcherConfig(
         port=0,
         work_dir=temp_work_dir,
         protocol="grpc",
         worker_addresses=worker_addresses)
     self._dispatcher = server_lib.DispatchServer(dispatcher_config,
                                                  start=True)
 def _start_dispatcher(self, worker_addresses, port=0):
   """Starts a fault-tolerant gRPC dispatcher backed by `self._work_dir`.

   The started server is stored on `self._dispatcher`.

   Args:
     worker_addresses: Worker addresses passed to the dispatcher config.
     port: Port for the dispatcher (defaults to 0).
   """
   config = server_lib.DispatcherConfig(
       port=port,
       work_dir=self._work_dir,
       protocol="grpc",
       worker_addresses=worker_addresses,
       fault_tolerant_mode=True)
   self._dispatcher = server_lib.DispatchServer(config, start=True)
Esempio n. 4
0
 def restart_dispatcher(self):
     """Stops `dispatcher` and creates a new dispatcher with the same port.

     The replacement reuses the old dispatcher's `work_dir` and
     `fault_tolerant_mode`. No other settings from the old config are passed
     to the new `DispatcherConfig`.
     """
     # Split on the last colon only, so the port is still extracted when the
     # host part itself contains colons (e.g. a bracketed IPv6 literal).
     port = int(self.dispatcher_address().rsplit(":", 1)[1])
     self.dispatcher._stop()
     self.dispatcher = server_lib.DispatchServer(
         server_lib.DispatcherConfig(
             port=port,
             work_dir=self.dispatcher._config.work_dir,
             fault_tolerant_mode=self.dispatcher._config.fault_tolerant_mode
         ))
Esempio n. 5
0
    def __init__(self,
                 num_workers,
                 dispatcher_port=0,
                 work_dir=TMP_WORK_DIR,
                 fault_tolerant_mode=True,
                 job_gc_check_interval_ms=None,
                 job_gc_timeout_ms=None,
                 worker_shutdown_quiet_period_ms=0,
                 start=True,
                 data_transfer_protocol=None):
        """Creates a tf.data service test cluster.

        Args:
          num_workers: The number of workers to initially add to the cluster.
          dispatcher_port: The port to use for the dispatcher.
          work_dir: The work directory to use for the dispatcher. If set to
            `TMP_WORK_DIR`, the cluster will create a new temporary directory
            to use as the work directory. If set to `NO_WORK_DIR`, no work
            directory will be used.
          fault_tolerant_mode: Whether the dispatcher should write its state to
            a journal so that it can recover from restarts.
          job_gc_check_interval_ms: How often the dispatcher should scan through
            to delete old and unused jobs, in milliseconds.
          job_gc_timeout_ms: How long a job needs to be unused before it becomes
            a candidate for garbage collection, in milliseconds.
          worker_shutdown_quiet_period_ms: When shutting down a worker, how long
            to wait for the gRPC server to process the final requests.
          start: Whether to immediately start the servers in the cluster. If
            `False`, the servers can be started later by calling
            `start_dispatcher()` and `start_workers()`.
          data_transfer_protocol: (Optional.) The protocol to use for
            transferring data with the tf.data service. If not provided, the
            default is read from `TRANSFER_PROTOCOL`, which can be controlled
            via the `tf_data_service_test_transfer_protocol` flag.
        """
        # Replace the TMP_WORK_DIR sentinel with a freshly created directory.
        resolved_work_dir = (tempfile.mkdtemp(dir=googletest.GetTempDir())
                             if work_dir == TMP_WORK_DIR else work_dir)
        self._worker_shutdown_quiet_period_ms = worker_shutdown_quiet_period_ms
        # Any falsy value (None, "") falls back to the flag-controlled default.
        self._data_transfer_protocol = (data_transfer_protocol or
                                        TRANSFER_PROTOCOL.value)
        dispatcher_config = server_lib.DispatcherConfig(
            port=dispatcher_port,
            work_dir=resolved_work_dir,
            protocol=PROTOCOL,
            fault_tolerant_mode=fault_tolerant_mode,
            job_gc_check_interval_ms=job_gc_check_interval_ms,
            job_gc_timeout_ms=job_gc_timeout_ms)
        self.dispatcher = server_lib.DispatchServer(dispatcher_config,
                                                    start=start)

        # NOTE(review): `add_worker` presumably appends to `self.workers`.
        self.workers = []
        for _worker_index in range(num_workers):
            self.add_worker(start=start)
 def start_dispatch_server(self,
                           name="",
                           port=0,
                           work_dir=None,
                           fault_tolerant_mode=True):
     """Creates a dispatch server, deriving a work dir from `name` by default.

     Tests that start several independent dispatch servers should pass a
     distinct `name` to each, so that they do not share a work directory.

     Args:
       name: Last path component of the default work directory.
       port: Port for the dispatcher.
       work_dir: Work directory to use. If None, it becomes
         `<self.get_temp_dir()>/work_dir_/<name>`.
       fault_tolerant_mode: Passed through to `DispatcherConfig`.

     Returns:
       The created `server_lib.DispatchServer`.
     """
     if work_dir is None:
         work_dir = os.path.join(self.get_temp_dir(), "work_dir_", name)
     config = server_lib.DispatcherConfig(
         port=port,
         work_dir=work_dir,
         fault_tolerant_mode=fault_tolerant_mode)
     return server_lib.DispatchServer(config)
Esempio n. 7
0
 def __init__(self,
              num_local_workers,
              num_remote_workers,
              dispatcher_port=0,
              worker_shutdown_quiet_period_ms=0):
     """Starts a gRPC dispatcher followed by local and remote workers.

     Args:
       num_local_workers: Count passed to `start_local_workers`.
       num_remote_workers: Count passed to `start_remote_workers`.
       dispatcher_port: Port for the dispatcher.
       worker_shutdown_quiet_period_ms: Stored as
         `self._worker_shutdown_quiet_period_ms`.
     """
     dispatcher_work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
     self._worker_shutdown_quiet_period_ms = worker_shutdown_quiet_period_ms
     config = server_lib.DispatcherConfig(
         port=dispatcher_port, work_dir=dispatcher_work_dir, protocol="grpc")
     self._dispatcher = server_lib.DispatchServer(config, start=True)
     # Workers are started only after the dispatcher exists.
     self._local_workers = self.start_local_workers(num_local_workers)
     self.start_remote_workers(num_remote_workers)
  def restart_dispatcher(self):
    """Stops `dispatcher` and creates a new dispatcher with the same port.

    Restarting is supported only when the dispatcher is configured with
    `fault_tolerant_mode=True`. The replacement reuses the old dispatcher's
    `work_dir` and `fault_tolerant_mode`. No other settings from the old
    config are passed to the new `DispatcherConfig`.

    Raises:
      ValueError: If the current dispatcher is not in fault-tolerant mode.
    """
    if not self.dispatcher._config.fault_tolerant_mode:
      raise ValueError(
          "Trying to restart the dispatcher without fault-tolerance.")
    # Split on the last colon only, so the port is still extracted when the
    # host part itself contains colons (e.g. a bracketed IPv6 literal).
    port = int(self.dispatcher_address().rsplit(":", 1)[1])
    self.dispatcher._stop()
    self.dispatcher = server_lib.DispatchServer(
        server_lib.DispatcherConfig(
            port=port,
            work_dir=self.dispatcher._config.work_dir,
            fault_tolerant_mode=self.dispatcher._config.fault_tolerant_mode))
 def start_dispatch_server(self,
                           name="",
                           port=0,
                           work_dir=TMP_WORK_DIR,
                           fault_tolerant_mode=True,
                           job_gc_check_interval_ms=None,
                           job_gc_timeout_ms=None):
   """Creates a dispatch server, by default with a per-`name` work directory.

   Tests that start several independent dispatch servers should pass a
   distinct `name` to each, so that they do not share a work directory.

   Args:
     name: Last path component of the default work directory.
     port: Port for the dispatcher.
     work_dir: Work directory to use. If equal to `TMP_WORK_DIR`, it becomes
       `<self.get_temp_dir()>/work_dir_/<name>`.
     fault_tolerant_mode: Passed through to `DispatcherConfig`.
     job_gc_check_interval_ms: Passed through to `DispatcherConfig`.
     job_gc_timeout_ms: Passed through to `DispatcherConfig`.

   Returns:
     The created `server_lib.DispatchServer`.
   """
   # Compare by value, not identity: `is` on strings depends on interning, so
   # an equal but distinct sentinel value could be missed. If TMP_WORK_DIR is
   # an `object()` sentinel instead, `==` falls back to identity anyway.
   if work_dir == TMP_WORK_DIR:
     work_dir = os.path.join(self.get_temp_dir(), "work_dir_", name)
   return server_lib.DispatchServer(
       server_lib.DispatcherConfig(
           port=port,
           work_dir=work_dir,
           fault_tolerant_mode=fault_tolerant_mode,
           job_gc_check_interval_ms=job_gc_check_interval_ms,
           job_gc_timeout_ms=job_gc_timeout_ms))
    def __init__(self,
                 num_workers,
                 dispatcher_port=0,
                 work_dir=TMP_WORK_DIR,
                 fault_tolerant_mode=True,
                 job_gc_check_interval_ms=None,
                 job_gc_timeout_ms=None,
                 start=True):
        """Creates a tf.data service test cluster.

        Args:
          num_workers: The number of workers to initially add to the cluster.
          dispatcher_port: The port to use for the dispatcher.
          work_dir: The work directory to use for the dispatcher. If set to
            `TMP_WORK_DIR`, the cluster will create a new temporary directory
            to use as the work directory. If set to `NO_WORK_DIR`, no work
            directory will be used.
          fault_tolerant_mode: Whether the dispatcher should write its state to
            a journal so that it can recover from restarts.
          job_gc_check_interval_ms: How often the dispatcher should scan through
            to delete old and unused jobs, in milliseconds.
          job_gc_timeout_ms: How long a job needs to be unused before it becomes
            a candidate for garbage collection, in milliseconds.
          start: Whether to immediately start the servers in the cluster. If
            `False`, the servers can be started later by calling
            `start_dispatcher()` and `start_workers()`.
        """
        resolved_work_dir = work_dir
        if resolved_work_dir == TMP_WORK_DIR:
            # Replace the sentinel with a freshly created temporary directory.
            resolved_work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
        dispatcher_config = server_lib.DispatcherConfig(
            port=dispatcher_port,
            work_dir=resolved_work_dir,
            fault_tolerant_mode=fault_tolerant_mode,
            job_gc_check_interval_ms=job_gc_check_interval_ms,
            job_gc_timeout_ms=job_gc_timeout_ms)
        self.dispatcher = server_lib.DispatchServer(dispatcher_config,
                                                    start=start)

        # NOTE(review): `add_worker` presumably appends to `self.workers`.
        self.workers = []
        for _worker_index in range(num_workers):
            self.add_worker(start=start)
Esempio n. 11
0
 def testStartDispatcherWithWrongFaultTolerantConfig(self):
     """Starting a fault-tolerant dispatcher without a work_dir raises."""
     bad_config = server_lib.DispatcherConfig(fault_tolerant_mode=True)
     expected_message = (
         "Cannot enable fault tolerant mode without configuring a work_dir")
     with self.assertRaisesRegex(ValueError, expected_message):
         _ = server_lib.DispatchServer(config=bad_config, start=True)
Esempio n. 12
0
 def testStartDispatcherWithFaultTolerantConfig(self):
     """A fault-tolerant dispatcher with a work_dir starts without raising."""
     journal_dir = tempfile.mkdtemp()
     _ = server_lib.DispatchServer(
         config=server_lib.DispatcherConfig(work_dir=journal_dir,
                                            fault_tolerant_mode=True),
         start=True)
Esempio n. 13
0
 def testStartDispatcherWithWorkDirConfig(self):
     """A dispatcher configured with only a work_dir starts without raising."""
     scratch_dir = tempfile.mkdtemp()
     _ = server_lib.DispatchServer(
         config=server_lib.DispatcherConfig(work_dir=scratch_dir),
         start=True)
Esempio n. 14
0
 def testStartDispatcherWithPortConfig(self):
     """A dispatcher started on an explicit port reports it in `target`."""
     chosen_port = pick_unused_port()
     expected_target = "grpc://localhost:{}".format(chosen_port)
     dispatcher = server_lib.DispatchServer(
         config=server_lib.DispatcherConfig(port=chosen_port), start=True)
     self.assertEqual(dispatcher.target, expected_target)
Esempio n. 15
0
 def _start_dispatcher(self, dispatcher_port):
   """Starts a gRPC dispatcher on `dispatcher_port` with a fresh work dir.

   The started server is stored on `self._dispatcher`.
   """
   dispatcher_work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
   config = server_lib.DispatcherConfig(
       port=dispatcher_port,
       work_dir=dispatcher_work_dir,
       protocol="grpc")
   self._dispatcher = server_lib.DispatchServer(config, start=True)