예제 #1
0
def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                      mesh_axes: MeshAxes) -> Mapping[Device, Index]:
    # Import here to avoid cyclic import error when importing gda in pjit.py.
    from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

    if not isinstance(mesh_axes, PartitionSpec):
        pspec = PartitionSpec(*mesh_axes)
    else:
        pspec = mesh_axes
    parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
    array_mapping = get_array_mapping(parsed_pspec)
    # The dtype doesn't matter for creating sharding specs.
    aval = core.ShapedArray(global_shape, np.float32)
    sharding_spec = pxla.mesh_sharding_specs(
        global_mesh.shape, global_mesh.axis_names)(aval, array_mapping)
    indices = pxla.spec_to_indices(global_shape, sharding_spec)
    for index in indices:
        assert isinstance(index, tuple)
        for idx in index:
            assert isinstance(idx, slice)
    # The type: ignore is to ignore the type returned by `spec_to_indices`.
    return dict((d, i) for d, i in safe_zip(global_mesh.devices.flat,
                                            indices))  # type: ignore
예제 #2
0
def _get_array_mapping(mesh_axes):
    # Import here to avoid cyclic import error when importing gda in pjit.py.
    from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

    parsed_pspec, _, _ = _prepare_axis_resources(mesh_axes, "GDA mesh_axes")
    return get_array_mapping(parsed_pspec)