Example 1
def get_device_ids(mesh: layout_lib.Mesh,
                   client_id: Optional[int] = None) -> List[int]:
  """Returns the device IDs of all TPU cores local to the given client.

  A device ID is a non-negative integer that uniquely identifies a device in the
  mesh. For example, for a 2x2 mesh ('x', 'y'), this function returns a
  permutation of [0, 1, 2, 3].

  Note that device IDs and device locations are equivalent. The former is a
  linearization of the latter along mesh dimensions.

  Args:
    mesh: A TPU mesh.
    client_id: Optional; A DTensor client ID. If None, query this client.
  """

  if mesh.device_type() != _TPU_DEVICE_TYPE:
    raise ValueError("The mesh must be a TPU mesh")

  if client_id is None or client_id == api.client_id():
    return mesh.local_device_ids()

  # It's not clear we should ever allow a client to query other clients for
  # their device IDs.
  raise NotImplementedError(
      "Looking up other clients' device IDs is not supported")
Example 2
def get_device_locations(
    mesh: layout_lib.Mesh,
    client_id: Optional[int] = None) -> List[Dict[str, int]]:
  """Returns the device locations of all TPU cores local to the given client.

  A device location is a dictionary from dimension names to indices on those
  dimensions. For example, for a 2x2 mesh ('x', 'y'), this function returns a
  permutation of this list:

    [{'x': 0, 'y': 0},
     {'x': 0, 'y': 1},
     {'x': 1, 'y': 0},
     {'x': 1, 'y': 1}].

  Note that device IDs and device locations are equivalent. The former is a
  linearization of the latter along mesh dimensions.

  Args:
    mesh: A TPU mesh.
    client_id: Optional; A DTensor client ID. If None, query this client.
  """

  if mesh.device_type() != _TPU_DEVICE_TYPE:
    raise ValueError("The mesh must be a TPU mesh")

  if client_id is None or client_id == api.client_id():
    return mesh.local_device_locations()

  # It's not clear we should ever allow a client to query other clients for
  # their device locations.
  raise NotImplementedError(
      "Looking up other clients' device locations is not supported")
Example 3
def dtensor_initialize_multi_client(
        enable_coordination_service: Optional[bool] = False) -> None:
    """Initializes Multi Client DTensor.

  The following environment variables control the behavior of this function.
  If the variables are unset, DTensor will be configured to run in single-client
  mode.

  - DTENSOR_CLIENT_ID: integer, between 0 and num_clients - 1, identifying the
      client ID of the current process.
  - DTENSOR_NUM_CLIENTS: integer, the number of clients.
  - DTENSOR_JOB_NAME: string, a hostname-like string for the name of the DTensor
      job. The job name is used by TensorFlow in the job name section of
      the DeviceSpec.
  - DTENSOR_JOBS: string, a comma-separated list. Each item in the list is
      of format `{hostname}:{port}` and the items must be sorted in alphabetical
      order. The implication is that the RPC port numbers of clients on the
      same host must be ordered by client ID.
      Examples of valid DTENSOR_JOBS values:
      - 4 clients on localhost:
        `localhost:10000,localhost:10001,localhost:10002,localhost:10003`
      - 2 clients on host1, 2 clients on host2
        `host1:10000,host1:10001,host2:10000,host2:10003`

  Args:
    enable_coordination_service: If true, enable the distributed coordination
      service so that workers know about each other's devices, a prerequisite
      for data transfer through cross-worker rendezvous.
  """
    global _in_multi_client_mode
    assert context.executing_eagerly()

    _in_multi_client_mode = api.job_name() != 'localhost'

    if not _in_multi_client_mode and api.num_clients() != 1:
        raise ValueError(
            'DTENSOR_NUM_CLIENTS is set and not 1, while DTENSOR_JOB_NAME is '
            'set to localhost for single client mode.')

    # Collective GRPC servers are only necessary in multi-client setup.
    # Single clients can use local mode of collectives.
    if _in_multi_client_mode:
        if api.jobs() is None:
            raise ValueError(
                'DTENSOR_JOBS environment variable is required when '
                'using multi-client to properly set up communications between servers'
            )
        multi_client_util.initialize_multi_client_cluster(
            job_name=api.job_name(),
            dtensor_jobs=api.jobs(),
            client_id=api.client_id(),
            collective_leader=api.full_job_name(task_id=0),
            enable_coordination_service=enable_coordination_service,
            protocol='grpc')

    # Make sure the server change is fully propagated before returning.
    context.ensure_initialized()
    context.async_wait()
    context.context()._clear_caches()  # pylint: disable=protected-access
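
A minimal launch sketch, assuming two clients running on localhost. The environment variables mirror the list in the docstring and would normally be set by the job launcher rather than inside the program; the port numbers are arbitrary.

import os

# Shown inline for illustration only; in practice the launcher sets these.
os.environ['DTENSOR_CLIENT_ID'] = '0'       # '1' in the second process
os.environ['DTENSOR_NUM_CLIENTS'] = '2'
os.environ['DTENSOR_JOB_NAME'] = 'worker'
os.environ['DTENSOR_JOBS'] = 'localhost:10000,localhost:10001'

dtensor_initialize_multi_client(enable_coordination_service=True)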
Example 4
def dtensor_initialize_multi_client(
        enable_coordination_service: Optional[bool] = False) -> None:
    """Initializes Multi Client DTensor.

  The following environment variables control the behavior of this function.
  If the variables are unset, DTensor will be configured to run in single-client
  mode.

  - DTENSOR_CLIENT_ID: integer, between 0 and num_clients - 1, identifying the
      client ID of the current process. The default value is 0.
  - DTENSOR_NUM_CLIENTS: integer, the number of clients. The default value is 1.
  - DTENSOR_JOB_NAME: string, a hostname-like string for the name of the DTensor
      job. The default is `localhost` when the number of clients is 1, and
      `worker` when the number of clients is greater than 1.
      The job name controls the job name section of the TensorFlow DeviceSpecs,
      e.g., `job:worker` in `/job:worker/replica:0/task:0/device:TPU:0` when
      the job name is `worker`.
  - DTENSOR_JOBS: string, a comma-separated list. Each item in the list is
      of format `{hostname}:{port}` and the items must be sorted in alphabetical
      order. The implication is that the RPC port numbers of clients on the
      same host must be ordered by client ID.
      Examples of valid DTENSOR_JOBS values:
      - 4 clients on localhost:
        `localhost:10000,localhost:10001,localhost:10002,localhost:10003`
      - 2 clients on host1, 2 clients on host2
        `host1:10000,host1:10001,host2:10000,host2:10003`

  Args:
    enable_coordination_service: If true, enable the distributed coordination
      service so that workers know about each other's devices, a prerequisite
      for data transfer through cross-worker rendezvous.
  """
    assert context.executing_eagerly()

    # Collective GRPC servers are only necessary in multi-client setup.
    # Single clients can use local mode of collectives.
    if api.num_clients() > 1:
        multi_client_util.initialize_multi_client_cluster(
            job_name=api.job_name(),
            dtensor_jobs=api.jobs(),
            client_id=api.client_id(),
            collective_leader=api.full_job_name(task_id=0),
            enable_coordination_service=enable_coordination_service)

    # Make sure the server change is fully propagated before returning.
    context.ensure_initialized()
    context.async_wait()
    context.context()._clear_caches()  # pylint: disable=protected-access
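
In this variant the defaults alone select single-client mode, so a sketch needs no environment variables at all (the assumption is that no DTENSOR_* variables are set in the process environment).

# DTENSOR_NUM_CLIENTS defaults to 1 and DTENSOR_JOB_NAME to 'localhost',
# so no GRPC cluster is created; only the eager context is (re)initialized.
dtensor_initialize_multi_client()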
Example 5
def create_distributed_mesh(mesh_dims: List[Tuple[str, int]],
                            mesh_name: str = '',
                            num_global_devices: Optional[int] = None,
                            num_clients: Optional[int] = None,
                            client_id: Optional[int] = None,
                            device_type: str = 'CPU') -> layout.Mesh:
    """Creates a single- or multi-client mesh.

  For CPU and GPU meshes, users can choose to use fewer local devices than what
  is available. If any argument is missing, it will be extracted from
  environment variables. The default values for these environment variables
  create a mesh using all devices (common for unit tests).

  For TPU meshes, users should not specify any of the nullable arguments. The
  DTensor runtime will set these arguments automatically, using all TPU cores
  available in the entire cluster.

  Args:
    mesh_dims: A list of (dim_name, dim_size) tuples.
    mesh_name: Name of the created mesh. Defaults to ''.
    num_global_devices: Number of devices in the DTensor cluster. Defaults to
      the corresponding environment variable.
    num_clients: Number of clients in the DTensor cluster. Defaults to the
      corresponding environment variable, DTENSOR_NUM_CLIENTS.
    client_id: This client's ID. Defaults to the corresponding environment
      variable, DTENSOR_CLIENT_ID.
    device_type: Type of device to build the mesh for. Defaults to 'CPU'.

  Returns:
    A mesh created from specified or default arguments.
  """
    dim_names, shape = zip(*mesh_dims)

    if device_type.upper() in ['CPU', 'GPU']:
        # For CPU and GPU meshes, user-specified args take precedence over env vars.
        # This is particularly useful on single clients when users want to create
        # meshes that use fewer logical devices than what's available.

        if num_global_devices is None:
            num_global_devices = api.num_global_devices(device_type)
        if num_global_devices <= 0:
            raise ValueError(
                f'num_global_devices ({num_global_devices}) must be > 0')
        if num_global_devices != np.prod(shape):
            raise ValueError(
                f'num_global_devices ({num_global_devices}) must be '
                f'equal to total size of the mesh of shape {shape}')

        if num_clients is None:
            num_clients = api.num_clients()
        if num_clients <= 0:
            raise ValueError(f'num_clients ({num_clients}) must be > 0')

        if _in_multi_client_mode is None and num_clients > 1:
            raise ValueError(
                'Invalid multi-client topology, run dtensor.initialize_multi_client() first'
            )

        if client_id is None:
            client_id = api.client_id()
        if client_id < 0:
            raise ValueError(f'client_id ({client_id}) must be >= 0')
        if client_id >= num_clients:
            raise ValueError(
                f'client_id ({client_id}) must be < {num_clients}')

        if num_global_devices % num_clients != 0:
            raise ValueError(
                f'num_global_devices ({num_global_devices}) must be '
                f'divisible by num_clients ({num_clients})')
        num_local_devices = num_global_devices // num_clients

        # It's allowed to create a CPU or GPU mesh using fewer logical devices than
        # what's available. If so, just use the first N logical devices.
        num_available_devices = api.num_local_devices(device_type)
        if num_local_devices > num_available_devices:
            raise ValueError(
                f'Not enough devices; {num_local_devices} needed, '
                f'only {num_available_devices} available')
        local_devices = api.local_devices(device_type,
                                          client_id)[:num_local_devices]

        global_device_ids = np.arange(num_global_devices).reshape(shape)
        flattened = np.ravel(global_device_ids).tolist()
        start_idx = num_local_devices * client_id
        local_device_ids = flattened[start_idx:start_idx + num_local_devices]

        mesh = layout.Mesh(dim_names=dim_names,
                           global_device_ids=global_device_ids,
                           local_device_ids=local_device_ids,
                           local_devices=local_devices,
                           mesh_name=mesh_name)
        _print_context(num_global_devices, num_clients, client_id, device_type,
                       mesh)
        return mesh

    if device_type.upper() == 'TPU':
        # TPU meshes can only be configured through environment variables that
        # reflect the actual TPU topology. Do not let users specify custom args.
        if num_global_devices is not None:
            raise ValueError(
                f'Do not specify num_global_devices for {device_type.upper()} meshes. '
                'It will be filled in automatically from environment variables. '
                'See api.py for the list of environment variables for DTensor.'
            )
        if num_clients is not None:
            raise ValueError(
                f'Do not specify num_clients for {device_type.upper()} meshes. '
                'It will be filled in automatically from environment variables. '
                'See api.py for the list of environment variables for DTensor.'
            )
        if client_id is not None:
            raise ValueError(
                f'Do not specify client_id for {device_type.upper()} meshes. '
                'It will be filled in automatically from environment variables. '
                'See api.py for the list of environment variables for DTensor.'
            )
        mesh = tpu_util.create_tpu_mesh(dim_names, shape, mesh_name)
        _print_context(api.num_global_devices(device_type), api.num_clients(),
                       api.client_id(), device_type, mesh)
        return mesh

    raise ValueError(f'Device type {device_type} is not CPU, GPU or TPU')
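
A usage sketch for the CPU path above, assuming a single client whose environment reports exactly 8 logical CPU devices (for example after configuring 8 virtual CPUs); the dimension names and sizes are illustrative.

# A 2x4 CPU mesh over 8 logical devices. The nullable arguments fall back to
# the DTENSOR_* environment variables (num_clients=1, client_id=0 by default).
mesh = create_distributed_mesh([('batch', 2), ('model', 4)],
                               mesh_name='cpu_mesh',
                               device_type='CPU')

# For TPU meshes the topology arguments must be left unset, e.g.:
#   tpu_mesh = create_distributed_mesh([('batch', 8)], device_type='TPU')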
Example 6
def create_tpu_mesh(mesh_dim_names: List[str],
                    mesh_shape: List[int],
                    mesh_name: str,
                    ring_dims: Optional[int] = None,
                    ring_axes: Optional[List[str]] = None,
                    ring_bounds: Optional[List[int]] = None,
                    can_split_host_across_rings: bool = True,
                    build_ring_across_rings: bool = False,
                    rotate_ring_across_rings: bool = False) -> layout_lib.Mesh:
  """Returns a TPU mesh optimized for AllReduce ring reductions.

  Only as many of the leading axes specified by `ring_axes` as necessary are
  used to build rings, as long as the subslice formed by those axes has enough
  cores to contain a ring of the required size. The leftover axes in `ring_axes`
  do not affect the result.

  See go/dtensor-device-assignment-api for details and performance tuning tips.

  Args:
    mesh_dim_names: List of mesh dimension names.
    mesh_shape: Shape of the mesh.
    mesh_name: A unique name for the mesh. If empty, internally generate one.
    ring_dims: Optional; The number of leading (ring_dims > 0) or trailing
      (ring_dims < 0) mesh dimensions to build rings for. If unspecified, build
      rings for all but the first dimension.
    ring_axes: Optional; A permutation of ["x", "y", "z", "core"], specifying
      the order of TPU topology axes to build rings in. If unspecified, default
      to ["core", "x", "y", "z"].
    ring_bounds: Optional; The maximum number of devices on each axis, in the x,
      y, z, core order. If unspecified, default to physical topology limits.
    can_split_host_across_rings: Optional; If true, devices attached to the same
      host (i.e., DTensor client) may get assigned to different rings. Setting
      it to false may cause some combinations of arguments to be infeasible; see
      DeviceAssignmentTest.testCreateMesh[No]SplittingHosts* for examples.
    build_ring_across_rings: Optional; If true, also build a data-parallel ring
      across model-parallel rings. This ring could be strided.
    rotate_ring_across_rings: Optional; If true, build the data-parallel ring in
      column-major instead of row-major order.
  """

  logging.info("Building a TPU mesh %s of shape %s", mesh_name, mesh_shape)
  logging.info("Requested ring_dims: %s", ring_dims)
  logging.info("Requested ring_axes: %s", ring_axes)
  logging.info("Requested ring_bounds: %s", ring_bounds)
  logging.info("Requested can_split_host_across_rings: %s",
               can_split_host_across_rings)
  if not mesh_name:
    mesh_name = "mesh_%f" % time.time()
  logging.info("Requested mesh_name: %s", mesh_name)

  # By default, build rings for all but the first (usually batch) dimension.
  if ring_dims is None:
    ring_dims = 1 - len(mesh_shape)
  elif ring_dims < -len(mesh_shape) or ring_dims > len(mesh_shape):
    raise ValueError("Invalid ring_dims value: %d" % ring_dims)
  logging.info("Actual ring_dims: %s", ring_dims)

  # By default, vary axes in the core -> x -> y -> z order.
  if ring_axes is None:
    ring_axes = ["core", "x", "y", "z"]
  elif len(ring_axes) != 4:
    raise ValueError("Expected 4 elements in ring_axes, got %s" % ring_axes)
  elif sorted(ring_axes) != ["core", "x", "y", "z"]:
    raise ValueError("Invalid ring_axes value: %s" % ring_axes)
  logging.info("Actual ring_axes: %s", ring_axes)

  # Validate ring_bounds values.
  global _tpu_topology
  if _tpu_topology is None:
    raise ValueError(
        "Invalid TPU topology, run dtensor_initialize_tpu_system() first")
  topology_shape = list(_tpu_topology.mesh_shape)
  if ring_bounds is None:
    ring_bounds = topology_shape
  elif len(ring_bounds) != 4:
    raise ValueError("Expected 4 elements in ring_bounds, got %s" % ring_bounds)
  elif any(b > t for b, t in zip(ring_bounds, topology_shape)):
    raise ValueError("ring_bounds %s should be <= topology sizes %s" %
                     (ring_bounds, topology_shape))
  logging.info("Actual ring_bounds: %s", ring_bounds)

  # Compute ring_size, the number of cores in a ring.
  if ring_dims > 0:
    ring_size = np.prod(mesh_shape[:ring_dims])
  elif ring_dims < 0:
    ring_size = np.prod(mesh_shape[ring_dims:])
  else:
    ring_size = 1  # single-core rings
  logging.info("Actual ring_size: %d", ring_size)

  # Rearrange all cores according to the axis iteration order.
  global_core_locations = _enumerate_core_locations(
      topology_shape, ring_bounds, ring_axes, can_split_host_across_rings,
      ring_size)
  logging.vlog(1, "Enumerated core locations: %s", global_core_locations)
  num_cores = len(global_core_locations)

  # The mesh to be created must use all TPU cores in the system.
  mesh_size = np.prod(mesh_shape)
  if mesh_size != num_cores:
    raise ValueError(
        "Invalid mesh size: mesh shape %s cannot 1:1 map to %d TPU cores" %
        (mesh_shape, num_cores))

  # Build a ring for the `ring_size` dimension and, if required, a strided ring
  # for the orthogonal dimension.
  if build_ring_across_rings:
    global_core_locations = _build_orthogonal_rings(global_core_locations,
                                                    ring_size,
                                                    rotate_ring_across_rings)
  else:
    permutation = _build_all_reduce_ring(global_core_locations[:ring_size])
    for r in range(0, num_cores, ring_size):
      global_core_locations[r:r + ring_size] = [
          global_core_locations[r + permutation[i]] for i in range(ring_size)
      ]
    logging.vlog(1, "Permutated core locations: %s", global_core_locations)

  # From this point on, change from List[CoreLocation] to List[List[int]] for
  # easier interaction with the C++ API.
  global_core_locations = [l.to_list() for l in global_core_locations]
  global _dtensor_device
  if _dtensor_device is None:
    raise ValueError(
        "Invalid system device, run dtensor_initialize_tpu_system() first")
  global_core_ids = _dtensor_device.tpu_core_locations_to_ids(
      global_core_locations)

  # Store a per-mesh mapping in the runtime.
  _dtensor_device.set_tpu_core_ids(mesh_name, global_core_ids)

  # Create the mesh by manually specifying local_device_ids.
  local_core_locations = _tpu_topology.device_coordinates[api.client_id()]
  indexes = [
      global_core_locations.index(list(local_core_location))
      for local_core_location in local_core_locations
  ]
  global_device_ids, local_device_ids, local_device_list = _create_device_array(
      mesh_shape, _TPU_DEVICE_TYPE, None, local_device_ids=indexes)
  return layout_lib.Mesh(mesh_dim_names, global_device_ids, local_device_ids,
                         local_device_list, mesh_name)
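
A hedged sketch of building a combined data- and model-parallel mesh, assuming dtensor_initialize_tpu_system() has already run on a hypothetical 32-core slice; the shape and argument values are illustrative, not a tuned configuration.

# A 'batch' x 'model' mesh over all 32 cores. With ring_dims left at its
# default (1 - len(mesh_shape) = -1), rings are built over the trailing
# 'model' dimension; build_ring_across_rings adds a (possibly strided)
# data-parallel ring across those rings.
mesh = create_tpu_mesh(mesh_dim_names=['batch', 'model'],
                       mesh_shape=[4, 8],
                       mesh_name='dp_mp_mesh',
                       build_ring_across_rings=True)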
Example 7
def dtensor_initialize_tpu_system(enable_coordination_service=False):
  """Initialize the TPU devices.

  Args:
    enable_coordination_service: If true, enable the distributed coordination
      service so that workers know about each other's devices, a prerequisite
      for data transfer through cross-worker rendezvous.

  Raises:
    RuntimeError: If running inside a tf.function.
    NotFoundError: If no TPU devices found in eager mode.
  """

  assert context.executing_eagerly()
  in_multi_client_mode = api.job_name() != "localhost"

  # Collective GRPC servers are only necessary in multi-client setup.
  # Single clients (e.g. Forge) can use local mode of collectives.
  if in_multi_client_mode:
    if api.jobs() is None:
      raise ValueError(
          "DTENSOR_JOBS environment variable is required when"
          "using multi-client to properly set up communications between servers"
      )
    multi_client_util.initialize_multi_client_cluster(
        job_name=api.job_name(),
        dtensor_jobs=api.jobs(),
        client_id=api.client_id(),
        collective_leader=api.full_job_name(task_id=0),
        enable_coordination_service=enable_coordination_service)

  # Make sure the server change is fully propagated before attempting to run
  # the core ID merging logic below.
  context.ensure_initialized()
  context.async_wait()
  context.context()._clear_caches()  # pylint: disable=protected-access

  @function.defun
  def _tpu_init_fn():
    return gen_dtensor_ops.configure_and_initialize_global_tpu()

  try:
    with ops.device("/job:" + api.full_job_name() + "/device:TPU_SYSTEM:0"):  # pylint: disable=protected-access
      my_core_ids = _tpu_init_fn()
    logging.info("TPU core IDs: %s", my_core_ids)
    context.initialize_logical_devices()

    # Configure virtual CPUs that are 1:1 mapped to TPU cores.
    context.context().set_logical_cpu_devices(
        len(api.local_devices(_TPU_DEVICE_TYPE)),
        tf_device.DeviceSpec(
            job=api.job_name(), replica=0, task=api.client_id()).to_string())

    # `my_core_ids` contains the IDs of TPU cores attached to this host.
    #
    # To generate correct and efficient XLA AllReduce group assignment, we must
    # merge these arrays from all hosts and broadcast the result back to all
    # hosts, so all hosts can use these mappings in their MLIR passes.
    #
    # This is essentially doing what WaitForDistributedTpuOp and
    # SetGlobalTPUArrayOp do, in our multi-client environment.
    task_id = api.client_id()
    num_tasks = api.num_clients()
    num_devices = api.num_global_devices(_TPU_DEVICE_TYPE)
    num_devices_per_task = int(num_devices / num_tasks)

    # Create a one-time use mesh and layout just for merging core IDs.
    mesh = layout_lib.Mesh([_MESH_DIM_X],
                           *_create_device_array((num_devices,),
                                                 _TPU_DEVICE_TYPE,
                                                 api.client_id()))
    layout = layout_lib.Layout([_MESH_DIM_X, layout_lib.UNSHARDED], mesh)
    device = dtensor_device.DTensorDevice(meshes=[mesh])
    logging.info("TPU core locations: %s",
                 device.tpu_core_ids_to_locations(my_core_ids))

    # At this point, we don't know which cores are attached to other hosts.
    # The core ID mappings in the runtime haven't been set yet.
    #
    # The core ID merging AllReduce below is carefully written so it works
    # without needing correct core mappings to be set in the runtime. We will
    # use this AllReduce's result to set the core ID mappings, and all future
    # user-initiated AllReduces will use the mappings.
    #
    # The runtime is hard-coded to ignore core ID mappings on this AllReduce.
    all_core_ids = np.zeros([num_devices], dtype=np.int32)
    for i in range(len(my_core_ids)):
      all_core_ids[task_id * num_devices_per_task + i] = my_core_ids[i]

    # Only one local device gets valid input: 8 local core IDs among
    # (num_tasks - 1) * 8 zeros. The 8 core IDs are set using task ID as offset.
    # The other 7 local devices get zero inputs. All devices on all host
    # participate in one AllReduce, whose result will be core IDs arranged by
    # task-device ordinals.
    all_core_ids = constant_op.constant([all_core_ids])
    zeros = array_ops.zeros_like(all_core_ids)
    all_core_ids = [all_core_ids] + [zeros] * (num_devices_per_task - 1)

    with ops.device(device.name):
      all_core_ids = device.pack(all_core_ids, layout)
      all_core_ids = math_ops.reduce_sum(all_core_ids, axis=[0])
      unpacked_all_tpu_ids = device.unpack(all_core_ids)

    all_core_ids = list(unpacked_all_tpu_ids[0].numpy())
    logging.info("All TPU core IDs: %s", all_core_ids)

    # Set the default core ID mappings in the runtime for legacy code and tests.
    #
    # Legacy code and tests create TPU meshes directly without using the
    # `create_tpu_mesh` function below. Those meshes have global device IDs
    # equal to TF task-device ordinals. The `all_core_ids` array happens to
    # arrange core IDs by TF task-device ordinals. Using this array on those
    # meshes guarantees correct, although inefficient, results.
    device.set_tpu_core_ids("", all_core_ids)

    # Remember enough global, immutable information to be able to build any
    # ring prescribed by `create_tpu_mesh` in the future.
    global _all_core_ids
    _all_core_ids = all_core_ids

    all_core_locations = device.tpu_core_ids_to_locations(all_core_ids)
    all_core_locations = [
        _CoreLocation(l[0], l[1], l[2], l[3]) for l in all_core_locations
    ]
    global _all_core_locations
    _all_core_locations = all_core_locations
    logging.info("All TPU core locations: %s", all_core_locations)

    tpu_topology = _create_tpu_topology(all_core_locations, num_tasks,
                                        num_devices_per_task)
    global _tpu_topology
    _tpu_topology = tpu_topology
    logging.vlog(1, "TPU Topology: %s, %s", tpu_topology.mesh_shape,
                 tpu_topology.device_coordinates)

    global _dtensor_device
    _dtensor_device = device

    context.async_wait()

  except errors.InvalidArgumentError as e:
    raise errors.NotFoundError(
        None, None, "Initialization failed, no valid TPUs found. " + str(e))

  except errors.InternalError as e:
    logging.error("Hit internal error during TPU system initialization. "
                  + "It is likely hareware failure. \nPlease check the error "
                  + "messages above to see whether that's the case. \nIf so, "
                  + "consider to restart the job or try another machine.")
    raise e

  # Optionally exchange heartbeats between workers every 3 minutes.
  if in_multi_client_mode and api.heartbeat_enabled():
    logging.info(
        "Starting DTensor heartbeat service exchanging signals every 10 minutes"
    )
    heartbeat.start(period=180)

  # Clear out the eager context caches since the memory is invalid now.
  logging.info("Clearing out eager caches")
  context.context()._clear_caches()  # pylint: disable=protected-access
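
A minimal single-client initialization sketch; a multi-client run would additionally set the DTENSOR_* variables described in the earlier examples before this call. The mesh construction at the end is an assumption for illustration.

# Must run eagerly, before any TPU computation and before create_tpu_mesh.
dtensor_initialize_tpu_system(enable_coordination_service=False)

# After initialization the module-level TPU topology and DTensor device are
# populated, so TPU meshes can be created, e.g. one using every core:
mesh = create_tpu_mesh(['x'], [api.num_global_devices(_TPU_DEVICE_TYPE)],
                       'all_cores_mesh')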
Example 8
def start(period: int) -> threading.Event:
    """Starts a persistent thread exchanging heartbeats between workers.

  Args:
    period: Heartbeat interval in seconds. Heartbeat timeout is set to the
      larger of `period` - 10 and 2s.

  Returns:
    A threading.Event object. Users can choose to call its set() method to shut
    down the heartbeat service gracefully. This isn't necessary in most cases,
    because the heartbeat service automatically shuts down at successful program
    exit through atexit handlers. But in situations when atexit handlers are not
    invoked, such as when multiprocessing processes exit in tests, users can
    manually request a shutdown.
  """
    global _heartbeat_timer
    if _heartbeat_timer is not None:
        logging.warning(
            'A heartbeat thread is already running, skipping this one.')
        return _heartbeat_timer

    task_id = api.client_id()
    num_tasks = api.num_clients()

    # Worker 0 generates a random token. All other workers receive that token.
    if task_id == 0:
        token = np.random.randint(0,
                                  pow(2, 16) - 1)  # reserve the other 16 bits
        signal = np.full([num_tasks], token, dtype=np.int32)
    else:
        signal = np.zeros([num_tasks], dtype=np.int32)
    logging.info('Initial heartbeat signal: %s', signal)

    device = tf_device.DeviceSpec(job=api.job_name(),
                                  replica=0,
                                  task=task_id,
                                  device_type='CPU',
                                  device_index=0)
    # Always use 0 for group and instance keys to reduce unnecessary
    # collective hangs and simplify failure analysis. This also avoids
    # collisions with normal collectives.
    with ops.device(device):
        signal = all_reduce(constant_op.constant(signal),
                            group_size=num_tasks,
                            group_key=0,
                            instance_key=0,
                            timeout=max(period - 10, 2)).numpy()
    logging.info('Merged heartbeat signal %s', signal)

    # The merged signal should have equal elements. If not, some worker(s) may be
    # out of sync, and we should terminate all workers.
    if task_id == 0:
        if not np.all(signal == token):
            logging.fatal('Merged heartbeat signal has value != %d', token)
    else:
        if len(set(signal)) != 1:
            logging.fatal('Merged heartbeat signal has unequal elements')
        token = signal[0]

    # On normal main process exit, set the timer to stop the heartbeat thread.
    _heartbeat_timer = threading.Event()

    def stop_heartbeat():
        logging.info('Stopping the heartbeat thread')
        _heartbeat_timer.set()
        # Give the threads some time to clean up.
        time.sleep(max(period // 10, 2))

    atexit.register(stop_heartbeat)

    # Start the persistent heartbeat thread.
    thread = threading.Thread(
        target=_heartbeat,
        args=[period, _heartbeat_timer, token, num_tasks, task_id, device],
        daemon=True)
    thread.start()

    return _heartbeat_timer
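
A short sketch of starting and stopping the service explicitly, for example in a test where atexit handlers may not run; the 60-second period is an assumption.

# Exchange heartbeats every 60 seconds (collective timeout max(60 - 10, 2) = 50s).
timer = start(period=60)

# ... run the multi-client program ...

# Optional explicit shutdown; normally the atexit hook takes care of this.
timer.set()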
Example 9
    def __init__(self,
                 dataset: dataset_ops.DatasetV2,
                 *,
                 mesh: layout_lib.Mesh,
                 layouts: Any,
                 global_batch_size: int,
                 dataset_already_batched: bool = False,
                 batch_dim: Optional[str] = None,
                 prefetch: Optional[int] = None,
                 tf_data_service_config: Optional[TFDataServiceConfig] = None):
        """Creates a DTensorDataset.

    DTensorDataset automatically handles distribution of the dataset elements to
    each client's devices. It can be used to create an iterator that returns
    DTensors of the input data on each iteration.

    DTensorDataset works best with unbatched datasets. It takes the mesh and the
    provided layouts to automatically calculate how to batch the input locally
    for each replica.

    If the provided dataset is already batched according to the per-replica
    batch size, then `dataset_already_batched` must be set and DTensorDataset
    will check that the batch size is consistent with the intended
    `global_batch_size` using the layout information. Each replica receives a
    separate slice of the global batch, thus the per-replica batch size can be
    computed as the global batch size divided by the number of model replicas.
    For a DTensor mesh, the number of replicas is equal to the size of the
    mesh's batch dimension.

    TODO(b/223275517): add support for input datasets that are already batched
    to the global batch size.

    Args:
      dataset: a `tf.data.Dataset` object.
      mesh: the DTensor mesh to place the dataset batches on.
      layouts: a structure of DTensor layouts to be applied to the input dataset
        values. This can be a single layout or (possibly nested) tuples or
        dictionaries of layouts, and the structure must match the structure of
        the dataset. Either all or none of the layouts should be sharded on the
        batch dimension; having only a subset of layouts batch sharded will not
        work and raises a ValueError.
      global_batch_size: the desired global batch size.
      dataset_already_batched: must be set only if the dataset is already
        batched to the per-replica batch size. The batched dataset must have
        `drop_remainder=True` set since DTensor requires static shapes for
        slicing the input tensors.
      batch_dim: the mesh dimension on which the input's batch dimension is
        sharded. Set to None if the input layouts do not shard on the batch
        dimension.
      prefetch: number of batches to prefetch using Dataset.prefetch.
      tf_data_service_config: if operating in multi-client mode, this config
        specifies the tf.data service configuration to use.

    Raises:
      ValueError: on any of the following situations,
        1. if the structures and ranks of layouts and the dataset do not match.
        2. if the shapes in the dataset's spec are not fully defined.
        3. if batch_dim is specified and all layouts are not batch-sharded.
        4. if per_replica_batch_size is specified for an already batched Dataset
           but it does not match the expected per-replica size based on the
           provided mesh.
      TypeError: if type of structures of layouts and the dataset do not match.
    """
        super().__init__(dataset, dataset_ops.to_variant(dataset))

        self._mesh = mesh
        self._layouts = layouts
        self._batch_dim = batch_dim
        self._prefetch = prefetch
        self._tf_data_service_config = tf_data_service_config

        self._element_spec = dataset.element_spec

        nest.assert_same_structure(self._element_spec, self._layouts)
        flattened_layouts = nest.flatten(self._layouts)
        flattened_elem_spec = nest.flatten(self._element_spec)

        if batch_dim:
            num_global_replicas = mesh.dim_size(batch_dim)
            self._local_replica_ids = list(
                dict.fromkeys(
                    [loc[batch_dim] for loc in mesh.local_device_locations()]))

            for layout in flattened_layouts:
                if batch_dim != layout.sharding_specs[0]:
                    raise ValueError((
                        'batch_dim %s was specified but at least one layout did not '
                        'contain it: %s') % (batch_dim, layout))
        else:
            # Only one replica since there is no sharding on the batch dimension.
            num_global_replicas = 1
            self._local_replica_ids = [0]

        # Validate layout and element spec compatibility, and raise ValueError if
        # invalid.
        _validate_input(flattened_layouts,
                        flattened_elem_spec,
                        dataset_already_batched=dataset_already_batched)

        expected_batch_size = global_batch_size // num_global_replicas
        if not dataset_already_batched:
            self._batched_dataset = dataset.batch(expected_batch_size,
                                                  drop_remainder=True)
        else:
            per_replica_batch_size = flattened_elem_spec[0].shape.as_list()[0]
            if per_replica_batch_size != expected_batch_size:
                raise ValueError((
                    'per_replica_batch_size does not match the expected size based on '
                    'the mesh, got %d but expected %d.') %
                                 (per_replica_batch_size, expected_batch_size))
            self._batched_dataset = dataset

        num_global_devices_per_replica = api.num_global_devices(
            mesh.device_type()) // num_global_replicas
        self._num_local_replicas = len(self._local_replica_ids)
        self._num_local_devices_per_replica = mesh.num_local_devices(
        ) // self._num_local_replicas
        # The number of clients each replica is split over.
        self._num_clients_per_replica = (num_global_devices_per_replica //
                                         self._num_local_devices_per_replica)
        # In the case where a replica is split across multiple clients, an offset
        # needs to be added to the index used by the partitioning logic such that
        # the local devices on that client can be correctly matched to slices of the
        # input tensor(s). If replicas are wholly contained within a client, then
        # this offset is always 0.
        self._partition_offset = (api.client_id() %
                                  self._num_clients_per_replica
                                  ) * self._num_local_devices_per_replica

        # Helper data structures used in partitioning the dataset tensors.
        self._all_shard_counts = [
            _shard_counts(layout, batch_dim) for layout in flattened_layouts
        ]
        self._index_matrices = [
            _index_matrix(layout, elem_spec) for layout, elem_spec in zip(
                flattened_layouts, flattened_elem_spec)
        ]
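
A hedged construction sketch, assuming the enclosing class is DTensorDataset (as in tf.experimental.dtensor), a mesh with a 'batch' dimension, and a simple unbatched dataset of scalars. The layout uses the layout_lib constructor seen elsewhere in these examples, and the rank convention assumed here is that layouts describe the batched elements; the sizes are illustrative.

import tensorflow as tf

dataset = tf.data.Dataset.range(1024)          # unbatched scalar elements

# Rank-1 layout (batch dimension only), sharded over the 'batch' mesh dim.
# The per-replica batch size becomes global_batch_size // mesh.dim_size('batch').
layout = layout_lib.Layout(['batch'], mesh)

d_dataset = DTensorDataset(dataset,
                           mesh=mesh,
                           layouts=layout,
                           global_batch_size=64,
                           batch_dim='batch')

for batch in d_dataset:
    ...  # each `batch` is a DTensor laid out according to `layout`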
Example 10
def create_distributed_mesh(mesh_dims: List[Tuple[str, int]],
                            mesh_name: str = '',
                            local_devices: Optional[List[str]] = None,
                            device_type: Optional[str] = None) -> layout.Mesh:
    """Creates a distributed mesh.

  This is similar to `create_mesh`, but with a different set of arguments to
  create a mesh that spans evenly across a multi-client DTensor cluster.

  For CPU and GPU meshes, users can choose to use fewer local devices than what
  is available by specifying `local_devices`.

  For TPU, only meshes that use all TPU cores are supported by the DTensor
  runtime.

  Args:
    mesh_dims: A list of (dim_name, dim_size) tuples.
    mesh_name: Name of the created mesh. Defaults to ''.
    local_devices: String representations of devices to use. This is the device
      part of tf.DeviceSpec, e.g. 'CPU:0'. Defaults to all available local
      logical devices.
    device_type: Type of device to build the mesh for. Defaults to 'CPU'.
      Supported values are 'CPU', 'GPU', 'TPU'.

  Returns:
    A mesh that spans evenly across all DTensor clients in the cluster.
  """
    dim_names, shape = zip(*mesh_dims)

    if device_type and device_type.upper() == 'TPU':
        # TODO(b/185940495): Allow multi-mesh and partial on TPU.
        # TPU meshes can only be configured through environment variables that
        # reflect the actual TPU topology. Do not let users specify custom args.
        if local_devices is not None:
            raise ValueError(
                f'Do not specify devices for {device_type.upper()} meshes. '
                f'Using a partial list of devices for {device_type.upper()} '
                f'is not supported.')

    device_specs, device_type = _make_device_specs(local_devices, device_type)

    if device_type.upper() in ['CPU', 'GPU']:
        # For CPU and GPU meshes, user-specified args take precedence over env vars.
        # This is particularly useful on single clients when users want to create
        # meshes that use fewer logical devices than what's available.

        if api.num_clients() > 1 and not multi_client_util.is_initialized():
            raise ValueError('Invalid multi-client topology, please run '
                             'dtensor.initialize_multi_client() first.')

        local_spec = tf_device.DeviceSpec(job=api.job_name(),
                                          replica=0,
                                          task=api.client_id())
        device_specs = [local_spec.make_merged_spec(d) for d in device_specs]

        # Assumes identical number of local devices per client.
        num_global_devices = len(device_specs) * api.num_clients()

        if np.prod(shape) != num_global_devices:
            raise ValueError(
                f'Global number of devices '
                f'({len(device_specs)} per client * {api.num_clients()} clients '
                f'= {num_global_devices}) must be '
                f'equal to total size of the mesh of shape {shape}')

        global_device_ids = np.arange(num_global_devices).reshape(shape)
        flattened = np.ravel(global_device_ids).tolist()
        start_idx = len(device_specs) * api.client_id()
        local_device_ids = flattened[start_idx:start_idx + len(device_specs)]

        mesh = layout.Mesh(dim_names=dim_names,
                           global_device_ids=global_device_ids,
                           local_device_ids=local_device_ids,
                           local_devices=device_specs,
                           mesh_name=mesh_name)
        _print_context(num_global_devices, api.num_clients(), api.client_id(),
                       device_type, mesh)
        return mesh

    if device_type.upper() == 'TPU':
        mesh = tpu_util.create_tpu_mesh(dim_names, shape, mesh_name)
        _print_context(api.num_global_devices(device_type), api.num_clients(),
                       api.client_id(), device_type, mesh)
        return mesh

    raise ValueError(f'Device type {device_type} is not CPU, GPU or TPU')
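
A usage sketch for this variant, assuming a single client where only four of the available logical CPU devices should be used; the device strings follow the device part of tf.DeviceSpec as described in the Args section.

# A 2x2 CPU mesh pinned to the first four logical CPU devices of this client.
mesh = create_distributed_mesh([('x', 2), ('y', 2)],
                               mesh_name='small_cpu_mesh',
                               local_devices=['CPU:0', 'CPU:1', 'CPU:2', 'CPU:3'],
                               device_type='CPU')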