Example 1
def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
                     config):
    """Creates a `tf.train.ServerDef` protocol buffer.

  Args:
    server_or_cluster_def: A `tf.train.ServerDef` or
      `tf.train.ClusterDef` protocol buffer, or a
      `tf.train.ClusterSpec` object, describing the server to be
      defined and/or the cluster of which it is a member.
    job_name: (Optional.) Specifies the name of the job of which the server
      is a member. Defaults to the value in `server_or_cluster_def`, if
      specified.
    task_index: (Optional.) Specifies the task index of the server in its job.
      Defaults to the value in `server_or_cluster_def`, if specified. Otherwise
      defaults to 0 if the server's job has only one task.
    protocol: (Optional.) Specifies the protocol to be used by the server.
      Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value
      in `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
    config: (Optional.) A `tf.ConfigProto` that specifies default configuration
      options for all sessions that run on this server.

  Returns:
    A `tf.train.ServerDef`.

  Raises:
    TypeError: If the arguments do not have the appropriate type.
    ValueError: If an argument is not specified and cannot be inferred.
  """
    server_def = tensorflow_server_pb2.ServerDef()
    if isinstance(server_or_cluster_def, tensorflow_server_pb2.ServerDef):
        server_def.MergeFrom(server_or_cluster_def)
        if job_name is not None:
            server_def.job_name = job_name
        if task_index is not None:
            server_def.task_index = task_index
        if protocol is not None:
            server_def.protocol = protocol
        if config is not None:
            server_def.default_session_config.MergeFrom(config)
    else:
        try:
            cluster_spec = ClusterSpec(server_or_cluster_def)
        except TypeError:
            raise TypeError("Could not convert `server_or_cluster_def` to a "
                            "`tf.train.ServerDef` or `tf.train.ClusterSpec`.")
        if job_name is None:
            if len(cluster_spec.jobs) == 1:
                job_name = cluster_spec.jobs[0]
            else:
                raise ValueError("Must specify an explicit `job_name`.")
        if task_index is None:
            task_indices = cluster_spec.task_indices(job_name)
            if len(task_indices) == 1:
                task_index = task_indices[0]
            else:
                raise ValueError("Must specify an explicit `task_index`.")
        if protocol is None:
            protocol = "grpc"

        server_def = tensorflow_server_pb2.ServerDef(
            cluster=cluster_spec.as_cluster_def(),
            job_name=job_name,
            task_index=task_index,
            protocol=protocol)
        if config is not None:
            server_def.default_session_config.MergeFrom(config)
    return server_def
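For context, here is a minimal usage sketch of the defaulting behavior documented above, going through the public `tf.distribute.Server` wrapper (which builds its `ServerDef` with a helper like this one). The single-job, single-task cluster is a placeholder; with it, `job_name`, `task_index`, and `protocol` can all be inferred:

import tensorflow as tf

# Minimal sketch: one job with one task, so job_name defaults to "worker",
# task_index defaults to 0, and protocol falls back to "grpc".
cluster = tf.train.ClusterSpec({"worker": ["localhost:2222"]})
server = tf.distribute.Server(cluster)  # a ServerDef is assembled internally
print(server.server_def)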
Example 2
  def _initialize_multi_worker(self, cluster_resolver):
    """Initializes the object for multi-worker training."""
    cluster_spec = multi_worker_util.normalize_cluster_spec(
        cluster_resolver.cluster_spec())
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    if task_type is None or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`.")
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id
    self._id_in_cluster = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)

    self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
    if not self._num_workers:
      raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found "
                       "in `cluster_spec`.")

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)

    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)

    if (ops.executing_eagerly_outside_functions() and
        not getattr(self, "_local_or_standalone_client_mode", False)):
      context.context().configure_collective_ops(
          collective_leader=multi_worker_util.collective_leader(
              cluster_spec, task_type, task_id),
          scoped_allocator_enabled_ops=("CollectiveReduce",),
          device_filters=("/job:%s/task:%d" % (task_type, task_id),))
      self._collective_ops_configured = True

    # Starting a std server in eager mode and in independent worker mode.
    if (context.executing_eagerly() and
        not getattr(self, "_std_server_started", False) and
        not getattr(self, "_local_or_standalone_client_mode", False)):
      # Checking _local_or_standalone_client_mode as well because we should not
      # create the std server in standalone client mode.
      config_proto = copy.deepcopy(context.context().config)
      config_proto = self._update_config_proto(config_proto)

      if hasattr(cluster_resolver, "port"):
        port = cluster_resolver.port
      else:
        port = 0
      server_def = tensorflow_server_pb2.ServerDef(
          cluster=cluster_spec.as_cluster_def(),
          default_session_config=config_proto,
          job_name=task_type,
          task_index=task_id,
          protocol=cluster_resolver.rpc_layer or "grpc",
          port=port)
      context.context().enable_collective_ops(server_def)
      self._std_server_started = True
      # The `ensure_initialized` is needed before calling
      # `context.context().devices()`.
      context.context().ensure_initialized()
      logging.info(
          "Enabled multi-worker collective ops with available devices: %r",
          context.context().devices())

    # TODO(yuefengz): The `num_gpus` is only for this particular task. It
    # assumes all workers have the same number of GPUs. We should remove this
    # assumption by querying all tasks for their numbers of GPUs.
    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
    # some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    if num_gpus:
      local_devices = tuple("%s/device:GPU:%d" % (self._worker_device, i)
                            for i in range(num_gpus))
    else:
      local_devices = (self._worker_device,)

    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + self._collective_key_base)
    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=local_devices,
        group_size=len(local_devices) * self._num_workers,
        collective_keys=self._collective_keys)
    # CrossDeviceOps for per host tensors.
    self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=[self._worker_device],
        group_size=self._num_workers,
        collective_keys=self._collective_keys)
    super(CollectiveAllReduceExtended, self)._initialize_single_worker(
        local_devices)

    # Add a default device so that ops without specified devices will not end up
    # on other workers.
    self._default_device = "/job:%s/task:%d" % (task_type, task_id)

    # Save the num_gpus_per_worker and rpc_layer for configure method.
    self._num_gpus_per_worker = num_gpus
    self._rpc_layer = cluster_resolver.rpc_layer
    self._warn_nccl_no_gpu()

    if self._enable_check_health:
      self._start_check_health_thread()

    logging.info(
        "MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, "
        "task_id = %r, num_workers = %r, local_devices = %r, "
        "communication = %s", cluster_spec.as_dict(), task_type, task_id,
        self._num_workers, local_devices,
        self._communication_options.implementation)
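As a rough sketch of how this path gets exercised, a worker process would typically export a TF_CONFIG environment variable so that a TF_CONFIG-based cluster resolver supplies the `cluster_spec`, `task_type`, and `task_id` consumed above; instantiating the multi-worker strategy then runs this initialization. Hostnames and ports below are placeholders:

import json
import os
import tensorflow as tf

# Placeholder cluster: two workers, and this process is worker 0.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"worker": ["host1:12345", "host2:12345"]},
    "task": {"type": "worker", "index": 0},
})
# Creating the strategy triggers the multi-worker initialization path above.
strategy = tf.distribute.MultiWorkerMirroredStrategy()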
Example 3
def initialize_multi_client_cluster(job_name: str,
                                    dtensor_jobs: List[str],
                                    client_id: int,
                                    collective_leader: str,
                                    port: Optional[int] = None,
                                    protocol: Optional[str] = "grpc+loas",
                                    enable_coordination_service: bool = False):
    """Initialize GRPC servers and collectives for multi-client DTensor setup.

  While a single client (e.g. Forge) can use the local mode of collectives, GRPC
  servers are necessary in a multi-client setup. This function can be used to
  initialize a cluster and enable collective ops.

  NOTE: this function must be called in an eager context.

  Args:
    job_name: The job name used by all clients in the DTensor cluster.
    dtensor_jobs: A list of the DTensor client jobs participating in the
      cluster. Must be strings of the form "hostname:port".
    client_id: The ID of the DTensor client this function is being called in.
    collective_leader: The job/task that will be used to run collectives.
    port: The port this client's GRPC server will run on.
    protocol: The protocol to be used by this server.
    enable_coordination_service: If true, enable the distributed coordination
      service so that each worker knows the devices on the other workers, a
      prerequisite for data transfer through cross-worker rendezvous.

  Raises:
    RuntimeError: If running inside a tf.function.
  """
    assert context.executing_eagerly()

    if not collective_leader.startswith("/job:"):
        collective_leader = "/job:" + collective_leader

    context.context().configure_collective_ops(
        collective_leader=collective_leader)
    if enable_coordination_service:
        context.context().configure_coordination_service(
            service_type="standalone", service_leader=collective_leader)

    config_proto = context.get_config()
    config_proto.experimental.collective_group_leader = collective_leader
    # Construct server def from the host directly instead of relying on
    # TF_CONFIG.
    cluster_def = cluster_pb2.ClusterDef()
    # Note that we will currently rely on the sorted string of job name as the
    # order of assigning task ids. This might be brittle once we have jobs
    # across multiple cells.
    cluster_def.job.add(name=job_name, tasks=dict(enumerate(dtensor_jobs)))
    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_def,
        default_session_config=config_proto,
        job_name=job_name,
        task_index=client_id,
        protocol=protocol,
        port=port)
    server_def.default_session_config.rpc_options.num_channels_per_target = 4
    server_def.default_session_config.experimental.recv_buf_max_chunk = -1

    logging.info("Enabling collectives with server_def: %s", server_def)
    context.context().enable_collective_ops(server_def)
    context.ensure_initialized()
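For illustration, a hedged sketch of how this helper might be invoked for a two-client DTensor cluster; the hostnames, ports, and leader string are placeholders, and the helper is assumed to be importable in the calling module:

# Hypothetical two-client invocation; each participating process runs this
# with its own client_id.
dtensor_jobs = ["worker-a.example.com:8888", "worker-b.example.com:8888"]
initialize_multi_client_cluster(
    job_name="worker",
    dtensor_jobs=dtensor_jobs,
    client_id=0,                                  # this process is task 0
    collective_leader="worker/replica:0/task:0",  # "/job:" prefix added above
    protocol="grpc",                              # override the "grpc+loas" default
)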
Example 4
    def _initialize_multi_worker(self, cluster_resolver):
        """Initializes the object for multi-worker training."""
        # TODO(yuefengz): The `num_gpus` is only for this particular task. It
        # assumes all workers have the same number of GPUs. We should remove this
        # assumption by querying all tasks for their numbers of GPUs.
        # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
        # some cases.
        if isinstance(cluster_resolver, TFConfigClusterResolver):
            num_gpus = context.num_gpus()
        else:
            num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

        cluster_spec = multi_worker_util.normalize_cluster_spec(
            cluster_resolver.cluster_spec())
        task_type = cluster_resolver.task_type
        task_id = cluster_resolver.task_id
        if task_type is None or task_id is None:
            raise ValueError(
                "When `cluster_spec` is given, you must also specify "
                "`task_type` and `task_id` in the `cluster_resolver`.")
        if task_type not in ("chief", "worker"):
            raise ValueError(
                "Unrecognized task_type: %r, valid task types are: \"chief\", "
                "\"worker\"." % task_type)

        self._num_workers = multi_worker_util.worker_count(
            cluster_spec, task_type)
        if not self._num_workers:
            raise ValueError("No `worker` or `chief` tasks can be found in "
                             "`cluster_spec`.")

        self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                    task_id)

        self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
        self._host_input_device = numpy_dataset.SingleDevice(
            self._worker_device)
        if num_gpus:
            local_devices = tuple("%s/device:GPU:%d" % (self._worker_device, i)
                                  for i in range(num_gpus))
        else:
            local_devices = (self._worker_device, )

        self._collective_keys = cross_device_utils.CollectiveKeys()
        super(CollectiveAllReduceExtended,
              self)._initialize_local(local_devices)
        self._input_workers = input_lib.InputWorkers(
            self._device_map, [(self._worker_device, self.worker_devices)])
        self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
            num_workers=self._num_workers,
            num_gpus_per_worker=num_gpus,
            collective_keys=self._collective_keys)

        # Add a default device so that ops without specified devices will not end up
        # on other workers.
        self._default_device = "/job:%s/task:%d" % (task_type, task_id)

        self._cluster_spec = cluster_spec
        self._task_type = task_type
        self._task_id = task_id

        # Save the num_gpus_per_worker and rpc_layer for configure method.
        self._num_gpus_per_worker = num_gpus
        self._rpc_layer = cluster_resolver.rpc_layer

        logging.info(
            "Multi-worker CollectiveAllReduceStrategy with cluster_spec = %r, "
            "task_type = %r, task_id = %r, num_workers = %r, local_devices = %r, "
            "communication = %s", cluster_spec.as_dict(), task_type, task_id,
            self._num_workers, local_devices, self._communication)

        if (context.executing_eagerly()
                and not getattr(self, "_std_server_started", False) and
                not getattr(self, "_local_or_standalone_client_mode", False)):
            # Checking _local_or_standalone_client_mode as well because we should not
            # create the std server in standalone client mode.
            config_proto = config_pb2.ConfigProto()
            config_proto = self._update_config_proto(config_proto)
            server_def = tensorflow_server_pb2.ServerDef(
                cluster=cluster_spec.as_cluster_def(),
                default_session_config=config_proto,
                job_name=task_type,
                task_index=task_id,
                protocol=cluster_resolver.rpc_layer or "grpc")
            context.context().enable_collective_ops(server_def)
            self._std_server_started = True
            logging.info(
                "Enabled multi-worker collective ops with available devices: %r",
                context.context().devices())
  def setUp(self):
    self._cluster = tensorflow_server_pb2.ServerDef(protocol="grpc").cluster
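The fixture above grabs the `ClusterDef` sub-message of a fresh `ServerDef`; a test body could then populate it with the same `job.add(...)` pattern used in Example 3. A minimal, hypothetical sketch:

# Hypothetical test code using the fixture's ClusterDef handle; the job name
# and address are placeholders.
self._cluster.job.add(name="worker", tasks={0: "localhost:2222"})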