Example #1
def sharded_prefix(
    mesh: layout_lib.Mesh,
    prefix: List[str],
    tensor_names: List[str],
    shape_and_slices: List[str],
    tensors: List[ops.Tensor],
):
    """Generates all sharded prefix in distributed Save.

  DTensor SaveV2 SPMD would generate multiple SaveV2 ops on saving devices,
  and it is desired to not save with same shard_prefix so that content will
  not be overwritten.

  (prefix, tensor_names, tensors(with layouts)) and saving mesh collectively
  defines a unique set of shard prefix that is generated for all the Save ops.
  Usually, (prefix, tensor_names, shape_and_slices, tensors) should match what
  is used in save.

  Args:
    mesh: The mesh that is used in save op. Usually a CPU mesh, and matches what
      is used in Save op.
    prefix: The prefix of saving files.
    tensor_names: a list of tensor names used in save op.
    shape_and_slices: a list of shape and slice specification used in save op.
      The only supported value is "" as we don't support distributed saving with
      slices yet.
    tensors: a list of tensors used in save op. The order should match
      tensor_names.

  Returns:
    A one d string tensor that represents all shard_prefix generated.
  """
    layout_str = array_ops.stack(
        [api.fetch_layout(tensor).to_string() for tensor in tensors], axis=0)
    layouts = api.pack([layout_str] * mesh.num_local_devices(),
                       layout_lib.Layout.replicated(mesh, rank=1))

    mesh_str_tensor = api.pack([mesh.to_string()] * mesh.num_local_devices(),
                               layout_lib.Layout.replicated(mesh, rank=0))
    return gen_dtensor_ops.d_tensor_sharded_prefix(prefix,
                                                   tensor_names,
                                                   shape_and_slices,
                                                   mesh_str_tensor,
                                                   layouts=layouts,
                                                   tensors=tensors)
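The shard prefixes are derived from the serialized layouts of the tensors and the serialized mesh, which is exactly what the function stacks and packs above. Below is a minimal sketch using the public `tf.experimental.dtensor` API (a single-client CPU mesh is assumed for illustration) showing those serialized strings:

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Single-client CPU mesh; the shard prefixes depend on these serialized forms.
mesh = dtensor.create_mesh([('x', dtensor.num_local_devices('CPU'))],
                           device_type='CPU')
layout = dtensor.Layout.replicated(mesh, rank=1)
d = dtensor.copy_to_mesh(tf.constant([1.0, 2.0, 3.0]), layout)

# Each saved tensor contributes its layout string; the mesh string is shared.
print(dtensor.fetch_layout(d).to_string())
print(mesh.to_string())
```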
Example #2
def name_based_save(mesh: layout_lib.Mesh,
                    checkpoint_prefix: Union[str, ops.Tensor],
                    name_tensor_dict: Dict[str, Union[ops.Tensor,
                                                      tf_variables.Variable]]):
    """Saves name based Tensor into a Checkpoint.

  The function prepares the input dictionary to the format of a `sharded_save`,
  so that it can take advantage of DTensor SPMD based distributed save.

  Same as restore, the function only supports saving on the single mesh.

  Args:
    mesh: The single mesh that all Tensors would be restored to.
    checkpoint_prefix : The prefix of checkpoint to be restored.
    name_tensor_dict: A ordered dictionary of tensor_names to a DTensor. The
      DTensor shape/dtype must match the tensors being saved/restored for now.
  """
    if not context.executing_eagerly():
        raise ValueError('name based save must run eagerly.')

    ordered_name_tensor_dict = name_tensor_dict
    if not isinstance(name_tensor_dict, collections.OrderedDict):
        ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

    # The current _dtensor_device() in api.py is the correct way of specifying
    # the DTensor device singleton. The API itself will eventually be moved to
    # a public API that provides a global singleton in the DTensor context.
    # For now, we just use the current `internal` API and aim to migrate in
    # one shot later.
    # TODO(hthu): Provide _dtensor_device() singleton as a public API.
    # pylint: disable=protected-access
    checkpoint_prefix = api.pack([checkpoint_prefix] *
                                 mesh.num_local_devices(),
                                 layout_lib.Layout.replicated(mesh.host_mesh(),
                                                              rank=0))
    tensor_names = api.pack([list(ordered_name_tensor_dict.keys())] *
                            mesh.num_local_devices(),
                            layout_lib.Layout.replicated(mesh.host_mesh(),
                                                         rank=1))

    sharded_save(mesh,
                 file_prefix=checkpoint_prefix,
                 tensor_names=tensor_names,
                 shape_and_slices=[''] * len(ordered_name_tensor_dict),
                 tensors=list(ordered_name_tensor_dict.values()))
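A hedged usage sketch follows. The internal import path below is an assumption (this function lives in DTensor's internal save/restore module and is not part of the public API):

```python
import collections
import tensorflow as tf
from tensorflow.experimental import dtensor
# Assumed internal module path; not a public API.
from tensorflow.dtensor.python import save_restore

mesh = dtensor.create_mesh([('x', dtensor.num_local_devices('CPU'))],
                           device_type='CPU')
w = dtensor.DVariable(
    dtensor.copy_to_mesh(tf.zeros([4]), dtensor.Layout.replicated(mesh, rank=1)))

# Eager-only; the values must already be DTensors on the (CPU) mesh.
save_restore.name_based_save(
    mesh, '/tmp/dtensor_ckpt/model', collections.OrderedDict([('w', w)]))
```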
Example #3
 def get_host_dvariable():
   # Note: this nested helper closes over `dvariable`, `original_layout`,
   # `host_layout`, and `self` from the enclosing method.
   # Copy to host mesh if needed.
   if original_layout.mesh.device_type().upper() != 'CPU':
     with ops.device(dvariable.device):
       host_dvariable = DVariable(
           api.pack(api.unpack(dvariable.read_value()), host_layout))
   else:
     host_dvariable = dvariable
   return (math_ops.cast(host_dvariable, dtypes.bfloat16)
           if self.should_cast(host_dvariable) else host_dvariable)
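The unpack/repack pattern above is how a DTensor is copied onto the host (CPU) mesh before saving. A minimal sketch of that pattern with the public API (the mesh construction is an assumption for illustration):

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

mesh = dtensor.create_mesh([('x', dtensor.num_local_devices('CPU'))],
                           device_type='CPU')
host_mesh = mesh.host_mesh()  # already a CPU mesh here, so effectively a no-op
d = dtensor.copy_to_mesh(tf.ones([2]), dtensor.Layout.replicated(mesh, rank=1))

# Unpack the per-device components and repack them with a host-mesh layout.
host_layout = dtensor.Layout.replicated(host_mesh, rank=1)
host_copy = dtensor.pack(dtensor.unpack(d), host_layout)
```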
Example #4
def barrier(mesh: layout.Mesh, barrier_name: Optional[str] = None):
    """Runs a barrier on the mesh.

  Upon returning from the barrier, all operations run before the barrier
  would have completed across all clients. Currently we allocate a fully
  sharded tensor with mesh shape and run an all_reduce on it.

  Example:

  A barrier can be used before application exit to ensure completion of pending
  ops.

  ```python

  x = [1, 2, 3]
  x = dtensor.relayout(x, dtensor.Layout.batch_sharded(mesh, 'batch', 1))
  dtensor.barrier(mesh)

  # At this point all devices on all clients in the mesh have completed
  # operations before the barrier. Therefore it is OK to tear down the clients.
  sys.exit()
  ```

  Args:
    mesh: The mesh to run the barrier on.
    barrier_name: The name of the barrier. mainly used for logging purpose.
  """
    if barrier_name is None:
        barrier_name = '(barrier)'

    logging.info('entering barrier before op: %s', barrier_name)

    # Make sure all ops are consumed before running the sync.
    context.async_wait()

    # Reduction on a fully sharded tensor requires all devices to participate
    # and serves as a barrier on the mesh.
    component = array_ops.reshape(1.0, [1] * len(mesh.shape()))
    ones = api.pack([component] * mesh.num_local_devices(),
                    layout.Layout(mesh.dim_names, mesh))

    mesh_size = math_ops.reduce_sum(ones)
    if mesh_size != mesh.size:
        raise ValueError(
            'Global barrier produced wrong mesh size: {0} while mesh has '
            'actual size: {1}'.format(mesh_size, mesh.size))

    # TODO(hthu): This isn't strictly needed, but removing it might cause
    # confusing behavior for users. Consider dropping it if there is a `big`
    # performance hit.
    context.async_wait()

    logging.info('finished running barrier across all clients after '
                 'op: %s', barrier_name)
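The barrier works by forcing a collective: an all-reduce over a fully sharded tensor requires every device to participate, and the reduction must equal the mesh size. A small sketch of that check using the public API (single-client CPU mesh assumed):

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

mesh = dtensor.create_mesh([('x', dtensor.num_local_devices('CPU'))],
                           device_type='CPU')

# One element per device, fully sharded along 'x'; summing it is an all-reduce.
ones = dtensor.call_with_layout(tf.ones, dtensor.Layout(['x'], mesh),
                                shape=[mesh.size])
assert int(tf.reduce_sum(ones)) == mesh.size
```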
Example #5
    def save(
        self,
        file_prefix: str,
        options: Optional[checkpoint_options.CheckpointOptions] = None
    ) -> Optional[ops.Operation]:
        """Saves the saveable objects to a checkpoint with `file_prefix`.

    Also query the generated shards from the distributed DTensor SaveV2 ops and
    do a MergeV2 on those. Each op here is backed by a global_barrier to avoid
    racing from multiple clients.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object. This is unused in DTensor.

    Returns:
      An `Operation`, or None when executing eagerly.
    """
        if options is not None and options.experimental_io_device is not None:
            raise ValueError(
                "Specified experimental_io_device in DTensor checkpoint is not supported."
            )
        del options
        tensor_names = []
        tensors = []
        tensor_slices = []
        for saveable in self._saveable_objects:
            for spec in saveable.specs:
                tensor = spec.tensor
                # A tensor value of `None` indicates that this SaveableObject gets
                # recorded in the object graph, but that no value is saved in the
                # checkpoint.
                if tensor is not None:
                    if api.device_name() != spec.device:
                        # Some small tensors are placed on CPU:0 by the save
                        # manager and broadcast to the DTensor mesh, e.g.,
                        # SaveCounter.
                        tensor = api.pack(
                            [tensor] *
                            self._mesh.host_mesh().num_local_devices(),
                            layout.Layout.replicated(self._mesh.host_mesh(),
                                                     rank=tensor.shape.rank))
                    tensor_names.append(spec.name)
                    tensors.append(tensor)
                    tensor_slices.append(spec.slice_spec)
        return save_restore.sharded_save(self._mesh, file_prefix, tensor_names,
                                         tensor_slices, tensors)
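In recent TensorFlow releases, the regular `tf.train.Checkpoint` API can save DTensor variables directly (treat this as an assumption for older versions); a hedged sketch of saving a DVariable that would exercise a DTensor-aware saver like the one above:

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

mesh = dtensor.create_mesh([('x', dtensor.num_local_devices('CPU'))],
                           device_type='CPU')
v = dtensor.DVariable(
    dtensor.copy_to_mesh(tf.zeros([4]), dtensor.Layout.replicated(mesh, rank=1)))

# Saving a DVariable through the regular Checkpoint API; the path is
# illustrative only.
ckpt = tf.train.Checkpoint(v=v)
save_path = ckpt.save('/tmp/dtensor_ckpt/ckpt')
```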
Example #6
    def get_next(self):
        """Returns the next element.

    Returns:
      A possibly nested structure of values matching
      `tf.data.Iterator.element_spec`.

    Raises:
      `tf.errors.OutOfRangeError`: if the end of the underlying iterators has
        been reached.
      RuntimeError: if any of the underlying iterators do not return the
        expected number of items.
    """
        # Create the data structure to store the individual elements of the
        # current batch. We store one list per element in the flattened dataset
        # batch, and each list should contain as many tensors as there are
        # local devices.
        curr_batch_elems = [[] for _ in range(len(self._flattened_layouts))]

        for _, iterator in self._iterators:
            for _ in range(self._num_local_devices_per_replica):
                element = iterator.get_next()

                # Separate the dataset elements based on the structure of the dataset.
                flattened_element = nest.flatten(element)
                for idx, batch in enumerate(flattened_element):
                    curr_batch_elems[idx].append(batch)

        flattened_output = []
        for batch_elems, layout in zip(curr_batch_elems,
                                       self._flattened_layouts):
            expected_num_elems = layout.mesh.num_local_devices()
            actual_num_elems = len(batch_elems)
            if actual_num_elems != expected_num_elems:
                raise RuntimeError(
                    'Expected to pack %d elements in batch but got %d' %
                    (expected_num_elems, actual_num_elems))
            flattened_output.append(api.pack(batch_elems, layout))
        return nest.pack_sequence_as(self._layouts, flattened_output)
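The final step above packs one component per local device into a single DTensor with the target layout. A minimal standalone sketch of that packing with the public API (mesh and values are illustrative):

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

mesh = dtensor.create_mesh([('batch', dtensor.num_local_devices('CPU'))],
                           device_type='CPU')
layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)

# One component per local device, as collected from the per-device iterators.
components = [tf.constant([float(i)]) for i in range(mesh.num_local_devices())]
batch = dtensor.pack(components, layout)
```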
Example #7
    def restore(self, save_path, options=None):
        """Restore a training checkpoint with host mesh placement."""
        options = options or checkpoint_options.CheckpointOptions()
        if save_path is None:
            return util.InitializationOnlyStatus(self._graph_view, ops.uid())
        reader = py_checkpoint_reader.NewCheckpointReader(save_path)
        graph_building = not context.executing_eagerly()
        if graph_building:
            dtype_map = None
        else:
            dtype_map = reader.get_variable_to_dtype_map()
        try:
            object_graph_string = reader.get_tensor(
                base.OBJECT_GRAPH_PROTO_KEY)
        except errors_impl.NotFoundError:
            # The object graph proto does not exist in this checkpoint. Try the
            # name-based compatibility mode.
            restore_coordinator = util._NameBasedRestoreCoordinator(  # pylint: disable=protected-access
                save_path=save_path,
                dtype_map=dtype_map)
            if not graph_building:
                for existing_trackable in self._graph_view.list_objects():
                    # pylint: disable=protected-access
                    existing_trackable._maybe_initialize_trackable()
                    existing_trackable._name_based_restores.add(
                        restore_coordinator)
                    existing_trackable._name_based_attribute_restore(
                        restore_coordinator)
                    # pylint: enable=protected-access
            return util.NameBasedSaverStatus(restore_coordinator,
                                             graph_view=self._graph_view)

        if graph_building:
            if self._file_prefix_placeholder is None:
                # DTensor change: provide a hint for mesh broadcasting to put the input
                # onto the host mesh.
                self._file_prefix_placeholder = api.pack(
                    [constant_op.constant("model")] *
                    self._mesh.num_local_devices(),
                    layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
            file_prefix_tensor = self._file_prefix_placeholder
            file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
        else:
            # DTensor change: provide a hint for mesh broadcasting to put the input
            # onto the host mesh.
            file_prefix_tensor = api.pack([constant_op.constant(save_path)] *
                                          self._mesh.num_local_devices(),
                                          layout.Layout.replicated(
                                              self._mesh.host_mesh(), rank=0))
            file_prefix_feed_dict = None
        object_graph_proto = (
            trackable_object_graph_pb2.TrackableObjectGraph())
        object_graph_proto.ParseFromString(object_graph_string)
        # DTensor Change: Hook the proper DSaver in restore.
        checkpoint = _DCheckpointRestoreCoordinator(
            mesh=self._mesh,
            object_graph_proto=object_graph_proto,
            save_path=save_path,
            save_path_tensor=file_prefix_tensor,
            reader=reader,
            restore_op_cache=self._restore_op_cache,
            graph_view=self._graph_view,
            options=options,
            saveables_cache=self._saveables_cache)
        base.CheckpointPosition(checkpoint=checkpoint,
                                proto_id=0).restore(self._graph_view.root)

        # Attached dependencies are not attached to the root, so should be restored
        # separately.
        if self._graph_view.attached_dependencies:
            for ref in self._graph_view.attached_dependencies:
                if ref.name == "root":
                    # Root dependency is automatically added to attached dependencies --
                    # this can be ignored since it maps back to the root object.
                    continue
                proto_id = None
                # Find proto ID of attached dependency (if it is in the proto).
                for proto_ref in object_graph_proto.nodes[0].children:
                    if proto_ref.local_name == ref.name:
                        proto_id = proto_ref.node_id
                        break

                if proto_id in checkpoint.object_by_proto_id:
                    # Object has already been restored. This can happen when there's an
                    # indirect connection from the attached object to the root.
                    continue

                base.CheckpointPosition(checkpoint=checkpoint,
                                        proto_id=proto_id).restore(ref.ref)

        load_status = util.CheckpointLoadStatus(
            checkpoint,
            graph_view=self._graph_view,
            feed_dict=file_prefix_feed_dict)
        return load_status
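A restore like the one above is typically reached through the regular checkpoint restore entry point; a hedged usage sketch with the public API (the path is illustrative, and DTensor awareness of `tf.train.Checkpoint` is assumed for recent TF releases):

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

mesh = dtensor.create_mesh([('x', dtensor.num_local_devices('CPU'))],
                           device_type='CPU')
v = dtensor.DVariable(
    dtensor.copy_to_mesh(tf.zeros([4]), dtensor.Layout.replicated(mesh, rank=1)))

# Restore into an already-built object graph; the directory must contain a
# checkpoint written with a matching structure (illustrative path).
ckpt = tf.train.Checkpoint(v=v)
status = ckpt.restore(tf.train.latest_checkpoint('/tmp/dtensor_ckpt'))
```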
Example #8
def name_based_restore(
    mesh: layout_lib.Mesh,
    checkpoint_prefix: str,
    name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]],
):
    """Restores from checkpoint_prefix to name based DTensors.

  It is required to have already-initialized DTensor variables that have same
  shape/dtype for the tensors being restored.

  Also, we currently only support a named based restore on a single mesh.

  Args:
    mesh: The single mesh that all Tensors would be restored to.
    checkpoint_prefix : The prefix of checkpoint to be restored.
    name_tensor_dict: A ordered dictionary of tensor_names to a DTensor. The
      DTensor shape/dtype must match the tensors being saved/restored for now.

  Returns:
    A dictionary of name to its restored DTensor value.
  """
    if not context.executing_eagerly():
        raise ValueError('name based restore must run eagerly.')

    ordered_name_tensor_dict = name_tensor_dict
    if not isinstance(name_tensor_dict, collections.OrderedDict):
        ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

    # Make sure that all tensors are on CPU mesh for now.
    # This might not be a hard limitation in the future.
    for name, tensor in ordered_name_tensor_dict.items():
        try:
            if api.fetch_layout(tensor).mesh.device_type().upper() != 'CPU':
                raise ValueError(
                    'Restoring a non CPU Tensor is not supported currently. Offending '
                    'tensor name : {tensor_name}'.format(tensor_name=name))
        except errors_impl.OpError as op_error:
            raise ValueError(
                'Saving/Restoring tensor must be a DTensor') from op_error

    # Now that we have all tensors on CPU mesh, do a DTensorRestoreV2.
    checkpoint_prefix = api.pack([checkpoint_prefix] *
                                 mesh.num_local_devices(),
                                 layout_lib.Layout.replicated(mesh.host_mesh(),
                                                              rank=0))
    # Explicitly pack to the mesh to avoid implicit small-constant extraction,
    # which does not work for larger restores that have lots of names.
    tensor_names = api.pack([list(ordered_name_tensor_dict.keys())] *
                            mesh.num_local_devices(),
                            layout_lib.Layout.replicated(mesh.host_mesh(),
                                                         rank=1))
    shape_and_slices = api.pack([[''] * len(ordered_name_tensor_dict)] *
                                mesh.num_local_devices(),
                                layout_lib.Layout.replicated(mesh.host_mesh(),
                                                             rank=1))
    # A list of TensorShape representing all shapes for the input tensors.
    input_shapes = [
        tensor.shape for tensor in ordered_name_tensor_dict.values()
    ]
    input_layouts = [
        api.fetch_layout(tensor).to_string()
        for tensor in ordered_name_tensor_dict.values()
    ]

    with ops.device(api.device_name()):
        restored_cpu_tensors = gen_dtensor_ops.d_tensor_restore_v2(
            prefix=checkpoint_prefix,
            tensor_names=tensor_names,
            shape_and_slices=shape_and_slices,
            input_shapes=input_shapes,
            input_layouts=input_layouts,
            dtypes=[
                tensor.dtype for tensor in ordered_name_tensor_dict.values()
            ])

    return collections.OrderedDict(
        zip(ordered_name_tensor_dict.keys(), restored_cpu_tensors))
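A hedged usage sketch mirroring the save example earlier: the dictionary values must be already-initialized CPU-mesh DTensors whose shapes, dtypes, and layouts describe what to restore. The internal import path is an assumption:

```python
import collections
import tensorflow as tf
from tensorflow.experimental import dtensor
# Assumed internal module path; not a public API.
from tensorflow.dtensor.python import save_restore

mesh = dtensor.create_mesh([('x', dtensor.num_local_devices('CPU'))],
                           device_type='CPU')
# Already-initialized CPU DTensor supplying the shape/dtype/layout to restore.
w = dtensor.copy_to_mesh(tf.zeros([4]), dtensor.Layout.replicated(mesh, rank=1))

restored = save_restore.name_based_restore(
    mesh, '/tmp/dtensor_ckpt/model', collections.OrderedDict([('w', w)]))
```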
Example #9
 def pack(tensors, layout):
     # Note: this nested helper closes over `dvariable` from the enclosing
     # scope and pins the pack onto that variable's DTensor device.
     with ops.device(dvariable.device):
         return api.pack(tensors, layout)
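Pinning the pack to the variable's device keeps the packed result on the same DTensor device. A rough public-API equivalent of that scoping (the mesh and values are illustrative):

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

mesh = dtensor.create_mesh([('x', dtensor.num_local_devices('CPU'))],
                           device_type='CPU')
layout = dtensor.Layout.replicated(mesh, rank=0)

# Pack per-device components under an explicit DTensor device scope.
with tf.device(dtensor.device_name()):
    packed = dtensor.pack([tf.constant(1.0)] * mesh.num_local_devices(), layout)
```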