def sharded_prefix(
    mesh: layout_lib.Mesh,
    prefix: List[str],
    tensor_names: List[str],
    shape_and_slices: List[str],
    tensors: List[ops.Tensor],
):
    """Generates all sharded prefixes used by a distributed Save.

    DTensor SaveV2 SPMD emits one SaveV2 op per saving device. Each of those
    ops must write under a distinct shard prefix so that no device overwrites
    another device's content.

    Together, (prefix, tensor_names, tensors with their layouts) and the saving
    mesh define the unique set of shard prefixes produced for all Save ops.
    The (prefix, tensor_names, shape_and_slices, tensors) passed here should
    normally match the arguments given to the save op itself.

    Args:
      mesh: The mesh used by the save op; usually a CPU mesh matching the one
        used in the Save op.
      prefix: The prefix of the files being saved.
      tensor_names: The list of tensor names used in the save op.
      shape_and_slices: The list of shape-and-slice specs used in the save op.
        Only "" is supported, since distributed saving with slices is not
        supported yet.
      tensors: The list of tensors used in the save op, in the same order as
        `tensor_names`.

    Returns:
      A 1-D string tensor holding every generated shard prefix.
    """
    local_device_count = mesh.num_local_devices()

    # One serialized layout string per tensor, stacked into a rank-1 tensor and
    # replicated onto every local device of `mesh`.
    per_tensor_layouts = [api.fetch_layout(t).to_string() for t in tensors]
    stacked_layouts = array_ops.stack(per_tensor_layouts, axis=0)
    packed_layouts = api.pack(
        [stacked_layouts] * local_device_count,
        layout_lib.Layout.replicated(mesh, rank=1),
    )

    # The serialized mesh is replicated as a scalar on every local device.
    packed_mesh = api.pack(
        [mesh.to_string()] * local_device_count,
        layout_lib.Layout.replicated(mesh, rank=0),
    )

    return gen_dtensor_ops.d_tensor_sharded_prefix(
        prefix,
        tensor_names,
        shape_and_slices,
        packed_mesh,
        layouts=packed_layouts,
        tensors=tensors,
    )
def name_based_save(mesh: layout_lib.Mesh,
                    checkpoint_prefix: Union[str, ops.Tensor],
                    name_tensor_dict: Dict[str, Union[ops.Tensor,
                                                      tf_variables.Variable]]):
    """Saves name based Tensors into a checkpoint.

    The function converts the input dictionary into the arguments expected by
    `sharded_save`, so that it can take advantage of DTensor SPMD based
    distributed save.

    As with name based restore, only saving from a single mesh is supported.

    Args:
      mesh: The single mesh that all Tensors are saved from. The checkpoint
        prefix and tensor names are replicated onto `mesh.host_mesh()`.
      checkpoint_prefix: The prefix of the checkpoint to be written.
      name_tensor_dict: An ordered dictionary of tensor names to DTensors.
        Tensors are saved in the dictionary's iteration order. The DTensor
        shape/dtype must match the tensors being saved/restored for now.

    Raises:
      ValueError: if not executing eagerly.
    """
    if not context.executing_eagerly():
        raise ValueError('name based save must run eagerly.')

    ordered_name_tensor_dict = name_tensor_dict
    if not isinstance(name_tensor_dict, collections.OrderedDict):
        ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

    # Replicate the prefix (a scalar) and the tensor names (a rank-1 list) onto
    # every local device of the host mesh.
    num_local_devices = mesh.num_local_devices()
    host_mesh = mesh.host_mesh()
    checkpoint_prefix = api.pack(
        [checkpoint_prefix] * num_local_devices,
        layout_lib.Layout.replicated(host_mesh, rank=0))
    tensor_names = api.pack(
        [list(ordered_name_tensor_dict.keys())] * num_local_devices,
        layout_lib.Layout.replicated(host_mesh, rank=1))

    # Slicing is not supported, so every tensor gets an empty slice spec.
    sharded_save(mesh,
                 file_prefix=checkpoint_prefix,
                 tensor_names=tensor_names,
                 shape_and_slices=[''] * len(ordered_name_tensor_dict),
                 tensors=list(ordered_name_tensor_dict.values()))
