Exemplo n.º 1
0
    def __init__(self,
                 global_shape: Shape,
                 global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes,
                 device_buffers: Sequence[DeviceArray],
                 _gda_fast_path_args: Optional[_GdaFastPathArgs] = None):
        """Constructor of GlobalDeviceArray class.

        Args:
          global_shape: The global shape of the array.
          global_mesh: The global mesh representing devices across multiple
            processes.
          mesh_axes: How each axis of the global array is partitioned over the
            axes of `global_mesh` (passed to `get_shard_shape`).
          device_buffers: Arrays on the local devices of `global_mesh`; exactly
            one per local device.
          _gda_fast_path_args: Optional precomputed values; when given, its
            `local_devices` is used instead of `global_mesh.local_devices`.

        Raises:
          AssertionError: if the number of buffers differs from the number of
            local devices, if any buffer's shape differs from the expected
            shard shape, or if the buffers' dtypes differ.
        """
        self._global_shape = global_shape
        self._global_mesh = global_mesh
        self._mesh_axes = mesh_axes
        self._device_buffers = device_buffers
        # Optionally precomputed for performance.
        self._gda_fast_path_args = _gda_fast_path_args
        self._current_process = xb.process_index()

        if self._gda_fast_path_args is None:
            self._local_devices = self._global_mesh.local_devices
        else:
            self._local_devices = self._gda_fast_path_args.local_devices
        assert len(device_buffers) == len(self._local_devices), (
            f"Expected {len(self._local_devices)} device buffers (one per "
            f"local device), got {len(device_buffers)}")

        self._local_shards = self._create_local_shards()

        ss = get_shard_shape(self._global_shape, self._global_mesh,
                             self._mesh_axes)
        # Report every buffer's shape: the mismatching buffer need not be the
        # first one, so printing only device_buffers[0] can be misleading.
        assert all(db.shape == ss for db in device_buffers), (
            f"Expected shard shape {ss} doesn't match the device buffer "
            f"shape, got: {[db.shape for db in device_buffers]}")

        dtype = device_buffers[0].dtype
        assert all(db.dtype == dtype for db in device_buffers), (
            "Input arrays to GlobalDeviceArray must have matching dtypes, "
            f"got: {[db.dtype for db in device_buffers]}")
        self.dtype = dtype
Exemplo n.º 2
0
    def __init__(self,
                 global_shape: Shape,
                 global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes,
                 device_buffers: Sequence[DeviceArray],
                 _gda_fast_path_args: Optional[_GdaFastPathArgs] = None):
        """Constructor of GlobalDeviceArray class.

        Args:
          global_shape: The global shape of the array.
          global_mesh: The global mesh representing devices across multiple
            processes.
          mesh_axes: A sequence with length less than or equal to the rank of
            the global array (i.e. the length of the global shape). Each
            element can be:
            * An axis name of `global_mesh`, indicating that the corresponding
              global array axis is partitioned across the given device axis of
              `global_mesh`.
            * A tuple of axis names of `global_mesh`. This is like the above
              option except the global array axis is partitioned across the
              product of axes named in the tuple.
            * None indicating that the corresponding global array axis is not
              partitioned.
            For more information, please see:
            https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html#more-information-on-partitionspec
          device_buffers: DeviceArrays that are on the local devices of
            `global_mesh`; exactly one per local device.
          _gda_fast_path_args: Optional precomputed values; when given, its
            `local_devices` is used instead of `global_mesh.local_devices`.

        Raises:
          AssertionError: if the number of buffers differs from the number of
            local devices, if any buffer's shape differs from the expected
            shard shape, or if the buffers' dtypes differ.
        """
        self._global_shape = global_shape
        self._global_mesh = global_mesh
        self._mesh_axes = mesh_axes
        self._device_buffers = device_buffers
        # Optionally precomputed for performance.
        self._gda_fast_path_args = _gda_fast_path_args
        self._current_process = xb.process_index()
        self._local_shards = self._create_local_shards()

        if self._gda_fast_path_args is None:
            local_devices = self._global_mesh.local_devices
        else:
            local_devices = self._gda_fast_path_args.local_devices
        assert len(device_buffers) == len(local_devices), (
            f"Expected {len(local_devices)} device buffers (one per local "
            f"device), got {len(device_buffers)}")

        ss = get_shard_shape(self._global_shape, self._global_mesh,
                             self._mesh_axes)
        # Report every buffer's shape: the mismatching buffer need not be the
        # first one, so printing only device_buffers[0] can be misleading.
        assert all(db.shape == ss for db in device_buffers), (
            f"Expected shard shape {ss} doesn't match the device buffer "
            f"shape, got: {[db.shape for db in device_buffers]}")

        dtype = device_buffers[0].dtype
        assert all(db.dtype == dtype for db in device_buffers), (
            "Input arrays to GlobalDeviceArray must have matching dtypes, "
            f"got: {[db.dtype for db in device_buffers]}")
        self.dtype = dtype
