示例#1
0
def dtensor_initialize_multi_client(
        enable_coordination_service: Optional[bool] = False) -> None:
    """Initializes Multi Client DTensor.

    The following environment variables control the behavior of this function.
    If the variables are unset, DTensor will be configured to run in
    single-client mode.

    - DTENSOR_CLIENT_ID: integer, between 0 and num_clients - 1, identifying
        the client id of the current process.
    - DTENSOR_NUM_CLIENTS: integer, the number of clients.
    - DTENSOR_JOB_NAME: string, a hostname-like string for the name of the
        dtensor job. The job name is used by TensorFlow in the job name section
        of the DeviceSpec. A job name of `localhost` selects single-client mode;
        any other value selects multi-client mode.
    - DTENSOR_JOBS: string, a comma separated list. Each item in the list is
        of format `{hostname}:{port}` and the items must be sorted in
        alphabetical order. The implication is that the RPC port numbers of the
        clients on the same host must be ordered by client ID.
        Examples of valid DTENSOR_JOBS values:
        - 4 clients on localhost:
          `localhost:10000,localhost:10001,localhost:10002,localhost:10003`
        - 2 clients on host1, 2 clients on host2
          `host1:10000,host1:10001,host2:10000,host2:10003`

    Side effects:
      Sets the module-level `_in_multi_client_mode` flag.

    Args:
      enable_coordination_service: If true, enable distributed coordination
        service to make sure that workers know the devices on each other, a
        prerequisite for data transfer through cross-worker rendezvous.

    Raises:
      ValueError: If the job name is `localhost` but the number of clients is
        not 1, or if in multi-client mode `api.jobs()` returns None (i.e. no
        DTENSOR_JOBS value is available).
    """
    global _in_multi_client_mode
    assert context.executing_eagerly()

    _in_multi_client_mode = api.job_name() != 'localhost'

    if not _in_multi_client_mode and api.num_clients() != 1:
        raise ValueError(
            'DTENSOR_NUM_CLIENTS is set and not 1, while DTENSOR_JOB_NAME is '
            'set to localhost for single client mode.')

    # Collective GRPC servers are only necessary in multi-client setup.
    # Single clients can use local mode of collectives.
    if _in_multi_client_mode:
        if api.jobs() is None:
            # Note the trailing space on the first literal: adjacent literals
            # are concatenated verbatim.
            raise ValueError(
                'DTENSOR_JOBS environment variable is required when '
                'using multi-client to properly set up communications between servers'
            )
        multi_client_util.initialize_multi_client_cluster(
            job_name=api.job_name(),
            dtensor_jobs=api.jobs(),
            client_id=api.client_id(),
            collective_leader=api.full_job_name(task_id=0),
            enable_coordination_service=enable_coordination_service,
            protocol='grpc')

    # Make sure the server change is fully propagated before returning.
    context.ensure_initialized()
    context.async_wait()
    context.context()._clear_caches()  # pylint: disable=protected-access
示例#2
0
def dtensor_initialize_multi_client(
        enable_coordination_service: Optional[bool] = False) -> None:
    """Sets up DTensor for a (possibly) multi-client run.

    Behavior is driven by these environment variables; when none of them are
    set, DTensor runs with a single client.

    - DTENSOR_CLIENT_ID: integer in [0, num_clients - 1] naming this process's
        client. Defaults to 0.
    - DTENSOR_NUM_CLIENTS: integer count of clients. Defaults to 1.
    - DTENSOR_JOB_NAME: hostname-like name of the dtensor job. Defaults to
        `localhost` for a single client and `worker` for more than one.
        It becomes the job section of TensorFlow DeviceSpecs, e.g. `job:worker`
        in `/job:worker/replica:0/task:0/device:TPU:0` when the name is
        `worker`.
    - DTENSOR_JOBS: comma separated `{hostname}:{port}` items, sorted
        alphabetically, so ports of clients sharing a host follow client-ID
        order. Valid examples:
        - 4 clients on localhost:
          `localhost:10000,localhost:10001,localhost:10002,localhost:10003`
        - 2 clients on host1, 2 clients on host2:
          `host1:10000,host1:10001,host2:10000,host2:10003`

    The multi-client gRPC cluster is only brought up when more than one client
    is configured. In every case the eager context is then initialized, pending
    async work is awaited, and eager caches are cleared.

    Args:
      enable_coordination_service: When true, turn on the distributed
        coordination service so workers learn each other's devices, which
        cross-worker rendezvous data transfer requires.
    """
    assert context.executing_eagerly()

    # gRPC collective servers are only needed when several clients take part;
    # a lone client can run collectives locally.
    is_multi_client = api.num_clients() > 1
    if is_multi_client:
        cluster_kwargs = dict(
            job_name=api.job_name(),
            dtensor_jobs=api.jobs(),
            client_id=api.client_id(),
            collective_leader=api.full_job_name(task_id=0),
            enable_coordination_service=enable_coordination_service,
        )
        multi_client_util.initialize_multi_client_cluster(**cluster_kwargs)

    # Wait until the server change has fully propagated before returning.
    context.ensure_initialized()
    context.async_wait()
    eager_context = context.context()
    eager_context._clear_caches()  # pylint: disable=protected-access