Exemple #3
0
def barrier(mesh: layout.Mesh, barrier_name: Optional[str] = None):
    """Runs a barrier on the mesh.

    Upon returning from the barrier, all operations run before the barrier
    would have completed across all clients. Currently we allocate a fully
    sharded tensor with mesh shape and run an all_reduce on it.

    Example:

    A barrier can be used before application exit to ensure completion of
    pending ops.

    ```python

    x = [1, 2, 3]
    x = dtensor.relayout(x, dtensor.Layout.batch_sharded(mesh, 'batch', 1))
    dtensor.barrier(mesh)

    # At this point all devices on all clients in the mesh have completed
    # operations before the barrier. Therefore it is OK to tear down the clients.
    sys.exit()
    ```

    Args:
      mesh: The mesh to run the barrier on.
      barrier_name: The name of the barrier. Mainly used for logging purposes.
        Defaults to '(barrier)'.

    Raises:
      ValueError: if the reduction over the mesh does not equal `mesh.size`.
    """
    if barrier_name is None:
        barrier_name = '(barrier)'

    logging.info('entering barrier before op: %s', barrier_name)

    # Make sure all ops are consumed before running the sync.
    context.async_wait()

    # Reduction on a fully sharded tensor requires all devices to participate
    # and serves as a barrier on the mesh. Each local device holds a single
    # 1.0 whose rank equals the number of mesh dimensions.
    component = array_ops.reshape(1.0, [1] * len(mesh.shape()))
    ones = api.pack([component] * mesh.num_local_devices(),
                    layout.Layout(mesh.dim_names, mesh))

    mesh_size = math_ops.reduce_sum(ones)
    if mesh_size != mesh.size:
        raise ValueError(
            'Global barrier produced wrong mesh size : {0} while mesh has actual '
            'size : {1}'.format(mesh_size, mesh.size))

    # TODO(hthu): This isn't strictly needed but might cause confusing behaviors
    # from users. Consider dropping this if there is a `big` performance hit.
    context.async_wait()

    logging.info('finished running barrier across all clients after '
                 'op: %s', barrier_name)
Exemple #4
0
def _global_barrier(mesh: layout_lib.Mesh, last_op_name: str):
    """Runs a global barrier on the mesh.

    Upon returning from the barrier, all operations run before the barrier
    would have completed across all clients.

    Currently we allocate a fully sharded tensor with mesh shape and run an
    all_reduce on it.

    Args:
      mesh: The mesh to run the global barrier on.
      last_op_name: The last op run before the global barrier. Mainly used for
        logging purposes.

    Raises:
      ValueError: if the reduction over the mesh does not equal `mesh.size`.
    """
    logging.info('entering global barrier before op: %s', last_op_name)

    # Make sure all ops are consumed before running the sync.
    context.async_wait()

    # Replicate the mesh shape on every local device, then build a ones tensor
    # of that shape that is fully sharded across all mesh dimensions.
    shape = api.pack(
        [mesh.shape()] * mesh.num_local_devices(),
        layout_lib.Layout.replicated(mesh, rank=1))
    ones = api.call_with_layout(array_ops.ones,
                                layout_lib.Layout(mesh.dim_names, mesh),
                                shape=shape,
                                dtype=dtypes.float32)
    # Reducing a fully sharded tensor requires every device to participate.
    mesh_size = math_ops.reduce_sum(ones)
    if mesh_size != mesh.size:
        raise ValueError(
            'Global barrier produced wrong mesh size : {0} while mesh has actual '
            'size : {1}'.format(mesh_size, mesh.size))

    # TODO(hthu): This isn't strictly needed but might cause confusing behaviors
    # from users. Consider dropping this if there is a `big` performance hit.
    context.async_wait()

    logging.info(
        'finished running global barrier across all clients after '
        'op: %s', last_op_name)
