def barrier(mesh: layout.Mesh, barrier_name: Optional[str] = None):
  """Runs a barrier on the mesh.

  Upon returning from the barrier, all operations run before the barrier
  will have completed across all clients. Currently we allocate a fully
  sharded tensor with mesh shape and run an all_reduce on it.

  Example:

  A barrier can be used before application exit to ensure completion of
  pending ops.

  ```python
  x = [1, 2, 3]
  x = dtensor.relayout(x, dtensor.Layout.batch_sharded(mesh, 'batch', 1))
  dtensor.barrier(mesh)

  # At this point all devices on all clients in the mesh have completed
  # operations before the barrier. Therefore it is OK to tear down the
  # clients.
  sys.exit()
  ```

  Args:
    mesh: The mesh to run the barrier on.
    barrier_name: The name of the barrier. Mainly used for logging purposes.
  """
  if barrier_name is None:
    barrier_name = '(barrier)'

  logging.info('entering barrier before op: %s', barrier_name)

  # Make sure all ops are consumed before running the sync.
  context.async_wait()

  # Reduction on a fully sharded tensor requires all devices to participate,
  # and so serves as a barrier on the mesh.
  component = array_ops.reshape(1.0, [1] * len(mesh.shape()))
  ones = api.pack([component] * mesh.num_local_devices(),
                  layout.Layout(mesh.dim_names, mesh))
  mesh_size = math_ops.reduce_sum(ones)
  if mesh_size != mesh.size:
    raise ValueError(
        'Global barrier produced wrong mesh size: {0} while mesh has actual '
        'size: {1}'.format(mesh_size, mesh.size))

  # TODO(hthu): This isn't strictly needed, but removing it might cause
  # confusing behaviors for users. Consider dropping it if there is a big
  # performance hit.
  context.async_wait()

  logging.info('finished running barrier across all clients after op: %s',
               barrier_name)
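
# A minimal usage sketch for `barrier` (illustrative only, not part of the
# original module). It assumes the public `tf.experimental.dtensor` API and
# that enough logical CPU devices have been configured beforehand; the mesh
# dimension name 'batch' and the barrier name are arbitrary choices for the
# example, mirroring the docstring above.
def _example_barrier_usage():  # Hypothetical helper, for illustration only.
  import sys
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  # Assumes 8 logical CPU devices were configured, e.g. via
  # tf.config.set_logical_device_configuration.
  mesh = dtensor.create_mesh([('batch', 8)], device_type='CPU')
  x = dtensor.relayout(
      tf.range(8, dtype=tf.float32),
      dtensor.Layout.batch_sharded(mesh, 'batch', rank=1))
  del x  # The relayout above may still be executing asynchronously.
  # Block until all clients have finished all previously issued ops, then it
  # is safe to tear down the clients.
  dtensor.barrier(mesh, barrier_name='before-exit')
  sys.exit()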
  def __init__(self, dvariable, name):
    with ops.device(dvariable.device):
      original_layout = api.fetch_layout(dvariable)
    # Record the original layout to allow restore.
    self._original_layout = original_layout
    self._dvariable = dvariable

    def pack(tensors, layout):
      # Pack per-device components on the DVariable's DTensor device.
      with ops.device(dvariable.device):
        return api.pack(tensors, layout)

    # The variable is saved through an equivalent layout on the host (CPU)
    # mesh: same sharding specs, host mesh.
    host_layout = layout_lib.Layout(original_layout.sharding_specs,
                                    original_layout.mesh.host_mesh())

    def get_host_dvariable():
      # Copy the variable to the host mesh if it lives on an accelerator.
      if original_layout.mesh.device_type().upper() != 'CPU':
        with ops.device(dvariable.device):
          host_dvariable = DVariable(
              api.pack(api.unpack(dvariable.read_value()), host_layout))
      else:
        host_dvariable = dvariable
      return (math_ops.cast(host_dvariable, dtypes.bfloat16)
              if self.should_cast(host_dvariable) else host_dvariable)

    num_local_devices = original_layout.mesh.num_local_devices()
    super(_DVariableSaveable, self).__init__(
        None,
        [
            DSaveSpec(
                tensor=get_host_dvariable,
                slice_spec=pack([''] * num_local_devices,
                                layout_lib.Layout.replicated(
                                    original_layout.mesh.host_mesh(), rank=0)),
                name=pack([name] * num_local_devices,
                          layout_lib.Layout.replicated(
                              original_layout.mesh.host_mesh(), rank=0)),
                global_shape=dvariable.shape,
                # The layout is attached as an attribute; no need to put it as
                # a Tensor on the DTensorDevice.
                layout=host_layout.to_string(),
                dtype=dtypes.bfloat16
                if self.should_cast(dvariable) else dvariable.dtype,
                device=dvariable.device)
        ],
        name)
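
# Sketch of the host-mesh layout derivation used above (illustrative only,
# not part of the original module). Given a layout on an accelerator mesh,
# checkpointing goes through an equivalent layout on the corresponding host
# (CPU) mesh: same sharding specs, host mesh. `some_device_layout` is a
# hypothetical input.
def _example_host_layout(some_device_layout):  # Hypothetical helper.
  host_layout = layout_lib.Layout(some_device_layout.sharding_specs,
                                  some_device_layout.mesh.host_mesh())
  # Tensor components are then unpacked from the device mesh and repacked
  # onto the host mesh before being handed to the checkpoint writer,
  # mirroring `get_host_dvariable` above.
  return host_layout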
def _global_barrier(mesh: layout_lib.Mesh, last_op_name: str):
  """Runs a global barrier on the mesh.

  Upon returning from the barrier, all operations run before the barrier
  will have completed across all clients. Currently we allocate a fully
  sharded tensor with mesh shape and run an all_reduce on it.

  Args:
    mesh: The mesh to run the global barrier on.
    last_op_name: The last op run before the global barrier. Mainly used for
      logging purposes.
  """
  logging.info('entering global barrier before op: %s', last_op_name)

  # Make sure all ops are consumed before running the sync.
  context.async_wait()

  shape = api._dtensor_device().pack(  # pylint: disable=protected-access
      [mesh.shape()] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh, rank=1))
  ones = api.call_with_layout(
      array_ops.ones,
      layout_lib.Layout(mesh.dim_names, mesh),
      shape=shape,
      dtype=dtypes.float32)
  mesh_size = math_ops.reduce_sum(ones)
  if mesh_size != mesh.size:
    raise ValueError(
        'Global barrier produced wrong mesh size: {0} while mesh has actual '
        'size: {1}'.format(mesh_size, mesh.size))

  # TODO(hthu): This isn't strictly needed, but removing it might cause
  # confusing behaviors for users. Consider dropping it if there is a big
  # performance hit.
  context.async_wait()

  logging.info('finished running global barrier across all clients after '
               'op: %s', last_op_name)
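
# Why the reduce_sum doubles as a barrier and a sanity check (illustrative
# numpy sketch, not DTensor code, and not part of the original module): each
# of the `mesh.size` devices holds a single component equal to 1.0, so the
# all_reduce can only complete once every device has contributed, and the
# global sum must equal the mesh size.
def _example_barrier_arithmetic(mesh_size=8):  # Hypothetical helper.
  import numpy as np
  components = [np.ones([1]) for _ in range(mesh_size)]  # One per device.
  global_sum = np.sum([c.sum() for c in components])  # The all_reduce result.
  assert global_sum == mesh_size
  return global_sum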
def dtensor_initialize_tpu_system(enable_coordination_service=False):
  """Initializes the TPU devices.

  Args:
    enable_coordination_service: If true, enable the distributed coordination
      service to make sure that workers know the devices on each other, which
      is a prerequisite for data transfer through cross-worker rendezvous.

  Raises:
    RuntimeError: If running inside a tf.function.
    NotFoundError: If no TPU devices are found in eager mode.
  """
  assert context.executing_eagerly()
  in_multi_client_mode = api.job_name() != "localhost"

  # Collective GRPC servers are only necessary in a multi-client setup.
  # Single clients (e.g. Forge) can use the local mode of collectives.
  if in_multi_client_mode:
    if api.jobs() is None:
      raise ValueError(
          "DTENSOR_JOBS environment variable is required when using "
          "multi-client to properly set up communications between servers.")
    multi_client_util.initialize_multi_client_cluster(
        job_name=api.job_name(),
        dtensor_jobs=api.jobs(),
        client_id=api.client_id(),
        collective_leader=api.full_job_name(task_id=0),
        enable_coordination_service=enable_coordination_service)

  # Make sure the server change is fully propagated before attempting to run
  # the core ID merging logic below.
  context.ensure_initialized()
  context.async_wait()
  context.context()._clear_caches()  # pylint: disable=protected-access

  @function.defun
  def _tpu_init_fn():
    return gen_dtensor_ops.configure_and_initialize_global_tpu()

  try:
    with ops.device("/job:" + api.full_job_name() + "/device:TPU_SYSTEM:0"):  # pylint: disable=protected-access
      my_core_ids = _tpu_init_fn()
    logging.info("TPU core IDs: %s", my_core_ids)
    context.initialize_logical_devices()

    # Configure virtual CPUs that are 1:1 mapped to TPU cores.
    context.context().set_logical_cpu_devices(
        len(api.local_devices(_TPU_DEVICE_TYPE)),
        tf_device.DeviceSpec(
            job=api.job_name(), replica=0, task=api.client_id()).to_string())

    # `my_core_ids` contains the IDs of TPU cores attached to this host.
    #
    # To generate a correct and efficient XLA AllReduce group assignment, we
    # must merge these arrays from all hosts and broadcast the result back to
    # all hosts, so every host can use these mappings in its MLIR passes.
    #
    # This is essentially doing what WaitForDistributedTpuOp and
    # SetGlobalTPUArrayOp do, in our multi-client environment.
    task_id = api.client_id()
    num_tasks = api.num_clients()
    num_devices = api.num_global_devices(_TPU_DEVICE_TYPE)
    num_devices_per_task = int(num_devices / num_tasks)

    # Create a one-time-use mesh and layout just for merging core IDs.
    mesh = layout_lib.Mesh([_MESH_DIM_X],
                           *_create_device_array((num_devices,),
                                                 _TPU_DEVICE_TYPE,
                                                 api.client_id()))
    layout = layout_lib.Layout([_MESH_DIM_X, layout_lib.UNSHARDED], mesh)
    device = dtensor_device.DTensorDevice(meshes=[mesh])
    logging.info("TPU core locations: %s",
                 device.tpu_core_ids_to_locations(my_core_ids))

    # At this point, we don't know which cores are attached to other hosts.
    # The core ID mappings in the runtime haven't been set yet.
    #
    # The core ID merging AllReduce below is carefully written so it works
    # without needing correct core mappings to be set in the runtime. We will
    # use this AllReduce's result to set the core ID mappings, and all future
    # user-initiated AllReduces will use the mappings.
    #
    # The runtime is hard-coded to ignore core ID mappings on this AllReduce.
    all_core_ids = np.zeros([num_devices], dtype=np.int32)
    for i in range(len(my_core_ids)):
      all_core_ids[task_id * num_devices_per_task + i] = my_core_ids[i]

    # Only one local device gets valid input: 8 local core IDs among
    # (num_tasks - 1) * 8 zeros.
    # The 8 core IDs are set using the task ID as offset. The other 7 local
    # devices get zero inputs. All devices on all hosts participate in one
    # AllReduce, whose result will be core IDs arranged by task-device
    # ordinals.
    all_core_ids = constant_op.constant([all_core_ids])
    zeros = array_ops.zeros_like(all_core_ids)
    all_core_ids = [all_core_ids] + [zeros] * (num_devices_per_task - 1)

    with ops.device(device.name):
      all_core_ids = device.pack(all_core_ids, layout)
    all_core_ids = math_ops.reduce_sum(all_core_ids, axis=[0])
    unpacked_all_tpu_ids = device.unpack(all_core_ids)

    all_core_ids = list(unpacked_all_tpu_ids[0].numpy())
    logging.info("All TPU core IDs: %s", all_core_ids)

    # Set the default core ID mappings in the runtime for legacy code and
    # tests.
    #
    # Legacy code and tests create TPU meshes directly without using the
    # `create_tpu_mesh` function below. Those meshes have global device IDs
    # equal to TF task-device ordinals. The `all_core_ids` array happens to
    # arrange core IDs by TF task-device ordinals. Using this array on those
    # meshes guarantees correct, although inefficient, results.
    device.set_tpu_core_ids("", all_core_ids)

    # Remember enough global, immutable information to be able to build any
    # ring we want prescribed by `create_tpu_mesh` in the future.
    global _all_core_ids
    _all_core_ids = all_core_ids

    all_core_locations = device.tpu_core_ids_to_locations(all_core_ids)
    all_core_locations = [
        _CoreLocation(l[0], l[1], l[2], l[3]) for l in all_core_locations
    ]
    global _all_core_locations
    _all_core_locations = all_core_locations
    logging.info("All TPU core locations: %s", all_core_locations)

    tpu_topology = _create_tpu_topology(all_core_locations, num_tasks,
                                        num_devices_per_task)
    global _tpu_topology
    _tpu_topology = tpu_topology
    logging.vlog(1, "TPU Topology: %s, %s", tpu_topology.mesh_shape,
                 tpu_topology.device_coordinates)

    global _dtensor_device
    _dtensor_device = device

    context.async_wait()
  except errors.InvalidArgumentError as e:
    raise errors.NotFoundError(
        None, None, "Initialization failed, no valid TPUs found. " + str(e))
  except errors.InternalError as e:
    logging.error("Hit internal error during TPU system initialization. "
                  "It is likely a hardware failure.\nPlease check the error "
                  "messages above to see whether that is the case.\nIf so, "
                  "consider restarting the job or trying another machine.")
    raise e

  # Optionally exchange heartbeats between workers every 3 minutes.
  if in_multi_client_mode and api.heartbeat_enabled():
    logging.info("Starting DTensor heartbeat service exchanging signals "
                 "every 3 minutes")
    heartbeat.start(period=180)

  # Clear out the eager context caches since the memory is invalid now.
  logging.info("Clearing out eager caches")
  context.context()._clear_caches()  # pylint: disable=protected-access
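
# A small numpy simulation of the core ID merging AllReduce above
# (illustrative only, not part of the original module). Each task writes its
# local core IDs at its task offset and zeros elsewhere; summing the per-task
# arrays element-wise (which is what the AllReduce computes) yields the global
# array of core IDs arranged by task-device ordinals. The core ID values and
# task counts below are made up.
def _example_core_id_merge():  # Hypothetical helper.
  import numpy as np
  num_tasks, num_devices_per_task = 2, 8
  num_devices = num_tasks * num_devices_per_task
  # Made-up local core IDs for each task.
  local_ids = [np.arange(100, 108), np.arange(200, 208)]
  per_task = []
  for task_id in range(num_tasks):
    contribution = np.zeros([num_devices], dtype=np.int32)
    offset = task_id * num_devices_per_task
    contribution[offset:offset + num_devices_per_task] = local_ids[task_id]
    per_task.append(contribution)
  merged = np.sum(per_task, axis=0)  # Element-wise sum == AllReduce result.
  return merged  # [100..107, 200..207], ordered by task-device ordinal.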