Example 1
def _register_mesh(self, mesh: layout_lib.Mesh):
    """Idempotently register `mesh` with the dtensor device."""
    with self._mesh_lock:
        if mesh not in self._meshes:
            _pywrap_dtensor_device.AddMesh(self._device_info,
                                           mesh.to_string(),
                                           self._is_async, False)
            self._meshes.add(mesh)
            if mesh.device_type().upper() == "TPU":
                logging.info(
                    "Registering virtual 1:1 mapped host mesh %s for mesh %s",
                    mesh.host_mesh().to_string(), mesh.to_string())
                _pywrap_dtensor_device.AddMesh(
                    self._device_info,
                    mesh.host_mesh().to_string(), self._is_async, True)
                self._meshes.add(mesh.host_mesh())
                embedding_host_mesh = self._create_embedding_host_mesh(mesh)
                if embedding_host_mesh:
                    logging.info(
                        "Registering embedding host mesh %s on each client for mesh %s",
                        embedding_host_mesh.to_string(), mesh.to_string())
                    _pywrap_dtensor_device.AddMesh(
                        self._device_info, embedding_host_mesh.to_string(),
                        self._is_async, False)
                    self._meshes.add(embedding_host_mesh)
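
The method above registers each mesh at most once, holding a lock around the membership check so that concurrent callers cannot register the same mesh twice. A minimal self-contained sketch of that lock-guarded, idempotent-registration pattern (the `MeshRegistry` class and its names are illustrative, not part of DTensor):

import threading


class MeshRegistry:
    """Toy registry mirroring the lock-guarded, idempotent pattern above."""

    def __init__(self):
        self._lock = threading.Lock()
        self._registered = set()

    def register(self, mesh_name: str) -> bool:
        """Registers `mesh_name` once; returns True only on the first call for that name."""
        with self._lock:
            if mesh_name in self._registered:
                return False  # Already registered: the call is a no-op.
            # The expensive backend registration (AddMesh in the code above)
            # would happen here, at most once per mesh.
            self._registered.add(mesh_name)
            return True


registry = MeshRegistry()
assert registry.register("tpu_mesh") is True
assert registry.register("tpu_mesh") is False  # Second registration is skipped.
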
Example 2
@contextlib.contextmanager
def _experimental_default_mesh(self, mesh: layout_lib.Mesh):
    """Sets a default mesh for all ops in the scope.

    Note: This is an internal helper method, not a user-facing API.

    Useful for requesting a specific mesh for ops which would have no inferred
    layout, e.g. tf.zeros.

    Args:
      mesh: A Mesh to be used for ops without a Mesh.

    Yields:
      Nothing.
    """
    previous_default = self._current_default_mesh
    self._register_mesh(mesh)
    _pywrap_dtensor_device.ExperimentalSetDefaultMesh(
        self._device_info,
        mesh.to_string().encode("utf-8"))
    self._current_default_mesh = mesh
    yield
    _pywrap_dtensor_device.ExperimentalClearDefaultMesh(self._device_info)
    if previous_default:
        _pywrap_dtensor_device.ExperimentalSetDefaultMesh(
            self._device_info,
            previous_default.to_string().encode("utf-8"))
    self._current_default_mesh = previous_default
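
Used as a context manager, the helper installs `mesh` as the default for the scope and puts back whatever default was active before. A small self-contained sketch of the same set/yield/restore pattern (the `DefaultTracker` class is illustrative only); note the sketch wraps the yield in try/finally so the outer default is restored even if the body raises, which the method above does only on normal exit:

import contextlib


class DefaultTracker:
    """Tracks a current default value, analogous to the device's default mesh."""

    def __init__(self):
        self.current = None

    @contextlib.contextmanager
    def default(self, value):
        previous = self.current      # Remember the enclosing default.
        self.current = value         # Install the new default for this scope.
        try:
            yield
        finally:
            self.current = previous  # Restore the enclosing default on exit.


tracker = DefaultTracker()
with tracker.default("mesh_a"):
    assert tracker.current == "mesh_a"
    with tracker.default("mesh_b"):
        assert tracker.current == "mesh_b"
    assert tracker.current == "mesh_a"   # Inner scope restored the outer default.
assert tracker.current is None
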
Example 3
def _create_embedding_host_mesh(self, tpu_mesh: layout_lib.Mesh):
    """Returns the embedding host mesh for each client."""
    if tpu_mesh.device_type().upper() != "TPU":
        raise ValueError("Must pass in a TPU mesh.")

    # Global device ids are global host ids, while local device ids contain
    # the local host id.

    ts_local_device_ids = []
    ts_local_devices = []
    for local_device_str in tpu_mesh.local_devices():
        # We only need to keep TPU:0 for each client.
        if not local_device_str.endswith("TPU:0"):
            continue

        device_spec = tf_device.DeviceSpec.from_string(local_device_str)
        ts_local_device_ids.append(device_spec.task)
        ts_local_devices.append(device_spec.replace(device_type="CPU"))

    if not ts_local_device_ids or not ts_local_devices:
        logging.info(
            "Cannot create TPU system mesh as %s has no `TPU:0` local device",
            tpu_mesh.to_string())
        return None

    ts_global_device_ids = np.arange(self._num_clients())
    # TODO(zhonglinhan): parse global device specs as input when not None.
    return layout_lib.Mesh(
        dim_names=[tpu_mesh.dim_names[0]],  # 1D mesh.
        global_device_ids=ts_global_device_ids,
        local_device_ids=ts_local_device_ids,
        local_devices=ts_local_devices)
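
The helper keeps one device per client (`TPU:0`) and rewrites it to the matching host CPU device via `DeviceSpec.replace`. A minimal sketch of that rewrite using the public `tf.DeviceSpec` API, with illustrative device strings standing in for `tpu_mesh.local_devices()`:

import tensorflow as tf

# Example worker-local TPU devices, as tpu_mesh.local_devices() might report them
# (the exact strings here are illustrative).
local_devices = [
    "/job:worker/replica:0/task:0/device:TPU:0",
    "/job:worker/replica:0/task:0/device:TPU:1",
    "/job:worker/replica:0/task:1/device:TPU:0",
]

host_ids = []
host_devices = []
for device_str in local_devices:
    if not device_str.endswith("TPU:0"):
        continue  # Keep only TPU:0 per client, as in the method above.
    spec = tf.DeviceSpec.from_string(device_str)
    host_ids.append(spec.task)                            # Task id doubles as host id.
    host_devices.append(spec.replace(device_type="CPU"))  # Same host, CPU device.

print(host_ids)  # [0, 1]
print([d.to_string() for d in host_devices])
# ['/job:worker/replica:0/task:0/device:CPU:0', '/job:worker/replica:0/task:1/device:CPU:0']
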
Example 4
@contextlib.contextmanager
def _experimental_default_mesh(self, mesh: layout_lib.Mesh):
    """Sets `mesh` as the default mesh for ops in the scope; the default is cleared on exit."""
    self._register_mesh(mesh)
    _dtensor_device.ExperimentalSetDefaultMesh(
        self._device_info,
        mesh.to_string().encode("utf-8"))
    yield
    _dtensor_device.ExperimentalClearDefaultMesh(self._device_info)
Example 5
def sharded_prefix(
    mesh: layout_lib.Mesh,
    prefix: List[str],
    tensor_names: List[str],
    shape_and_slices: List[str],
    tensors: List[ops.Tensor],
):
    """Generates all sharded prefixes used in a distributed Save.

    DTensor SaveV2 SPMD generates multiple SaveV2 ops on the saving devices,
    and the ops must not share the same shard prefix, or they would overwrite
    each other's content.

    (prefix, tensor_names, tensors (with layouts)) and the saving mesh
    collectively define the unique set of shard prefixes generated for all the
    Save ops. Usually, (prefix, tensor_names, shape_and_slices, tensors) should
    match what is used in the Save op.

    Args:
      mesh: The mesh that is used in the Save op. Usually a CPU mesh, and
        matches what is used in the Save op.
      prefix: The prefix of the saved files.
      tensor_names: A list of tensor names used in the Save op.
      shape_and_slices: A list of shape and slice specifications used in the
        Save op. The only supported value is "" as we don't support distributed
        saving with slices yet.
      tensors: A list of tensors used in the Save op. The order should match
        tensor_names.

    Returns:
      A 1-D string tensor containing all generated shard prefixes.
    """
    layout_str = array_ops.stack(
        [api.fetch_layout(tensor).to_string() for tensor in tensors], axis=0)
    layouts = api.pack([layout_str] * mesh.num_local_devices(),
                       layout_lib.Layout.replicated(mesh, rank=1))

    mesh_str_tensor = api.pack([mesh.to_string()] * mesh.num_local_devices(),
                               layout_lib.Layout.replicated(mesh, rank=0))
    return gen_dtensor_ops.d_tensor_sharded_prefix(prefix,
                                                   tensor_names,
                                                   shape_and_slices,
                                                   mesh_str_tensor,
                                                   layouts=layouts,
                                                   tensors=tensors)
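
Before calling the `DTensorShardedPrefix` op, the function replicates two host-side strings across the saving mesh: the stacked per-tensor layout strings (rank 1) and the mesh string itself (rank 0). A hedged sketch of that replication step using the public `tf.experimental.dtensor` API, with an illustrative two-device CPU mesh and placeholder layout strings standing in for real layouts:

import tensorflow as tf
from tensorflow.experimental import dtensor

# Split the single physical CPU into two logical devices so a 2-device CPU
# mesh can be built (illustrative setup; a real saving mesh comes from the caller).
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration()] * 2)
mesh = dtensor.create_mesh([("x", 2)], device_type="CPU")

# Stand-in for the stacked per-tensor layout strings built in `sharded_prefix`.
layout_str = tf.stack(["layout_of_t0", "layout_of_t1"], axis=0)

# Same pattern as the function body: give every local device an identical copy
# and mark the result as fully replicated on the mesh.
layouts = dtensor.pack([layout_str] * mesh.num_local_devices(),
                       dtensor.Layout.replicated(mesh, rank=1))
print(dtensor.fetch_layout(layouts))  # Fully replicated rank-1 layout on the CPU mesh.
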