def name_based_restore(
    mesh: layout_lib.Mesh,
    checkpoint_prefix: str,
    name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]],
):
    """Restores from checkpoint_prefix to name based DTensors.

    It is required to have already-initialized DTensor variables that have the
    same shape/dtype as the tensors being restored.

    Also, we currently only support a name based restore on a single mesh.

    Args:
      mesh: The single mesh that all Tensors would be restored to.
      checkpoint_prefix: The prefix of the checkpoint to be restored.
      name_tensor_dict: An ordered dictionary of tensor names to DTensors. The
        DTensor shape/dtype must match the tensors being saved/restored for now.

    Returns:
      An OrderedDict mapping each name to its restored DTensor value, in the
      iteration order of `name_tensor_dict`.

    Raises:
      ValueError: if not executing eagerly, if a value is not a DTensor, or if
        a value's layout is not on a CPU mesh.
    """
    if not context.executing_eagerly():
        raise ValueError('name based restore must run eagerly.')

    ordered_name_tensor_dict = name_tensor_dict
    if not isinstance(name_tensor_dict, collections.OrderedDict):
        ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

    # Make sure that all tensors are on CPU mesh for now.
    # This might not be a hard limitation in the future.
    tensor_layouts = []
    for name, tensor in ordered_name_tensor_dict.items():
        # Only fetch_layout is expected to raise OpError (for non-DTensors);
        # keep it alone in the try so the CPU check below is not wrapped.
        try:
            tensor_layout = api.fetch_layout(tensor)
        except errors_impl.OpError as op_error:
            raise ValueError(
                'Saving/Restoring tensor must be a DTensor') from op_error
        if tensor_layout.mesh.device_type().upper() != 'CPU':
            raise ValueError(
                'Restoring a non CPU Tensor is not supported currently. Offending '
                'tensor name : {tensor_name}'.format(tensor_name=name))
        tensor_layouts.append(tensor_layout)

    # Now that we have all tensors on CPU mesh, do a DTensorRestoreV2.
    checkpoint_prefix = api.pack([checkpoint_prefix] *
                                 mesh.num_local_devices(),
                                 layout_lib.Layout.replicated(mesh.host_mesh(),
                                                              rank=0))
    # Explicitly pack to mesh to avoid implicit small constant extraction, which
    # does not work for larger restores that have lots of names.
    tensor_names = api.pack([list(ordered_name_tensor_dict.keys())] *
                            mesh.num_local_devices(),
                            layout_lib.Layout.replicated(mesh.host_mesh(),
                                                         rank=1))
    shape_and_slices = api.pack([[''] * len(ordered_name_tensor_dict)] *
                                mesh.num_local_devices(),
                                layout_lib.Layout.replicated(mesh.host_mesh(),
                                                             rank=1))
    # A list of TensorShape representing all shapes for the input tensors.
    input_shapes = [
        tensor.shape for tensor in ordered_name_tensor_dict.values()
    ]
    # Reuse the layouts fetched during validation (same order as the dict).
    input_layouts = [layout.to_string() for layout in tensor_layouts]

    with ops.device(api.device_name()):
        restored_cpu_tensors = gen_dtensor_ops.d_tensor_restore_v2(
            prefix=checkpoint_prefix,
            tensor_names=tensor_names,
            shape_and_slices=shape_and_slices,
            input_shapes=input_shapes,
            input_layouts=input_layouts,
            dtypes=[
                tensor.dtype for tensor in ordered_name_tensor_dict.values()
            ])

    return collections.OrderedDict(
        zip(ordered_name_tensor_dict.keys(), restored_cpu_tensors))
