Code Example #1
File: pmap_test.py  Project: ziyadedher/jax
    def testUnmaterializedAxis(self):
        shape = (4, 8)
        spec = pxla.ShardingSpec(shards_per_axis=(4, 1),
                                 is_axis_materialized=(False, True),
                                 replication_factor=1)
        self.assertEqual(pxla.spec_to_indices(shape, spec), (0, 1, 2, 3))

        shape = (2, 2)
        spec = pxla.ShardingSpec(shards_per_axis=(1, 2),
                                 is_axis_materialized=(True, False),
                                 replication_factor=1)
        self.assertEqual(pxla.spec_to_indices(shape, spec),
                         ((slice(None), 0), (slice(None), 1)))
Code Example #2
File: pmap_test.py  Project: mattwescott/jax
 def testReplication(self):
   shape = (2, 8)
   spec = pxla.ShardingSpec(shards_per_axis=(2, 1),
                            is_axis_materialized=(False, True),
                            replication_factor=3)
   self.assertEqual(pxla.spec_to_indices(shape, spec),
                    (0, 0, 0, 1, 1, 1))
Code Example #3
File: pmap_test.py  Project: mattwescott/jax
 def testNoSharding(self):
   shape = (4, 8)
   spec = pxla.ShardingSpec(shards_per_axis=(1, 1),
                            is_axis_materialized=(True, True),
                            replication_factor=1)
   self.assertEqual(pxla.spec_to_indices(shape, spec),
                    (slice(None),))
Code Example #4
File: pmap_test.py  Project: mattwescott/jax
 def testUnshardedAxis(self):
   shape = (4, 8)
   spec = pxla.ShardingSpec(shards_per_axis=(2, 1),
                            is_axis_materialized=(True, True),
                            replication_factor=1)
   self.assertEqual(pxla.spec_to_indices(shape, spec),
                    (slice(0,2), slice(2,4)))
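
The four tests above call pxla.spec_to_indices directly. As a rough illustration of what the returned indices are for, here is a minimal sketch (not taken from any of the files above, and assuming the same legacy ShardingSpec constructor used in these tests, which newer JAX versions no longer provide) that slices a host array into one block per shard:

import numpy as np
from jax.interpreters import pxla

x = np.arange(4 * 8).reshape(4, 8)
spec = pxla.ShardingSpec(shards_per_axis=(2, 1),
                         is_axis_materialized=(True, True),
                         replication_factor=1)
indices = pxla.spec_to_indices(x.shape, spec)   # (slice(0, 2), slice(2, 4)), as in Example #4
shards = [x[idx] for idx in indices]            # one (2, 8) block per shard

This mirrors how the sharded_jit examples further down pass per-input indices to pxla.shard_args to cut the arguments into per-device pieces.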
Code Example #5
def _aval_to_result_handler(npart, parts, aval):
    if aval is not core.abstract_unit:
        spec = pxla.partitioned_sharding_spec(npart, parts, aval)
        indices = pxla.spec_to_indices(aval.shape, spec)
    else:
        spec = indices = None
    return pxla.local_aval_to_result_handler(aval, spec, indices)
Code Example #6
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes) -> Tuple[Index, ...]:
    array_mapping = _get_array_mapping(mesh_axes)
    # The dtype doesn't matter for creating sharding specs.
    aval = core.ShapedArray(global_shape, np.float32)
    sharding_spec = pxla.mesh_sharding_specs(
        global_mesh.shape, global_mesh.axis_names)(aval, array_mapping)
    indices = pxla.spec_to_indices(global_shape, sharding_spec)
    return indices  # type: ignore
Code Example #7
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes) -> Tuple[Index, ...]:
  # Import here to avoid cyclic import error when importing gda in pjit.py.
  from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

  pspec = _canonicalize_mesh_axes(mesh_axes)
  parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
  array_mapping = get_array_mapping(parsed_pspec)
  # The dtype doesn't matter for creating sharding specs.
  aval = core.ShapedArray(global_shape, np.float32)
  sharding_spec = pxla.mesh_sharding_specs(
      global_mesh.shape, global_mesh.axis_names)(aval, array_mapping)
  indices = pxla.spec_to_indices(global_shape, sharding_spec)
  return indices  # type: ignore
Code Example #8
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes) -> Tuple[pxla.Index, ...]:
    # Import here to avoid cyclic import error when importing gda in pjit.py.
    from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

    if not isinstance(mesh_axes, PartitionSpec):
        pspec = PartitionSpec(*mesh_axes)
    else:
        pspec = mesh_axes
    parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
    array_mapping = get_array_mapping(parsed_pspec)
    # The dtype doesn't matter for creating sharding specs.
    aval = core.ShapedArray(global_shape, np.float32)
    sharding_spec = pxla.mesh_sharding_specs(
        global_mesh.shape, global_mesh.axis_names)(aval, array_mapping)
    indices = pxla.spec_to_indices(global_shape, sharding_spec)
    for index in indices:
        assert isinstance(index, tuple)
        for idx in index:
            assert isinstance(idx, slice)
    return indices
Code Example #9
def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                      mesh_axes: MeshAxes) -> Mapping[Device, Index]:
    # Import here to avoid cyclic import error when importing gda in pjit.py.
    from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

    if not isinstance(mesh_axes, PartitionSpec):
        pspec = PartitionSpec(*mesh_axes)
    else:
        pspec = mesh_axes
    parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
    array_mapping = get_array_mapping(parsed_pspec)
    # The dtype doesn't matter for creating sharding specs.
    aval = core.ShapedArray(global_shape, np.float32)
    sharding_spec = pxla.mesh_sharding_specs(
        global_mesh.shape, global_mesh.axis_names)(aval, array_mapping)
    indices = pxla.spec_to_indices(global_shape, sharding_spec)
    for index in indices:
        assert isinstance(index, tuple)
        for idx in index:
            assert isinstance(idx, slice)
    # The type: ignore is to ignore the type returned by `spec_to_indices`.
    return dict((d, i) for d, i in safe_zip(global_mesh.devices.flat,
                                            indices))  # type: ignore
Code Example #10
File: global_device_array.py  Project: jbampton/jax
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes) -> Tuple[Index, ...]:
    sharding_spec = _get_sharding_spec(global_shape, global_mesh, mesh_axes)
    indices = pxla.spec_to_indices(global_shape, sharding_spec)
    return indices  # type: ignore
Code Example #11
File: sharded_jit.py  Project: xueeinstein/jax
def _sharded_callable(
        fun: lu.WrappedFun, nparts: Optional[int],
        in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
        out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
        local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
        local_out_parts_thunk: Callable[[], Optional[Tuple[
            pxla.PartitionsOrReplicated,
            ...]]], local_nparts: Optional[int], name: str, *abstract_args):
    nrep = 1

    if local_in_parts is None:
        local_in_parts = in_parts

    global_abstract_args = [
        pxla.get_global_aval(arg, parts,
                             lparts) for arg, parts, lparts in safe_zip(
                                 abstract_args, in_parts, local_in_parts)
    ]

    if logging.vlog_is_on(2):
        logging.vlog(2, "abstract_args: %s", abstract_args)
        logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
        logging.vlog(2, "in_parts: %s", in_parts)
        logging.vlog(2, "local_in_parts: %s", local_in_parts)

    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
        fun, global_abstract_args)

    platform = xb.get_backend().platform

    nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
    assert nparts is not None
    if nparts > xb.device_count():
        raise ValueError(
            f"sharded_jit computation requires {nparts} devices, "
            f"but only {xb.device_count()} devices are available.")
    if xb.local_device_count() < nparts < xb.device_count():
        raise NotImplementedError(
            f"sharded_jit across multiple hosts must use all available devices. "
            f"Got {nparts} out of {xb.device_count()} requested devices "
            f"(local device count: {xb.local_device_count()})")

    if local_nparts is None:
        if nparts > xb.local_device_count():
            raise ValueError(
                "Specify 'local_nparts' when using cross-process sharded_jit "
                "and all inputs and outputs are replicated.")
        else:
            local_nparts = nparts
    if local_nparts > xb.local_device_count():
        raise ValueError(
            f"sharded_jit computation requires {local_nparts} local devices, "
            f"but only {xb.local_device_count()} local devices are available.")

    if logging.vlog_is_on(2):
        logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

    out_parts = out_parts_thunk()

    local_out_parts = local_out_parts_thunk()
    if local_out_parts is None:
        local_out_parts = out_parts

    if logging.vlog_is_on(2):
        logging.vlog(2, "out_parts: %s", out_parts)
        logging.vlog(2, "local_out_parts: %s", local_out_parts)

    local_out_avals = [
        pxla.get_local_aval(out, parts,
                            lparts) for out, parts, lparts in safe_zip(
                                global_out_avals, out_parts, local_out_parts)
    ]

    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    logging.log(log_priority, "Compiling %s for %d devices with args %s.",
                fun.__name__, nparts, global_abstract_args)

    axis_env = xla.AxisEnv(nrep, (), ())
    unordered_effects = [
        eff for eff in jaxpr.effects if eff not in core.ordered_effects
    ]
    ordered_effects = [
        eff for eff in jaxpr.effects if eff in core.ordered_effects
    ]
    module, _ = mlir.lower_jaxpr_to_module(
        f"spjit_{fun.__name__}",
        core.ClosedJaxpr(jaxpr, consts),
        unordered_effects,
        ordered_effects,
        platform=platform,
        axis_context=mlir.ReplicaAxisContext(axis_env),
        name_stack=new_name_stack(wrap_name(name, "sharded_jit")),
        donated_args=[False] * len(in_parts),
        arg_shardings=safe_map(xla.sharding_to_proto, in_parts),
        result_shardings=safe_map(xla.sharding_to_proto, out_parts))
    built = xc._xla.mlir.mlir_module_to_xla_computation(
        mlir.module_to_string(module), use_tuple_args=False, return_tuple=True)

    if nparts <= xb.local_device_count():
        devices = xb.local_devices()[:nparts]
    else:
        assert nparts == xb.device_count()
        devices = xb.devices()
    device_assignment = np.array([[d for d in devices]])
    device_assignment = np.reshape(device_assignment, (-1, nparts))
    # device_assignment = None  # TODO(skye): replace with default device assignment?

    compiled = dispatch.backend_compile(
        xb.get_backend(), built,
        xb.get_compile_options(nrep, nparts, device_assignment))

    input_specs = [
        pxla.partitioned_sharding_spec(local_nparts, parts, aval)
        for parts, aval in zip(local_in_parts, abstract_args)
    ]
    input_indices = [
        pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
        for aval, spec in zip(abstract_args, input_specs)
    ]

    handle_args = partial(pxla.shard_args, compiled.local_devices(),
                          input_indices)
    handle_outs = _avals_to_results_handler(
        nrep,
        local_nparts,  # type: ignore
        local_out_parts,
        local_out_avals)
    return partial(_execute_spatially_partitioned, compiled, handle_args,
                   handle_outs)
Code Example #12
File: sharded_jit.py  Project: xueeinstein/jax
def _aval_to_result_handler(npart, parts, aval):
    spec = pxla.partitioned_sharding_spec(npart, parts, aval)
    indices = pxla.spec_to_indices(aval.shape, spec)
    return pxla.local_aval_to_result_handler(aval, spec, indices)
Code Example #13
def _sharded_callable(
        fun: lu.WrappedFun, nparts: Optional[int],
        in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
        out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
        local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
        local_out_parts_thunk: Callable[[], Optional[Tuple[
            pxla.PartitionsOrReplicated,
            ...]]], local_nparts: Optional[int], name: str, *abstract_args):
    nrep = 1

    if local_in_parts is None:
        local_in_parts = in_parts

    global_abstract_args = [
        pxla.get_global_aval(arg, parts,
                             lparts) for arg, parts, lparts in safe_zip(
                                 abstract_args, in_parts, local_in_parts)
    ]

    if logging.vlog_is_on(2):
        logging.vlog(2, "abstract_args: %s", abstract_args)
        logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
        logging.vlog(2, "in_parts: %s", in_parts)
        logging.vlog(2, "local_in_parts: %s", local_in_parts)

    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
        fun, global_abstract_args)

    platform = xb.get_backend().platform
    if platform not in ["tpu", "gpu"]:
        # TODO(skye): fall back to regular jit?
        raise ValueError(f"sharded_jit not supported for {platform}")

    nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
    assert nparts is not None
    if nparts > xb.device_count():
        raise ValueError(
            f"sharded_jit computation requires {nparts} devices, "
            f"but only {xb.device_count()} devices are available.")
    if xb.local_device_count() < nparts < xb.device_count():
        raise NotImplementedError(
            f"sharded_jit across multiple hosts must use all available devices. "
            f"Got {nparts} out of {xb.device_count()} requested devices "
            f"(local device count: {xb.local_device_count()})")

    if local_nparts is None:
        if nparts > xb.local_device_count():
            raise ValueError(
                "Specify 'local_nparts' when using cross-process sharded_jit "
                "and all inputs and outputs are replicated.")
        else:
            local_nparts = nparts
    if local_nparts > xb.local_device_count():
        raise ValueError(
            f"sharded_jit computation requires {local_nparts} local devices, "
            f"but only {xb.local_device_count()} local devices are available.")

    if logging.vlog_is_on(2):
        logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

    out_parts = out_parts_thunk()

    local_out_parts = local_out_parts_thunk()
    if local_out_parts is None:
        local_out_parts = out_parts

    if logging.vlog_is_on(2):
        logging.vlog(2, "out_parts: %s", out_parts)
        logging.vlog(2, "local_out_parts: %s", local_out_parts)

    local_out_avals = [
        pxla.get_local_aval(out, parts,
                            lparts) for out, parts, lparts in safe_zip(
                                global_out_avals, out_parts, local_out_parts)
    ]

    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    logging.log(log_priority, "Compiling %s for %d devices with args %s.",
                fun.__name__, nparts, global_abstract_args)

    c = xc.XlaBuilder("spjit_{}".format(fun.__name__))
    xla_consts = _map(partial(xla.pyval_to_ir_constant, c), consts)
    xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
    axis_env = xla.AxisEnv(nrep, (), ())
    ctx = xla.TranslationContext(
        c, platform, axis_env,
        extend_name_stack(wrap_name(name, "sharded_jit")))
    out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
    out_tuple = xla.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
    built = c.Build(out_tuple)

    if nparts <= xb.local_device_count():
        devices = xb.local_devices()[:nparts]
    else:
        assert nparts == xb.device_count()
        devices = xb.devices()
    device_assignment = np.array([[d.id for d in devices]])
    device_assignment = np.reshape(device_assignment, (-1, nparts))
    # device_assignment = None  # TODO(skye): replace with default device assignment?

    compiled = dispatch.backend_compile(
        xb.get_backend(), built,
        xb.get_compile_options(nrep, nparts, device_assignment))

    input_specs = [
        pxla.partitioned_sharding_spec(local_nparts, parts, aval)
        for parts, aval in zip(local_in_parts, abstract_args)
    ]
    input_indices = [
        pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
        for aval, spec in zip(abstract_args, input_specs)
    ]

    handle_args = partial(pxla.shard_args, compiled.local_devices(),
                          input_indices)
    handle_outs = _avals_to_results_handler(
        nrep,
        local_nparts,  # type: ignore
        local_out_parts,
        local_out_avals)
    return partial(_execute_spatially_partitioned, compiled, handle_args,
                   handle_outs)
Code Example #14
 def devices_indices_map(
     self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
   indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
   return {d: i for d, i in safe_zip(self.devices.flat, indices)}  # type: ignore
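
The last few examples all follow the same device-to-index pattern: compute one index per shard with pxla.spec_to_indices, then zip the indices against the devices. A standalone sketch of that pattern, again assuming the legacy ShardingSpec constructor from the tests at the top (variable names here are illustrative, not from the original files):

import jax
from jax.interpreters import pxla

shape = (4, 8)
spec = pxla.ShardingSpec(shards_per_axis=(2, 1),
                         is_axis_materialized=(True, True),
                         replication_factor=1)
indices = pxla.spec_to_indices(shape, spec)     # one index per shard
devices = jax.local_devices()[:len(indices)]    # one device per shard
device_to_index = dict(zip(devices, indices))   # same kind of mapping as
                                                # devices_indices_map above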