示例#1
0
  def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
    """Build the sharding proto for this sharding's spec and mesh.

    Parses ``self.spec``, converts the parsed spec into an array mapping,
    builds a sharding spec from the mesh's shape and axis names, and returns
    that sharding spec's ``sharding_proto()``.

    Args:
      num_dimensions: Forwarded to the sharding-spec factory alongside the
        array mapping (presumably the rank of the array being sharded --
        confirm against ``pxla.new_mesh_sharding_specs``).

    Returns:
      The value produced by ``sharding_proto()`` on the built sharding spec.
    """
    from jax.experimental.pjit import _prepare_axis_resources, get_array_mapping

    # Only the first of the four returned values (the parsed spec) is needed.
    parsed, _, _, _ = _prepare_axis_resources(self.spec, "spec")
    mapping = get_array_mapping(parsed)
    # TODO(yashkatariya): Move away from sharding spec in MeshPspecSharding
    # since we don't really need sharding spec.
    make_sharding_spec = pxla.new_mesh_sharding_specs(
        self.mesh.shape, self.mesh.axis_names)
    return make_sharding_spec(num_dimensions, mapping).sharding_proto()
示例#2
0
def _get_array_mapping(mesh_axes):
    """Convert ``mesh_axes`` into an array mapping.

    Args:
        mesh_axes: Either a ``PartitionSpec`` (used as-is) or an iterable
            whose elements are unpacked into a new ``PartitionSpec``.

    Returns:
        The result of ``get_array_mapping`` applied to the parsed spec.
    """
    # Deferred import: a module-level import would be cyclic because pjit.py
    # imports gda.
    from jax.experimental.pjit import _prepare_axis_resources, get_array_mapping

    spec = (mesh_axes if isinstance(mesh_axes, PartitionSpec)
            else PartitionSpec(*mesh_axes))
    # Only the first of the three returned values (the parsed spec) is needed.
    parsed, _, _ = _prepare_axis_resources(spec, "mesh_axes")
    return get_array_mapping(parsed)
示例#3
0
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes) -> Tuple[Index, ...]:
  """Compute the indices for ``global_shape`` sharded over ``global_mesh``.

  Args:
    global_shape: Shape used to build the abstract value and passed to
      ``pxla.spec_to_indices``.
    global_mesh: Mesh whose ``shape`` and ``axis_names`` parameterize the
      sharding-spec factory.
    mesh_axes: Axis specification; canonicalized through
      ``_canonicalize_mesh_axes`` before being parsed.

  Returns:
    The value returned by ``pxla.spec_to_indices`` for the derived sharding
    spec.
  """
  # Deferred import: a module-level import would be cyclic because pjit.py
  # imports gda.
  from jax.experimental.pjit import _prepare_axis_resources, get_array_mapping

  canonical_spec = _canonicalize_mesh_axes(mesh_axes)
  # Only the first of the three returned values (the parsed spec) is needed.
  parsed, _, _ = _prepare_axis_resources(canonical_spec, "mesh_axes")
  mapping = get_array_mapping(parsed)
  # float32 is a placeholder: the dtype doesn't matter for creating sharding
  # specs.
  placeholder_aval = core.ShapedArray(global_shape, np.float32)
  make_sharding_spec = pxla.mesh_sharding_specs(
      global_mesh.shape, global_mesh.axis_names)
  spec = make_sharding_spec(placeholder_aval, mapping)
  return pxla.spec_to_indices(global_shape, spec)  # type: ignore
