def barrier(mesh: layout.Mesh, barrier_name: Optional[str] = None):
  """Runs a barrier on the mesh.

  Upon returning from the barrier, all operations run before the barrier
  would have completed across all clients. Currently we allocate a fully
  sharded tensor with mesh shape and run an all_reduce on it.

  Example:

  A barrier can be used before application exit to ensure completion of pending
  ops.

  ```python
  x = [1, 2, 3]
  x = dtensor.relayout(x, dtensor.Layout.batch_sharded(mesh, 'batch', 1))
  dtensor.barrier(mesh)

  # At this point all devices on all clients in the mesh have completed
  # operations before the barrier. Therefore it is OK to tear down the clients.
  sys.exit()
  ```

  Args:
    mesh: The mesh to run the barrier on.
    barrier_name: The name of the barrier. Mainly used for logging purposes.
  """
  if barrier_name is None:
    barrier_name = '(barrier)'

  logging.info('entering barrier before op: %s', barrier_name)

  # Make sure all ops are consumed before running the sync.
  context.async_wait()

  # Reduction on a fully sharded tensor requires all devices to participate
  # and serves as a barrier on the mesh.
  component = array_ops.reshape(1.0, [1] * len(mesh.shape()))
  ones = api.pack([component] * mesh.num_local_devices(),
                  layout.Layout(mesh.dim_names, mesh))
  mesh_size = math_ops.reduce_sum(ones)
  if mesh_size != mesh.size:
    # NOTE: a space was missing between 'actual' and 'size' in the implicit
    # string-literal concatenation; the message previously read "actualsize".
    raise ValueError(
        'Global barrier produced wrong mesh size : {0} while mesh has actual '
        'size : {1}'.format(mesh_size, mesh.size))

  # TODO(hthu): This isn't strictly needed but might cause confusing behaviors
  # from users. Consider dropping this if there is a `big` performance hit.
  context.async_wait()

  logging.info('finished running barrier across all clients after '
               'op: %s', barrier_name)
def _global_barrier(mesh: layout_lib.Mesh, last_op_name: str):
  """Runs a global barrier on the mesh.

  Upon returning from the barrier, all operations run before the barrier
  would have completed across all clients. Currently we allocate a fully
  sharded tensor with mesh shape and run an all_reduce on it.

  Args:
    mesh: The mesh to run the global barrier on.
    last_op_name: The last op run before the global_barrier. Mainly used for
      logging purposes.
  """
  logging.info('entering global barrier before op: %s', last_op_name)

  # Make sure all ops are consumed before running the sync.
  context.async_wait()

  shape = api._dtensor_device().pack(  # pylint: disable=protected-access
      [mesh.shape()] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh, rank=1))
  ones = api.call_with_layout(
      array_ops.ones,
      layout_lib.Layout(mesh.dim_names, mesh),
      shape=shape,
      dtype=dtypes.float32)
  mesh_size = math_ops.reduce_sum(ones)
  if mesh_size != mesh.size:
    # NOTE: a space was missing between 'actual' and 'size' in the implicit
    # string-literal concatenation; the message previously read "actualsize".
    raise ValueError(
        'Global barrier produced wrong mesh size : {0} while mesh has actual '
        'size : {1}'.format(mesh_size, mesh.size))

  # TODO(hthu): This isn't strictly needed but might cause confusing behaviors
  # from users. Consider dropping this if there is a `big` performance hit.
  context.async_wait()

  logging.info(
      'finished running global barrier across all clients after '
      'op: %s', last_op_name)