def __init__(self,
             global_shape: Shape,
             global_mesh: pxla.Mesh,
             mesh_axes: MeshAxes,
             device_buffers: Sequence[DeviceArray],
             _gda_fast_path_args: Optional[_GdaFastPathArgs] = None):
  """Initializes a GlobalDeviceArray from per-device buffers.

  Args:
    global_shape: The global shape of the array.
    global_mesh: The global mesh representing devices across multiple
      processes.
    mesh_axes: Partitioning spec mapping global array axes to mesh axes
      (or None for unpartitioned axes).
    device_buffers: DeviceArrays that live on the local devices of
      `global_mesh`, one per local device and in the same order.
    _gda_fast_path_args: Optional precomputed values (e.g. local devices),
      used to skip recomputation for performance.
  """
  self._global_shape = global_shape
  self._global_mesh = global_mesh
  self._mesh_axes = mesh_axes
  self._device_buffers = device_buffers
  # Optionally precomputed for performance.
  self._gda_fast_path_args = _gda_fast_path_args
  self._current_process = xb.process_index()

  if self._gda_fast_path_args is None:
    self._local_devices = self._global_mesh.local_devices
  else:
    self._local_devices = self._gda_fast_path_args.local_devices
  # One buffer per local device is required.
  assert len(device_buffers) == len(self._local_devices)

  self._local_shards = self._create_local_shards()

  ss = get_shard_shape(self._global_shape, self._global_mesh, self._mesh_axes)
  # Report every buffer's shape: previously only device_buffers[0].shape was
  # shown, which is misleading when a *different* buffer mismatches.
  assert all(db.shape == ss for db in device_buffers), (
      f"Expected shard shape {ss} doesn't match the device buffer "
      f"shapes, got: {[db.shape for db in device_buffers]}")

  dtype = device_buffers[0].dtype
  assert all(db.dtype == dtype for db in device_buffers), (
      "Input arrays to GlobalDeviceArray must have matching dtypes, "
      f"got: {[db.dtype for db in device_buffers]}")
  self.dtype = dtype
def __init__(self,
             global_shape: Shape,
             global_mesh: pxla.Mesh,
             mesh_axes: MeshAxes,
             device_buffers: Sequence[DeviceArray],
             _gda_fast_path_args: Optional[_GdaFastPathArgs] = None):
  """Constructor of GlobalDeviceArray class.

  Args:
    global_shape: The global shape of the array
    global_mesh: The global mesh representing devices across multiple
      processes.
    mesh_axes: A sequence with length less than or equal to the rank of the
      global array (i.e. the length of the global shape). Each element can be:

      * An axis name of `global_mesh`, indicating that the corresponding
        global array axis is partitioned across the given device axis of
        `global_mesh`.
      * A tuple of axis names of `global_mesh`. This is like the above option
        except the global array axis is partitioned across the product of axes
        named in the tuple.
      * None indicating that the corresponding global array axis is not
        partitioned.

      For more information, please see:
      https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html#more-information-on-partitionspec
    device_buffers: DeviceArrays that are on the local devices of
      `global_mesh`.
    _gda_fast_path_args: Optional precomputed values (e.g. local devices),
      used to skip recomputation for performance.
  """
  self._global_shape = global_shape
  self._global_mesh = global_mesh
  self._mesh_axes = mesh_axes
  self._device_buffers = device_buffers
  # Optionally precomputed for performance.
  self._gda_fast_path_args = _gda_fast_path_args
  self._current_process = xb.process_index()
  self._local_shards = self._create_local_shards()

  if self._gda_fast_path_args is None:
    local_devices = self._global_mesh.local_devices
  else:
    local_devices = self._gda_fast_path_args.local_devices
  # One buffer per local device is required.
  assert len(device_buffers) == len(local_devices)

  ss = get_shard_shape(self._global_shape, self._global_mesh, self._mesh_axes)
  # Report every buffer's shape: previously only device_buffers[0].shape was
  # shown, which is misleading when a *different* buffer mismatches.
  assert all(db.shape == ss for db in device_buffers), (
      f"Expected shard shape {ss} doesn't match the device buffer "
      f"shapes, got: {[db.shape for db in device_buffers]}")

  dtype = device_buffers[0].dtype
  assert all(db.dtype == dtype for db in device_buffers), (
      "Input arrays to GlobalDeviceArray must have matching dtypes, "
      f"got: {[db.dtype for db in device_buffers]}")
  self.dtype = dtype
