def testUnmaterializedAxis(self):
  """Unmaterialized sharded axes produce plain-integer shard indices."""
  # First axis unmaterialized and split into 4 shards.
  spec_rows = pxla.ShardingSpec(shards_per_axis=(4, 1),
                                is_axis_materialized=(False, True),
                                replication_factor=1)
  self.assertEqual(pxla.spec_to_indices((4, 8), spec_rows), (0, 1, 2, 3))

  # Second axis unmaterialized and split into 2 shards; the materialized
  # first axis contributes a full slice to each index.
  spec_cols = pxla.ShardingSpec(shards_per_axis=(1, 2),
                                is_axis_materialized=(True, False),
                                replication_factor=1)
  self.assertEqual(pxla.spec_to_indices((2, 2), spec_cols),
                   ((slice(None), 0), (slice(None), 1)))
def testReplication(self):
  """Each shard index is repeated once per replica."""
  spec = pxla.ShardingSpec(shards_per_axis=(2, 1),
                           is_axis_materialized=(False, True),
                           replication_factor=3)
  # 2 shards x 3 replicas -> each shard's index appears 3 times.
  expected = (0, 0, 0, 1, 1, 1)
  self.assertEqual(pxla.spec_to_indices((2, 8), spec), expected)
def testNoSharding(self):
  """A fully unsharded array maps to a single full-slice index."""
  trivial_spec = pxla.ShardingSpec(shards_per_axis=(1, 1),
                                   is_axis_materialized=(True, True),
                                   replication_factor=1)
  self.assertEqual(pxla.spec_to_indices((4, 8), trivial_spec),
                   (slice(None),))
def testUnshardedAxis(self):
  """Materialized sharded axes are indexed by slices."""
  spec = pxla.ShardingSpec(shards_per_axis=(2, 1),
                           is_axis_materialized=(True, True),
                           replication_factor=1)
  result = pxla.spec_to_indices((4, 8), spec)
  # Two shards along the first axis -> two half-open row slices.
  self.assertEqual(result, (slice(0, 2), slice(2, 4)))
def _aval_to_result_handler(npart, parts, aval):
  """Build a local result handler for `aval` partitioned `npart` ways."""
  if aval is core.abstract_unit:
    # Unit avals carry no data, so there is no sharding spec or index set.
    spec = None
    indices = None
  else:
    spec = pxla.partitioned_sharding_spec(npart, parts, aval)
    indices = pxla.spec_to_indices(aval.shape, spec)
  return pxla.local_aval_to_result_handler(aval, spec, indices)
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes) -> Tuple[Index, ...]:
  """Per-device indices for an array of `global_shape` sharded on `global_mesh`."""
  mapping = _get_array_mapping(mesh_axes)
  # Sharding specs depend only on the shape, so any dtype works here.
  dummy_aval = core.ShapedArray(global_shape, np.float32)
  spec = pxla.mesh_sharding_specs(global_mesh.shape,
                                  global_mesh.axis_names)(dummy_aval, mapping)
  return pxla.spec_to_indices(global_shape, spec)  # type: ignore
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes) -> Tuple[Index, ...]:
  """Per-device indices for an array of `global_shape` sharded on `global_mesh`."""
  # Import here to avoid cyclic import error when importing gda in pjit.py.
  from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

  canonical_pspec = _canonicalize_mesh_axes(mesh_axes)
  parsed, _, _ = _prepare_axis_resources(canonical_pspec, "mesh_axes")
  mapping = get_array_mapping(parsed)
  # Sharding specs depend only on the shape, so any dtype works here.
  dummy_aval = core.ShapedArray(global_shape, np.float32)
  spec = pxla.mesh_sharding_specs(global_mesh.shape,
                                  global_mesh.axis_names)(dummy_aval, mapping)
  return pxla.spec_to_indices(global_shape, spec)  # type: ignore
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes) -> Tuple[pxla.Index, ...]:
  """Per-device indices for an array of `global_shape` sharded on `global_mesh`."""
  # Import here to avoid cyclic import error when importing gda in pjit.py.
  from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

  # Accept either a PartitionSpec or any iterable of axis resources.
  pspec = (mesh_axes if isinstance(mesh_axes, PartitionSpec)
           else PartitionSpec(*mesh_axes))
  parsed, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
  mapping = get_array_mapping(parsed)
  # Sharding specs depend only on the shape, so any dtype works here.
  dummy_aval = core.ShapedArray(global_shape, np.float32)
  spec = pxla.mesh_sharding_specs(global_mesh.shape,
                                  global_mesh.axis_names)(dummy_aval, mapping)
  indices = pxla.spec_to_indices(global_shape, spec)
  # Debug invariant: every index is a tuple of per-dimension slices.
  for index in indices:
    assert isinstance(index, tuple)
    for dim_index in index:
      assert isinstance(dim_index, slice)
  return indices