Exemple #6
0
    def __init__(self,
                 dataset: dataset_ops.DatasetV2,
                 *,
                 mesh: layout_lib.Mesh,
                 layouts: Any,
                 global_batch_size: int,
                 dataset_already_batched: bool = False,
                 batch_dim: Optional[str] = None,
                 prefetch: Optional[int] = None,
                 tf_data_service_config: Optional[TFDataServiceConfig] = None):
        """Creates a DTensorDataset.

    DTensorDataset automatically handles distribution of the dataset elements to
    each client's devices. It can be used to create an iterator that returns
    DTensors of the input data on each iteration.

    DTensorDataset works best with unbatched datasets. It takes the mesh and the
    provided layouts to automatically calculate how to batch the input locally
    for each replica.

    If the provided dataset is already batched according to the per-replica
    batch size, then `dataset_already_batched` must be set and DTensorDataset
    will check that the batch size is consistent with the intended
    `global_batch_size` using the layout information. Each replica receives a
    separate slice of the global batch, thus the per-replica batch size can be
    computed as the global batch size divided by the number of model replicas.
    For a DTensor mesh, the number of replicas is equal to the size of the
    mesh's batch dimension.

    TODO(b/223275517): add support for input datasets that are already batched
    to the global batch size.

    Args:
      dataset: a `tf.data.Dataset` object.
      mesh: the DTensor mesh to place the dataset batches on.
      layouts: a structure of DTensor layouts to be applied to the input dataset
        values. This can be a single layout or (possibly nested) tuples or
        dictionaries of layouts, and the structure must match the structure of
        the dataset. Either all or none of the layouts should be sharded on the
        batch dimension; having only a subset of layouts batch sharded will not
        work and raises a ValueError.
      global_batch_size: the desired global batch size.
      dataset_already_batched: must be set only if the dataset is already
        batched to the per-replica batch size. The batched dataset must have
        `drop_remainder=True` set since DTensor requires static shapes for
        slicing the input tensors.
      batch_dim: the mesh dimension on which the input's batch dimension is
        sharded. Set to None if the input layouts do not shard on the batch
        dimension.
      prefetch: number of batches to prefetch using Dataset.prefetch.
      tf_data_service_config: if operating in multi-client mode, this config
        specifies the tf.data service configuration to use.

    Raises:
      ValueError: on any of the following situations,
        1. if the structures and ranks of layouts and the dataset do not match.
        2. if the shapes in the dataset's spec are not fully defined.
        3. if batch_dim is specified and the first sharding spec of any layout
           is not batch_dim.
        4. if dataset_already_batched is set but the leading dimension of the
           dataset's first element spec does not equal global_batch_size
           divided (floor division) by the size of the batch_dim mesh
           dimension (or by 1 when batch_dim is None).
      TypeError: if type of structures of layouts and the dataset do not match.
    """
        super().__init__(dataset, dataset_ops.to_variant(dataset))

        # Store the configuration on the instance.
        self._mesh = mesh
        self._layouts = layouts
        self._batch_dim = batch_dim
        self._prefetch = prefetch
        self._tf_data_service_config = tf_data_service_config

        self._element_spec = dataset.element_spec

        # The layouts structure must mirror the dataset's element structure;
        # afterwards both are flattened so index i pairs a layout with a spec.
        nest.assert_same_structure(self._element_spec, self._layouts)
        flattened_layouts = nest.flatten(self._layouts)
        flattened_elem_spec = nest.flatten(self._element_spec)

        if batch_dim:
            # One replica per index along the batch mesh dimension.
            num_global_replicas = mesh.dim_size(batch_dim)
            # Distinct batch_dim coordinates of this client's devices, in
            # first-seen order (dict.fromkeys de-duplicates preserving order).
            self._local_replica_ids = list(
                dict.fromkeys(
                    [loc[batch_dim] for loc in mesh.local_device_locations()]))

            # Every layout must shard its leading (batch) axis on batch_dim.
            # NOTE(review): this indexes sharding_specs[0] before
            # _validate_input runs below; a rank-0 layout would presumably
            # raise IndexError here — confirm whether that case can occur.
            for layout in flattened_layouts:
                if batch_dim != layout.sharding_specs[0]:
                    raise ValueError((
                        'batch_dim %s was specified but at least one layout did not '
                        'contain it: %s') % (batch_dim, layout))
        else:
            # Only one replica since there is no sharding on the batch dimension.
            num_global_replicas = 1
            self._local_replica_ids = [0]

        # Validate layout and element spec compatibility, and raise ValueError if
        # invalid.
        _validate_input(flattened_layouts,
                        flattened_elem_spec,
                        dataset_already_batched=dataset_already_batched)

        # Per-replica batch size. NOTE(review): floor division silently drops
        # the remainder when global_batch_size is not a multiple of
        # num_global_replicas.
        expected_batch_size = global_batch_size // num_global_replicas
        if not dataset_already_batched:
            # drop_remainder=True keeps the batch dimension static.
            self._batched_dataset = dataset.batch(expected_batch_size,
                                                  drop_remainder=True)
        else:
            # The existing batch size is read from the leading dimension of the
            # first flattened element spec.
            per_replica_batch_size = flattened_elem_spec[0].shape.as_list()[0]
            if per_replica_batch_size != expected_batch_size:
                raise ValueError((
                    'per_replica_batch_size does not matched expected size based on '
                    'the mesh, got %d but expected %d.') %
                                 (per_replica_batch_size, expected_batch_size))
            self._batched_dataset = dataset

        num_global_devices_per_replica = api.num_global_devices(
            mesh.device_type()) // num_global_replicas
        self._num_local_replicas = len(self._local_replica_ids)
        self._num_local_devices_per_replica = mesh.num_local_devices(
        ) // self._num_local_replicas
        # The number of clients each replica is split over.
        self._num_clients_per_replica = (num_global_devices_per_replica //
                                         self._num_local_devices_per_replica)
        # In the case where a replica is split across multiple clients, an offset
        # needs to be added to the index used by the partitioning logic such that
        # the local devices on that client can be correctly matched to slices of the
        # input tensor(s). If replicas are wholly contained within a client, then
        # this offset is always 0.
        self._partition_offset = (api.client_id() %
                                  self._num_clients_per_replica
                                  ) * self._num_local_devices_per_replica

        # Helper data structures used in partitioning the dataset tensors.
        # Both lists are parallel to flattened_layouts / flattened_elem_spec.
        self._all_shard_counts = [
            _shard_counts(layout, batch_dim) for layout in flattened_layouts
        ]
        self._index_matrices = [
            _index_matrix(layout, elem_spec) for layout, elem_spec in zip(
                flattened_layouts, flattened_elem_spec)
        ]