def __init__(self,
             global_shape: Shape,
             global_mesh: pxla.Mesh,
             mesh_axes: MeshAxes,
             device_buffers: Sequence[DeviceArray],
             _gda_fast_path_args: Optional[_GdaFastPathArgs] = None,
             _enable_checks: bool = True):
  """Initializes a GlobalDeviceArray from per-device buffers.

  Args:
    global_shape: The global shape of the array.
    global_mesh: The global mesh representing devices across multiple
      processes.
    mesh_axes: Partitioning spec mapping global array axes to mesh axes
      (or None for unpartitioned axes).
    device_buffers: DeviceArrays that are on the local devices of
      `global_mesh`, in the order of `global_mesh.local_devices`.
    _gda_fast_path_args: Optional precomputed values (e.g. local devices),
      used to skip recomputation for performance.
    _enable_checks: Whether to run the (potentially expensive) validation
      checks on `device_buffers`. Checks also run, regardless of this flag,
      when `config.jax_enable_checks` is set.

  Raises:
    ValueError: If `device_buffers` is not ordered like
      `global_mesh.local_devices`.
  """
  self._global_shape = global_shape
  self._global_mesh = global_mesh
  self._mesh_axes = mesh_axes
  self._device_buffers = device_buffers
  # Optionally precomputed for performance.
  self._gda_fast_path_args = _gda_fast_path_args
  self._current_process = xb.process_index()

  if self._gda_fast_path_args is None:
    self._local_devices = self._global_mesh.local_devices
  else:
    self._local_devices = self._gda_fast_path_args.local_devices

  # Hoist the repeated check condition; it was previously recomputed three
  # times in this constructor.
  run_checks = _enable_checks or config.jax_enable_checks

  if run_checks:
    for db, ld in safe_zip(device_buffers, self._local_devices):
      if db.device() != ld:
        raise ValueError(
            "The `global_mesh.local_devices` and `device_buffers` device "
            "order doesn't match. Please use `global_mesh.local_devices` to "
            "put arrays on devices instead of `jax.local_devices()`")
    # Use `self._mesh_axes` for consistency with the rest of the class
    # (was `self.mesh_axes`).
    ss = get_shard_shape(self._global_shape, self._global_mesh,
                         self._mesh_axes)
    assert all(db.shape == ss for db in device_buffers), (
        f"Expected shard shape {ss} doesn't match the device buffer "
        f"shape, got: {[db.shape for db in device_buffers]}")

  dtype = device_buffers[0].dtype
  if run_checks:
    assert all(db.dtype == dtype for db in device_buffers), (
        "Input arrays to GlobalDeviceArray must have matching dtypes, "
        f"got: {[db.dtype for db in device_buffers]}")
  self.dtype = dtype
def _create_shards(
    self, device_buffers: Sequence[DeviceArray]
) -> Tuple[Sequence[Shard], Sequence[Shard]]:
  """Builds the global and local `Shard` lists for this array.

  For every (device, index) pair of the global sharding, a `Shard` is
  created; shards whose device belongs to the current process additionally
  carry their device buffer and appear in the local list.

  Args:
    device_buffers: Buffers residing on this process's local devices.

  Returns:
    A `(global_shards, local_shards)` pair; `local_shards` is a subset of
    `global_shards` whose entries hold non-None data.

  Raises:
    ValueError: If a local shard ends up without data.
  """
  indices = get_shard_indices(self._global_shape, self._global_mesh,
                              self._mesh_axes)
  buffer_by_device = {db.device(): db for db in device_buffers}
  current_process = xb.process_index()

  global_shards: List[Shard] = []
  local_shards: List[Shard] = []
  # Counts how many devices have been seen per (hashable) index so far;
  # the running count becomes each shard's replica id.
  replica_counts: Dict[_HashableIndex, int] = Counter()

  for device, index in indices.items():
    key = _HashableIndex(index)
    replica_id = replica_counts[key]
    replica_counts[key] += 1

    is_local = device.process_index == current_process
    data = buffer_by_device[device] if is_local else None
    shard = Shard(device, index, replica_id, data)
    global_shards.append(shard)
    if is_local:
      if shard.data is None:
        raise ValueError("Local shard's data field should not be None.")
      local_shards.append(shard)

  return global_shards, local_shards
def _addressable_device_assignment(self) -> XLADeviceAssignment:
  """Returns the device assignment restricted to this process's devices."""
  current_process = xb.process_index()
  addressable = []
  for device in self._device_assignment():
    if device.process_index == current_process:
      addressable.append(device)
  return addressable
def addressable_devices(self) -> Set[Device]:
  """A set of addressable devices by the current process"""
  current_process = xb.process_index()
  addressable = set()
  for device in self.device_set:
    if device.process_index == current_process:
      addressable.add(device)
  return addressable