def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                      mesh_axes: MeshAxes) -> Mapping[Device, Index]:
  """Maps each device of `global_mesh` to its index into the global array.

  Args:
    global_shape: Shape of the global (unsharded) array.
    global_mesh: The device mesh the array is sharded over.
    mesh_axes: Partitioning of the array's dimensions over mesh axes; either a
      `PartitionSpec` or an iterable convertible to one.

  Returns:
    A mapping from each device in `global_mesh.devices` to the tuple of
    per-dimension slices that device holds.
  """
  # Import here to avoid cyclic import error when importing gda in pjit.py.
  from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

  if not isinstance(mesh_axes, PartitionSpec):
    pspec = PartitionSpec(*mesh_axes)
  else:
    pspec = mesh_axes
  parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
  array_mapping = get_array_mapping(parsed_pspec)
  # The dtype doesn't matter for creating sharding specs.
  aval = core.ShapedArray(global_shape, np.float32)
  sharding_spec = pxla.mesh_sharding_specs(
      global_mesh.shape, global_mesh.axis_names)(aval, array_mapping)
  indices = pxla.spec_to_indices(global_shape, sharding_spec)
  # Debug invariant: every index is a tuple of per-dimension slices.
  for index in indices:
    assert isinstance(index, tuple)
    for idx in index:
      assert isinstance(idx, slice)
  # dict() consumes the (device, index) pairs from safe_zip directly; the
  # original `dict((d, i) for d, i in ...)` generator was a redundant
  # unpack/repack of each pair.
  # The type: ignore is to ignore the type returned by `spec_to_indices`.
  return dict(safe_zip(global_mesh.devices.flat, indices))  # type: ignore
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes) -> Tuple[Index, ...]:
  """Per-device indices for `global_shape` under the given mesh sharding."""
  spec = _get_sharding_spec(global_shape, global_mesh, mesh_axes)
  return pxla.spec_to_indices(global_shape, spec)  # type: ignore
