def _connect_to_cluster(self, coordinator_name):
    if coordinator_name in ["worker", "ps"]:
      raise ValueError("coordinator name should not be 'worker' or 'ps'.")
    cluster_spec = self._cluster_resolver.cluster_spec()
    self._num_workers = len(cluster_spec.as_dict().get("worker", ()))
    self._num_ps = len(cluster_spec.as_dict().get("ps", ()))

    device_filters = server_lib.ClusterDeviceFilters()
    # For any worker, only the devices on ps and coordinator nodes are visible
    for i in range(self._num_workers):
      device_filters.set_device_filters(
          "worker", i, ["/job:ps", "/job:%s" % coordinator_name])
    # Similarly, for any ps, only the devices on workers and the coordinator
    # are visible.
    for i in range(self._num_ps):
      device_filters.set_device_filters(
          "ps", i, ["/job:worker", "/job:%s" % coordinator_name])

    # Allow at most one outstanding RPC for each worker at any given time. This
    # simplifies worker failure handling in the runtime.
    os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] = "False"

    logging.info("%s is now connecting to cluster with cluster_spec: %r",
                 self.__class__.__name__, cluster_spec)
    remote.connect_to_cluster(
        cluster_spec,
        job_name=coordinator_name,
        protocol=self._cluster_resolver.rpc_layer,
        cluster_device_filters=device_filters)

    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "ps_strategy_num_workers").set(self._num_workers)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "ps_strategy_num_ps").set(self._num_ps)
Example #2
    def DISABLED_testSimpleParameterServerWithDeviceFilters(self):
        cluster_device_filters = server_lib.ClusterDeviceFilters()
        for i in range(2):
            cluster_device_filters.set_device_filters('my_worker', i,
                                                      ['/job:my_ps'])
            cluster_device_filters.set_device_filters('my_ps', i,
                                                      ['/job:my_worker'])
        remote.connect_to_cluster(
            self._cluster, cluster_device_filters=cluster_device_filters)

        with ops.device('/job:my_ps/task:0/device:CPU:0'):
            v1 = variables.Variable(initial_value=0)
        with ops.device('/job:my_ps/task:1/device:CPU:0'):
            v2 = variables.Variable(initial_value=10)

        @def_function.function
        def worker_fn():
            v1.assign_add(1)
            v2.assign_sub(2)
            return v1.read_value() + v2.read_value()

        with ops.device('/job:my_worker/task:0/device:CPU:0'):
            self.assertAllEqual(worker_fn(), 9)
        with ops.device('/job:my_worker/task:1/device:CPU:0'):
            self.assertAllEqual(worker_fn(), 8)

        # The following remote calls fail because the ps tasks cannot see each
        # other due to the device filters.
        with self.assertRaises(errors.InvalidArgumentError) as cm:
            with ops.device('/job:my_ps/task:0/device:CPU:0'):
                worker_fn().numpy()
        self.assertIn(
            '/job:my_ps/replica:0/task:1/device:CPU:0 unknown device',
            cm.exception.message)

        with self.assertRaises(errors.InvalidArgumentError) as cm:
            with ops.device('/job:my_ps/task:1/device:CPU:0'):
                worker_fn().numpy()
        self.assertIn(
            '/job:my_ps/replica:0/task:0/device:CPU:0 unknown device',
            cm.exception.message)

        with ops.device('/job:my_worker/task:0/device:CPU:0'):
            self.assertAllEqual(worker_fn(), 7)
        with ops.device('/job:my_worker/task:1/device:CPU:0'):
            self.assertAllEqual(worker_fn(), 6)
        # Explicitly delete the variables to avoid triggering errors when they
        # are garbage-collected during subsequent tests.
        del v1, v2
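
A minimal sketch of the cluster this test expects in self._cluster: two
'my_worker' tasks and two 'my_ps' tasks running as in-process servers. The
ports are illustrative assumptions; a real test harness would pick free ones.
The resulting cluster_spec is what gets passed to remote.connect_to_cluster
above.

import tensorflow as tf

cluster_def = {
    "my_worker": ["localhost:3001", "localhost:3002"],
    "my_ps": ["localhost:3003", "localhost:3004"],
}
cluster_spec = tf.train.ClusterSpec(cluster_def)

# Start one in-process gRPC server per task; all servers share the cluster def.
servers = [
    tf.distribute.Server(cluster_spec, job_name=job, task_index=i, start=True)
    for job, addresses in cluster_def.items()
    for i in range(len(addresses))
]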
Example #3
  def __init__(self, cluster_resolver, client_name="chief"):
    """Initializes the cluster instance and connect to the remote cluster."""
    if client_name in ["worker", "ps"]:
      raise ValueError("Client name should not be 'worker' or 'ps'.")
    cluster_spec = cluster_resolver.cluster_spec()

    self._num_workers = len(cluster_spec.as_dict().get("worker", ()))
    self._num_ps = len(cluster_spec.as_dict().get("ps", ()))
    device_filters = server_lib.ClusterDeviceFilters()
    # For any worker, only the devices on PS and chief nodes are visible
    for i in range(self._num_workers):
      device_filters.set_device_filters(
          "worker", i, ["/job:ps", "/job:%s" % client_name])
    # Similarly, for any ps, only the devices on workers and chief are visible
    for i in range(self._num_ps):
      device_filters.set_device_filters(
          "ps", i, ["/job:worker", "/job:%s" % client_name])

    context.context().mirroring_policy = context.MIRRORING_ALL
    # Allow at most one outstanding RPC for each worker at any given time. This
    # simplifies worker failure handling in the runtime.
    os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] = "False"
    remote.connect_to_cluster(cluster_spec,
                              job_name=client_name,
                              protocol=cluster_resolver.rpc_layer,
                              cluster_device_filters=device_filters)

    self._cancellation_mgr = cancellation.CancellationManager()
    self._closure_queue = _CoordinatedClosureQueue(self._cancellation_mgr)
    self.failure_handler = WorkerPreemptionHandler(context.get_server_def())
    worker_device_strings = [
        "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
    ]
    self.workers = [
        Worker(i, w, self) for i, w in enumerate(worker_device_strings)
    ]