示例#3
0
def dtensor_initialize_tpu_system(enable_coordination_service=False):
  """Initialize the TPU devices.

  When the job name is not "localhost" (multi-client mode), the multi-client
  gRPC cluster is brought up first. The TPU system is then configured and
  initialized, and the per-host TPU core IDs are merged into one global list
  with an AllReduce. The merged core IDs, core locations, TPU topology and
  DTensor device are stored in module-level globals. Finally, in multi-client
  mode with heartbeats enabled, a heartbeat service is started.

  Args:
    enable_coordination_service: If true, enable distributed coordination
      service to make sure that workers know the devices on each other, a
      prerequisite for data transfer through cross-worker rendezvous.

  Raises:
    AssertionError: If not executing eagerly (e.g. inside a tf.function).
    ValueError: If in multi-client mode and `api.jobs()` returns None (no
      DTENSOR_JOBS value is available).
    NotFoundError: If TPU initialization fails with an InvalidArgumentError,
      e.g. because no valid TPUs are found.
    InternalError: Re-raised after logging if TPU initialization hits an
      internal error.
  """
  global _all_core_ids, _all_core_locations, _tpu_topology, _dtensor_device

  assert context.executing_eagerly()
  in_multi_client_mode = api.job_name() != "localhost"

  # Collective GRPC servers are only necessary in multi-client setup.
  # Single clients (e.g. Forge) can use local mode of collectives.
  if in_multi_client_mode:
    if api.jobs() is None:
      # Note the trailing space on the first literal: adjacent literals are
      # concatenated verbatim.
      raise ValueError(
          "DTENSOR_JOBS environment variable is required when "
          "using multi-client to properly set up communications between servers"
      )
    multi_client_util.initialize_multi_client_cluster(
        job_name=api.job_name(),
        dtensor_jobs=api.jobs(),
        client_id=api.client_id(),
        collective_leader=api.full_job_name(task_id=0),
        enable_coordination_service=enable_coordination_service)

  # Make sure the server change is fully propagated before attempting to run
  # the core ID merging logic below.
  context.ensure_initialized()
  context.async_wait()
  context.context()._clear_caches()  # pylint: disable=protected-access

  @function.defun
  def _tpu_init_fn():
    return gen_dtensor_ops.configure_and_initialize_global_tpu()

  try:
    with ops.device("/job:" + api.full_job_name() + "/device:TPU_SYSTEM:0"):  # pylint: disable=protected-access
      my_core_ids = _tpu_init_fn()
    logging.info("TPU core IDs: %s", my_core_ids)
    context.initialize_logical_devices()

    # Configure virtual CPUs that are 1:1 mapped to TPU cores.
    context.context().set_logical_cpu_devices(
        len(api.local_devices(_TPU_DEVICE_TYPE)),
        tf_device.DeviceSpec(
            job=api.job_name(), replica=0, task=api.client_id()).to_string())

    # `my_core_ids` contains the IDs of TPU cores attached to this host.
    #
    # To generate correct and efficient XLA AllReduce group assignment, we must
    # merge these arrays from all hosts and broadcast the result back to all
    # hosts, so all hosts can use these mappings in their MLIR passes.
    #
    # This is essentially doing what WaitForDistributedTpuOp and
    # SetGlobalTPUArrayOp do, in our multi-client environment.
    task_id = api.client_id()
    num_tasks = api.num_clients()
    num_devices = api.num_global_devices(_TPU_DEVICE_TYPE)
    num_devices_per_task = num_devices // num_tasks

    # Create a one-time use mesh and layout just for merging core IDs.
    mesh = layout_lib.Mesh([_MESH_DIM_X],
                           *_create_device_array((num_devices,),
                                                 _TPU_DEVICE_TYPE,
                                                 api.client_id()))
    layout = layout_lib.Layout([_MESH_DIM_X, layout_lib.UNSHARDED], mesh)
    device = dtensor_device.DTensorDevice(meshes=[mesh])
    logging.info("TPU core locations: %s",
                 device.tpu_core_ids_to_locations(my_core_ids))

    # At this point, we don't know which cores are attached to other hosts.
    # The core ID mappings in the runtime haven't been set yet.
    #
    # The core ID merging AllReduce below is carefully written so it works
    # without needing correct core mappings to be set in the runtime. We will
    # use this AllReduce's result to set the core ID mappings, and all future
    # user-initiated AllReduces will use the mappings.
    #
    # The runtime is hard-coded to ignore core ID mappings on this AllReduce.
    all_core_ids = np.zeros([num_devices], dtype=np.int32)
    offset = task_id * num_devices_per_task
    for i, core_id in enumerate(my_core_ids):
      all_core_ids[offset + i] = core_id

    # Only one local device gets valid input: this host's core IDs placed at
    # offset `task_id * num_devices_per_task`, with zeros everywhere else. The
    # other local devices get all-zero inputs. All devices on all hosts
    # participate in one AllReduce, whose result is the core IDs arranged by
    # task-device ordinals.
    all_core_ids = constant_op.constant([all_core_ids])
    zeros = array_ops.zeros_like(all_core_ids)
    all_core_ids = [all_core_ids] + [zeros] * (num_devices_per_task - 1)

    with ops.device(device.name):
      all_core_ids = device.pack(all_core_ids, layout)
      all_core_ids = math_ops.reduce_sum(all_core_ids, axis=[0])
      unpacked_all_tpu_ids = device.unpack(all_core_ids)

    all_core_ids = list(unpacked_all_tpu_ids[0].numpy())
    logging.info("All TPU core IDs: %s", all_core_ids)

    # Set the default core ID mappings in the runtime for legacy code and tests.
    #
    # Legacy code and tests create TPU meshes directly without using the
    # `create_tpu_mesh` function. Those meshes have global device IDs equal to
    # TF task-device ordinals. The `all_core_ids` array happens to arrange core
    # IDs by TF task-device ordinals. Using this array on those meshes
    # guarantees correct although inefficient results.
    device.set_tpu_core_ids("", all_core_ids)

    # Remember enough global, immutable information to be able to build any ring
    # we want prescribed by `create_tpu_mesh` in the future.
    _all_core_ids = all_core_ids

    all_core_locations = device.tpu_core_ids_to_locations(all_core_ids)
    all_core_locations = [
        _CoreLocation(l[0], l[1], l[2], l[3]) for l in all_core_locations
    ]
    _all_core_locations = all_core_locations
    logging.info("All TPU core locations: %s", all_core_locations)

    tpu_topology = _create_tpu_topology(all_core_locations, num_tasks,
                                        num_devices_per_task)
    _tpu_topology = tpu_topology
    logging.vlog(1, "TPU Topology: %s, %s", tpu_topology.mesh_shape,
                 tpu_topology.device_coordinates)

    _dtensor_device = device

    context.async_wait()

  except errors.InvalidArgumentError as e:
    raise errors.NotFoundError(
        None, None,
        "Initialization failed, no valid TPUs found. " + str(e)) from e

  except errors.InternalError:
    logging.error("Hit internal error during TPU system initialization. "
                  + "It is likely hardware failure. \nPlease check the error "
                  + "messages above to see whether that's the case. \nIf so, "
                  + "consider restarting the job or trying another machine.")
    raise

  # Optionally exchange heartbeats between workers periodically.
  if in_multi_client_mode and api.heartbeat_enabled():
    # NOTE(review): the period is assumed to be in seconds (i.e. 3 minutes);
    # confirm against `heartbeat.start`. The log message is derived from the
    # same value so the two cannot drift apart again.
    heartbeat_period = 180
    logging.info(
        "Starting DTensor heartbeat service exchanging signals every %d "
        "seconds", heartbeat_period)
    heartbeat.start(period=heartbeat_period)

  # Clear out the eager context caches since the memory is invalid now.
  logging.info("Clearing out eager caches")
  context.context()._clear_caches()  # pylint: disable=protected-access
