Example #1
 def testMasterNumWorkers(self):
     master = server_lib.MasterServer(0)
     self.assertEqual(0, master._num_workers())
     worker1 = server_lib.WorkerServer(0, master._address)  # pylint: disable=unused-variable
     self.assertEqual(1, master._num_workers())
     worker2 = server_lib.WorkerServer(0, master._address)  # pylint: disable=unused-variable
     self.assertEqual(2, master._num_workers())
Example #2
 def testDispatcherNumWorkers(self):
   dispatcher = server_lib.DispatchServer(0)
   self.assertEqual(0, dispatcher._num_workers())
   worker1 = server_lib.WorkerServer(0, dispatcher._address)  # pylint: disable=unused-variable
   self.assertEqual(1, dispatcher._num_workers())
   worker2 = server_lib.WorkerServer(0, dispatcher._address)  # pylint: disable=unused-variable
   self.assertEqual(2, dispatcher._num_workers())
Example #3
    def testAddWorkerMidJob(self):
        self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
        self._worker = server_lib.WorkerServer(
            port=0,
            dispatcher_address=self._dispatcher._address,
            protocol=PROTOCOL)
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements)
        ds = _make_distributed_dataset(ds, self._dispatcher._address)
        iterator = iter(ds)
        results = []
        # Read halfway through the dataset.
        for _ in range(num_elements // 2):
            results.append(next(iterator).numpy())

        self._new_worker = server_lib.WorkerServer(
            port=0,
            dispatcher_address=self._dispatcher._address,
            protocol=PROTOCOL)

        # Wait for the new worker to register with the dispatcher.
        while self._dispatcher._num_workers() < 2:
            time.sleep(10 / 1000)  # 10ms

        for elem in iterator:
            results.append(elem.numpy())

        # With two workers, each worker processes the full dataset, so every
        # element appears exactly twice.
        self.assertCountEqual(2 * list(range(num_elements)), results)
Example #4
    def testRestartWorker(self, use_same_port):
        self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
        self._worker = server_lib.WorkerServer(
            port=0, master_address=self._master._address, protocol=PROTOCOL)
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements)
        ds = _make_distributed_dataset(ds, self._master._address)
        iterator = iter(ds)
        # Read halfway through the dataset.
        midpoint = num_elements // 2
        for i in range(midpoint):
            self.assertEqual(i, next(iterator).numpy())

        # Stop the original worker and start a new one.
        port = 0
        if use_same_port:
            port = int(self._worker._address.split(":")[1])
        self._worker._stop()
        self._new_worker = server_lib.WorkerServer(
            port=port, master_address=self._master._address, protocol=PROTOCOL)

        # There may have been some elements prefetched from the first worker
        # before it was stopped.
        while True:
            val = next(iterator).numpy()
            if val == 0:
                break

        # The dataset starts over now that we read from the new worker.
        # TODO(b/157086991): Iterate until end of sequence when we support
        # detecting lost workers.
        for i in range(1, num_elements // 2):
            val = next(iterator).numpy()
            self.assertEqual(i, val)
Example #5
  def testRestartWorker(self, use_same_port):
    self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
    self._worker = server_lib.WorkerServer(
        port=0, master_address=self._master._address, protocol=PROTOCOL)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = _make_distributed_dataset(ds, self._master._address)
    iterator = iter(ds)
    # Read halfway through the dataset.
    midpoint = num_elements // 2
    for i in range(midpoint):
      self.assertEqual(i, next(iterator).numpy())

    # Stop the original worker and start a new one.
    port = 0
    if use_same_port:
      port = int(self._worker._address.split(":")[1])
    self._worker._stop()
    self._new_worker = server_lib.WorkerServer(
        port=port, master_address=self._master._address, protocol=PROTOCOL)

    # The dataset starts over now that we read from the new worker.
    for i in range(num_elements):
      val = next(iterator).numpy()
      if val == midpoint and i != midpoint:
        # There may have been one last element prefetched from the first worker
        # before it was stopped.
        val = next(iterator).numpy()
      self.assertEqual(i, val)
Example #6
 def testStartWorkerWithPortConfig(self):
     dispatcher = server_lib.DispatchServer()
     port = pick_unused_port()
     worker = server_lib.WorkerServer(server_lib.WorkerConfig(
         dispatcher._address, port=port),
                                      start=True)
     self.assertEqual(worker._address, "localhost:{}".format(port))
Example #7
    def testGcClient(self):
        dispatcher = server_lib.DispatchServer(
            service_config_pb2.DispatcherConfig(protocol="grpc",
                                                job_gc_check_interval_ms=50,
                                                job_gc_timeout_ms=20,
                                                client_timeout_ms=50))
        dispatcher_address = dispatcher.target.split("://")[1]
        _ = server_lib.WorkerServer(
            server_lib.WorkerConfig(dispatcher_address=dispatcher_address,
                                    heartbeat_interval_ms=100))

        num_elements = 1000
        dataset = dataset_ops.Dataset.range(num_elements)
        dataset = dataset.apply(
            data_service_ops._distribute(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=dispatcher.target,
                task_refresh_interval_hint_ms=10000))
        get_next = self.getNext(dataset)

        # With a task refresh interval of 10 seconds, the client does not
        # heartbeat during the sleep below, so it will be garbage-collected
        # once client_timeout_ms elapses.
        with self.assertRaisesRegex(errors.NotFoundError,
                                    "Unknown job client id"):
            self.evaluate(get_next())
            time.sleep(3)
            self.getIteratorOutput(get_next)
Example #8
    def testKeepClientAliveBeforeReading(self):
        dispatcher = server_lib.DispatchServer(
            service_config_pb2.DispatcherConfig(protocol="grpc",
                                                job_gc_check_interval_ms=50,
                                                job_gc_timeout_ms=20,
                                                client_timeout_ms=1000))
        dispatcher_address = dispatcher.target.split("://")[1]
        _ = server_lib.WorkerServer(
            server_lib.WorkerConfig(dispatcher_address=dispatcher_address,
                                    heartbeat_interval_ms=100))

        num_elements = 1000
        dataset = dataset_ops.Dataset.range(num_elements)
        dataset = dataset.apply(
            data_service_ops._distribute(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=dispatcher.target,
                task_refresh_interval_hint_ms=100))
        get_next = self.getNext(dataset)

        # The client heartbeats every 100 milliseconds, so it should not be
        # garbage-collected even though it does not start reading for 3 seconds.
        time.sleep(3)
        self.assertEqual(self.getIteratorOutput(get_next),
                         list(range(num_elements)))
Example #9
 def add_worker(self, start=True):
     self.workers.append(
         server_lib.WorkerServer(server_lib.WorkerConfig(
             dispatcher_address=self.dispatcher_address(),
             heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS,
             dispatcher_timeout_ms=1000),
                                 start=start))
Example #10
    def testStartServersLate(self):
        # Test that the data service client performs retries instead of failing
        # when the dataset is created before the dispatcher and worker are
        # started.
        try:
            import portpicker  # pylint: disable=g-import-not-at-top
            dispatcher_port = portpicker.pick_unused_port()
        except Exception:
            raise self.skipTest(
                "Flakes in portpicker library do not represent "
                "TensorFlow errors.")
        dispatcher = server_lib.DispatchServer(
            server_lib.DispatcherConfig(port=dispatcher_port), start=False)
        worker = server_lib.WorkerServer(server_lib.WorkerConfig(
            dispatcher_address=_address_from_target(dispatcher.target),
            port=0),
                                         start=False)

        def start_servers():
            time.sleep(1)
            dispatcher.start()
            worker.start()

        start_servers_thread = threading.Thread(target=start_servers,
                                                daemon=True)
        start_servers_thread.start()

        num_elements = 10
        ds = _make_distributed_range_dataset(num_elements, dispatcher)
        results = [elem.numpy() for elem in ds]
        self.assertEqual(list(range(num_elements)), results)
        start_servers_thread.join()
Example #11
 def testStopStartWorker(self):
   dispatcher = server_lib.DispatchServer(0)
   worker = server_lib.WorkerServer(0, dispatcher._address)
   worker._stop()
   with self.assertRaisesRegex(
       RuntimeError, "Server cannot be started after it has been stopped"):
     worker.start()
Example #12
 def testProfileWorker(self):
   dispatcher = server_lib.DispatchServer()
   worker = server_lib.WorkerServer(
       server_lib.WorkerConfig(dispatcher._address))
   # Test that the profiler client starts successfully and connects to the
   # profiler service on the worker. Since no op is running, the trace request
   # is expected to raise UnavailableError with a "no trace event was
   # collected" message.
   with self.assertRaises(errors.UnavailableError) as error:
     profiler_client.trace(worker._address, tempfile.mkdtemp(), duration_ms=10)
   self.assertStartsWith(str(error.exception), "No trace event was collected")
Example #13
    def start_workers(self):
        self._workers = []
        for _ in range(self._num_workers):
            worker = server_lib.WorkerServer(server_lib.WorkerConfig(
                dispatcher_address=self._dispatcher_address),
                                             start=True)
            self._workers.append(worker)

        self._pipe_writer.send("Remote workers are ready.")
        for worker in self._workers:
            worker.join()
Example #14
 def restart_worker(self, worker_index=0, use_same_port=True):
     """Replaces the worker at index `worker_index` with a new worker."""
     worker = self.workers[worker_index]
     port = 0
     if use_same_port:
         port = int(worker._address.split(":")[1])
     worker._stop()
     self.workers[worker_index] = server_lib.WorkerServer(
         server_lib.WorkerConfig(
             dispatcher_address=self.dispatcher_address(),
             port=port,
             heartbeat_interval_ms=worker._config.heartbeat_interval_ms))
Example #15
def _make_worker(dispatcher_address, shutdown_quiet_period_ms=0, port=0):
    """Creates a worker server."""
    defaults = server_lib.WorkerConfig(dispatcher_address=dispatcher_address)
    config_proto = service_config_pb2.WorkerConfig(
        dispatcher_address=dispatcher_address,
        worker_address=defaults.worker_address,
        port=port,
        protocol=PROTOCOL,
        heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS,
        dispatcher_timeout_ms=TEST_DISPATCHER_TIMEOUT_MS,
        data_transfer_protocol=None,
        shutdown_quiet_period_ms=shutdown_quiet_period_ms)
    return server_lib.WorkerServer(config_proto, start=False)
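Since `_make_worker` returns the server unstarted (`start=False`), the caller starts it explicitly. A minimal usage sketch, assuming a `dispatcher_address` string obtained from a running `server_lib.DispatchServer` as in the other examples:

    worker = _make_worker(dispatcher_address)
    worker.start()  # The worker registers itself with the dispatcher.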
Example #16
 def testDistributeLargeGraphThenRegisterWorker(self, work_dir):
     dispatcher = self.start_dispatch_server(work_dir=work_dir,
                                             fault_tolerant_mode=False)
     worker = server_lib.WorkerServer(server_lib.WorkerConfig(
         dispatcher_address=self.dispatcher_address(dispatcher), port=0),
                                      start=False)
     # Larger than the default OSS gRPC message size limit of 4 MB.
     tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
     ds = dataset_ops.Dataset.from_tensors(tensor)
     ds = self.make_distributed_dataset(ds, dispatcher)
     it = iter(ds)
     worker.start()
     self.assertAllEqual(next(it), tensor)
Example #17
  def create_cluster(self, num_workers):
    """Creates a cluster of tf.data service servers.

    Args:
      num_workers: The number of workers in the cluster.

    Returns:
      The address of the master.
    """
    self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
    self._servers = []
    for _ in range(num_workers):
      self._servers.append(
          server_lib.WorkerServer(
              port=0, master_address=self._master._address, protocol=PROTOCOL))

    return self._master._address
Example #18
 def testDispatcherPreemption(self):
     self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
     self._worker = server_lib.WorkerServer(
         port=0,
         dispatcher_address=self._dispatcher._address,
         protocol=PROTOCOL)
     num_elements = 100
     ds = dataset_ops.Dataset.range(num_elements)
     ds = _make_distributed_dataset(
         ds, "{}://{}".format(PROTOCOL, self._dispatcher._address))
     iterator = iter(ds)
     results = []
     results.append(next(iterator).numpy())
     self._dispatcher._stop()
     # After the dispatcher dies, the worker should continue providing the rest
     # of the dataset's elements.
     for _ in range(num_elements - 1):
         results.append(next(iterator).numpy())
     self.assertEqual(results, list(range(num_elements)))
Example #19
  def create_cluster(self, num_workers):
    """Creates a cluster of tf.data service servers.

    Args:
      num_workers: The number of workers in the cluster.

    Returns:
      A string for connecting to the tf.data service.
    """
    self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
    self._servers = []
    for _ in range(num_workers):
      self._servers.append(
          server_lib.WorkerServer(
              port=0,
              dispatcher_address=self._dispatcher._address,
              protocol=PROTOCOL))

    return "{0}://{1}".format(PROTOCOL, self._dispatcher._address)
Example #20
    def testCancellation(self):
        self.skipTest("b/162521601")
        sleep_microseconds = int(1e6) * 1000

        self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
        self._worker = server_lib.WorkerServer(
            port=0,
            dispatcher_address=self._dispatcher._address,
            protocol=PROTOCOL)
        # Create a dataset which produces the first element quickly, and the second
        # element slowly. Fetching the first element triggers prefetching of the
        # second element, which we should be able to cancel.
        slow = dataset_ops.Dataset.range(1)
        slow = slow.apply(testing.sleep(sleep_microseconds))
        ds = dataset_ops.Dataset.range(1).concatenate(slow)
        ds = _make_distributed_dataset(
            ds, "{}://{}".format(PROTOCOL, self._dispatcher._address))
        ds = ds.prefetch(1)
        get_next = self.getNext(ds, requires_initialization=True)
        self.assertEqual(0, self.evaluate(get_next()))
Example #21
 def testJoinWorker(self):
   dispatcher = server_lib.DispatchServer(0)
   worker = server_lib.WorkerServer(0, dispatcher._address)
   worker._stop()
   # After _stop(), join() returns once the server has shut down.
   worker.join()
Example #22
 def start_worker_server(self, dispatcher, port=0):
     return server_lib.WorkerServer(
         server_lib.WorkerConfig(dispatcher_address=_address_from_target(
             dispatcher.target),
                                 port=port))
Example #23
 def testMultipleStartWorker(self):
     dispatcher = server_lib.DispatchServer()
     worker = server_lib.WorkerServer(server_lib.WorkerConfig(
         dispatcher._address),
                                      start=True)
     # Starting an already-started worker should not raise.
     worker.start()
Example #24
 def start_worker_server(self, dispatcher, port=0):
     return server_lib.WorkerServer(port=port,
                                    dispatcher_address=_address_from_target(
                                        dispatcher.target),
                                    protocol=server_lib.DEFAULT_PROTOCOL)
Example #25
 def testStopWorker(self):
     dispatcher = server_lib.DispatchServer()
     worker = server_lib.WorkerServer(
         server_lib.WorkerConfig(dispatcher._address))
     worker._stop()
     # Stopping an already-stopped worker should not raise.
     worker._stop()
Example #26
 def start_worker_server(self, dispatcher, port=0):
   return server_lib.WorkerServer(
       server_lib.WorkerConfig(
           dispatcher_address=self.dispatcher_address(dispatcher),
           port=port,
           heartbeat_interval_ms=200))
Example #27
 def testStartWorker(self):
     dispatcher = server_lib.DispatchServer()
     worker = server_lib.WorkerServer(server_lib.WorkerConfig(
         dispatcher._address),
                                      start=False)
     worker.start()
Example #28
 def start_worker_server(self, dispatcher, port=0):
   return server_lib.WorkerServer(
       server_lib.WorkerConfig(
           dispatcher_address=_address_from_target(dispatcher.target),
           port=port,
           heartbeat_interval_ms=200))
Example #29
 def testMultipleStartWorker(self):
   dispatcher = server_lib.DispatchServer(0)
   worker = server_lib.WorkerServer(0, dispatcher._address, start=True)
   worker.start()
Example #30
 def testStartWorker(self):
   dispatcher = server_lib.DispatchServer(0)
   worker = server_lib.WorkerServer(0, dispatcher._address, start=False)
   worker.start()
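The examples above exercise the internal `server_lib` test surface. For orientation, here is a minimal end-to-end sketch using the public `tf.data.experimental.service` API, which wraps the same dispatcher and worker servers; treat it as a sketch of the public TF 2.x API rather than code taken from the tests above:

    import tensorflow as tf

    # Start an in-process dispatcher and worker on free ports. Keep references
    # to both servers so they are not garbage-collected while in use.
    dispatcher = tf.data.experimental.service.DispatchServer()
    worker = tf.data.experimental.service.WorkerServer(
        tf.data.experimental.service.WorkerConfig(
            dispatcher_address=dispatcher.target.split("://")[1]))

    # Route a dataset through the service and read it back.
    dataset = tf.data.Dataset.range(10)
    dataset = dataset.apply(
        tf.data.experimental.service.distribute(
            processing_mode="parallel_epochs", service=dispatcher.target))
    print(sorted(dataset.as_numpy_iterator()))  # [0, 1, ..., 9]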