def get_shard_indices(global_shape: Shape,
                      global_mesh: pxla.Mesh,
                      mesh_axes: MeshAxes) -> Mapping[Device, Index]:
    """Pair every device of `global_mesh` with its index into the global shape.

    Args:
      global_shape: Shape used both to build the abstract array and to turn
        the sharding spec into concrete indices.
      global_mesh: Mesh whose `shape`, `axis_names` and `devices` are consulted.
      mesh_axes: Either a `PartitionSpec` or an iterable that is unpacked into
        one.

    Returns:
      A dict keyed by the devices of `global_mesh.devices.flat`, zipped in
      order with the indices produced by `pxla.spec_to_indices`. Every index
      is a tuple of `slice` objects.

    Raises:
      AssertionError: If an index returned by `pxla.spec_to_indices` is not a
        tuple made up solely of slices (only when assertions are enabled).
    """
    # Deferred import: pjit.py imports gda, so a module-level import of pjit
    # here would be cyclic.
    from jax.experimental.pjit import (
        _prepare_axis_resources,
        get_array_mapping,
    )

    spec = (mesh_axes if isinstance(mesh_axes, PartitionSpec)
            else PartitionSpec(*mesh_axes))
    parsed_spec, _, _ = _prepare_axis_resources(spec, "mesh_axes")
    axis_mapping = get_array_mapping(parsed_spec)

    # Sharding specs do not depend on the dtype, so float32 is an arbitrary
    # placeholder.
    abstract_value = core.ShapedArray(global_shape, np.float32)
    make_sharding_spec = pxla.mesh_sharding_specs(
        global_mesh.shape, global_mesh.axis_names)
    shard_spec = make_sharding_spec(abstract_value, axis_mapping)
    device_indices = pxla.spec_to_indices(global_shape, shard_spec)

    for device_index in device_indices:
        assert isinstance(device_index, tuple)
        assert all(isinstance(part, slice) for part in device_index)

    # The type: ignore silences the mismatch with the type that
    # `spec_to_indices` is declared to return.
    return {  # type: ignore
        device: index
        for device, index in safe_zip(global_mesh.devices.flat, device_indices)
    }
def _get_array_mapping(mesh_axes):
    """Parse `mesh_axes` and return the array mapping derived from it.

    `mesh_axes` is handed to pjit's `_prepare_axis_resources` (labelled
    "GDA mesh_axes"), and the first element of its result is passed to
    `get_array_mapping`, whose return value is returned unchanged.
    """
    # Deferred import: pjit.py imports gda, so a module-level import of pjit
    # here would be cyclic.
    from jax.experimental.pjit import (
        _prepare_axis_resources,
        get_array_mapping,
    )

    parsed_axes, _, _ = _prepare_axis_resources(mesh_axes, "GDA mesh_axes")
    array_mapping = get_array_mapping(parsed_axes)
    return array_mapping