def _register_mesh(self, mesh: layout_lib.Mesh):
  """Idempotently register `mesh` with the dtensor device."""
  with self._mesh_lock:
    if mesh not in self._meshes:
      _pywrap_dtensor_device.AddMesh(self._device_info, mesh.to_string(),
                                     self._is_async, False)
      self._meshes.add(mesh)
      if mesh.device_type().upper() == "TPU":
        logging.info(
            "Registering virtual 1:1 mapped host mesh %s for mesh %s",
            mesh.host_mesh().to_string(), mesh.to_string())
        _pywrap_dtensor_device.AddMesh(self._device_info,
                                       mesh.host_mesh().to_string(),
                                       self._is_async, True)
        self._meshes.add(mesh.host_mesh())
        embedding_host_mesh = self._create_embedding_host_mesh(mesh)
        if embedding_host_mesh:
          logging.info(
              "Registering embedding host mesh %s on each client for mesh %s",
              embedding_host_mesh.to_string(), mesh.to_string())
          _pywrap_dtensor_device.AddMesh(self._device_info,
                                         embedding_host_mesh.to_string(),
                                         self._is_async, False)
          self._meshes.add(embedding_host_mesh)
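# Registration is idempotent: a second call with the same mesh skips the
# AddMesh call because the mesh is already in `self._meshes`. A minimal usage
# sketch, assuming hypothetical `device` (a DTensor device instance) and
# `tpu_mesh` objects:
#
#   device._register_mesh(tpu_mesh)  # registers tpu_mesh, its host mesh, and,
#                                    # when possible, an embedding host mesh
#   device._register_mesh(tpu_mesh)  # no-op; mesh is already registered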
@contextlib.contextmanager
def _experimental_default_mesh(self, mesh: layout_lib.Mesh):
  """Sets a default mesh for all ops in the scope.

  Note: This is an internal helper method, not a user-facing API.

  Useful for requesting a specific mesh for ops which would have no inferred
  layout, e.g. tf.zeros.

  Args:
    mesh: A Mesh to be used for ops without a mesh.

  Yields:
    Nothing.
  """
  previous_default = self._current_default_mesh
  self._register_mesh(mesh)
  _pywrap_dtensor_device.ExperimentalSetDefaultMesh(
      self._device_info, mesh.to_string().encode("utf-8"))
  self._current_default_mesh = mesh
  yield
  _pywrap_dtensor_device.ExperimentalClearDefaultMesh(self._device_info)
  if previous_default:
    _pywrap_dtensor_device.ExperimentalSetDefaultMesh(
        self._device_info, previous_default.to_string().encode("utf-8"))
  self._current_default_mesh = previous_default
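# A minimal usage sketch of the default-mesh scope, assuming hypothetical
# `device` (a DTensor device instance) and `cpu_mesh` objects. Ops with no
# inferred layout created inside the scope are placed on `cpu_mesh`; on exit
# the previous default, if any, is restored:
#
#   with device._experimental_default_mesh(cpu_mesh):
#     z = tf.zeros([4, 4])  # no operands to infer a mesh from; uses cpu_mesh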
def _create_embedding_host_mesh(self, tpu_mesh: layout_lib.Mesh):
  """Returns the embedding host mesh for each client."""
  if tpu_mesh.device_type().upper() != "TPU":
    raise ValueError("Must pass input of a tpu mesh.")

  # Global device ids are global host ids, while local device ids contain
  # local host ids.
  ts_local_device_ids = []
  ts_local_devices = []
  for local_device_str in tpu_mesh.local_devices():
    # We only need to keep TPU:0 for each client.
    if not local_device_str.endswith("TPU:0"):
      continue

    device_spec = tf_device.DeviceSpec.from_string(local_device_str)
    ts_local_device_ids.append(device_spec.task)
    ts_local_devices.append(device_spec.replace(device_type="CPU"))

  if not ts_local_device_ids or not ts_local_devices:
    logging.info(
        "Cannot create tpu system mesh as %s has no `TPU:0` local device "
        "found", tpu_mesh.to_string())
    return None

  ts_global_device_ids = np.arange(self._num_clients())
  # TODO(zhonglinhan): parse global device specs as input when not None.
  return layout_lib.Mesh(
      dim_names=[tpu_mesh.dim_names[0]],  # 1D mesh.
      global_device_ids=ts_global_device_ids,
      local_device_ids=ts_local_device_ids,
      local_devices=ts_local_devices)
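# A minimal standalone sketch of the device-string conversion above: each
# client's `TPU:0` entry keeps its task (host) id but is re-targeted to the
# CPU device of the same task. The default device string is illustrative only.
def _example_tpu0_to_host_cpu(
    local_device_str="/job:worker/replica:0/task:3/device:TPU:0"):
  device_spec = tf_device.DeviceSpec.from_string(local_device_str)
  # (host id, matching CPU device spec) pair used to build the host mesh.
  return device_spec.task, device_spec.replace(device_type="CPU")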
def sharded_prefix(
    mesh: layout_lib.Mesh,
    prefix: List[str],
    tensor_names: List[str],
    shape_and_slices: List[str],
    tensors: List[ops.Tensor],
):
  """Generates all sharded prefixes used in a distributed Save.

  DTensor SaveV2 SPMD generates multiple SaveV2 ops on the saving devices, and
  they must not save with the same shard prefix so that content is not
  overwritten.

  (prefix, tensor_names, tensors (with layouts)) and the saving mesh
  collectively define the unique set of shard prefixes generated for all the
  Save ops. Usually, (prefix, tensor_names, shape_and_slices, tensors) should
  match what is used in the save.

  Args:
    mesh: The mesh that is used in the save op. Usually a CPU mesh, and matches
      what is used in the Save op.
    prefix: The prefix of the saving files.
    tensor_names: A list of tensor names used in the save op.
    shape_and_slices: A list of shape and slice specifications used in the save
      op. The only supported value is "" as we don't support distributed saving
      with slices yet.
    tensors: A list of tensors used in the save op. The order should match
      tensor_names.

  Returns:
    A 1-D string tensor containing all generated shard prefixes.
  """
  layout_str = array_ops.stack(
      [api.fetch_layout(tensor).to_string() for tensor in tensors], axis=0)
  layouts = api.pack([layout_str] * mesh.num_local_devices(),
                     layout_lib.Layout.replicated(mesh, rank=1))
  mesh_str_tensor = api.pack([mesh.to_string()] * mesh.num_local_devices(),
                             layout_lib.Layout.replicated(mesh, rank=0))
  return gen_dtensor_ops.d_tensor_sharded_prefix(prefix,
                                                 tensor_names,
                                                 shape_and_slices,
                                                 mesh_str_tensor,
                                                 layouts=layouts,
                                                 tensors=tensors)
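# A minimal sketch (hypothetical helper, not part of the DTensor API) of how
# `sharded_prefix` is typically paired with the arguments of a SaveV2 call.
# `cpu_mesh`, `prefix`, `tensor_names`, and `dtensors` are assumed inputs, and
# `shape_and_slices` stays "" per the restriction documented above.
def _example_all_shard_prefixes(cpu_mesh, prefix, tensor_names, dtensors):
  shape_and_slices = [""] * len(tensor_names)  # distributed slicing unsupported
  return sharded_prefix(cpu_mesh, [prefix], tensor_names, shape_and_slices,
                        dtensors)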