def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
  # Imported locally to avoid a cyclic import with pjit.py.
  from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources
  parsed_spec, _, _, _ = _prepare_axis_resources(self.spec, "spec")
  array_mapping = get_array_mapping(parsed_spec)
  # TODO(yashkatariya): Move away from sharding spec in MeshPspecSharding
  # since we don't really need sharding spec.
  sharding_spec = pxla.new_mesh_sharding_specs(
      self.mesh.shape, self.mesh.axis_names)(num_dimensions, array_mapping)
  return sharding_spec.sharding_proto()

def _get_array_mapping(mesh_axes):
  # Import here to avoid cyclic import error when importing gda in pjit.py.
  from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

  if not isinstance(mesh_axes, PartitionSpec):
    pspec = PartitionSpec(*mesh_axes)
  else:
    pspec = mesh_axes
  parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
  return get_array_mapping(parsed_pspec)

def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes) -> Tuple[Index, ...]:
  # Import here to avoid cyclic import error when importing gda in pjit.py.
  from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

  pspec = _canonicalize_mesh_axes(mesh_axes)
  parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
  array_mapping = get_array_mapping(parsed_pspec)
  # The dtype doesn't matter for creating sharding specs.
  aval = core.ShapedArray(global_shape, np.float32)
  sharding_spec = pxla.mesh_sharding_specs(
      global_mesh.shape, global_mesh.axis_names)(aval, array_mapping)
  indices = pxla.spec_to_indices(global_shape, sharding_spec)
  return indices  # type: ignore

def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes) -> Tuple[pxla.Index, ...]:
  # Import here to avoid cyclic import error when importing gda in pjit.py.
  from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

  if not isinstance(mesh_axes, PartitionSpec):
    pspec = PartitionSpec(*mesh_axes)
  else:
    pspec = mesh_axes
  parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
  array_mapping = get_array_mapping(parsed_pspec)
  # The dtype doesn't matter for creating sharding specs.
  aval = core.ShapedArray(global_shape, np.float32)
  sharding_spec = pxla.mesh_sharding_specs(
      global_mesh.shape, global_mesh.axis_names)(aval, array_mapping)
  indices = pxla.spec_to_indices(global_shape, sharding_spec)
  for index in indices:
    assert isinstance(index, tuple)
    for idx in index:
      assert isinstance(idx, slice)
  return indices

def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                      mesh_axes: MeshAxes) -> Mapping[Device, Index]:
  # Import here to avoid cyclic import error when importing gda in pjit.py.
  from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

  if not isinstance(mesh_axes, PartitionSpec):
    pspec = PartitionSpec(*mesh_axes)
  else:
    pspec = mesh_axes
  parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
  array_mapping = get_array_mapping(parsed_pspec)
  # The dtype doesn't matter for creating sharding specs.
  aval = core.ShapedArray(global_shape, np.float32)
  sharding_spec = pxla.mesh_sharding_specs(
      global_mesh.shape, global_mesh.axis_names)(aval, array_mapping)
  indices = pxla.spec_to_indices(global_shape, sharding_spec)
  for index in indices:
    assert isinstance(index, tuple)
    for idx in index:
      assert isinstance(idx, slice)
  # The type: ignore is to ignore the type returned by `spec_to_indices`.
  return dict((d, i) for d, i in safe_zip(global_mesh.devices.flat, indices))  # type: ignore

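# Minimal usage sketch for get_shard_indices (not part of the original
# source). It assumes the module-level imports above are in scope and that at
# least eight devices are available; the 4x2 mesh shape, the axis names
# ('x', 'y'), the (8, 2) global shape, and the Mesh/PartitionSpec import paths
# are illustrative assumptions for this era of the jax.experimental API.
import jax
import numpy as np
from jax.experimental import PartitionSpec
from jax.experimental.maps import Mesh

# Arrange the available devices into a 2D mesh with named axes.
devices = np.asarray(jax.devices()[:8]).reshape(4, 2)
global_mesh = Mesh(devices, ('x', 'y'))

# Shard dimension 0 of an (8, 2) array over 'x' and dimension 1 over 'y'.
device_to_index = get_shard_indices((8, 2), global_mesh, PartitionSpec('x', 'y'))

# Each device maps to a tuple of slices selecting its shard of the global array.
for device, index in device_to_index.items():
  print(device, index)
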
def _get_array_mapping(mesh_axes):
  # Import here to avoid cyclic import error when importing gda in pjit.py.
  from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources
  parsed_pspec, _, _ = _prepare_axis_resources(mesh_axes, "GDA mesh_axes")
  return get_array_mapping(parsed_pspec)

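# Minimal usage sketch for this _get_array_mapping variant (not part of the
# original source). Unlike the earlier variant above, it does not wrap raw
# tuples into a PartitionSpec itself, so it is assumed here that mesh_axes
# arrives as a PartitionSpec; the axis name 'x' and the PartitionSpec import
# path are illustrative assumptions.
from jax.experimental import PartitionSpec

# Shard dimension 0 over mesh axis 'x' and leave dimension 1 replicated.
array_mapping = _get_array_mapping(PartitionSpec('x', None))
# The result maps each mesh axis name to the array dimension it partitions;
# downstream pjit/GDA machinery combines it with a concrete mesh.
print(array_mapping)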