Example #1
def sharded_save(
    mesh: layout_lib.Mesh,
    file_prefix: Union[str, ops.Tensor],
    tensor_names: Union[List[str], ops.Tensor],
    shape_and_slices: Union[List[str], ops.Tensor],
    tensors: List[Union[ops.Tensor, tf_variables.Variable]],
):
    """Saves given named tensor slices in a sharded, multi-client safe fashion.

  The method makes sure the checkpoint directory state is correct in a sharded
  multi-client save. Namely, we place a barrier after SaveV2 to make sure
  every client has finished writing its files, and another one after
  MergeV2Checkpoints to make sure all metadata is properly merged.

  Upon exiting, the checkpoint is complete and all directory operations are
  done.

  Args:
    mesh: The Mesh that contains the Tensors to save.
    file_prefix: The prefix of the checkpoint.
    tensor_names: a list of tensor names used in the save op.
    shape_and_slices: a list of shape and slice specifications used in the save
      op. The only supported value is "" as we don't support distributed saving
      with slices yet.
    tensors: a list of tensors used in the save op. The order should match
      tensor_names.

  Returns:
    A MergeV2Checkpoints op that merges all metadata.
  """
    with ops.device(api.device_name()):
        io_ops.save_v2(file_prefix, tensor_names, shape_and_slices, tensors)

    # Query generated shards and generate MergeV2.
    generated_shards = sharded_prefix(mesh.host_mesh(), [file_prefix],
                                      tensor_names, shape_and_slices, tensors)
    # api.py is still visible to external users but the _global_barrier() isn't
    # intended for public usage.
    # Once we have locked down api.py's visibility, we will be able to drop the
    # `_` prefix on these APIs.

    # Make sure all clients have written the files
    _global_barrier(mesh.host_mesh(), 'SaveV2')  # pylint: disable=protected-access

    with ops.device(api.device_name()):
        merge_op = io_ops.MergeV2Checkpoints(
            checkpoint_prefixes=generated_shards,
            destination_prefix=file_prefix,
            delete_old_dirs=True)

    # Make sure first device in first host has finished merge.
    # pylint: disable=protected-access
    _global_barrier(mesh.host_mesh(), 'MergeV2Checkpoints')
    # pylint: enable=protected-access

    return merge_op
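
For context, here is a minimal usage sketch of `sharded_save`. It assumes the public `tf.experimental.dtensor` API (TF 2.9+) and that `sharded_save` is importable from the internal `save_restore` module; the mesh size, tensor name, and file prefix are illustrative only and are not part of the snippet above.

import tensorflow as tf
from tensorflow.experimental import dtensor

# Assumption: internal module path, may differ across TF versions.
from tensorflow.dtensor.python import save_restore

# Split the physical CPU into two logical devices so a 2-device mesh fits on
# one host (illustrative only; any device configuration works).
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration()] * 2)

mesh = dtensor.create_mesh([("x", 2)], device_type="CPU")

# A replicated DTensor to checkpoint.
value = dtensor.copy_to_mesh(
    tf.constant([1.0, 2.0, 3.0]),
    dtensor.Layout.replicated(mesh, rank=1))

# One empty shape_and_slices entry per tensor: slicing is not supported.
merge_op = save_restore.sharded_save(
    mesh=mesh,
    file_prefix="/tmp/dtensor_ckpt/ckpt",
    tensor_names=["my_tensor"],
    shape_and_slices=[""],
    tensors=[value])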
Example #2
    def save(
        self,
        file_prefix: str,
        options: Optional[checkpoint_options.CheckpointOptions] = None
    ) -> Optional[ops.Operation]:
        """Saves the saveable objects to a checkpoint with `file_prefix`.

    Also queries the generated shards from the distributed DTensor SaveV2 ops
    and runs a MergeV2 on them. Each op here is backed by a global_barrier to
    avoid races between multiple clients.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object. This is unused in DTensor.

    Returns:
      An `Operation`, or None when executing eagerly.
    """
        if options is not None and options.experimental_io_device is not None:
            raise ValueError(
                "Specified experimental_io_device in DTensor checkpoint is not supported."
            )
        del options
        tensor_names = []
        tensors = []
        tensor_slices = []
        for saveable in self._saveable_objects:
            for spec in saveable.specs:
                tensor = spec.tensor
                # A tensor value of `None` indicates that this SaveableObject gets
                # recorded in the object graph, but that no value is saved in the
                # checkpoint.
                if tensor is not None:
                    if api.device_name() != spec.device:
                        # Some small tensors are placed on CPU:0 by the save manager and
                        # broadcast to the DTensor mesh, e.g., SaveCounter.
                        tensor = api.pack(
                            [tensor] *
                            self._mesh.host_mesh().num_local_devices(),
                            layout.Layout.replicated(self._mesh.host_mesh(),
                                                     rank=tensor.shape.rank))
                    tensor_names.append(spec.name)
                    tensors.append(tensor)
                    tensor_slices.append(spec.slice_spec)
        return save_restore.sharded_save(self._mesh, file_prefix, tensor_names,
                                         tensor_slices, tensors)
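
The packing step in the loop above (broadcasting a small CPU-placed tensor onto every local device of the host mesh) can be reproduced with the public API roughly as follows; the mesh construction and the counter value are illustrative assumptions, not part of `_DSaver.save` itself.

import tensorflow as tf
from tensorflow.experimental import dtensor

# Assumes local (virtual) CPU devices are configured as in the earlier sketch.
mesh = dtensor.create_mesh([("x", 2)], device_type="CPU")
host_mesh = mesh.host_mesh()

# A small scalar living on the ordinary CPU device, e.g. a save counter.
save_counter = tf.constant(7, dtype=tf.int64)

# Replicate it onto every local device of the host mesh, mirroring the
# api.pack(...) call in the save() method above.
packed = dtensor.pack(
    [save_counter] * host_mesh.num_local_devices(),
    dtensor.Layout.replicated(host_mesh, rank=save_counter.shape.rank))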
Example #3
    def restore(
        self,
        file_prefix: str,
        options: Optional[checkpoint_options.CheckpointOptions] = None
    ) -> Dict[str, ops.Operation]:
        """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object. This is unused in DTensor.

    Returns:
      A dictionary mapping from SaveableObject names to restore operations.
    """
        if options is not None and options.experimental_io_device is not None:
            raise ValueError(
                "Specified experimental_io_device in DTensor checkpoint is not "
                "supported.")
        del options
        restore_specs = []
        tensor_structure = []
        for saveable in self._saveable_objects:
            saveable_tensor_structure = []
            tensor_structure.append(saveable_tensor_structure)
            # DTensor change 1: Gather shapes and layouts from the original
            # saveable specs.
            # Note that this relies on the fact that the variables are already
            # initialized -- which isn't the behavior we want eventually.
            # TODO(b/159035705): Handle the variable initialization in restore.
            for spec in saveable.specs:
                saveable_tensor_structure.append(spec.name)
                if isinstance(spec, d_variable.DSaveSpec):
                    restore_specs.append(
                        (spec.name, spec.slice_spec, spec.dtype, spec.layout,
                         spec.global_shape))
                # Fall back to replicated layouts for non-DTensor saves that construct a
                # normal SaveSpec.
                elif isinstance(spec, saveable_object.SaveSpec):
                    restore_specs.append(
                        (spec.name, spec.slice_spec, spec.dtype,
                         layout.Layout.replicated(
                             self._mesh.host_mesh(),
                             spec.tensor.shape.rank).to_string(),
                         spec.tensor.shape.as_list()))
        tensor_names, tensor_slices, tensor_dtypes, layouts, global_shapes = zip(
            *restore_specs)
        with ops.device(api.device_name()):
            # DTensor change 2: Run the customized DTensor RestoreV2 op rather
            # than the stock TF io_ops.RestoreV2.
            restored_tensors = gen_dtensor_ops.d_tensor_restore_v2(
                prefix=file_prefix,
                tensor_names=tensor_names,
                shape_and_slices=tensor_slices,
                input_shapes=global_shapes,
                input_layouts=layouts,
                dtypes=tensor_dtypes)
        structured_restored_tensors = nest.pack_sequence_as(
            tensor_structure, restored_tensors)
        restore_ops = {}
        for saveable, restored_tensors in zip(self._saveable_objects,
                                              structured_restored_tensors):
            restore_ops[saveable.name] = saveable.restore(restored_tensors,
                                                          restored_shapes=None)
        return restore_ops
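
The re-nesting at the end relies on `tf.nest.pack_sequence_as`: the restore op returns one flat list of tensors ordered like `restore_specs`, which is then regrouped into one inner list per saveable. A toy illustration (names and values are made up):

import tensorflow as tf

# One inner list of spec names per saveable object.
tensor_structure = [["dense/kernel", "dense/bias"], ["step"]]

# The restore op yields a flat list in the same order.
restored_tensors = [tf.constant(1.0), tf.constant(2.0), tf.constant(3.0)]

grouped = tf.nest.pack_sequence_as(tensor_structure, restored_tensors)
# grouped == [[<1.0>, <2.0>], [<3.0>]], ready to hand to each
# saveable.restore(...) call.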
Example #4
def name_based_restore(
    mesh: layout_lib.Mesh,
    checkpoint_prefix: str,
    name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]],
):
    """Restores from checkpoint_prefix to name based DTensors.

  It requires already-initialized DTensor variables that have the same
  shape/dtype as the tensors being restored.

  Also, we currently only support a name-based restore on a single mesh.

  Args:
    mesh: The single mesh that all Tensors would be restored to.
    checkpoint_prefix: The prefix of the checkpoint to be restored.
    name_tensor_dict: An ordered dictionary mapping tensor names to DTensors.
      The DTensor shape/dtype must match the tensors being saved/restored for
      now.

  Returns:
    A dictionary of name to its restored DTensor value.
  """
    if not context.executing_eagerly():
        raise ValueError('name based restore must run eagerly.')

    ordered_name_tensor_dict = name_tensor_dict
    if not isinstance(name_tensor_dict, collections.OrderedDict):
        ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

    # Make sure that all tensors are on CPU mesh for now.
    # This might not be a hard limitation in the future.
    for name, tensor in ordered_name_tensor_dict.items():
        try:
            if api.fetch_layout(tensor).mesh.device_type().upper() != 'CPU':
                raise ValueError(
                    'Restoring a non CPU Tensor is not supported currently. Offending '
                    'tensor name : {tensor_name}'.format(tensor_name=name))
        except errors_impl.OpError as op_error:
            raise ValueError(
                'Saving/Restoring tensor must be a DTensor') from op_error

    # Now that we have all tensors on CPU mesh, do a DTensorRestoreV2.
    checkpoint_prefix = api.pack([checkpoint_prefix] *
                                 mesh.num_local_devices(),
                                 layout_lib.Layout.replicated(mesh.host_mesh(),
                                                              rank=0))
    # Explicitly pack to the mesh to avoid implicit small constant extraction,
    # which does not work for larger restores that have lots of names.
    tensor_names = api.pack([list(ordered_name_tensor_dict.keys())] *
                            mesh.num_local_devices(),
                            layout_lib.Layout.replicated(mesh.host_mesh(),
                                                         rank=1))
    shape_and_slices = api.pack([[''] * len(ordered_name_tensor_dict)] *
                                mesh.num_local_devices(),
                                layout_lib.Layout.replicated(mesh.host_mesh(),
                                                             rank=1))
    # A list of TensorShape representing all shapes for the input tensors.
    input_shapes = [
        tensor.shape for tensor in ordered_name_tensor_dict.values()
    ]
    input_layouts = [
        api.fetch_layout(tensor).to_string()
        for tensor in ordered_name_tensor_dict.values()
    ]

    with ops.device(api.device_name()):
        restored_cpu_tensors = gen_dtensor_ops.d_tensor_restore_v2(
            prefix=checkpoint_prefix,
            tensor_names=tensor_names,
            shape_and_slices=shape_and_slices,
            input_shapes=input_shapes,
            input_layouts=input_layouts,
            dtypes=[
                tensor.dtype for tensor in ordered_name_tensor_dict.values()
            ])

    return collections.OrderedDict(
        zip(ordered_name_tensor_dict.keys(), restored_cpu_tensors))
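
Finally, a hedged usage sketch of `name_based_restore` in eager mode on a single CPU mesh; the import path, checkpoint prefix, and tensor names are assumptions for illustration only.

import tensorflow as tf
from tensorflow.experimental import dtensor

# Assumption: internal module path, may differ across TF versions.
from tensorflow.dtensor.python import save_restore

# Assumes local (virtual) CPU devices are configured as in the first sketch.
mesh = dtensor.create_mesh([("x", 2)], device_type="CPU")
layout = dtensor.Layout.replicated(mesh, rank=1)

# Already-initialized DTensors whose shape/dtype match the saved tensors.
kernel = dtensor.copy_to_mesh(tf.zeros([3]), layout)
bias = dtensor.copy_to_mesh(tf.zeros([3]), layout)

restored = save_restore.name_based_restore(
    mesh=mesh,
    checkpoint_prefix="/tmp/dtensor_ckpt/ckpt",  # hypothetical prefix
    name_tensor_dict={"dense/kernel": kernel, "dense/bias": bias})
# `restored` maps each name to its restored DTensor on the CPU mesh.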