Ejemplo n.º 1
0
def barrier(mesh: layout.Mesh, barrier_name: Optional[str] = None):
    """Runs a barrier on the mesh.

    Upon returning from the barrier, all operations run before the barrier
    will have completed across all clients. Currently this is implemented by
    allocating a fully sharded tensor with the mesh shape and running an
    all_reduce on it.

    Example:

    A barrier can be used before application exit to ensure completion of
    pending ops.

    ```python

    x = [1, 2, 3]
    x = dtensor.relayout(x, dtensor.Layout.batch_sharded(mesh, 'batch', 1))
    dtensor.barrier(mesh)

    # At this point all devices on all clients in the mesh have completed
    # operations before the barrier. Therefore it is OK to tear down the clients.
    sys.exit()
    ```

    Args:
      mesh: The mesh to run the barrier on.
      barrier_name: The name of the barrier, used only in log messages.
        Defaults to '(barrier)' when None.

    Raises:
      ValueError: If the reduced value does not equal `mesh.size`, i.e. the
        all-reduce did not count every device in the mesh.
    """
    if barrier_name is None:
        barrier_name = '(barrier)'

    logging.info('entering barrier before op: %s', barrier_name)

    # Make sure all ops are consumed before running the sync.
    context.async_wait()

    # Reduction on a fully sharded tensor requires all devices to participate
    # and serves as a barrier on the mesh. Each local device contributes one
    # 1.0 (reshaped to rank len(mesh.shape())), so the global sum is expected
    # to equal the total number of devices in the mesh.
    component = array_ops.reshape(1.0, [1] * len(mesh.shape()))
    ones = api.pack([component] * mesh.num_local_devices(),
                    layout.Layout(mesh.dim_names, mesh))

    mesh_size = math_ops.reduce_sum(ones)
    if mesh_size != mesh.size:
        # Adjacent literals are concatenated; the trailing space after
        # 'actual' keeps the words from running together.
        raise ValueError(
            'Global barrier produced wrong mesh size : {0} while mesh has actual '
            'size : {1}'.format(mesh_size, mesh.size))

    # TODO(hthu): This isn't strictly needed but might cause confusing behaviors
    # from users. Consider dropping this if there is a `big` performance hit.
    context.async_wait()

    logging.info('finished running barrier across all clients after '
                 'op: %s', barrier_name)
Ejemplo n.º 2
0
def _global_barrier(mesh: layout_lib.Mesh, last_op_name: str):
    """Runs a global barrier on the mesh.

    Upon returning from the barrier, all operations run before the barrier
    will have completed across all clients.

    Currently this is implemented by allocating a fully sharded tensor with
    the mesh shape and running an all_reduce on it.

    Args:
      mesh: The mesh to run the global barrier on.
      last_op_name: The last op run before the global barrier, used only in
        log messages.

    Raises:
      ValueError: If the reduced value does not equal `mesh.size`, i.e. the
        all-reduce did not count every device in the mesh.
    """
    logging.info('entering global barrier before op: %s', last_op_name)

    # Make sure all ops are consumed before running the sync.
    context.async_wait()

    # Every local device receives the same copy of the mesh shape (a rank-1
    # value, replicated), which is then used as the shape argument for `ones`.
    shape = api._dtensor_device().pack(  # pylint: disable=protected-access
        [mesh.shape()] * mesh.num_local_devices(),
        layout_lib.Layout.replicated(mesh, rank=1))
    # Presumably yields a tensor of global shape mesh.shape() sharded along
    # every mesh dimension, so summing it counts all devices — verify against
    # the DTensor layout semantics if this changes.
    ones = api.call_with_layout(array_ops.ones,
                                layout_lib.Layout(mesh.dim_names, mesh),
                                shape=shape,
                                dtype=dtypes.float32)
    mesh_size = math_ops.reduce_sum(ones)
    if mesh_size != mesh.size:
        # Adjacent literals are concatenated; the trailing space after
        # 'actual' keeps the words from running together.
        raise ValueError(
            'Global barrier produced wrong mesh size : {0} while mesh has actual '
            'size : {1}'.format(mesh_size, mesh.size))

    # TODO(hthu): This isn't strictly needed but might cause confusing behaviors
    # from users. Consider dropping this if there is a `big` performance hit.
    context.async_wait()

    logging.info(
        'finished running global barrier across all clients after '
        'op: %s', last_op_name)