Example No. 1
def barrier(mesh: layout.Mesh, barrier_name: Optional[str] = None):
    """Runs a barrier on the mesh.

  Upon returning from the barrier, all operations run before the barrier
  would have completed across all clients. Currently we allocate a fully
  sharded tensor with mesh shape and run an all_reduce on it.

  Example:

  A barrier can be used before application exit to ensure completion of pending
  ops.

  ```python

  x = [1, 2, 3]
  x = dtensor.relayout(x, dtensor.Layout.batch_sharded(mesh, 'batch', 1))
  dtensor.barrier(mesh)

  # At this point all devices on all clients in the mesh have completed
  # operations before the barrier. Therefore it is OK to tear down the clients.
  sys.exit()
  ```

  Args:
    mesh: The mesh to run the barrier on.
    barrier_name: The name of the barrier, mainly used for logging purposes.
  """
  if barrier_name is None:
    barrier_name = '(barrier)'

  logging.info('entering barrier before op: %s', barrier_name)

  # Make sure all ops are consumed before running the sync.
  context.async_wait()

  # Reduction on a fully sharded tensor requires all devices to participate
  # and serves as a barrier on the mesh.
  component = array_ops.reshape(1.0, [1] * len(mesh.shape()))
  ones = api.pack([component] * mesh.num_local_devices(),
                  layout.Layout(mesh.dim_names, mesh))

  mesh_size = math_ops.reduce_sum(ones)
  if mesh_size != mesh.size:
    raise ValueError(
        'Global barrier produced wrong mesh size: {0} while mesh has actual '
        'size: {1}'.format(mesh_size, mesh.size))

  # TODO(hthu): This isn't strictly needed, but removing it might confuse
  # users. Consider dropping it if there is a big performance hit.
  context.async_wait()

  logging.info('finished running barrier across all clients after '
               'op: %s', barrier_name)
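For reference, a minimal self-contained sketch of the same pattern through the public `tf.experimental.dtensor` API. It assumes a TF release that exports `dtensor.barrier`, and the single-host CPU mesh with a `"batch"` dimension is purely illustrative:

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Expose four logical CPU devices on one host so a small mesh can be built.
tf.config.set_logical_device_configuration(
    tf.config.list_physical_devices("CPU")[0],
    [tf.config.LogicalDeviceConfiguration()] * 4)

mesh = dtensor.create_mesh([("batch", 4)], device_type="CPU")

# Launch some sharded work, then block until it has completed everywhere.
x = dtensor.call_with_layout(
    tf.ones, dtensor.Layout(["batch", dtensor.UNSHARDED], mesh), shape=[4, 2])
dtensor.barrier(mesh, barrier_name="before_exit")
```

On a single client this mostly just drains pending async ops; the barrier earns its keep in multi-client jobs, where it guarantees every client has finished its work before, say, the processes are torn down.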
Example No. 2
    def __init__(self, dvariable, name):
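        """Builds the save spec for a single DVariable.

        Records the variable's original layout (so that restore can reproduce
        it), derives the matching layout on the host mesh, and registers one
        DSaveSpec whose tensor callback reads the value on the host mesh,
        casting to bfloat16 when `should_cast` returns True.

        Args:
          dvariable: The DVariable to be saved.
          name: The checkpoint key to save the variable under.
        """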
        with ops.device(dvariable.device):
            original_layout = api.fetch_layout(dvariable)
        # Record original layout to allow restore.
        self._original_layout = original_layout
        self._dvariable = dvariable

        def pack(tensors, layout):
            with ops.device(dvariable.device):
                return api.pack(tensors, layout)

        host_layout = layout_lib.Layout(original_layout.sharding_specs,
                                        original_layout.mesh.host_mesh())

        def get_host_dvariable():
            # Copy to host mesh if needed.
            if original_layout.mesh.device_type().upper() != 'CPU':
                with ops.device(dvariable.device):
                    host_dvariable = DVariable(
                        api.pack(api.unpack(dvariable.read_value()),
                                 host_layout))
            else:
                host_dvariable = dvariable
            return (math_ops.cast(host_dvariable, dtypes.bfloat16)
                    if self.should_cast(host_dvariable) else host_dvariable)

        num_local_devices = original_layout.mesh.num_local_devices()
        super(_DVariableSaveable, self).__init__(
            None,
            [
                DSaveSpec(
                    tensor=get_host_dvariable,
                    slice_spec=pack([''] * num_local_devices,
                                    layout_lib.Layout.replicated(
                                        original_layout.mesh.host_mesh(),
                                        rank=0)),
                    name=pack([name] * num_local_devices,
                              layout_lib.Layout.replicated(
                                  original_layout.mesh.host_mesh(), rank=0)),
                    global_shape=dvariable.shape,
                    # The layout is attached as an attribute, so there is no
                    # need to put it as a Tensor on the DTensorDevice.
                    layout=host_layout.to_string(),
                    dtype=dtypes.bfloat16
                    if self.should_cast(dvariable) else dvariable.dtype,
                    device=dvariable.device)
            ],
            name)
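A less obvious detail above is that even per-slice metadata such as `slice_spec` and `name` is packed as a replicated rank-0 DTensor on the host mesh, one component per local device. A rough, hedged sketch of that packing pattern through the public `tf.experimental.dtensor` API (the two-device CPU mesh, its `"x"` dimension, and the `"dense/kernel"` name are illustrative assumptions):

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Expose two logical CPU devices on a single host so a 2-device mesh fits.
tf.config.set_logical_device_configuration(
    tf.config.list_physical_devices("CPU")[0],
    [tf.config.LogicalDeviceConfiguration()] * 2)

mesh = dtensor.create_mesh([("x", 2)], device_type="CPU")

# One identical scalar per local device, packed under a replicated layout.
replicated_name = dtensor.pack(
    [tf.constant("dense/kernel")] * mesh.num_local_devices(),
    dtensor.Layout.replicated(mesh, rank=0))
```

Replicating the metadata this way gives every local device its own copy of the strings the save op needs.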
Example No. 3
def _global_barrier(mesh: layout_lib.Mesh, last_op_name: str):
    """Runs a global barrier on the mesh.

  Upon returning from the barrier, all operations run before the barrier
  would have completed across all clients.

  Currently we allocate a fully sharded tensor with mesh shape and run an
  all_reduce on it.

  Args:
    mesh: The mesh to run the global barrier on.
    last_op_name: The last op run before the global barrier. Mainly used for
      logging purposes.
  """
  logging.info('entering global barrier before op: %s', last_op_name)

  # Make sure all ops are consumed before running the sync.
  context.async_wait()

  shape = api._dtensor_device().pack(  # pylint: disable=protected-access
      [mesh.shape()] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh, rank=1))
  ones = api.call_with_layout(array_ops.ones,
                              layout_lib.Layout(mesh.dim_names, mesh),
                              shape=shape,
                              dtype=dtypes.float32)
  mesh_size = math_ops.reduce_sum(ones)
  if mesh_size != mesh.size:
    raise ValueError(
        'Global barrier produced wrong mesh size: {0} while mesh has actual '
        'size: {1}'.format(mesh_size, mesh.size))

  # TODO(hthu): This isn't strictly needed, but removing it might confuse
  # users. Consider dropping it if there is a big performance hit.
  context.async_wait()

  logging.info(
      'finished running global barrier across all clients after '
      'op: %s', last_op_name)
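The barrier mechanism itself is just "materialize a fully sharded tensor, then reduce it", which forces every device on every client to contribute a partial sum. A hedged standalone sketch of that mechanism through the public API (the two-device CPU mesh and its `"x"` dimension are illustrative):

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Two logical CPU devices on a single host stand in for the real mesh.
tf.config.set_logical_device_configuration(
    tf.config.list_physical_devices("CPU")[0],
    [tf.config.LogicalDeviceConfiguration()] * 2)

mesh = dtensor.create_mesh([("x", 2)], device_type="CPU")

# A tensor with one element per device, fully sharded across the mesh.
ones = dtensor.call_with_layout(
    tf.ones, dtensor.Layout(mesh.dim_names, mesh), shape=mesh.shape())

# Every device must participate in this reduction, so obtaining the result
# implies all previously issued ops have completed on all devices.
mesh_size = tf.reduce_sum(ones)
if mesh_size != mesh.size:
  raise ValueError("Mesh did not fully participate in the reduction.")
```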
Example No. 4
def dtensor_initialize_tpu_system(enable_coordination_service=False):
  """Initialize the TPU devices.

  Args:
    enable_coordination_service: If true, enable the distributed coordination
      service so that each worker knows the devices on the other workers, a
      prerequisite for data transfer through cross-worker rendezvous.

  Raises:
    RuntimeError: If running inside a tf.function.
    NotFoundError: If no TPU devices found in eager mode.
  """

  if not context.executing_eagerly():
    raise RuntimeError(
        "dtensor_initialize_tpu_system must be called eagerly, outside of a "
        "tf.function.")
  in_multi_client_mode = api.job_name() != "localhost"

  # Collective GRPC servers are only necessary in a multi-client setup.
  # Single clients (e.g. Forge) can use local mode of collectives.
  if in_multi_client_mode:
    if api.jobs() is None:
      raise ValueError(
          "DTENSOR_JOBS environment variable is required when using "
          "multi-client mode to properly set up communications between "
          "servers.")
    multi_client_util.initialize_multi_client_cluster(
        job_name=api.job_name(),
        dtensor_jobs=api.jobs(),
        client_id=api.client_id(),
        collective_leader=api.full_job_name(task_id=0),
        enable_coordination_service=enable_coordination_service)

  # Make sure the server change is fully propagated before attempting to run
  # the core ID merging logic below.
  context.ensure_initialized()
  context.async_wait()
  context.context()._clear_caches()  # pylint: disable=protected-access

  @function.defun
  def _tpu_init_fn():
    return gen_dtensor_ops.configure_and_initialize_global_tpu()

  try:
    with ops.device("/job:" + api.full_job_name() + "/device:TPU_SYSTEM:0"):  # pylint: disable=protected-access
      my_core_ids = _tpu_init_fn()
    logging.info("TPU core IDs: %s", my_core_ids)
    context.initialize_logical_devices()

    # Configure virtual CPUs that are 1:1 mapped to TPU cores.
    context.context().set_logical_cpu_devices(
        len(api.local_devices(_TPU_DEVICE_TYPE)),
        tf_device.DeviceSpec(
            job=api.job_name(), replica=0, task=api.client_id()).to_string())

    # `my_core_ids` contains the IDs of TPU cores attached to this host.
    #
    # To generate correct and efficient XLA AllReduce group assignment, we must
    # merge these arrays from all hosts and broadcast the result back to all
    # hosts, so all hosts can use these mappings in their MLIR passes.
    #
    # This is essentially doing what WaitForDistributedTpuOp and
    # SetGlobalTPUArrayOp do, in our multi-client environment.
    task_id = api.client_id()
    num_tasks = api.num_clients()
    num_devices = api.num_global_devices(_TPU_DEVICE_TYPE)
    num_devices_per_task = int(num_devices / num_tasks)

    # Create a one-time use mesh and layout just for merging core IDs.
    mesh = layout_lib.Mesh([_MESH_DIM_X],
                           *_create_device_array((num_devices,),
                                                 _TPU_DEVICE_TYPE,
                                                 api.client_id()))
    layout = layout_lib.Layout([_MESH_DIM_X, layout_lib.UNSHARDED], mesh)
    device = dtensor_device.DTensorDevice(meshes=[mesh])
    logging.info("TPU core locations: %s",
                 device.tpu_core_ids_to_locations(my_core_ids))

    # At this point, we don't know which cores are attached to other hosts.
    # The core ID mappings in the runtime haven't been set yet.
    #
    # The core ID merging AllReduce below is carefully written so it works
    # without needing correct core mappings to be set in the runtime. We will
    # use this AllReduce's result to set the core ID mappings, and all future
    # user-initiated AllReduces will use the mappings.
    #
    # The runtime is hard-coded to ignore core ID mappings on this AllReduce.
    all_core_ids = np.zeros([num_devices], dtype=np.int32)
    for i in range(len(my_core_ids)):
      all_core_ids[task_id * num_devices_per_task + i] = my_core_ids[i]

    # Only one local device gets valid input: 8 local core IDs among
    # (num_tasks - 1) * 8 zeros. The 8 core IDs are set using task ID as offset.
    # The other 7 local devices get zero inputs. All devices on all hosts
    # participate in one AllReduce, whose result will be core IDs arranged by
    # task-device ordinals.
    all_core_ids = constant_op.constant([all_core_ids])
    zeros = array_ops.zeros_like(all_core_ids)
    all_core_ids = [all_core_ids] + [zeros] * (num_devices_per_task - 1)

    with ops.device(device.name):
      all_core_ids = device.pack(all_core_ids, layout)
      all_core_ids = math_ops.reduce_sum(all_core_ids, axis=[0])
      unpacked_all_tpu_ids = device.unpack(all_core_ids)

    all_core_ids = list(unpacked_all_tpu_ids[0].numpy())
    logging.info("All TPU core IDs: %s", all_core_ids)

    # Set the default core ID mappings in the runtime for legacy code and tests.
    #
    # Legacy code and tests create TPU meshes directly without using the
    # `create_tpu_mesh` function below. Those meshes have global device IDs
    # equal to TF task-device ordinals. The `all_core_ids` array happens to
    # arrange core IDs by TF task-device ordinals. Using this array on those
    # meshes guarantees correct, although inefficient, results.
    device.set_tpu_core_ids("", all_core_ids)

    # Remember enough global, immutable information to be able to build any ring
    # we want prescribed by `create_tpu_mesh` in the future.
    global _all_core_ids
    _all_core_ids = all_core_ids

    all_core_locations = device.tpu_core_ids_to_locations(all_core_ids)
    all_core_locations = [
        _CoreLocation(l[0], l[1], l[2], l[3]) for l in all_core_locations
    ]
    global _all_core_locations
    _all_core_locations = all_core_locations
    logging.info("All TPU core locations: %s", all_core_locations)

    tpu_topology = _create_tpu_topology(all_core_locations, num_tasks,
                                        num_devices_per_task)
    global _tpu_topology
    _tpu_topology = tpu_topology
    logging.vlog(1, "TPU Topology: %s, %s", tpu_topology.mesh_shape,
                 tpu_topology.device_coordinates)

    global _dtensor_device
    _dtensor_device = device

    context.async_wait()

  except errors.InvalidArgumentError as e:
    raise errors.NotFoundError(
        None, None, "Initialization failed, no valid TPUs found. " + str(e))

  except errors.InternalError as e:
    logging.error("Hit internal error during TPU system initialization. "
                  + "It is likely hareware failure. \nPlease check the error "
                  + "messages above to see whether that's the case. \nIf so, "
                  + "consider to restart the job or try another machine.")
    raise e

  # Optionally exchange heartbeats between workers every 3 minutes.
  if in_multi_client_mode and api.heartbeat_enabled():
    logging.info(
        "Starting DTensor heartbeat service exchanging signals every "
        "3 minutes")
    heartbeat.start(period=180)

  # Clear out the eager context caches since the memory is invalid now.
  logging.info("Clearing out eager caches")
  context.context()._clear_caches()  # pylint: disable=protected-access
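For completeness, a hedged sketch of how this initializer is typically driven from user code through the public `tf.experimental.dtensor` API. Exact entry points vary by TF release (newer versions fold this into `dtensor.initialize_accelerator_system`), and the mesh dimension names and shape below are illustrative:

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Must run eagerly, before any TPU op, on every client of the job.
dtensor.initialize_tpu_system(enable_coordination_service=True)

# After the core-ID merge above, TPU meshes can be created with rings laid
# out over the merged core locations.
num_tpus = dtensor.num_global_devices("TPU")
mesh = dtensor.create_tpu_mesh(["batch", "model"], [num_tpus // 2, 2],
                               "tpu_mesh")

layout = dtensor.Layout.batch_sharded(mesh, "batch", rank=2)
```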