Exemplo n.º 3
0
    def __init__(self,
                 global_shape: Shape,
                 global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes,
                 device_buffers: Sequence[DeviceArray],
                 _gda_fast_path_args: Optional[_GdaFastPathArgs] = None,
                 _enable_checks: bool = True):
        """Constructor of GlobalDeviceArray class.

        Args:
          global_shape: The global shape of the array.
          global_mesh: The global mesh representing devices across multiple
            processes.
          mesh_axes: How each axis of the global array is partitioned over the
            axes of `global_mesh` (passed to `get_shard_shape`).
          device_buffers: Arrays on the local devices of `global_mesh`, in the
            same order as the mesh's local devices.
          _gda_fast_path_args: Optional precomputed values; when given, its
            `local_devices` is used instead of `global_mesh.local_devices`.
          _enable_checks: When False and `config.jax_enable_checks` is also
            off, the device-order, shard-shape and dtype validation is skipped.

        Raises:
          ValueError: (checks enabled) if a buffer is not on the local device
            at the same position.
          AssertionError: (checks enabled) if a buffer's shape differs from the
            expected shard shape or the buffers' dtypes differ.
        """
        self._global_shape = global_shape
        self._global_mesh = global_mesh
        self._mesh_axes = mesh_axes
        self._device_buffers = device_buffers
        # Optionally precomputed for performance.
        self._gda_fast_path_args = _gda_fast_path_args
        self._current_process = xb.process_index()

        if self._gda_fast_path_args is None:
            self._local_devices = self._global_mesh.local_devices
        else:
            self._local_devices = self._gda_fast_path_args.local_devices

        # Decide once whether validation runs, rather than re-evaluating the
        # same condition before every individual check.
        checks_enabled = _enable_checks or config.jax_enable_checks

        if checks_enabled:
            # Presumably safe_zip also rejects a length mismatch between the
            # buffers and the local devices — confirm against its definition.
            for db, ld in safe_zip(device_buffers, self._local_devices):
                if db.device() != ld:
                    raise ValueError(
                        "The `global_mesh.local_devices` and `device_buffers` device "
                        "order doesn't match. Please use `global_mesh.local_devices` to "
                        "put arrays on devices instead of `jax.local_devices()`"
                    )

            ss = get_shard_shape(self._global_shape, self._global_mesh,
                                 self._mesh_axes)
            assert all(db.shape == ss for db in device_buffers), (
                f"Expected shard shape {ss} doesn't match the device buffer "
                f"shape, got: {[db.shape for db in device_buffers]}")

        # The dtype is always taken from the first buffer, checks or not.
        dtype = device_buffers[0].dtype
        if checks_enabled:
            assert all(db.dtype == dtype for db in device_buffers), (
                "Input arrays to GlobalDeviceArray must have matching dtypes, "
                f"got: {[db.dtype for db in device_buffers]}")
        self.dtype = dtype
Exemplo n.º 4
0
 def _create_shards(
     self, device_buffers: Sequence[DeviceArray]
 ) -> Tuple[Sequence[Shard], Sequence[Shard]]:
     """Builds the global and process-local shard lists for this array.

     Args:
       device_buffers: Buffers keyed (via `db.device()`) by the device they
         live on; expected to cover every device of the current process.

     Returns:
       A `(global_shards, local_shards)` pair. `global_shards` has one `Shard`
       per device returned by `get_shard_indices`; `local_shards` is the subset
       whose device belongs to the current process, carrying that device's
       buffer.

     Raises:
       ValueError: if a shard of the current process ends up with no data,
         e.g. because no buffer was supplied for one of its devices.
     """
     indices = get_shard_indices(self._global_shape, self._global_mesh,
                                 self._mesh_axes)
     device_to_buffer = {db.device(): db for db in device_buffers}
     # Loop-invariant: look up the current process index once.
     current_process = xb.process_index()
     gs, ls = [], []
     # Number of shards seen so far with each index; that count is used as the
     # replica id of the next shard sharing the same index.
     index_to_replica: Dict[_HashableIndex, int] = Counter()
     for device, index in indices.items():
         h_index = _HashableIndex(index)
         replica_id = index_to_replica[h_index]
         index_to_replica[h_index] += 1
         local_shard = device.process_index == current_process
         # `.get` so a missing local buffer reaches the explicit ValueError
         # below instead of escaping as a bare KeyError.
         buf = device_to_buffer.get(device) if local_shard else None
         sh = Shard(device, index, replica_id, buf)
         gs.append(sh)
         if local_shard:
             # Assumes Shard exposes the buffer passed above as `.data`.
             if sh.data is None:
                 raise ValueError(
                     "Local shard's data field should not be None.")
             ls.append(sh)
     return gs, ls
Exemplo n.º 5
0
 def _addressable_device_assignment(self) -> XLADeviceAssignment:
   """Returns the entries of `_device_assignment()` owned by this process.

   Order is preserved from `_device_assignment()`; the result is a list.
   """
   my_process = xb.process_index()
   full_assignment = self._device_assignment()
   return list(filter(lambda dev: dev.process_index == my_process,
                      full_assignment))
Exemplo n.º 6
0
 def addressable_devices(self) -> Set[Device]:
   """Returns the subset of `device_set` belonging to the current process."""
   current = xb.process_index()
   owned = set()
   for dev in self.device_set:
     if dev.process_index == current:
       owned.add(dev)
   return owned