Example #1
 def testDispatcherNumWorkers(self):
     dispatcher = server_lib.DispatchServer()
     self.assertEqual(0, dispatcher._num_workers())
     worker1 = server_lib.WorkerServer(  # pylint: disable=unused-variable
         server_lib.WorkerConfig(dispatcher._address))
     self.assertEqual(1, dispatcher._num_workers())
     worker2 = server_lib.WorkerServer(  # pylint: disable=unused-variable
         server_lib.WorkerConfig(dispatcher._address))
     self.assertEqual(2, dispatcher._num_workers())
Example #2
    def testStartServersLate(self):
        # Test that the data service client performs retries instead of failing
        # when the dataset is created before the dispatcher and worker are started.
        try:
            import portpicker  # pylint: disable=g-import-not-at-top
            dispatcher_port = portpicker.pick_unused_port()
        except:
            raise self.skipTest(
                "Flakes in portpicker library do not represent "
                "TensorFlow errors.")
        dispatcher = server_lib.DispatchServer(
            server_lib.DispatcherConfig(port=dispatcher_port), start=False)
        worker = server_lib.WorkerServer(server_lib.WorkerConfig(
            dispatcher_address=_address_from_target(dispatcher.target),
            port=0),
                                         start=False)

        def start_servers():
            time.sleep(1)
            dispatcher.start()
            worker.start()

        start_servers_thread = threading.Thread(target=start_servers,
                                                daemon=True)
        start_servers_thread.start()

        num_elements = 10
        ds = _make_distributed_range_dataset(num_elements, dispatcher)
        results = [elem.numpy() for elem in ds]
        self.assertEqual(list(range(num_elements)), results)
        start_servers_thread.join()
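
The helpers _address_from_target and _make_distributed_range_dataset are used above but not defined in these snippets. A minimal sketch of what they could look like, assuming the public distribute transformation from data_service_ops (the actual helpers in the test suite may differ):

def _address_from_target(target):
    # A dispatcher target looks like "grpc://localhost:5050"; drop the protocol
    # prefix to get the plain "host:port" address that WorkerConfig expects.
    return target.split("://")[1]

def _make_distributed_range_dataset(num_elements, dispatcher):
    # Build a range dataset and route it through the tf.data service so that
    # its elements are produced by the workers registered with the dispatcher.
    dataset = dataset_ops.Dataset.range(num_elements)
    return dataset.apply(
        data_service_ops.distribute(processing_mode="parallel_epochs",
                                    service=dispatcher.target))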
Example #3
 def add_worker(self, start=True):
     self.workers.append(
         server_lib.WorkerServer(server_lib.WorkerConfig(
             dispatcher_address=self.dispatcher_address(),
             heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS,
             dispatcher_timeout_ms=1000),
                                 start=start))
Example #4
 def testStartWorkerWithPortConfig(self):
     dispatcher = server_lib.DispatchServer()
     port = pick_unused_port()
     worker = server_lib.WorkerServer(server_lib.WorkerConfig(
         dispatcher._address, port=port),
                                      start=True)
     self.assertEqual(worker._address, "localhost:{}".format(port))
Example #5
    def testGcClient(self):
        dispatcher = server_lib.DispatchServer(
            service_config_pb2.DispatcherConfig(protocol="grpc",
                                                job_gc_check_interval_ms=50,
                                                job_gc_timeout_ms=20,
                                                client_timeout_ms=50))
        dispatcher_address = dispatcher.target.split("://")[1]
        _ = server_lib.WorkerServer(
            server_lib.WorkerConfig(dispatcher_address=dispatcher_address,
                                    heartbeat_interval_ms=100))

        num_elements = 1000
        dataset = dataset_ops.Dataset.range(num_elements)
        dataset = dataset.apply(
            data_service_ops._distribute(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=dispatcher.target,
                task_refresh_interval_hint_ms=10000))
        get_next = self.getNext(dataset)

        # The client does not heartbeat for 10 seconds, so it will be garbage-collected.
        with self.assertRaisesRegex(errors.NotFoundError,
                                    "Unknown job client id"):
            self.evaluate(get_next())
            time.sleep(3)
            self.getIteratorOutput(get_next)
Example #6
    def testKeepClientAliveBeforeReading(self):
        dispatcher = server_lib.DispatchServer(
            service_config_pb2.DispatcherConfig(protocol="grpc",
                                                job_gc_check_interval_ms=50,
                                                job_gc_timeout_ms=20,
                                                client_timeout_ms=1000))
        dispatcher_address = dispatcher.target.split("://")[1]
        _ = server_lib.WorkerServer(
            server_lib.WorkerConfig(dispatcher_address=dispatcher_address,
                                    heartbeat_interval_ms=100))

        num_elements = 1000
        dataset = dataset_ops.Dataset.range(num_elements)
        dataset = dataset.apply(
            data_service_ops._distribute(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=dispatcher.target,
                task_refresh_interval_hint_ms=100))
        get_next = self.getNext(dataset)

        # The client heartbeats every 100 milliseconds, so it should not be
        # garbage-collected even though it does not start reading for 3 seconds.
        time.sleep(3)
        self.assertEqual(self.getIteratorOutput(get_next),
                         list(range(num_elements)))
Example #7
 def testStopStartWorker(self):
   dispatcher = server_lib.DispatchServer()
   worker = server_lib.WorkerServer(
       server_lib.WorkerConfig(dispatcher._address))
   worker._stop()
   with self.assertRaisesRegex(
       RuntimeError, "Server cannot be started after it has been stopped"):
     worker.start()
Example #8
 def testProfileWorker(self):
   dispatcher = server_lib.DispatchServer()
   worker = server_lib.WorkerServer(
       server_lib.WorkerConfig(dispatcher._address))
    # Test that the profiler is successfully started and connected to the
    # profiler service on the worker. Since no op is running, the trace is
    # expected to raise UnavailableError with a "no trace event was collected"
    # message.
   with self.assertRaises(errors.UnavailableError) as error:
     profiler_client.trace(worker._address, tempfile.mkdtemp(), duration_ms=10)
   self.assertStartsWith(str(error.exception), "No trace event was collected")
Example #9
    def start_workers(self):
        self._workers = []
        for _ in range(self._num_workers):
            worker = server_lib.WorkerServer(server_lib.WorkerConfig(
                dispatcher_address=self._dispatcher_address),
                                             start=True)
            self._workers.append(worker)

        self._pipe_writer.send("Remote workers are ready.")
        for worker in self._workers:
            worker.join()
Example #10
 def restart_worker(self, worker_index=0, use_same_port=True):
     """Replaces the worker at index `worker_index` with a new worker."""
     worker = self.workers[worker_index]
     port = 0
     if use_same_port:
         port = int(worker._address.split(":")[1])
     worker._stop()
     self.workers[worker_index] = server_lib.WorkerServer(
         server_lib.WorkerConfig(
             dispatcher_address=self.dispatcher_address(),
             port=port,
             heartbeat_interval_ms=worker._config.heartbeat_interval_ms))
Example #11
 def testDistributeLargeGraphThenRegisterWorker(self, work_dir):
     dispatcher = self.start_dispatch_server(work_dir=work_dir,
                                             fault_tolerant_mode=False)
     worker = server_lib.WorkerServer(server_lib.WorkerConfig(
         dispatcher_address=self.dispatcher_address(dispatcher), port=0),
                                      start=False)
     # Larger than the default OSS gRPC message size limit of 4 MB.
     tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
     ds = dataset_ops.Dataset.from_tensors(tensor)
     ds = self.make_distributed_dataset(ds, dispatcher)
     it = iter(ds)
     worker.start()
     self.assertAllEqual(next(it), tensor)
Example #12
def _make_worker(dispatcher_address, shutdown_quiet_period_ms=0, port=0):
    """Creates a worker server."""
    defaults = server_lib.WorkerConfig(dispatcher_address=dispatcher_address)
    config_proto = service_config_pb2.WorkerConfig(
        dispatcher_address=dispatcher_address,
        worker_address=defaults.worker_address,
        port=port,
        protocol=PROTOCOL,
        heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS,
        dispatcher_timeout_ms=TEST_DISPATCHER_TIMEOUT_MS,
        data_transfer_protocol=None,
        shutdown_quiet_period_ms=shutdown_quiet_period_ms)
    return server_lib.WorkerServer(config_proto, start=False)
Example #13
 def start_worker_server(self, dispatcher, port=0):
     return server_lib.WorkerServer(
         server_lib.WorkerConfig(dispatcher_address=_address_from_target(
             dispatcher.target),
                                 port=port))
Example #14
 def start_worker_server(self, dispatcher, port=0):
   return server_lib.WorkerServer(
       server_lib.WorkerConfig(
           dispatcher_address=self.dispatcher_address(dispatcher),
           port=port,
           heartbeat_interval_ms=200))
Example #15
 def start_worker_server(self, dispatcher, port=0):
   return server_lib.WorkerServer(
       server_lib.WorkerConfig(
           dispatcher_address=_address_from_target(dispatcher.target),
           port=port,
           heartbeat_interval_ms=200))
Example #16
 def testJoinWorker(self):
     dispatcher = server_lib.DispatchServer()
     worker = server_lib.WorkerServer(
         server_lib.WorkerConfig(dispatcher._address))
     worker._stop()
     worker.join()
Example #17
 def testMultipleStartWorker(self):
     dispatcher = server_lib.DispatchServer()
     worker = server_lib.WorkerServer(server_lib.WorkerConfig(
         dispatcher._address),
                                      start=True)
     worker.start()
Example #18
 def testStartWorker(self):
     dispatcher = server_lib.DispatchServer()
     worker = server_lib.WorkerServer(server_lib.WorkerConfig(
         dispatcher._address),
                                      start=False)
     worker.start()
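
All of these snippets follow the same pattern: start a DispatchServer, point one or more WorkerServers at its address, then distribute a dataset through the dispatcher's target. A minimal end-to-end sketch of that pattern using the public tf.data service API (tf.data.experimental.service is the public counterpart of the internal server_lib and data_service_ops modules used in the tests; this sketch is an illustration, not the tests' own code):

import tensorflow as tf

# Start an in-process dispatcher on an ephemeral port, then a worker that
# registers itself with the dispatcher.
dispatcher = tf.data.experimental.service.DispatchServer()
worker = tf.data.experimental.service.WorkerServer(
    tf.data.experimental.service.WorkerConfig(
        dispatcher_address=dispatcher.target.split("://")[1]))

# Route a small dataset through the service; the worker produces the elements.
dataset = tf.data.Dataset.range(10)
dataset = dataset.apply(
    tf.data.experimental.service.distribute(
        processing_mode="parallel_epochs", service=dispatcher.target))

print(sorted(int(elem) for elem in dataset))  # [0, 1, ..., 9]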