Example #1
def get_device_locations(
    mesh: layout_lib.Mesh,
    client_id: Optional[int] = None) -> List[Dict[str, int]]:
  """Returns the device locations of all TPU cores local to the given client.

  A device location is a dictionary from dimension names to indices on those
  dimensions. For example, for a 2x2 mesh ('x', 'y'), this function returns a
  permutation of this list:

    [{'x': 0, 'y': 0},
     {'x': 0, 'y': 1},
     {'x': 1, 'y': 0},
     {'x': 1, 'y': 1}].

  Note that device IDs and device locations are equivalent. The former is a
  linearization of the latter along mesh dimensions.

  Args:
    mesh: A TPU mesh.
    client_id: Optional; a DTensor client ID. If None, queries this client.
  """

  if mesh.device_type() != _TPU_DEVICE_TYPE:
    raise ValueError("The mesh must be a TPU mesh")

  if client_id is None or client_id == api.client_id():
    return mesh.local_device_locations()

  # It's not clear we should ever allow a client to query other clients for
  # their device locations.
  raise NotImplementedError(
      "Looking up other clients' device locations is not supported")
Example #2
    def _create_embedding_host_mesh(self, tpu_mesh: layout_lib.Mesh):
        """Returns Embedding host mesh for each client."""
        if tpu_mesh.device_type().upper() != "TPU":
            raise ValueError("Must pass input of a tpu mesh.")

        # Global device IDs are global host IDs, while local device IDs contain
        # the local host ID.

        ts_local_device_ids = []
        ts_local_devices = []
        for local_device_str in tpu_mesh.local_devices():
            # We only need to keep TPU:0 for each client.
            if not local_device_str.endswith("TPU:0"):
                continue

            device_spec = tf_device.DeviceSpec.from_string(local_device_str)
            ts_local_device_ids.append(device_spec.task)
            ts_local_devices.append(device_spec.replace(device_type="CPU"))

        if not ts_local_device_ids or not ts_local_devices:
            logging.info(
                "Cannot create TPU system mesh as %s has no local `TPU:0` "
                "device", tpu_mesh.to_string())
            return None

        ts_global_device_ids = np.arange(self._num_clients())
        # TODO(zhonglinhan): parse global device specs as input when not None.
        return layout_lib.Mesh(
            dim_names=[tpu_mesh.dim_names[0]],  # 1D mesh.
            global_device_ids=ts_global_device_ids,
            local_device_ids=ts_local_device_ids,
            local_devices=ts_local_devices)
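
As a side note, the device-string manipulation above relies on `DeviceSpec` (publicly available as `tf.DeviceSpec`). A small standalone sketch of that transformation, with an illustrative device string:

```python
import tensorflow as tf

# Illustrative TPU device string; the task field identifies the local host.
spec = tf.DeviceSpec.from_string('/job:worker/replica:0/task:1/device:TPU:0')
print(spec.task)  # -> 1, used above as the local device (host) id.

# Swapping the device type yields the corresponding CPU host device.
print(spec.replace(device_type='CPU').to_string())
# -> '/job:worker/replica:0/task:1/device:CPU:0'
```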
Example #3
def get_device_ids(mesh: layout_lib.Mesh,
                   client_id: Optional[int] = None) -> List[int]:
  """Returns the device IDs of all TPU cores local to the given client.

  A device ID is a non-negative integer that uniquely identifies a device in the
  mesh. For example, for a 2x2 mesh ('x', 'y'), this function returns a
  permutation of [0, 1, 2, 3].

  Note that device IDs and device locations are equivalent. The former is a
  linearization of the latter along mesh dimensions.

  Args:
    mesh: A TPU mesh.
    client_id: Optional; a DTensor client ID. If None, queries this client.
  """

  if mesh.device_type() != _TPU_DEVICE_TYPE:
    raise ValueError("The mesh must be a TPU mesh")

  if client_id is None or client_id == api.client_id():
    return mesh.local_device_ids()

  # It's not clear we should ever allow a client to query other clients for
  # their device IDs.
  raise NotImplementedError(
      "Looking up other clients' device IDs is not supported")
Example #4
def sharded_save(
    mesh: layout_lib.Mesh,
    file_prefix: Union[str, ops.Tensor],
    tensor_names: Union[List[str], ops.Tensor],
    shape_and_slices: Union[List[str], ops.Tensor],
    tensors: List[Union[ops.Tensor, tf_variables.Variable]],
):
    """Saves given named tensor slices in a sharded, multi-client safe fashion.

  The method makes sure the checkpoint directory state is correct in a sharded
  multi-client save. Namely, we place a barrier after SaveV2 to make sure
  every client has finished writing its files, and another one after
  MergeV2Checkpoints to make sure all metadata is properly merged.

  Upon return, the checkpoint is complete and all directory operations
  are done.

  Args:
    mesh: The Mesh that contains the Tensors to save.
    file_prefix: The prefix of checkpoint.
    tensor_names: a list of tensor names used in save op.
    shape_and_slices: a list of shape and slice specification used in save op.
      The only supported value is "" as we don't support distributed saving with
      slices yet.
    tensors: a list of tensors used in save op. The order should match
      tensor_names.

  Returns:
    A MergeV2Checkpoints op that merges all metadata.
  """
    with ops.device(api.device_name()):
        io_ops.save_v2(file_prefix, tensor_names, shape_and_slices, tensors)

    # Query generated shards and generate MergeV2.
    generated_shards = sharded_prefix(mesh.host_mesh(), [file_prefix],
                                      tensor_names, shape_and_slices, tensors)
    # api.py is still visible to external users but the _global_barrier() isn't
    # intended for public usage.
    # Once we locked down api.py visibility, we shall be able to make the `_`
    # prefix on these APIs go away.

    # Make sure all clients have written the files
    _global_barrier(mesh.host_mesh(), 'SaveV2')  # pylint: disable=protected-access

    with ops.device(api.device_name()):
        merge_op = io_ops.MergeV2Checkpoints(
            checkpoint_prefixes=generated_shards,
            destination_prefix=file_prefix,
            delete_old_dirs=True)

    # Make sure first device in first host has finished merge.
    # pylint: disable=protected-access
    _global_barrier(mesh.host_mesh(), 'MergeV2Checkpoints')
    # pylint: enable=protected-access

    return merge_op
Example #5
def barrier(mesh: layout.Mesh, barrier_name: Optional[str] = None):
    """Runs a barrier on the mesh.

  Upon returning from the barrier, all operations run before the barrier
  would have completed across all clients. Currently we allocate a fully
  sharded tensor with mesh shape and run an all_reduce on it.

  Example:

  A barrier can be used before application exit to ensure completion of pending
  ops.

  ```python

  x = [1, 2, 3]
  x = dtensor.relayout(x, dtensor.Layout.batch_sharded(mesh, 'batch', 1))
  dtensor.barrier(mesh)

  # At this point all devices on all clients in the mesh have completed
  # operations before the barrier. Therefore it is OK to tear down the clients.
  sys.exit()
  ```

  Args:
    mesh: The mesh to run the barrier on.
    barrier_name: The name of the barrier. Mainly used for logging purposes.
  """
    if barrier_name is None:
        barrier_name = '(barrier)'

    logging.info('entering barrier before op: %s', barrier_name)

    # Make sure all ops are consumed before running the sync.
    context.async_wait()

    # Reduction on a fully sharded tensor requires all devices to participate
    # and serves as a barrier on the mesh.
    component = array_ops.reshape(1.0, [1] * len(mesh.shape()))
    ones = api.pack([component] * mesh.num_local_devices(),
                    layout.Layout(mesh.dim_names, mesh))

    mesh_size = math_ops.reduce_sum(ones)
    if mesh_size != mesh.size:
        raise ValueError(
            'Global barrier produced wrong mesh size: {0} while the mesh has '
            'actual size: {1}'.format(mesh_size, mesh.size))

    # TODO(hthu): This isn't strictly needed but might cause confusing behaviors
    # from users. Consider dropping this if there is a `big` performance hit.
    context.async_wait()

    logging.info('finished running barrier across all clients after '
                 'op: %s', barrier_name)
Example #6
    def _experimental_default_mesh(self, mesh: layout_lib.Mesh):
        """Sets a default mesh for all ops in the scope.

    Note: This is an internal helper method, not a user-facing API.

    Useful for requesting a specific mesh for ops which would have no inferred
    layout, e.g. tf.zeros.

    Args:
      mesh: A Mesh to be used for ops without Mesh.

    Yields:
      Nothing.
    """
        previous_default = self._current_default_mesh
        self._register_mesh(mesh)
        _pywrap_dtensor_device.ExperimentalSetDefaultMesh(
            self._device_info,
            mesh.to_string().encode("utf-8"))
        self._current_default_mesh = mesh
        yield
        _pywrap_dtensor_device.ExperimentalClearDefaultMesh(self._device_info)
        if previous_default:
            _pywrap_dtensor_device.ExperimentalSetDefaultMesh(
                self._device_info,
                previous_default.to_string().encode("utf-8"))
        self._current_default_mesh = previous_default
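
A hedged usage sketch for the scope above; it assumes `device` is an initialized DTensorDevice instance, `mesh` is a Mesh that can be registered on it, and that the method is wrapped with `contextlib.contextmanager` (as its Yields section implies).

```python
import tensorflow as tf

# Hypothetical usage; `device` (a DTensorDevice) and `mesh` are assumed to
# already exist, and the method is assumed to be a contextmanager.
with device._experimental_default_mesh(mesh):
  # Ops with no inferred layout (e.g. tf.zeros) are placed on `mesh`.
  zeros = tf.zeros([4, 4])
# On exit, the previously active default mesh (if any) is restored.
```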
Example #7
 def _experimental_default_mesh(self, mesh: layout_lib.Mesh):
     self._register_mesh(mesh)
     _dtensor_device.ExperimentalSetDefaultMesh(
         self._device_info,
         mesh.to_string().encode("utf-8"))
     yield
     _dtensor_device.ExperimentalClearDefaultMesh(self._device_info)
Example #8
 def _register_mesh(self, mesh: layout_lib.Mesh):
     """Idempotently register `mesh` with the dtensor device."""
     with self._mesh_lock:
         if mesh not in self._meshes:
             _pywrap_dtensor_device.AddMesh(self._device_info,
                                            mesh.to_string(),
                                            self._is_async, False)
             self._meshes.add(mesh)
             if mesh.device_type().upper() == "TPU":
                 logging.info(
                     "Registering virtual 1:1 mapped host mesh %s for mesh %s",
                     mesh.host_mesh().to_string(), mesh.to_string())
                 _pywrap_dtensor_device.AddMesh(
                     self._device_info,
                     mesh.host_mesh().to_string(), self._is_async, True)
                 self._meshes.add(mesh.host_mesh())
                 embedding_host_mesh = self._create_embedding_host_mesh(
                     mesh)
                 if embedding_host_mesh:
                     logging.info(
                         "Registering embedding host mesh %s on each client for mesh %s",
                         embedding_host_mesh.to_string(), mesh.to_string())
                     _pywrap_dtensor_device.AddMesh(
                         self._device_info, embedding_host_mesh.to_string(),
                         self._is_async, False)
                     self._meshes.add(embedding_host_mesh)
Example #9
def sharded_prefix(
    mesh: layout_lib.Mesh,
    prefix: List[str],
    tensor_names: List[str],
    shape_and_slices: List[str],
    tensors: List[ops.Tensor],
):
    """Generates all sharded prefix in distributed Save.

  DTensor SaveV2 SPMD would generate multiple SaveV2 ops on saving devices,
  and it is desired to not save with same shard_prefix so that content will
  not be overwritten.

  (prefix, tensor_names, tensors(with layouts)) and saving mesh collectively
  defines a unique set of shard prefix that is generated for all the Save ops.
  Usually, (prefix, tensor_names, shape_and_slices, tensors) should match what
  is used in save.

  Args:
    mesh: The mesh used in the save op. Usually a CPU mesh, matching what is
      used in the Save op.
    prefix: The prefix of saving files.
    tensor_names: a list of tensor names used in save op.
    shape_and_slices: a list of shape and slice specification used in save op.
      The only supported value is "" as we don't support distributed saving with
      slices yet.
    tensors: a list of tensors used in save op. The order should match
      tensor_names.

  Returns:
    A 1-D string tensor representing all generated shard prefixes.
  """
    layout_str = array_ops.stack(
        [api.fetch_layout(tensor).to_string() for tensor in tensors], axis=0)
    layouts = api.pack([layout_str] * mesh.num_local_devices(),
                       layout_lib.Layout.replicated(mesh, rank=1))

    mesh_str_tensor = api.pack([mesh.to_string()] * mesh.num_local_devices(),
                               layout_lib.Layout.replicated(mesh, rank=0))
    return gen_dtensor_ops.d_tensor_sharded_prefix(prefix,
                                                   tensor_names,
                                                   shape_and_slices,
                                                   mesh_str_tensor,
                                                   layouts=layouts,
                                                   tensors=tensors)
Example #10
def name_based_save(mesh: layout_lib.Mesh,
                    checkpoint_prefix: Union[str, ops.Tensor],
                    name_tensor_dict: Dict[str, Union[ops.Tensor,
                                                      tf_variables.Variable]]):
    """Saves name based Tensor into a Checkpoint.

  The function prepares the input dictionary to the format of a `sharded_save`,
  so that it can take advantage of DTensor SPMD based distributed save.

  Same as restore, the function only supports saving on the single mesh.

  Args:
    mesh: The single mesh that all Tensors would be restored to.
    checkpoint_prefix : The prefix of checkpoint to be restored.
    name_tensor_dict: A ordered dictionary of tensor_names to a DTensor. The
      DTensor shape/dtype must match the tensors being saved/restored for now.
  """
    if not context.executing_eagerly():
        raise ValueError('name based save must run eagerly.')

    ordered_name_tensor_dict = name_tensor_dict
    if not isinstance(name_tensor_dict, collections.OrderedDict):
        ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

    # Current _dtensor_device() in api.py is the correct way of specifying
    # DTensor device singletons. The API itself will eventually be moved to
    # a public API and provides global singleton in DTensor context.
    # For now, we just use the current `internal` API and aim at migrating in
    # one shot later.
    # TODO(hthu): Provide _dtensor_device() singleton as a public API.
    # pylint: disable=protected-access
    checkpoint_prefix = api.pack([checkpoint_prefix] *
                                 mesh.num_local_devices(),
                                 layout_lib.Layout.replicated(mesh.host_mesh(),
                                                              rank=0))
    tensor_names = api.pack([list(ordered_name_tensor_dict.keys())] *
                            mesh.num_local_devices(),
                            layout_lib.Layout.replicated(mesh.host_mesh(),
                                                         rank=1))

    sharded_save(mesh,
                 file_prefix=checkpoint_prefix,
                 tensor_names=tensor_names,
                 shape_and_slices=[''] * len(ordered_name_tensor_dict),
                 tensors=list(ordered_name_tensor_dict.values()))
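
A hedged usage sketch, assuming this function is exposed as `tf.experimental.dtensor.name_based_save` and that a CPU mesh is available; the tensor name, path, and shape below are illustrative, and additional accelerator-system initialization may be needed depending on the TF version.

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Build a small 1-D CPU mesh spanning the local CPU devices of this client.
mesh = dtensor.create_mesh([('batch', len(dtensor.local_devices('CPU')))],
                           device_type='CPU')
layout = dtensor.Layout.replicated(mesh, rank=1)

# Illustrative name -> DTensor mapping; values must already live on the mesh.
weights = {
    'dense/kernel': dtensor.call_with_layout(tf.ones, layout, shape=[8]),
}
dtensor.name_based_save(mesh, '/tmp/dtensor_ckpt/model', weights)
```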
Example #11
def _global_barrier(mesh: layout_lib.Mesh, last_op_name: str):
    """Runs a global barrier on the mesh.

  Upon returning from the barrier, all operations run before the barrier
  would have completed across all clients.

  Currently we allocate a fully sharded tensor with mesh shape and run an
  all_reduce on it.

  Args:
    mesh: The mesh to run the global barrier on.
    last_op_name: The last op run before the global_barrier. Mainly used for
      logging purposes.
  """
    logging.info('entering global barrier before op: %s', last_op_name)

    # Make sure all ops are consumed before running the sync.
    context.async_wait()

    shape = api._dtensor_device().pack(  # pylint: disable=protected-access
        [mesh.shape()] * mesh.num_local_devices(),
        layout_lib.Layout.replicated(mesh, rank=1))
    ones = api.call_with_layout(array_ops.ones,
                                layout_lib.Layout(mesh.dim_names, mesh),
                                shape=shape,
                                dtype=dtypes.float32)
    mesh_size = math_ops.reduce_sum(ones)
    if mesh_size != mesh.size:
        raise ValueError(
            'Global barrier produced wrong mesh size: {0} while the mesh has '
            'actual size: {1}'.format(mesh_size, mesh.size))

    # TODO(hthu): This isn't strictly needed but might cause confusing behaviors
    # from users. Consider dropping this if there is a `big` performance hit.
    context.async_wait()

    logging.info(
        'finished running global barrier across all clients after '
        'op: %s', last_op_name)
Example #12
def name_based_restore(
    mesh: layout_lib.Mesh,
    checkpoint_prefix: str,
    name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]],
):
    """Restores from checkpoint_prefix to name based DTensors.

  It is required to have already-initialized DTensor variables with the same
  shape/dtype as the tensors being restored.

  Also, we currently only support name-based restore on a single mesh.

  Args:
    mesh: The single mesh that all Tensors would be restored to.
    checkpoint_prefix: The prefix of the checkpoint to be restored.
    name_tensor_dict: An ordered dictionary mapping tensor names to DTensors.
      The DTensor shape/dtype must match the tensors being saved/restored for
      now.

  Returns:
    A dictionary of name to its restored DTensor value.
  """
    if not context.executing_eagerly():
        raise ValueError('name based restore must run eagerly.')

    ordered_name_tensor_dict = name_tensor_dict
    if not isinstance(name_tensor_dict, collections.OrderedDict):
        ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

    # Make sure that all tensors are on CPU mesh for now.
    # This might not be a hard limitation in the future.
    for name, tensor in ordered_name_tensor_dict.items():
        try:
            if api.fetch_layout(tensor).mesh.device_type().upper() != 'CPU':
                raise ValueError(
                    'Restoring a non CPU Tensor is not supported currently. Offending '
                    'tensor name : {tensor_name}'.format(tensor_name=name))
        except errors_impl.OpError as op_error:
            raise ValueError(
                'Saving/Restoring tensor must be a DTensor') from op_error

    # Now that we have all tensors on CPU mesh, do a DTensorRestoreV2.
    checkpoint_prefix = api.pack([checkpoint_prefix] *
                                 mesh.num_local_devices(),
                                 layout_lib.Layout.replicated(mesh.host_mesh(),
                                                              rank=0))
    # Explicitly pack to mesh to avoid implicit small constant extraction, which
    # does not work for larger restores that have many names.
    tensor_names = api.pack([list(ordered_name_tensor_dict.keys())] *
                            mesh.num_local_devices(),
                            layout_lib.Layout.replicated(mesh.host_mesh(),
                                                         rank=1))
    shape_and_slices = api.pack([[''] * len(ordered_name_tensor_dict)] *
                                mesh.num_local_devices(),
                                layout_lib.Layout.replicated(mesh.host_mesh(),
                                                             rank=1))
    # A list of TensorShape representing all shapes for the input tensors.
    input_shapes = [
        tensor.shape for tensor in ordered_name_tensor_dict.values()
    ]
    input_layouts = [
        api.fetch_layout(tensor).to_string()
        for tensor in ordered_name_tensor_dict.values()
    ]

    with ops.device(api.device_name()):
        restored_cpu_tensors = gen_dtensor_ops.d_tensor_restore_v2(
            prefix=checkpoint_prefix,
            tensor_names=tensor_names,
            shape_and_slices=shape_and_slices,
            input_shapes=input_shapes,
            input_layouts=input_layouts,
            dtypes=[
                tensor.dtype for tensor in ordered_name_tensor_dict.values()
            ])

    return collections.OrderedDict(
        zip(ordered_name_tensor_dict.keys(), restored_cpu_tensors))
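
And a matching restore sketch, assuming the public export `tf.experimental.dtensor.name_based_restore` and the checkpoint written in the save sketch above (same `dtensor`, `mesh`, and `layout` objects); shapes and dtypes must match what was saved.

```python
# Placeholders with the same names, shapes, dtypes, and (CPU) mesh as the
# saved tensors; their values are irrelevant and will be overwritten.
placeholders = {
    'dense/kernel': dtensor.call_with_layout(tf.zeros, layout, shape=[8]),
}
restored = dtensor.name_based_restore(mesh, '/tmp/dtensor_ckpt/model',
                                      placeholders)
kernel = restored['dense/kernel']  # A replicated DTensor on the CPU mesh.
```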
Example #13
    def __init__(self,
                 dataset: dataset_ops.DatasetV2,
                 *,
                 mesh: layout_lib.Mesh,
                 layouts: Any,
                 global_batch_size: int,
                 dataset_already_batched: bool = False,
                 batch_dim: Optional[str] = None,
                 prefetch: Optional[int] = None,
                 tf_data_service_config: Optional[TFDataServiceConfig] = None):
        """Creates a DTensorDataset.

    DTensorDataset automatically handles distribution of the dataset elements to
    each client's devices. It can be used to create an iterator that returns
    DTensors of the input data on each iteration.

    DTensorDataset works best with unbatched datasets. It takes the mesh and the
    provided layouts to automatically calculate how to batch the input locally
    for each replica.

    If the provided dataset is already batched according to the per-replica
    batch size, then `dataset_already_batched` must be set and DTensorDataset
    will check that the batch size is consistent with the intended
    `global_batch_size` using the layout information. Each replica receives a
    separate slice of the global batch, thus the per-replica batch size can be
    computed as the global batch size divided by the number of model replicas.
    For a DTensor mesh, the number of replicas is equal to the size of the
    mesh's batch dimension.

    TODO(b/223275517): add support for input datasets that are already batched
    to the global batch size.

    Args:
      dataset: a `tf.data.Dataset` object.
      mesh: the DTensor mesh to place the dataset batches on.
      layouts: a structure of DTensor layouts to be applied to the input dataset
        values. This can be a single layout or (possibly nested) tuples or
        dictionaries of layouts, and the structure must match the structure of
        the dataset. Either all or none of the layouts should be sharded on the
        batch dimension; having only a subset of layouts batch sharded will not
        work and raises a ValueError.
      global_batch_size: the desired global batch size.
      dataset_already_batched: must be set only if the dataset is already
        batched to the per-replica batch size. The batched dataset must have
        `drop_remainder=True` set since DTensor requires static shapes for
        slicing the input tensors.
      batch_dim: the mesh dimension on which the input's batch dimension is
        sharded. Set to None if the input layouts do not shard on the batch
        dimension.
      prefetch: number of batches to prefetch using Dataset.prefetch.
      tf_data_service_config: if operating in multi-client mode, this config
        specifies the tf.data service configuration to use.

    Raises:
      ValueError: in any of the following situations:
        1. if the structures and ranks of layouts and the dataset do not match.
        2. if the shapes in the dataset's spec are not fully defined.
        3. if batch_dim is specified and all layouts are not batch-sharded.
        4. if per_replica_batch_size is specified for an already batched Dataset
           but it does not match the expected per-replica size based on the
           provided mesh.
      TypeError: if type of structures of layouts and the dataset do not match.
    """
        super().__init__(dataset, dataset_ops.to_variant(dataset))

        self._mesh = mesh
        self._layouts = layouts
        self._batch_dim = batch_dim
        self._prefetch = prefetch
        self._tf_data_service_config = tf_data_service_config

        self._element_spec = dataset.element_spec

        nest.assert_same_structure(self._element_spec, self._layouts)
        flattened_layouts = nest.flatten(self._layouts)
        flattened_elem_spec = nest.flatten(self._element_spec)

        if batch_dim:
            num_global_replicas = mesh.dim_size(batch_dim)
            self._local_replica_ids = list(
                dict.fromkeys(
                    [loc[batch_dim] for loc in mesh.local_device_locations()]))

            for layout in flattened_layouts:
                if batch_dim != layout.sharding_specs[0]:
                    raise ValueError((
                        'batch_dim %s was specified but at least one layout did not '
                        'contain it: %s') % (batch_dim, layout))
        else:
            # Only one replica since there is no sharding on the batch dimension.
            num_global_replicas = 1
            self._local_replica_ids = [0]

        # Validate layout and element spec compatibility, and raise ValueError if
        # invalid.
        _validate_input(flattened_layouts,
                        flattened_elem_spec,
                        dataset_already_batched=dataset_already_batched)

        expected_batch_size = global_batch_size // num_global_replicas
        if not dataset_already_batched:
            self._batched_dataset = dataset.batch(expected_batch_size,
                                                  drop_remainder=True)
        else:
            per_replica_batch_size = flattened_elem_spec[0].shape.as_list()[0]
            if per_replica_batch_size != expected_batch_size:
                raise ValueError((
                    'per_replica_batch_size does not match the expected size based on '
                    'the mesh, got %d but expected %d.') %
                                 (per_replica_batch_size, expected_batch_size))
            self._batched_dataset = dataset

        num_global_devices_per_replica = api.num_global_devices(
            mesh.device_type()) // num_global_replicas
        self._num_local_replicas = len(self._local_replica_ids)
        self._num_local_devices_per_replica = mesh.num_local_devices(
        ) // self._num_local_replicas
        # The number of clients each replica is split over.
        self._num_clients_per_replica = (num_global_devices_per_replica //
                                         self._num_local_devices_per_replica)
        # In the case where a replica is split across multiple clients, an offset
        # needs to be added to the index used by the partitioning logic such that
        # the local devices on that client can be correctly matched to slices of the
        # input tensor(s). If replicas are wholly contained within a client, then
        # this offset is always 0.
        self._partition_offset = (api.client_id() %
                                  self._num_clients_per_replica
                                  ) * self._num_local_devices_per_replica

        # Helper data structures used in partitioning the dataset tensors.
        self._all_shard_counts = [
            _shard_counts(layout, batch_dim) for layout in flattened_layouts
        ]
        self._index_matrices = [
            _index_matrix(layout, elem_spec) for layout, elem_spec in zip(
                flattened_layouts, flattened_elem_spec)
        ]
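
A hedged usage sketch, assuming this class is exposed as `tf.experimental.dtensor.DTensorDataset` and that eight logical CPU devices can be configured to form a one-dimensional 'batch' mesh; sizes and names are illustrative, and extra accelerator-system initialization may be needed depending on the TF version.

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Assumption: split the single physical CPU into 8 logical devices so an
# 8-way 'batch' mesh can be created on one client.
tf.config.set_logical_device_configuration(
    tf.config.list_physical_devices('CPU')[0],
    [tf.config.LogicalDeviceConfiguration()] * 8)

mesh = dtensor.create_mesh([('batch', 8)], device_type='CPU')
# Rank-2 layout for the *batched* elements: [batch, features].
layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=2)

# 256 examples of 16 features; DTensorDataset applies the batching itself.
dataset = tf.data.Dataset.from_tensor_slices(tf.random.uniform([256, 16]))
d_dataset = dtensor.DTensorDataset(
    dataset=dataset,
    mesh=mesh,
    layouts=layout,
    global_batch_size=32,   # 8 replicas -> per-replica batch of 32 // 8 = 4.
    batch_dim='batch')

for batch in d_dataset:
  # `batch` is a DTensor of shape [32, 16], sharded on the batch dimension.
  break
```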