Ejemplo n.º 1
0
def _get_sharding_spec(global_shape, global_mesh, mesh_axes):
    """Build the sharding spec for an array of ``global_shape`` on ``global_mesh``.

    Args:
      global_shape: Shape of the full (global) array.
      global_mesh: Mesh whose ``shape`` and ``axis_names`` parameterize
        ``pxla.mesh_sharding_specs``.
      mesh_axes: Mesh-axis assignment, converted to an array mapping by
        ``_get_array_mapping``.

    Returns:
      Whatever ``pxla.mesh_sharding_specs(...)`` produces for a placeholder
      aval of ``global_shape`` and the computed array mapping.
    """
    mapping = _get_array_mapping(mesh_axes)
    # Only the shape matters for the sharding spec; float32 is an arbitrary
    # placeholder dtype.
    placeholder = core.ShapedArray(global_shape, np.float32)
    make_spec = pxla.mesh_sharding_specs(global_mesh.shape,
                                         global_mesh.axis_names)
    return make_spec(placeholder, mapping)
Ejemplo n.º 2
0
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes) -> Tuple[Index, ...]:
    """Compute the per-shard indices of an array sharded over ``global_mesh``.

    Args:
      global_shape: Shape of the full (global) array.
      global_mesh: Mesh providing ``shape`` and ``axis_names`` for the spec.
      mesh_axes: Mesh-axis assignment, converted to an array mapping by
        ``_get_array_mapping``.

    Returns:
      The indices returned by ``pxla.spec_to_indices`` for the sharding spec
      derived from ``global_shape``, ``global_mesh`` and ``mesh_axes``.
    """
    mapping = _get_array_mapping(mesh_axes)
    # The dtype is irrelevant to the sharding spec; float32 is a placeholder.
    placeholder = core.ShapedArray(global_shape, np.float32)
    make_spec = pxla.mesh_sharding_specs(global_mesh.shape,
                                         global_mesh.axis_names)
    spec = make_spec(placeholder, mapping)
    return pxla.spec_to_indices(global_shape, spec)  # type: ignore
Ejemplo n.º 3
0
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes) -> Tuple[Index, ...]:
  """Compute the per-shard indices of an array sharded over ``global_mesh``.

  Args:
    global_shape: Shape of the full (global) array.
    global_mesh: Mesh providing ``shape`` and ``axis_names`` for the spec.
    mesh_axes: Mesh-axis assignment; canonicalized by
      ``_canonicalize_mesh_axes`` and parsed by pjit's
      ``_prepare_axis_resources``.

  Returns:
    The indices returned by ``pxla.spec_to_indices`` for the resulting
    sharding spec.
  """
  # Import here to avoid cyclic import error when importing gda in pjit.py.
  from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

  canonical = _canonicalize_mesh_axes(mesh_axes)
  parsed, _, _ = _prepare_axis_resources(canonical, "mesh_axes")
  mapping = get_array_mapping(parsed)
  # float32 is a placeholder: only the shape feeds into the sharding spec.
  placeholder = core.ShapedArray(global_shape, np.float32)
  make_spec = pxla.mesh_sharding_specs(global_mesh.shape,
                                       global_mesh.axis_names)
  spec = make_spec(placeholder, mapping)
  return pxla.spec_to_indices(global_shape, spec)  # type: ignore
Ejemplo n.º 4
0
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes) -> Tuple[pxla.Index, ...]:
    """Compute the per-shard indices of an array sharded over ``global_mesh``.

    Args:
      global_shape: Shape of the full (global) array.
      global_mesh: Mesh providing ``shape`` and ``axis_names`` for the spec.
      mesh_axes: A ``PartitionSpec``, or a sequence whose elements are
        unpacked into one.

    Returns:
      The indices from ``pxla.spec_to_indices``; each is checked (when
      asserts are enabled) to be a tuple consisting only of slices.

    Raises:
      AssertionError: if any index is not a tuple of slices.
    """
    # Import here to avoid cyclic import error when importing gda in pjit.py.
    from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

    pspec = (mesh_axes if isinstance(mesh_axes, PartitionSpec)
             else PartitionSpec(*mesh_axes))
    parsed, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
    mapping = get_array_mapping(parsed)
    # float32 is a placeholder: the dtype does not affect the sharding spec.
    placeholder = core.ShapedArray(global_shape, np.float32)
    make_spec = pxla.mesh_sharding_specs(global_mesh.shape,
                                         global_mesh.axis_names)
    shard_indices = pxla.spec_to_indices(global_shape,
                                         make_spec(placeholder, mapping))
    # The tuple check short-circuits before the inner iteration, so a
    # non-tuple index fails the assertion rather than being iterated.
    assert all(isinstance(index, tuple)
               and all(isinstance(idx, slice) for idx in index)
               for index in shard_indices)
    return shard_indices
Ejemplo n.º 5
0
def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                      mesh_axes: MeshAxes) -> Mapping[Device, Index]:
    """Map each device of ``global_mesh`` to the index of its shard.

    Args:
      global_shape: Shape of the full (global) array.
      global_mesh: Mesh over which the array is sharded. Its devices, in
        ``devices.flat`` order, are paired with the computed indices.
      mesh_axes: A ``PartitionSpec``, or a sequence whose elements are
        unpacked into one.

    Returns:
      A dict from each device to its index (a tuple of slices).

    Raises:
      AssertionError: if any index is not a tuple of slices (only when
        asserts are enabled).
    """
    # Import here to avoid cyclic import error when importing gda in pjit.py.
    from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

    if not isinstance(mesh_axes, PartitionSpec):
        pspec = PartitionSpec(*mesh_axes)
    else:
        pspec = mesh_axes
    parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
    array_mapping = get_array_mapping(parsed_pspec)
    # The dtype doesn't matter for creating sharding specs.
    aval = core.ShapedArray(global_shape, np.float32)
    sharding_spec = pxla.mesh_sharding_specs(
        global_mesh.shape, global_mesh.axis_names)(aval, array_mapping)
    indices = pxla.spec_to_indices(global_shape, sharding_spec)
    for index in indices:
        assert isinstance(index, tuple)
        for idx in index:
            assert isinstance(idx, slice)
    # safe_zip already yields (device, index) pairs, so it can feed dict()
    # directly; presumably it also rejects a device/index count mismatch —
    # verify against its definition.
    # The type: ignore is to ignore the type returned by `spec_to_indices`.
    return dict(safe_zip(global_mesh.devices.flat, indices))  # type: ignore