示例#4
0
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes) -> Tuple[pxla.Index, ...]:
    """Compute the indices for ``global_shape`` sharded over ``global_mesh``.

    Args:
        global_shape: Shape used to build the abstract value and passed to
            ``pxla.spec_to_indices``.
        global_mesh: Mesh whose ``shape`` and ``axis_names`` parameterize the
            sharding-spec factory.
        mesh_axes: Either a ``PartitionSpec`` (used as-is) or an iterable
            whose elements are unpacked into a new ``PartitionSpec``.

    Returns:
        The value returned by ``pxla.spec_to_indices``; every entry is
        asserted to be a tuple containing only ``slice`` objects.
    """
    # Deferred import: a module-level import would be cyclic because pjit.py
    # imports gda.
    from jax.experimental.pjit import _prepare_axis_resources, get_array_mapping

    spec = (mesh_axes if isinstance(mesh_axes, PartitionSpec)
            else PartitionSpec(*mesh_axes))
    # Only the first of the three returned values (the parsed spec) is needed.
    parsed, _, _ = _prepare_axis_resources(spec, "mesh_axes")
    mapping = get_array_mapping(parsed)
    # float32 is a placeholder: the dtype doesn't matter for creating
    # sharding specs.
    placeholder_aval = core.ShapedArray(global_shape, np.float32)
    make_sharding_spec = pxla.mesh_sharding_specs(
        global_mesh.shape, global_mesh.axis_names)
    shard_indices = pxla.spec_to_indices(
        global_shape, make_sharding_spec(placeholder_aval, mapping))
    # Sanity-check the structure of the result: tuples of slices only.
    for shard_index in shard_indices:
        assert isinstance(shard_index, tuple)
        for component in shard_index:
            assert isinstance(component, slice)
    return shard_indices
示例#5
0
def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                      mesh_axes: MeshAxes) -> Mapping[Device, Index]:
    """Map each device of ``global_mesh`` to its index into ``global_shape``.

    Args:
        global_shape: Shape used to build the abstract value and passed to
            ``pxla.spec_to_indices``.
        global_mesh: Mesh whose ``shape`` and ``axis_names`` parameterize the
            sharding-spec factory, and whose ``devices.flat`` supplies the
            keys of the result.
        mesh_axes: Either a ``PartitionSpec`` (used as-is) or an iterable
            whose elements are unpacked into a new ``PartitionSpec``.

    Returns:
        A dict pairing the entries of ``global_mesh.devices.flat``, in order,
        with the entries returned by ``pxla.spec_to_indices``. Every index is
        asserted to be a tuple containing only ``slice`` objects.
    """
    # Deferred import: a module-level import would be cyclic because pjit.py
    # imports gda.
    from jax.experimental.pjit import _prepare_axis_resources, get_array_mapping

    spec = (mesh_axes if isinstance(mesh_axes, PartitionSpec)
            else PartitionSpec(*mesh_axes))
    # Only the first of the three returned values (the parsed spec) is needed.
    parsed, _, _ = _prepare_axis_resources(spec, "mesh_axes")
    mapping = get_array_mapping(parsed)
    # float32 is a placeholder: the dtype doesn't matter for creating
    # sharding specs.
    placeholder_aval = core.ShapedArray(global_shape, np.float32)
    make_sharding_spec = pxla.mesh_sharding_specs(
        global_mesh.shape, global_mesh.axis_names)
    shard_indices = pxla.spec_to_indices(
        global_shape, make_sharding_spec(placeholder_aval, mapping))
    # Sanity-check the structure of the result: tuples of slices only.
    for shard_index in shard_indices:
        assert isinstance(shard_index, tuple)
        for component in shard_index:
            assert isinstance(component, slice)
    # The type: ignore is to ignore the type returned by `spec_to_indices`.
    return {
        device: shard_index
        for device, shard_index in safe_zip(global_mesh.devices.flat,
                                            shard_indices)
    }  # type: ignore
示例#6
0
def _get_array_mapping(mesh_axes):
    """Convert ``mesh_axes`` into an array mapping.

    Args:
        mesh_axes: Axis specification, passed unchanged to
            ``_prepare_axis_resources`` under the label ``"GDA mesh_axes"``.

    Returns:
        The result of ``get_array_mapping`` applied to the parsed spec.
    """
    # Deferred import: a module-level import would be cyclic because pjit.py
    # imports gda.
    from jax.experimental.pjit import _prepare_axis_resources, get_array_mapping

    # Only the first of the three returned values (the parsed spec) is needed.
    prepared = _prepare_axis_resources(mesh_axes, "GDA mesh_axes")
    parsed, _, _ = prepared
    return get_array_mapping(parsed)