def setUp(self):
    super(OptimizersTest, self).setUp()
    global_ids = test_util.create_device_ids_array((2, 2))
    local_device_ids = np.ravel(global_ids).tolist()
    mesh_dict = {
        'CPU':
            dtensor.Mesh(['X', 'Y'], global_ids, local_device_ids,
                         test_util.create_device_list((2, 2), 'CPU'))
    }
    self.mesh = self.configTestMesh(mesh_dict)
def setUp(self):
    super(LayersTest, self).setUp()
    tf_utils.set_random_seed(1337)
    global_ids = test_util.create_device_ids_array((2, 2))
    local_device_ids = np.ravel(global_ids).tolist()
    mesh_dict = {
        'CPU':
            dtensor.Mesh(['X', 'Y'], global_ids, local_device_ids,
                         test_util.create_device_list((2, 2), 'CPU'))
    }
    self.mesh = self.configTestMesh(mesh_dict)
    self.layout_4d = dtensor.Layout.replicated(self.mesh, rank=4)
    self.layout_3d = dtensor.Layout.replicated(self.mesh, rank=3)
    self.layout_2d = dtensor.Layout.replicated(self.mesh, rank=2)
    self.layout_1d = dtensor.Layout.replicated(self.mesh, rank=1)
def setUp(self):
    super(LayoutMapTest, self).setUp()
    backend.enable_tf_random_generator()
    tf_utils.set_random_seed(1337)
    global_ids = test_util.create_device_ids_array((2, 2))
    local_device_ids = np.ravel(global_ids).tolist()
    mesh_dict = {
        'CPU':
            dtensor.Mesh(['X', 'Y'], global_ids, local_device_ids,
                         test_util.create_device_list((2, 2), 'CPU'))
    }
    self.mesh = self.configTestMesh(mesh_dict)
    self.layout_2d = dtensor.Layout.replicated(self.mesh, rank=2)
    self.layout_1d = dtensor.Layout.replicated(self.mesh, rank=1)
    self.sharded_2d = dtensor.Layout.batch_sharded(self.mesh, 'X', rank=2)
    self.sharded_1d = dtensor.Layout.batch_sharded(self.mesh, 'X', rank=1)
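# Illustrative standalone sketch (an assumption, not part of the test classes
# above): it mirrors what these setUp methods build, a 2x2 CPU mesh with
# replicated and batch-sharded layouts, using the public
# tf.experimental.dtensor API. Device counts and names here are hypothetical.
import tensorflow as tf
from tensorflow.experimental import dtensor

# Split the single physical CPU into 4 logical devices so a 2x2 mesh fits.
cpu = tf.config.list_physical_devices('CPU')[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration()] * 4)

mesh = dtensor.create_mesh([('X', 2), ('Y', 2)], device_type='CPU')
layout_2d = dtensor.Layout.replicated(mesh, rank=2)  # full copy on each device
sharded_2d = dtensor.Layout.batch_sharded(mesh, 'X', rank=2)  # dim 0 split over 'X'

x = tf.ones([4, 8])
replicated = dtensor.copy_to_mesh(x, layout_2d)     # every device holds [4, 8]
sharded = dtensor.relayout(replicated, sharded_2d)  # every device holds a [2, 8] slice
print(dtensor.fetch_layout(sharded))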
def create_mesh(mesh_dims: Optional[List[Tuple[str, int]]] = None,
                mesh_name: str = '',
                devices: Optional[List[str]] = None,
                device_type: Optional[str] = None) -> dtensor.Mesh:
  """Creates a single-client mesh.

  If both `mesh_dims` and `devices` are specified, they must match each other.
  As a special case, when all arguments are missing, this creates a 1D CPU mesh
  with an empty name, assigning all available devices to that dimension.

  Args:
    mesh_dims: A list of (dim_name, dim_size) tuples. Defaults to a single
      batch-parallel dimension called 'x' using all devices. As a special case,
      a single-element mesh_dims whose dim_size is -1 also uses all devices.
    mesh_name: Name of the created mesh. Defaults to ''.
    devices: String representations of devices to use. This is the device part
      of tf.DeviceSpec, e.g. 'CPU:0'. Defaults to all available logical
      devices.
    device_type: If `devices` is missing, the type of devices to use. Defaults
      to 'CPU'.

  Returns:
    A single-client mesh created from specified or default arguments.
  """
  if devices is None:
    if device_type is None:
      device_type = 'CPU'
    devices = [
        tf_device.DeviceSpec.from_string(d.name)
        for d in tf_config.list_logical_devices(device_type)
    ]
  else:
    devices = [
        tf_device.DeviceSpec.from_string('/job:localhost/replica:0/task:0/' +
                                         d) for d in devices
    ]
  if device_type is None:
    device_type = devices[0].device_type
  if device_type.upper() != devices[0].device_type.upper():
    raise ValueError(
        f'Conflicting devices {str(devices)} and device_type {device_type}')

  if mesh_dims is None:
    mesh_dims = [('x', len(devices))]
  elif len(mesh_dims) == 1 and mesh_dims[0][1] == -1:
    # Replace a -1 dim_size in a 1D mesh with the number of all devices.
    mesh_dims[0] = (mesh_dims[0][0], len(devices))

  dim_names = [d[0] for d in mesh_dims]
  shape = [d[1] for d in mesh_dims]
  global_device_ids = np.arange(len(devices)).reshape(shape)
  local_device_ids = np.ravel(global_device_ids).tolist()
  mesh = dtensor.Mesh(
      dim_names=dim_names,
      global_device_ids=global_device_ids,
      local_device_ids=local_device_ids,
      local_devices=devices,
      mesh_name=mesh_name)
  _print_context(
      num_global_devices=len(devices),
      num_clients=1,
      client_id=0,
      device_type=devices[0].device_type,
      mesh=mesh)
  return mesh
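# Hedged usage sketch for create_mesh above; it assumes 8 logical CPU devices
# have already been configured (e.g. via
# tf.config.set_logical_device_configuration). Mesh names and sizes are
# illustrative only.
mesh_default = create_mesh()                # 1D mesh 'x' over all CPU devices
mesh_2d = create_mesh([('batch', 4), ('model', 2)], mesh_name='my_mesh')
mesh_all = create_mesh([('x', -1)])         # -1 expands to all devices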
def create_distributed_mesh(mesh_dims: List[Tuple[str, int]],
                            mesh_name: str = '',
                            num_global_devices: Optional[int] = None,
                            num_clients: Optional[int] = None,
                            client_id: Optional[int] = None,
                            device_type: str = 'CPU') -> dtensor.Mesh:
  """Creates a single- or multi-client mesh.

  For CPU and GPU meshes, users can choose to use fewer local devices than what
  is available. If any argument is missing, it will be extracted from
  environment variables. The default values for these environment variables
  create a single-client mesh using all devices (common for unit tests).

  For TPU meshes, users should not specify any of the nullable arguments. The
  DTensor runtime will set these arguments automatically, using all TPU cores
  available in the entire cluster.

  Args:
    mesh_dims: A list of (dim_name, dim_size) tuples.
    mesh_name: Name of the created mesh. Defaults to ''.
    num_global_devices: Number of devices in the DTensor cluster. Defaults to
      the corresponding environment variable.
    num_clients: Number of clients in the DTensor cluster. Defaults to the
      corresponding environment variable.
    client_id: This client's ID. Defaults to the corresponding environment
      variable.
    device_type: Type of device to build the mesh for. Defaults to 'CPU'.

  Returns:
    A mesh created from specified or default arguments.
  """
  if device_type.upper() in ['CPU', 'GPU']:
    # For CPU and GPU meshes, user-specified args take precedence over env
    # vars. This is particularly useful on single clients when users want to
    # create meshes that use fewer logical devices than what's available.
    if num_global_devices is None:
      num_global_devices = dtensor.num_global_devices(device_type)
    if num_global_devices <= 0:
      raise ValueError(f'num_global_devices ({num_global_devices}) must be > 0')
    if num_clients is None:
      num_clients = dtensor.num_clients()
    if num_clients <= 0:
      raise ValueError(f'num_clients ({num_clients}) must be > 0')
    if client_id is None:
      client_id = dtensor.client_id()
    if client_id < 0:
      raise ValueError(f'client_id ({client_id}) must be >= 0')
    if client_id >= num_clients:
      raise ValueError(f'client_id ({client_id}) must be < {num_clients}')
    if num_global_devices % num_clients != 0:
      raise ValueError(f'num_global_devices ({num_global_devices}) must be '
                       f'divisible by num_clients ({num_clients})')
    num_local_devices = num_global_devices // num_clients

    # It's allowed to create a CPU or GPU mesh using fewer logical devices than
    # what's available. If so, just use the first N logical devices.
    num_available_devices = dtensor.num_local_devices(device_type)
    if num_local_devices > num_available_devices:
      raise ValueError(f'Not enough devices; {num_local_devices} needed, '
                       f'only {num_available_devices} available')
    local_devices = dtensor.local_devices(device_type,
                                          client_id)[:num_local_devices]

    dim_names = [d[0] for d in mesh_dims]
    shape = [d[1] for d in mesh_dims]
    global_device_ids = np.arange(num_global_devices).reshape(shape)
    flattened = np.ravel(global_device_ids).tolist()
    start_idx = num_local_devices * client_id
    local_device_ids = flattened[start_idx:start_idx + num_local_devices]

    mesh = dtensor.Mesh(
        dim_names=dim_names,
        global_device_ids=global_device_ids,
        local_device_ids=local_device_ids,
        local_devices=local_devices,
        mesh_name=mesh_name)
    _print_context(num_global_devices, num_clients, client_id, device_type,
                   mesh)
    return mesh

  if device_type.upper() == 'TPU':
    # TPU meshes can only be configured through environment variables that
    # reflect the actual TPU topology. Do not let users specify custom args.
    if num_global_devices is not None:
      raise ValueError(
          f'Do not specify num_global_devices for {device_type.upper()} '
          'meshes. It will be filled in automatically from environment '
          'variables. See api.py for the list of environment variables for '
          'DTensor.')
    if num_clients is not None:
      raise ValueError(
          f'Do not specify num_clients for {device_type.upper()} meshes. '
          'It will be filled in automatically from environment variables. '
          'See api.py for the list of environment variables for DTensor.')
    if client_id is not None:
      raise ValueError(
          f'Do not specify client_id for {device_type.upper()} meshes. '
          'It will be filled in automatically from environment variables. '
          'See api.py for the list of environment variables for DTensor.')

    dim_names = [mesh_dim[0] for mesh_dim in mesh_dims]
    shape = [mesh_dim[1] for mesh_dim in mesh_dims]
    mesh = tpu_util.create_tpu_mesh(dim_names, shape, mesh_name)
    _print_context(
        dtensor.num_global_devices(device_type), dtensor.num_clients(),
        dtensor.client_id(), device_type, mesh)
    return mesh

  raise ValueError(f'Device type {device_type} is not CPU, GPU or TPU')
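# Hedged usage sketch for create_distributed_mesh above. On a single client
# with the default DTensor environment variables this behaves like a local
# mesh over all devices of the given type; in a multi-client job the DTENSOR_*
# environment variables (see api.py) supply num_clients and client_id. The
# dimension sizes below are assumptions for illustration.
cpu_mesh = create_distributed_mesh([('batch', 8)], device_type='CPU')
# For TPU, leave the nullable arguments unset; topology comes from the runtime:
# tpu_mesh = create_distributed_mesh([('batch', 32), ('model', 2)],
#                                    device_type='TPU')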
def create_tpu_mesh(mesh_dim_names: List[str],
                    mesh_shape: List[int],
                    mesh_name: str,
                    ring_dims: Optional[int] = None,
                    ring_axes: Optional[List[str]] = None,
                    ring_bounds: Optional[List[int]] = None,
                    can_split_host_across_rings: bool = True,
                    build_ring_across_rings: bool = False,
                    rotate_ring_across_rings: bool = False) -> dtensor.Mesh:
  """Returns a TPU mesh optimized for AllReduce ring reductions.

  Only as many of the leading axes specified by `ring_axes` as necessary are
  used to build rings, as long as the subslice formed by these axes has enough
  cores to contain a ring of the required size. The leftover axes in
  `ring_axes` won't affect results.

  See go/dtensor-device-assignment-api for details and performance tuning tips.

  Args:
    mesh_dim_names: List of mesh dimension names.
    mesh_shape: Shape of the mesh.
    mesh_name: A unique name for the mesh. If empty, one is generated
      internally.
    ring_dims: Optional; The number of leading (ring_dims > 0) or trailing
      (ring_dims < 0) mesh dimensions to build rings for. If unspecified, build
      rings for all but the first dimension.
    ring_axes: Optional; A permutation of ["x", "y", "z", "core"], specifying
      the order of TPU topology axes to build rings in. If unspecified,
      defaults to ["core", "x", "y", "z"].
    ring_bounds: Optional; The maximum number of devices on each axis, in the
      x, y, z, core order. If unspecified, defaults to physical topology
      limits.
    can_split_host_across_rings: Optional; If true, devices attached to the
      same host (i.e., DTensor client) may get assigned to different rings.
      Setting it to false may cause some combinations of arguments to be
      infeasible; see DeviceAssignmentTest.testCreateMesh[No]SplittingHosts*
      for examples.
    build_ring_across_rings: Optional; If true, also build a data-parallel ring
      across model-parallel rings. This ring could be strided.
    rotate_ring_across_rings: Optional; If true, build the data-parallel ring
      in column-major instead of row-major order.
  """
  logging.info("Building a TPU mesh %s of shape %s", mesh_name, mesh_shape)
  logging.info("Requested ring_dims: %s", ring_dims)
  logging.info("Requested ring_axes: %s", ring_axes)
  logging.info("Requested ring_bounds: %s", ring_bounds)
  logging.info("Requested can_split_host_across_rings: %s",
               can_split_host_across_rings)
  if not mesh_name:
    mesh_name = "mesh_%f" % time.time()
  logging.info("Requested mesh_name: %s", mesh_name)

  # By default, build rings for all but the first (usually batch) dimension.
  if ring_dims is None:
    ring_dims = 1 - len(mesh_shape)
  elif ring_dims < -len(mesh_shape) or ring_dims > len(mesh_shape):
    raise ValueError("Invalid ring_dims value: %d" % ring_dims)
  logging.info("Actual ring_dims: %s", ring_dims)

  # By default, vary axes in the core -> x -> y -> z order.
  if ring_axes is None:
    ring_axes = ["core", "x", "y", "z"]
  elif len(ring_axes) != 4:
    raise ValueError("Expected 4 elements in ring_axes, got %s" % ring_axes)
  elif sorted(ring_axes) != ["core", "x", "y", "z"]:
    raise ValueError("Invalid ring_axes value: %s" % ring_axes)
  logging.info("Actual ring_axes: %s", ring_axes)

  # Validate ring_bounds values.
  global _tpu_topology
  if _tpu_topology is None:
    raise ValueError(
        "Invalid TPU topology, run dtensor_initialize_tpu_system() first")
  topology_shape = list(_tpu_topology.mesh_shape)
  if ring_bounds is None:
    ring_bounds = topology_shape
  elif len(ring_bounds) != 4:
    raise ValueError("Expected 4 elements in ring_bounds, got %s" % ring_bounds)
  elif any(b > t for b, t in zip(ring_bounds, topology_shape)):
    # Each ring bound must not exceed the corresponding topology size.
    raise ValueError("ring_bounds %s should be <= topology sizes %s" %
                     (ring_bounds, topology_shape))
  logging.info("Actual ring_bounds: %s", ring_bounds)

  # Compute ring_size, the number of cores in a ring.
  if ring_dims > 0:
    ring_size = np.prod(mesh_shape[:ring_dims])
  elif ring_dims < 0:
    ring_size = np.prod(mesh_shape[ring_dims:])
  else:
    ring_size = 1  # single-core rings
  logging.info("Actual ring_size: %d", ring_size)

  # Rearrange all cores according to the axis iteration order.
  global_core_locations = _enumerate_core_locations(
      topology_shape, ring_bounds, ring_axes, can_split_host_across_rings,
      ring_size)
  logging.vlog(1, "Enumerated core locations: %s", global_core_locations)
  num_cores = len(global_core_locations)

  # The mesh to be created must use all TPU cores in the system.
  mesh_size = np.prod(mesh_shape)
  if mesh_size != num_cores:
    raise ValueError(
        "Invalid mesh size: mesh shape %s cannot 1:1 map to %d TPU cores" %
        (mesh_shape, num_cores))

  # Build a ring for the `ring_size` dimension and, if required, a strided ring
  # for the orthogonal dimension.
  if build_ring_across_rings:
    global_core_locations = _build_orthogonal_rings(global_core_locations,
                                                    ring_size,
                                                    rotate_ring_across_rings)
  else:
    permutation = _build_all_reduce_ring(global_core_locations[:ring_size])
    for r in range(0, num_cores, ring_size):
      global_core_locations[r:r + ring_size] = [
          global_core_locations[r + permutation[i]] for i in range(ring_size)
      ]
    logging.vlog(1, "Permuted core locations: %s", global_core_locations)

  # From this point on, change from List[CoreLocation] to List[List[int]] for
  # easier interaction with the C++ API.
  global_core_locations = [l.to_list() for l in global_core_locations]
  global _dtensor_device
  if _dtensor_device is None:
    raise ValueError(
        "Invalid system device, run dtensor_initialize_tpu_system() first")
  global_core_ids = _dtensor_device.tpu_core_locations_to_ids(
      global_core_locations)

  # Store a per-mesh mapping in the runtime.
  _dtensor_device.set_tpu_core_ids(mesh_name, global_core_ids)

  # Create the mesh by manually specifying local_device_ids.
  local_core_locations = _tpu_topology.device_coordinates[dtensor.client_id()]
  indexes = [
      global_core_locations.index(list(local_core_location))
      for local_core_location in local_core_locations
  ]
  global_device_ids, local_device_ids, local_device_list = _create_device_array(
      mesh_shape, _TPU_DEVICE_TYPE, None, local_device_ids=indexes)
  return dtensor.Mesh(mesh_dim_names, global_device_ids, local_device_ids,
                      local_device_list, mesh_name)
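# Hedged usage sketch for create_tpu_mesh above; it requires
# dtensor_initialize_tpu_system() to have run first. The shape, name, and ring
# settings below are illustrative assumptions for a 32-core slice.
mesh = create_tpu_mesh(
    mesh_dim_names=['batch', 'model'],
    mesh_shape=[8, 4],
    mesh_name='tpu_mesh',
    ring_dims=-1,                   # build rings over the trailing 'model' dim
    build_ring_across_rings=True)   # plus a strided ring for the 'batch' dim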
def dtensor_initialize_tpu_system(enable_coordination_service=False):
  """Initialize the TPU devices.

  Args:
    enable_coordination_service: If true, enable distributed coordination
      service so that workers know about each other's devices, a prerequisite
      for data transfer through cross-worker rendezvous.

  Raises:
    RuntimeError: If running inside a tf.function.
    NotFoundError: If no TPU devices are found in eager mode.
  """
  assert context.executing_eagerly()
  in_multi_client_mode = dtensor.job_name() != "localhost"

  # Collective GRPC servers are only necessary in a multi-client setup.
  # Single clients (e.g. Forge) can use local mode of collectives.
  if in_multi_client_mode:
    if dtensor.jobs() is None:
      raise ValueError(
          "DTENSOR_JOBS environment variable is required when using "
          "multi-client to properly set up communications between servers")
    multi_client_util.initialize_multi_client_cluster(
        job_name=dtensor.job_name(),
        dtensor_jobs=dtensor.jobs(),
        client_id=dtensor.client_id(),
        collective_leader=dtensor.full_job_name(task_id=0),
        enable_coordination_service=enable_coordination_service)

  # Make sure the server change is fully propagated before attempting to run
  # the core ID merging logic below.
  context.ensure_initialized()
  context.async_wait()
  context.context()._clear_caches()  # pylint: disable=protected-access

  @function.defun
  def _tpu_init_fn():
    return dtensor.ops.configure_and_initialize_global_tpu()

  try:
    with ops.device("/job:" + dtensor.full_job_name() +
                    "/device:TPU_SYSTEM:0"):  # pylint: disable=protected-access
      my_core_ids = _tpu_init_fn()
    logging.info("TPU core IDs: %s", my_core_ids)
    context.initialize_logical_devices()

    # Configure virtual CPUs that are 1:1 mapped to TPU cores.
    context.context().set_logical_cpu_devices(
        len(dtensor.local_devices(_TPU_DEVICE_TYPE)),
        tf_device.DeviceSpec(
            job=dtensor.job_name(), replica=0,
            task=dtensor.client_id()).to_string())

    # `my_core_ids` contains the IDs of TPU cores attached to this host.
    #
    # To generate correct and efficient XLA AllReduce group assignments, we
    # must merge these arrays from all hosts and broadcast the result back to
    # all hosts, so all hosts can use these mappings in their MLIR passes.
    #
    # This is essentially doing what WaitForDistributedTpuOp and
    # SetGlobalTPUArrayOp do, in our multi-client environment.
    task_id = dtensor.client_id()
    num_tasks = dtensor.num_clients()
    num_devices = dtensor.num_global_devices(_TPU_DEVICE_TYPE)
    num_devices_per_task = int(num_devices / num_tasks)

    # Create a one-time use mesh and layout just for merging core IDs.
    mesh = dtensor.Mesh([_MESH_DIM_X],
                        *_create_device_array((num_devices,),
                                              _TPU_DEVICE_TYPE,
                                              dtensor.client_id()))
    layout = dtensor.Layout([_MESH_DIM_X, dtensor.UNSHARDED], mesh)
    device = dtensor_device.DTensorDevice(meshes=[mesh])
    logging.info("TPU core locations: %s",
                 device.tpu_core_ids_to_locations(my_core_ids))

    # At this point, we don't know which cores are attached to other hosts.
    # The core ID mappings in the runtime haven't been set yet.
    #
    # The core ID merging AllReduce below is carefully written so it works
    # without needing correct core mappings to be set in the runtime. We will
    # use this AllReduce's result to set the core ID mappings, and all future
    # user-initiated AllReduces will use the mappings.
    #
    # The runtime is hard-coded to ignore core ID mappings on this AllReduce.
    all_core_ids = np.zeros([num_devices], dtype=np.int32)
    for i in range(len(my_core_ids)):
      all_core_ids[task_id * num_devices_per_task + i] = my_core_ids[i]

    # Only one local device gets valid input: 8 local core IDs among
    # (num_tasks - 1) * 8 zeros. The 8 core IDs are set using the task ID as
    # the offset. The other 7 local devices get zero inputs. All devices on
    # all hosts participate in one AllReduce, whose result will be core IDs
    # arranged by task-device ordinals.
    all_core_ids = constant_op.constant([all_core_ids])
    zeros = array_ops.zeros_like(all_core_ids)
    all_core_ids = [all_core_ids] + [zeros] * (num_devices_per_task - 1)

    with ops.device(device.name):
      all_core_ids = device.pack(all_core_ids, layout)
      all_core_ids = math_ops.reduce_sum(all_core_ids, axis=[0])
      unpacked_all_tpu_ids = device.unpack(all_core_ids)

    all_core_ids = list(unpacked_all_tpu_ids[0].numpy())
    logging.info("All TPU core IDs: %s", all_core_ids)

    # Set the default core ID mappings in the runtime for legacy code and
    # tests.
    #
    # Legacy code and tests create TPU meshes directly without using the
    # `create_tpu_mesh` function below. Those meshes have global device IDs
    # equal to TF task-device ordinals. The `all_core_ids` array happens to
    # arrange core IDs by TF task-device ordinals. Using this array on those
    # meshes guarantees correct, although inefficient, results.
    device.set_tpu_core_ids("", all_core_ids)

    # Remember enough global, immutable information to be able to build any
    # ring we want prescribed by `create_tpu_mesh` in the future.
    global _all_core_ids
    _all_core_ids = all_core_ids

    all_core_locations = device.tpu_core_ids_to_locations(all_core_ids)
    all_core_locations = [
        _CoreLocation(l[0], l[1], l[2], l[3]) for l in all_core_locations
    ]
    global _all_core_locations
    _all_core_locations = all_core_locations
    logging.info("All TPU core locations: %s", all_core_locations)

    tpu_topology = _create_tpu_topology(all_core_locations, num_tasks,
                                        num_devices_per_task)
    global _tpu_topology
    _tpu_topology = tpu_topology
    logging.vlog(1, "TPU Topology: %s, %s", tpu_topology.mesh_shape,
                 tpu_topology.device_coordinates)

    global _dtensor_device
    _dtensor_device = device

    context.async_wait()

  except errors.InvalidArgumentError as e:
    raise errors.NotFoundError(
        None, None, "Initialization failed, no valid TPUs found. " + str(e))

  except errors.InternalError as e:
    logging.error("Hit internal error during TPU system initialization. "
                  "It is likely a hardware failure. \nPlease check the error "
                  "messages above to see whether that's the case. \nIf so, "
                  "consider restarting the job or trying another machine.")
    raise e

  # Optionally exchange heartbeats between workers every 180 seconds.
  if in_multi_client_mode and dtensor.heartbeat_enabled():
    logging.info(
        "Starting DTensor heartbeat service exchanging signals every 180 "
        "seconds")
    heartbeat.start(period=180)

  # Clear out the eager context caches since the memory is invalid now.
  logging.info("Clearing out eager caches")
  context.context()._clear_caches()  # pylint: disable=protected-access
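# Hedged usage sketch: the call order assumed for the TPU path in this module.
# Initialize the TPU system once, eagerly, before creating any TPU mesh; the
# mesh name and core count below are illustrative assumptions.
dtensor_initialize_tpu_system(enable_coordination_service=True)
data_mesh = create_tpu_mesh(['batch'], [32], 'data_parallel_mesh')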