def _sharded_callable(
    fun: lu.WrappedFun, nparts: Optional[int],
    in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
    out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
    local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
    local_out_parts_thunk: Callable[[], Optional[Tuple[
        pxla.PartitionsOrReplicated, ...]]],
    local_nparts: Optional[int], name: str, *abstract_args):
  """Traces, lowers (via MLIR) and compiles `fun` for spatial partitioning.

  Returns a callable that shards its arguments across the local devices,
  executes the compiled computation, and assembles the sharded outputs.
  `nparts`/`local_nparts` are the global/per-process partition counts;
  `*_parts` describe how each input/output aval is partitioned.
  """
  nrep = 1  # sharded_jit uses partitioning only, never replication.
  if local_in_parts is None:
    local_in_parts = in_parts

  # Promote the per-process (local) avals to their global shapes for tracing.
  global_abstract_args = [
      pxla.get_global_aval(arg, parts, lparts)
      for arg, parts, lparts
      in safe_zip(abstract_args, in_parts, local_in_parts)]

  if logging.vlog_is_on(2):
    logging.vlog(2, "abstract_args: %s", abstract_args)
    logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
    logging.vlog(2, "in_parts: %s", in_parts)
    logging.vlog(2, "local_in_parts: %s", local_in_parts)

  jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
      fun, global_abstract_args)

  platform = xb.get_backend().platform

  # Inner sharding annotations in the jaxpr may imply a partition count;
  # reconcile it with the caller-provided one.
  nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
  assert nparts is not None
  if nparts > xb.device_count():
    raise ValueError(
        f"sharded_jit computation requires {nparts} devices, "
        f"but only {xb.device_count()} devices are available.")
  if xb.local_device_count() < nparts < xb.device_count():
    raise NotImplementedError(
        f"sharded_jit across multiple hosts must use all available devices. "
        f"Got {nparts} out of {xb.device_count()} requested devices "
        f"(local device count: {xb.local_device_count()})")

  if local_nparts is None:
    if nparts > xb.local_device_count():
      raise ValueError(
        "Specify 'local_nparts' when using cross-process sharded_jit "
        "and all inputs and outputs are replicated.")
    else:
      local_nparts = nparts
  if local_nparts > xb.local_device_count():
    raise ValueError(
      f"sharded_jit computation requires {local_nparts} local devices, "
      f"but only {xb.local_device_count()} local devices are available.")

  if logging.vlog_is_on(2):
    logging.vlog(2, "nparts: %d local_nparts: %d", nparts, local_nparts)

  out_parts = out_parts_thunk()

  local_out_parts = local_out_parts_thunk()
  if local_out_parts is None:
    local_out_parts = out_parts

  if logging.vlog_is_on(2):
    logging.vlog(2, "out_parts: %s", out_parts)
    logging.vlog(2, "local_out_parts: %s", local_out_parts)

  # Narrow the global output avals back to per-process (local) shapes.
  local_out_avals = [
      pxla.get_local_aval(out, parts, lparts)
      for out, parts, lparts
      in safe_zip(global_out_avals, out_parts, local_out_parts)]

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  logging.log(log_priority, "Compiling %s for %d devices with args %s.",
              fun.__name__, nparts, global_abstract_args)

  axis_env = xla.AxisEnv(nrep, (), ())
  # Split effects by ordering requirement, as the lowering API expects.
  unordered_effects = [
      eff for eff in jaxpr.effects if eff not in core.ordered_effects]
  ordered_effects = [
      eff for eff in jaxpr.effects if eff in core.ordered_effects]
  module, _ = mlir.lower_jaxpr_to_module(
      f"spjit_{fun.__name__}",
      core.ClosedJaxpr(jaxpr, consts),
      unordered_effects,
      ordered_effects,
      platform=platform,
      axis_context=mlir.ReplicaAxisContext(axis_env),
      name_stack=new_name_stack(wrap_name(name, "sharded_jit")),
      donated_args=[False] * len(in_parts),
      arg_shardings=safe_map(xla.sharding_to_proto, in_parts),
      result_shardings=safe_map(xla.sharding_to_proto, out_parts))
  # Round-trip the MLIR module to an XLA computation for compilation.
  built = xc._xla.mlir.mlir_module_to_xla_computation(
      mlir.module_to_string(module), use_tuple_args=False, return_tuple=True)

  if nparts <= xb.local_device_count():
    devices = xb.local_devices()[:nparts]
  else:
    # The earlier NotImplementedError check guarantees all-devices here.
    assert nparts == xb.device_count()
    devices = xb.devices()
  device_assignment = np.array([[d for d in devices]])
  device_assignment = np.reshape(device_assignment, (-1, nparts))
  # device_assignment = None  # TODO(skye): replace with default device assignment?
  compiled = dispatch.backend_compile(
      xb.get_backend(), built,
      xb.get_compile_options(nrep, nparts, device_assignment))

  # Note: input specs/indices are built from the *local* avals and partition
  # count, since argument sharding happens per process.
  input_specs = [
      pxla.partitioned_sharding_spec(local_nparts, parts, aval)
      for parts, aval in zip(local_in_parts, abstract_args)]
  input_indices = [pxla.spec_to_indices(aval.shape, spec)
                   if spec is not None else None
                   for aval, spec in zip(abstract_args, input_specs)]

  handle_args = partial(pxla.shard_args, compiled.local_devices(),
                        input_indices)
  handle_outs = _avals_to_results_handler(nrep, local_nparts,  # type: ignore
                                          local_out_parts, local_out_avals)

  return partial(_execute_spatially_partitioned, compiled, handle_args,
                 handle_outs)
def _aval_to_result_handler(npart, parts, aval):
  """Build a local result handler for `aval` partitioned `npart` ways."""
  sharding_spec = pxla.partitioned_sharding_spec(npart, parts, aval)
  shard_indices = pxla.spec_to_indices(aval.shape, sharding_spec)
  return pxla.local_aval_to_result_handler(aval, sharding_spec, shard_indices)
