def _register_mesh(self, mesh: layout_lib.Mesh):
  """Idempotently register `mesh` with the dtensor device."""
  with self._mesh_lock:
    if mesh not in self._meshes:
      _pywrap_dtensor_device.AddMesh(self._device_info, mesh.to_string(),
                                     self._is_async, False)
      self._meshes.add(mesh)
      if mesh.device_type().upper() == "TPU":
        logging.info(
            "Registering virtual 1:1 mapped host mesh %s for mesh %s",
            mesh.host_mesh().to_string(), mesh.to_string())
        _pywrap_dtensor_device.AddMesh(
            self._device_info, mesh.host_mesh().to_string(), self._is_async,
            True)
        self._meshes.add(mesh.host_mesh())
      embedding_host_mesh = self._create_embedding_host_mesh(mesh)
      if embedding_host_mesh:
        logging.info(
            "Registering embedding host mesh %s on each client for mesh %s",
            embedding_host_mesh.to_string(), mesh.to_string())
        _pywrap_dtensor_device.AddMesh(
            self._device_info, embedding_host_mesh.to_string(),
            self._is_async, False)
        self._meshes.add(embedding_host_mesh)
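# A minimal usage sketch for _register_mesh (not from the source; the device
# construction and mesh below are assumptions for illustration). Registration
# is idempotent, so repeated calls with the same mesh are cheap no-ops:
#
#   device = DTensorDevice(meshes=[])   # hypothetical construction
#   mesh = layout_lib.Mesh(...)         # details elided
#   device._register_mesh(mesh)  # registers `mesh` (and, on TPU, its
#                                # 1:1 mapped host mesh) with the C++ device
#   device._register_mesh(mesh)  # no-op: `mesh` is already in self._meshes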
def sharded_save(
    mesh: layout_lib.Mesh,
    file_prefix: Union[str, ops.Tensor],
    tensor_names: Union[List[str], ops.Tensor],
    shape_and_slices: Union[List[str], ops.Tensor],
    tensors: List[Union[ops.Tensor, tf_variables.Variable]],
):
  """Saves given named tensor slices in a sharded, multi-client safe fashion.

  The method makes sure the checkpoint directory state is correct in a sharded
  multi-client save. Namely, we place a barrier after SaveV2 to make sure
  every client has finished writing its files, and another one after
  MergeV2Checkpoints to make sure all metadata is properly merged.

  Upon exiting, the checkpoint is complete and all directory operations
  are done.

  Args:
    mesh: The Mesh that contains the Tensors to save.
    file_prefix: The prefix of the checkpoint.
    tensor_names: a list of tensor names used in the save op.
    shape_and_slices: a list of shape and slice specifications used in the
      save op. The only supported value is "" as we don't support distributed
      saving with slices yet.
    tensors: a list of tensors used in the save op. The order should match
      tensor_names.

  Returns:
    A MergeV2Checkpoints op that merged all metadata.
  """
  with ops.device(api.device_name()):
    io_ops.save_v2(file_prefix, tensor_names, shape_and_slices, tensors)

  # Query generated shards and generate MergeV2.
  generated_shards = sharded_prefix(mesh.host_mesh(), [file_prefix],
                                    tensor_names, shape_and_slices, tensors)
  # api.py is still visible to external users but the _global_barrier() isn't
  # intended for public usage.
  # Once we lock down api.py visibility, we will be able to drop the `_`
  # prefix on these APIs.

  # Make sure all clients have written the files.
  _global_barrier(mesh.host_mesh(), 'SaveV2')  # pylint: disable=protected-access

  with ops.device(api.device_name()):
    merge_op = io_ops.MergeV2Checkpoints(
        checkpoint_prefixes=generated_shards,
        destination_prefix=file_prefix,
        delete_old_dirs=True)

  # Make sure the first device on the first host has finished the merge.
  # pylint: disable=protected-access
  _global_barrier(mesh.host_mesh(), 'MergeV2Checkpoints')
  # pylint: enable=protected-access

  return merge_op
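# A minimal sketch of calling sharded_save directly (illustrative only; the
# mesh, layout, and tensor `w` are assumptions, not from the source). Every
# client runs the same code; the barriers inside sharded_save keep the
# checkpoint directory state consistent across clients:
#
#   w = api.pack(...)  # a DTensor living on `mesh`, details elided
#   merge_op = sharded_save(
#       mesh,
#       file_prefix='/tmp/ckpt/model',
#       tensor_names=['w'],
#       shape_and_slices=[''],  # only '' is supported today
#       tensors=[w])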
def name_based_save(mesh: layout_lib.Mesh,
                    checkpoint_prefix: Union[str, ops.Tensor],
                    name_tensor_dict: Dict[str, Union[ops.Tensor,
                                                      tf_variables.Variable]]):
  """Saves name based Tensors into a checkpoint.

  The function prepares the input dictionary to the format of a `sharded_save`,
  so that it can take advantage of DTensor SPMD based distributed save.

  Same as restore, the function only supports saving on a single mesh.

  Args:
    mesh: The single mesh that contains the Tensors to save.
    checkpoint_prefix: The prefix of the checkpoint to be saved.
    name_tensor_dict: An ordered dictionary of tensor_names to a DTensor. The
      DTensor shape/dtype must match the tensors being saved/restored for now.
  """
  if not context.executing_eagerly():
    raise ValueError('name based save must run eagerly.')

  ordered_name_tensor_dict = name_tensor_dict
  if not isinstance(name_tensor_dict, collections.OrderedDict):
    ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

  # The current _dtensor_device() in api.py is the correct way of specifying
  # the DTensor device singleton. The API itself will eventually be moved to
  # a public API that provides a global singleton in DTensor context.
  # For now, we just use the current `internal` API and aim at migrating in
  # one shot later.
  # TODO(hthu): Provide _dtensor_device() singleton as a public API.
  # pylint: disable=protected-access
  checkpoint_prefix = api.pack(
      [checkpoint_prefix] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=0))
  tensor_names = api.pack(
      [list(ordered_name_tensor_dict.keys())] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))

  sharded_save(
      mesh,
      file_prefix=checkpoint_prefix,
      tensor_names=tensor_names,
      shape_and_slices=[''] * len(ordered_name_tensor_dict),
      tensors=list(ordered_name_tensor_dict.values()))
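# A minimal sketch of name_based_save (illustrative; `mesh`, `w`, and `b` are
# assumed to be an existing mesh and DTensors on it, and the tensor names are
# made up for the example):
#
#   name_based_save(
#       mesh, '/tmp/ckpt/model',
#       collections.OrderedDict([('dense/kernel', w), ('dense/bias', b)]))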
def name_based_restore(
    mesh: layout_lib.Mesh,
    checkpoint_prefix: str,
    name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]],
):
  """Restores from checkpoint_prefix to name based DTensors.

  It is required to have already-initialized DTensor variables that have the
  same shape/dtype as the tensors being restored. Also, we currently only
  support a name based restore on a single mesh.

  Args:
    mesh: The single mesh that all Tensors would be restored to.
    checkpoint_prefix: The prefix of the checkpoint to be restored.
    name_tensor_dict: An ordered dictionary of tensor_names to a DTensor. The
      DTensor shape/dtype must match the tensors being saved/restored for now.

  Returns:
    A dictionary of name to its restored DTensor value.
  """
  if not context.executing_eagerly():
    raise ValueError('name based restore must run eagerly.')

  ordered_name_tensor_dict = name_tensor_dict
  if not isinstance(name_tensor_dict, collections.OrderedDict):
    ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

  # Make sure that all tensors are on the CPU mesh for now.
  # This might not be a hard limitation in the future.
  for name, tensor in ordered_name_tensor_dict.items():
    try:
      if api.fetch_layout(tensor).mesh.device_type().upper() != 'CPU':
        raise ValueError(
            'Restoring a non CPU Tensor is not supported currently. Offending '
            'tensor name : {tensor_name}'.format(tensor_name=name))
    except errors_impl.OpError as op_error:
      raise ValueError(
          'Saving/Restoring tensor must be a DTensor') from op_error

  # Now that we have all tensors on the CPU mesh, do a DTensorRestoreV2.
  checkpoint_prefix = api.pack(
      [checkpoint_prefix] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=0))
  # Explicitly pack to the mesh to avoid implicit small constant extraction,
  # which does not work for larger restores that have lots of names.
  tensor_names = api.pack(
      [list(ordered_name_tensor_dict.keys())] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))
  shape_and_slices = api.pack(
      [[''] * len(ordered_name_tensor_dict)] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))
  # A list of TensorShape objects representing the shapes of all input tensors.
  input_shapes = [tensor.shape for tensor in ordered_name_tensor_dict.values()]
  input_layouts = [
      api.fetch_layout(tensor).to_string()
      for tensor in ordered_name_tensor_dict.values()
  ]

  with ops.device(api.device_name()):
    restored_cpu_tensors = gen_dtensor_ops.d_tensor_restore_v2(
        prefix=checkpoint_prefix,
        tensor_names=tensor_names,
        shape_and_slices=shape_and_slices,
        input_shapes=input_shapes,
        input_layouts=input_layouts,
        dtypes=[tensor.dtype for tensor in ordered_name_tensor_dict.values()])

  return collections.OrderedDict(
      zip(ordered_name_tensor_dict.keys(), restored_cpu_tensors))
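# A minimal sketch of name_based_restore (illustrative; `w_init` and `b_init`
# are assumed to be already-initialized CPU-mesh DTensors whose shape/dtype
# match the checkpointed values, and the names are made up for the example):
#
#   restored = name_based_restore(
#       mesh, '/tmp/ckpt/model',
#       collections.OrderedDict([('dense/kernel', w_init),
#                                ('dense/bias', b_init)]))
#   # restored['dense/kernel'] is a DTensor on the CPU host mesh.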