示例#4
0
def initialize_multi_client_cluster(job_name: str,
                                    dtensor_jobs: List[str],
                                    client_id: int,
                                    collective_leader: str,
                                    port: Optional[int] = None,
                                    enable_coordination_service: bool = False):
    """Initialize GRPC servers and collectives for multi-client DTensor setup.

    While single clients (e.g. Forge) can use local mode of collectives, GRPC
    servers are necessary in multi-client setup. This function can be used to
    initialize a cluster and enable collective ops. On success it sets the
    module-level `_is_multi_client_initialized` flag, so a second call raises.

    NOTE: this function must be called in an eager context.

    Args:
      job_name: The job name used by all clients in the DTensor cluster.
      dtensor_jobs: A list of the DTensor client jobs participating in the
        cluster. Must be strings of the form "hostname:port". Task ids are
        assigned by list position.
      client_id: The ID of the DTensor client this function is being called in;
        used as this server's task index.
      collective_leader: The job/task that will be used to run collectives. A
        "/job:" prefix is added if missing.
      port: The port this client's GRPC server will run on (passed through to
        `ServerDef.port`; may be None).
      enable_coordination_service: If true, enable distributed coordination
        service to make sure that workers know the devices on each other, a
        prerequisite for data transfer through cross-worker rendezvous.

    Raises:
      AssertionError: If not executing eagerly (e.g. inside a tf.function).
      ValueError: If multi-client mode has already been initialized, if
        `api.num_clients()` is not greater than 1, if `api.jobs()` is empty or
        has at most one entry, or if the number of `api.jobs()` entries differs
        from `api.num_clients()`.
    """
    global _is_multi_client_initialized
    assert context.executing_eagerly()

    if _is_multi_client_initialized:
        raise ValueError("Multi-client mode has already been initialized.")

    num_clients = api.num_clients()
    if num_clients <= 1:
        raise ValueError(
            "DTENSOR_NUM_CLIENTS must be set greater than 1 for multi-client mode."
        )

    # NOTE(review): these checks validate the settings exposed by `api`, not the
    # `dtensor_jobs` argument that is actually used to build the cluster below.
    # Callers in this codebase pass `api.jobs()`, so the two agree there.
    configured_jobs = api.jobs()
    if not configured_jobs or len(configured_jobs) <= 1:
        raise ValueError(
            "DTENSOR_JOBS environment variable is required when using multi-client "
            "mode to properly set up communications between servers.")

    if len(configured_jobs) != num_clients:
        raise ValueError(
            "DTENSOR_JOBS environment variable must be configured with the same "
            "number of items as DTENSOR_NUM_CLIENTS.")

    if not collective_leader.startswith("/job:"):
        collective_leader = "/job:" + collective_leader

    context.context().configure_collective_ops(
        collective_leader=collective_leader)
    if enable_coordination_service:
        context.context().configure_coordination_service(
            service_type="standalone", service_leader=collective_leader)

    config_proto = context.get_config()
    config_proto.experimental.collective_group_leader = collective_leader
    # Construct server def from the host directly instead of relying on
    # TF_CONFIG.
    cluster_def = cluster_pb2.ClusterDef()
    # Note that we will currently rely on the sorted string of job name as the
    # order of assigning task ids. This might be brittle once we have jobs
    # across multiple cells.
    cluster_def.job.add(name=job_name, tasks=dict(enumerate(dtensor_jobs)))
    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_def,
        default_session_config=config_proto,
        job_name=job_name,
        task_index=client_id,
        protocol=remote_utils.get_default_communication_protocol(),
        port=port)
    server_def.default_session_config.rpc_options.num_channels_per_target = 4
    server_def.default_session_config.experimental.recv_buf_max_chunk = -1

    logging.info("Enabling collectives with server_def: %s", server_def)
    context.context().enable_collective_ops(server_def)
    context.ensure_initialized()

    # Only mark as initialized once every step above has succeeded.
    _is_multi_client_initialized = True