def _sharded_callable(
    fun: lu.WrappedFun, nparts: Optional[int],
    in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
    out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
    local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
    local_out_parts_thunk: Callable[[], Optional[Tuple[
        pxla.PartitionsOrReplicated, ...]]],
    local_nparts: Optional[int], name: str, *abstract_args):
  """Traces, lowers (via XlaBuilder) and compiles `fun` for partitioning.

  Returns a callable that shards its arguments across the local devices,
  executes the compiled computation, and assembles the sharded outputs.
  `nparts`/`local_nparts` are the global/per-process partition counts;
  `*_parts` describe how each input/output aval is partitioned.
  """
  nrep = 1  # sharded_jit uses partitioning only, never replication.
  if local_in_parts is None:
    local_in_parts = in_parts

  # Promote the per-process (local) avals to their global shapes for tracing.
  global_abstract_args = [
      pxla.get_global_aval(arg, parts, lparts)
      for arg, parts, lparts
      in safe_zip(abstract_args, in_parts, local_in_parts)]

  if logging.vlog_is_on(2):
    logging.vlog(2, "abstract_args: %s", abstract_args)
    logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
    logging.vlog(2, "in_parts: %s", in_parts)
    logging.vlog(2, "local_in_parts: %s", local_in_parts)

  jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
      fun, global_abstract_args)

  platform = xb.get_backend().platform
  if platform not in ["tpu", "gpu"]:
    # TODO(skye): fall back to regular jit?
    raise ValueError(f"sharded_jit not supported for {platform}")

  # Inner sharding annotations in the jaxpr may imply a partition count;
  # reconcile it with the caller-provided one.
  nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
  assert nparts is not None
  if nparts > xb.device_count():
    raise ValueError(
        f"sharded_jit computation requires {nparts} devices, "
        f"but only {xb.device_count()} devices are available.")
  if xb.local_device_count() < nparts < xb.device_count():
    raise NotImplementedError(
        f"sharded_jit across multiple hosts must use all available devices. "
        f"Got {nparts} out of {xb.device_count()} requested devices "
        f"(local device count: {xb.local_device_count()})")

  if local_nparts is None:
    if nparts > xb.local_device_count():
      raise ValueError(
        "Specify 'local_nparts' when using cross-process sharded_jit "
        "and all inputs and outputs are replicated.")
    else:
      local_nparts = nparts
  if local_nparts > xb.local_device_count():
    raise ValueError(
      f"sharded_jit computation requires {local_nparts} local devices, "
      f"but only {xb.local_device_count()} local devices are available.")

  if logging.vlog_is_on(2):
    logging.vlog(2, "nparts: %d local_nparts: %d", nparts, local_nparts)

  out_parts = out_parts_thunk()

  local_out_parts = local_out_parts_thunk()
  if local_out_parts is None:
    local_out_parts = out_parts

  if logging.vlog_is_on(2):
    logging.vlog(2, "out_parts: %s", out_parts)
    logging.vlog(2, "local_out_parts: %s", local_out_parts)

  # Narrow the global output avals back to per-process (local) shapes.
  local_out_avals = [
      pxla.get_local_aval(out, parts, lparts)
      for out, parts, lparts
      in safe_zip(global_out_avals, out_parts, local_out_parts)]

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  logging.log(log_priority, "Compiling %s for %d devices with args %s.",
              fun.__name__, nparts, global_abstract_args)

  # Build the XLA computation directly with XlaBuilder (pre-MLIR lowering).
  c = xc.XlaBuilder("spjit_{}".format(fun.__name__))
  xla_consts = _map(partial(xla.pyval_to_ir_constant, c), consts)
  xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
  axis_env = xla.AxisEnv(nrep, (), ())
  ctx = xla.TranslationContext(
      c, platform, axis_env, extend_name_stack(wrap_name(name, "sharded_jit")))
  out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
  # Attach the output sharding annotations to the result tuple.
  out_tuple = xla.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
  built = c.Build(out_tuple)

  if nparts <= xb.local_device_count():
    devices = xb.local_devices()[:nparts]
  else:
    # The earlier NotImplementedError check guarantees all-devices here.
    assert nparts == xb.device_count()
    devices = xb.devices()
  device_assignment = np.array([[d.id for d in devices]])
  device_assignment = np.reshape(device_assignment, (-1, nparts))
  # device_assignment = None  # TODO(skye): replace with default device assignment?
  compiled = dispatch.backend_compile(
      xb.get_backend(), built,
      xb.get_compile_options(nrep, nparts, device_assignment))

  # Note: input specs/indices are built from the *local* avals and partition
  # count, since argument sharding happens per process.
  input_specs = [
      pxla.partitioned_sharding_spec(local_nparts, parts, aval)
      for parts, aval in zip(local_in_parts, abstract_args)]
  input_indices = [pxla.spec_to_indices(aval.shape, spec)
                   if spec is not None else None
                   for aval, spec in zip(abstract_args, input_specs)]

  handle_args = partial(pxla.shard_args, compiled.local_devices(),
                        input_indices)
  handle_outs = _avals_to_results_handler(nrep, local_nparts,  # type: ignore
                                          local_out_parts, local_out_avals)

  return partial(_execute_spatially_partitioned, compiled, handle_args,
                 handle_outs)
def devices_indices_map(
    self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
  """Maps each device to the index it holds of a `global_shape` array."""
  per_device_indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
  return dict(safe_zip(self.devices.flat, per_device_indices))  # type: ignore