def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable], in_tree, in_avals, primitive_name: str):
  """Stage each of `funs` to a jaxpr, then unify their constant signatures.

  Returns (closed_jaxprs, consts, all_out_trees): every returned closed jaxpr
  accepts the full concatenated constant list as leading arguments, using only
  the ones it actually needs.
  """
  # When staging the branches of a conditional into jaxprs, constants are
  # extracted from each branch and converted to jaxpr arguments. To use the
  # staged jaxprs as the branches to a conditional *primitive*, we need for
  # their (input) signatures to match. This function "joins" the staged jaxprs:
  # for each one, it makes another that accepts *all* constants, but only uses
  # those that it needs (dropping the rest).
  jaxprs, all_consts, all_out_trees = \
      unzip3(_initial_style_open_jaxpr(fun, in_tree, in_avals, primitive_name)
             for fun in funs)
  # Fresh variables (unique across all jaxprs) to stand in for the constants
  # that belong to the *other* jaxprs.
  newvar = core.gensym(jaxprs, suffix='_')
  all_const_avals = [map(_abstractify, consts) for consts in all_consts]
  unused_const_vars = [map(newvar, const_avals)
                       for const_avals in all_const_avals]
  def pad_jaxpr_constvars(i, jaxpr):
    # Splice jaxpr i's own constvars between placeholder vars for the consts
    # of the jaxprs before and after it, preserving the global const order.
    prefix = util.concatenate(unused_const_vars[:i])
    suffix = util.concatenate(unused_const_vars[i + 1:])
    constvars = [*prefix, *jaxpr.constvars, *suffix]
    return jaxpr.replace(constvars=constvars)
  consts = util.concatenate(all_consts)
  jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
  # Convert constvars into regular invars so consts are passed explicitly.
  closed_jaxprs = [core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
                   for jaxpr in jaxprs]
  return closed_jaxprs, consts, all_out_trees
def callback_jaxpr(closed_jaxpr, callback, strip_calls):
  """Re-stage `closed_jaxpr` with `callback` installed via the callback trace.

  Wraps the jaxpr as a callable, threads it through the callback subtrace and
  the callback transform, then stages the result back out to a ClosedJaxpr.
  """
  wrapped = lu.wrap_init(jaxpr_as_fun(closed_jaxpr))
  wrapped = callback_subtrace(wrapped)
  wrapped = _callback_fun(wrapped, callback, strip_calls)
  # Stage at the original input avals so the signature is preserved.
  out_jaxpr, out_consts = cd._initial_style_jaxpr(wrapped, closed_jaxpr.in_avals)
  return core.ClosedJaxpr(out_jaxpr, out_consts)
def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_avals, reduce_axes):
  """Transpose rule for remat: partial-eval out the primal part, then transpose."""
  # backward_pass can only transpose linear computations, but the call_jaxpr embedded in
  # remat contains primal (non-linear) equations too. Hence, we have to eliminate those
  # (in this case via partial_eval) before we call into backward_pass again.
  typed_call_jaxpr = core.ClosedJaxpr(call_jaxpr, [])
  unknowns = map(is_undefined_primal, primals_in)
  primal_jaxpr, tangent_jaxpr, out_unknowns = \
      pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True)  # type: ignore
  def do_transpose(primals_in, cotangents_in):
    # NOTE: This is passing in undefined primals in place of tangent arguments, but it
    # should all work out, because we're only computing the primal part here.
    residuals = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)[len(cotangents_in):]
    # Now that we have a purely linear jaxpr, we can transpose it
    cotangents_out = backward_pass(tangent_jaxpr.jaxpr, reduce_axes, (),
                                   primals_in + residuals, cotangents_in)
    # backward_pass will return cotangents computed for all invars, but some of them
    # are residuals appended by partial eval, so we need to skip those before we return.
    return cotangents_out[:len(primals_in)]
  # Flatten the transpose computation and rebind it under the remat call
  # primitive so the transpose itself is also rematerialized.
  flat_args, in_tree_def = tree_flatten((primals_in, cotangents_in))
  flat_do_transpose, out_tree = flatten_fun_nokwargs(lu.wrap_init(do_transpose), in_tree_def)
  flat_cotangents_out = pe.remat_call_p.bind(flat_do_transpose, *flat_args, **params)
  return tree_unflatten(out_tree(), flat_cotangents_out)
def batched_fwd_jaxpr_thunk(): fwd_jaxpr = core.ClosedJaxpr( *fwd_jaxpr_thunk()) # consts can be tracers batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( fwd_jaxpr, axis_size, args_batched, False, axis_name, main_type) out_dims2.append([0 if b else not_mapped for b in out_batched]) return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts
def _custom_derivative_call_jaxpr_callback_rule(primitive, trace, *tracers, fun_jaxpr, num_consts, **params):
  """Callback-trace rule shared by custom_jvp/custom_vjp call-jaxpr primitives.

  Rewrites the primal jaxpr and the appropriate derivative-rule thunk through
  the callback transform, then rebinds the primitive on the lowered values.
  """
  main = trace.main
  vals = [t.val for t in tracers]
  new_closed_jaxpr = callback_jaxpr(fun_jaxpr, trace.callback, strip_calls=trace.strip_calls)
  # Pick which derivative thunk this primitive carries; custom_vjp also has a
  # bwd function that must run under the callback subtrace.
  if primitive == cd.custom_jvp_call_jaxpr_p:
    thunk_name = 'jvp_jaxpr_thunk'
  elif primitive == cd.custom_vjp_call_jaxpr_p:
    thunk_name = 'fwd_jaxpr_thunk'
    params['bwd'] = callback_subtrace(params['bwd'], main)
  else:
    raise NotImplementedError(primitive)
  thunk = params.pop(thunk_name)
  @pe._memoize
  def new_thunk():
    # Apply the callback transform lazily, when the derivative is first needed.
    thunk_jaxpr = core.ClosedJaxpr(*thunk())
    closed_jaxpr = callback_jaxpr(thunk_jaxpr, trace.callback, trace.strip_calls)
    return closed_jaxpr.jaxpr, closed_jaxpr.literals
  params[thunk_name] = new_thunk
  # Hoist the transformed jaxpr's consts into explicit leading arguments.
  new_fun_jaxpr, new_consts = new_closed_jaxpr.jaxpr, new_closed_jaxpr.literals
  closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(new_fun_jaxpr), ())
  new_num_consts = len(new_consts) + num_consts
  out = primitive.bind(*it.chain(new_consts, vals), fun_jaxpr=closed_fun_jaxpr,
                       num_consts=new_num_consts, **params)
  return safe_map(trace.pure, out)
def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
  """Return `jaxpr` with its invars/outvars permuted via `_perm`.

  Equations, constvars, effects, and consts are carried over unchanged; only
  the binder order changes.
  """
  inner = jaxpr.jaxpr
  permuted_invars = _perm(primals_in, tangents_in, inner.invars)
  permuted_outvars = _perm(primals_out, tangents_out, inner.outvars)
  rebuilt = core.Jaxpr(inner.constvars, permuted_invars, permuted_outvars,
                       inner.eqns, inner.effects)
  return core.ClosedJaxpr(rebuilt, jaxpr.consts)
def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear): nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents] # We need to find out which `Ref`s have nonzero tangents after running the # for loop. Ordinarily we do this with a fixed point on the body jaxpr but # a `for` body jaxpr is stateful and has no outputs. We therefore discharge # the state effect from the jaxpr and we will now have a "symmetric" jaxpr # where the inputs line up with the outputs. We use this discharged jaxpr # for the fixed point. discharged_jaxpr, body_consts = discharge_state(jaxpr, ()) for _ in range(len(nonzero_tangents)): _, out_nonzero_tangents = ad.jvp_jaxpr(core.ClosedJaxpr( discharged_jaxpr, body_consts), [False] + nonzero_tangents, instantiate=nonzero_tangents) if out_nonzero_tangents == nonzero_tangents: break nonzero_tangents = map(operator.or_, nonzero_tangents, out_nonzero_tangents) else: raise Exception("Invalid fixpoint") tangents = [ ad.instantiate_zeros(t) if inst else t for t, inst in zip(tangents, nonzero_tangents) ] tangents = [t for t in tangents if type(t) is not ad_util.Zero] closed_jaxpr = core.ClosedJaxpr(jaxpr, ()) jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, [False] + nonzero_tangents, []) jvp_jaxpr, jvp_consts = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts jvp_which_linear = ((False, ) * len(jvp_consts) + which_linear + (True, ) * len(tangents)) out_flat = for_p.bind(*jvp_consts, *primals, *tangents, jaxpr=jvp_jaxpr, nsteps=nsteps, reverse=reverse, which_linear=jvp_which_linear) # `out_flat` includes constant inputs into the `for_loop` which are # converted into outputs as well. We don't care about these in AD so we # throw them out. _, out_primals, out_tangents = split_list( out_flat, [len(jvp_consts), len(primals)]) out_tangents_iter = iter(out_tangents) out_tangents = [ next(out_tangents_iter) if nz else ad_util.Zero.from_value(p) for p, nz in zip(out_primals, nonzero_tangents) ] return out_primals, out_tangents
def checkify_fun_to_jaxpr(f, error, enabled_errors, in_avals):
  """Trace the checkified `f` to a ClosedJaxpr; returns (closed_jaxpr, msgs).

  The staged jaxpr takes the error value and error code as its two leading
  inputs, followed by `in_avals`.
  """
  subtraced, msgs = checkify_subtrace(f)
  traceable = checkify_traceable(subtraced, tuple(error.msgs.items()), enabled_errors)
  # Error payload and code ride along as the two leading arguments.
  error_avals = [core.raise_to_shaped(core.get_aval(v))
                 for v in (error.err, error.code)]
  jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(
      traceable, [*error_avals, *in_avals])
  return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals, primitive_name: Optional[str] = None):
  """Stage `fun` out and close it; returns (closed_jaxpr, consts, out_tree)."""
  open_jaxpr, const_vals, tree = _initial_style_open_jaxpr(
      fun, in_tree, in_avals, primitive_name)
  # Hoist constants into explicit invars so callers pass them positionally.
  closed = core.ClosedJaxpr(pe.convert_constvars_jaxpr(open_jaxpr), ())
  return closed, const_vals, tree
def _initial_style_jaxpr(fun, in_avals):
  """Trace `fun` at `in_avals` to a ClosedJaxpr via the legacy tracing path."""
  unknown_pvals = [pe.PartialVal.unknown(a) for a in in_avals]
  jaxpr, _, consts = pe.trace_to_jaxpr(
      fun, unknown_pvals, instantiate=True, bottom=True, stage_out=False)  # type: ignore
  # Consts must be concrete values here; tracers must not leak into the jaxpr.
  assert all(not isinstance(c, core.Tracer) for c in consts)
  return core.ClosedJaxpr(jaxpr, consts)
def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, axis_name, main_type):
  """Batch `closed_jaxpr` along the given input axes.

  Returns (batched_closed_jaxpr, out_batched_thunk_result).
  """
  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  f, out_batched = _batch_jaxpr_inner(f, axis_size, out_axes_dest)
  f = _batch_jaxpr_outer(f, axis_name, axis_size, in_axes, main_type)
  # Inputs that are batched get the unmapped (axis-expanded) aval; unmapped
  # inputs keep their original aval.
  avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval)
              if b is not not_mapped else aval
              for aval, b in zip(closed_jaxpr.in_avals, in_axes)]
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
def make_transpose_from_thunk(thunk, lin_tree):
  """Build a transpose callable from a thunk producing (jaxpr, consts)."""
  # Force the thunk eagerly; the returned closure captures the closed jaxpr
  # and its constant values.
  open_jaxpr, const_vals = thunk()
  closed = core.ClosedJaxpr(pe.convert_constvars_jaxpr(open_jaxpr), ())
  def transpose(res_arg, ct_out):
    # Flatten (residuals, output cotangent), run the transposed jaxpr, and
    # unflatten the input cotangents back into the linear-arg structure.
    leaves = tree_leaves((res_arg, ct_out))
    in_cts = core.jaxpr_as_fun(closed)(*const_vals, *leaves)
    return tree_unflatten(lin_tree, in_cts)
  return transpose
def remat_vmap(axis_size, axis_name, main_type, args, dims, *, jaxpr, **params):
  """Batching rule for remat using per-axis batching (batch_jaxpr_axes)."""
  assert not jaxpr.constvars
  jaxpr_ = core.ClosedJaxpr(jaxpr, ())
  jaxpr_batched_, out_batched = batching.batch_jaxpr_axes(
      jaxpr_, axis_size, dims, [batching.zero_if_mapped] * len(jaxpr.outvars),
      axis_name=axis_name, main_type=main_type)
  # New consts may be introduced by batching; pass them as leading args.
  jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts
  out_dims = [0 if b else None for b in out_batched]
  return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims
def checkify_jaxpr(jaxpr, error):
  """Re-trace `jaxpr` under the error-checking transform.

  Returns (closed_jaxpr, msgs); the new jaxpr takes the error value and code
  as its two leading inputs.
  """
  wrapped = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  wrapped, msgs = check_errors_subtrace(wrapped)
  wrapped = check_errors_traceable(wrapped, tuple(error.msgs.items()))
  err_aval = core.raise_to_shaped(core.get_aval(error.err))
  code_aval = core.raise_to_shaped(core.get_aval(error.code))
  out_jaxpr, _, out_literals = pe.trace_to_jaxpr_dynamic(
      wrapped, [err_aval, code_aval, *jaxpr.in_avals])
  return core.ClosedJaxpr(out_jaxpr, out_literals), msgs()
def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
  """Stage the JVP of `jaxpr`, with tangent inputs only where `nonzeros` is True.

  Returns (closed_jvp_jaxpr, out_nonzeros).
  """
  assert len(jaxpr.in_avals) == len(nonzeros)
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  f_jvp, out_nonzeros = f_jvp_traceable(
      jvp(f, instantiate=instantiate, transform_stack=False), nonzeros)
  # Tangent inputs exist only for the nonzero positions; they follow the
  # primal inputs in the staged signature.
  tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
  avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
  jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
  return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
  """Finalize a trace: lift output tracers, build the jaxpr, and close it."""
  # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
  instantiate_flags = [instantiate] * len(out_tracers)
  lifted = safe_map(trace.full_raise, safe_map(core.full_lower, out_tracers))
  lifted = safe_map(partial(pe.instantiate_const_at, trace),
                    instantiate_flags, lifted)
  jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, lifted)
  assert not env  # TODO: this is from partial_eval.trace_to_jaxpr. Share.
  # Convert constvars to invars; consts are returned separately.
  return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()), consts
def remat_vmap(axis_size, axis_name, main_type, args, dims, *, jaxpr, **params):
  """Batching rule for remat using boolean batched-ness (batch_jaxpr)."""
  assert not jaxpr.constvars
  in_batched = [d is not batching.not_mapped for d in dims]
  jaxpr_ = core.ClosedJaxpr(jaxpr, ())
  jaxpr_batched_, out_batched = batching.batch_jaxpr(
      jaxpr_, axis_size, in_batched, instantiate=False,
      axis_name=axis_name, main_type=main_type)
  # New consts may be introduced by batching; pass them as leading args.
  jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts
  out_dims = [0 if b else None for b in out_batched]
  return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims
def ignore_errors_jaxpr(jaxpr, error):
  """Constructs a jaxpr which takes two extra args but ignores them.

  The two new leading invars have the avals of `error.err` and `error.code`;
  the body, outputs, and consts are unchanged, so the extra inputs are simply
  dropped.
  """
  err_aval = core.raise_to_shaped(core.get_aval(error.err))
  code_aval = core.raise_to_shaped(core.get_aval(error.code))
  consts = jaxpr.consts
  jaxpr = jaxpr.jaxpr
  new_vars = core.gensym([jaxpr])
  new_invars = (new_vars(err_aval), new_vars(code_aval), *jaxpr.invars)
  # Pass `jaxpr.effects` through explicitly, matching how rearrange_binders
  # and augment_jaxpr rebuild core.Jaxpr elsewhere in this file; previously
  # the effects argument was omitted here.
  new_jaxpr = core.Jaxpr(jaxpr.constvars, new_invars, jaxpr.outvars,
                         jaxpr.eqns, jaxpr.effects)
  return core.ClosedJaxpr(new_jaxpr, consts)
def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in, avals_out, *args):
  """Lower a call primitive to an MLIR func + std.call."""
  xla.check_backend_matches(backend, ctx.platform)
  output_types = map(aval_to_ir_types, avals_out)
  flat_output_types = util.flatten(output_types)
  # Extend the name stack so ops inside the callee are attributed to it.
  sub_ctx = ctx.replace(
      name_stack=xla.extend_name_stack(ctx.name_stack, stack_name))
  symbol_name = lower_jaxpr_to_fun(sub_ctx, fn_name,
                                   core.ClosedJaxpr(call_jaxpr, ()))
  call = std.CallOp(flat_output_types,
                    ir.FlatSymbolRefAttr.get(symbol_name),
                    flatten_lowering_ir_args(args))
  # Regroup the flat results to match the per-output IR type structure.
  return util.unflatten(call.results, map(len, output_types))
def batched_jvp_jaxpr_thunk(): jvp_jaxpr = core.ClosedJaxpr(*jvp_jaxpr_thunk()) # consts can be tracers _, args_batched = split_list(in_batched, [num_consts]) _, all_batched = batching.batch_jaxpr(jvp_jaxpr, size, args_batched * 2, False, axis_name, main_type) primals_batched, tangents_batched = split_list(all_batched, [num_out]) out_batched = map(op.or_, primals_batched, tangents_batched) out_dims2.append([0 if b else not_mapped for b in out_batched]) batched_jvp_jaxpr, _ = batching.batch_jaxpr( jvp_jaxpr, size, args_batched * 2, out_batched * 2, axis_name, main_type) return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts
def augment_jaxpr(jaxpr, res_indices):
  """Widen `jaxpr`'s leading residual invars to the full residual set.

  NOTE(review): `all_res_vars` is not defined in this function — presumably a
  closure variable from an enclosing scope not visible in this chunk; confirm
  against the full source.
  """
  num_res = len(res_indices)
  res_vars = jaxpr.jaxpr.invars[:num_res]
  non_res_vars = jaxpr.jaxpr.invars[num_res:]
  # Place this jaxpr's residual vars at their positions within the full
  # residual variable list.
  aug_res_vars = list(util.subvals(all_res_vars, zip(res_indices, res_vars)))
  aug_invars = aug_res_vars + non_res_vars
  jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars,
                         jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns,
                         jaxpr.jaxpr.effects)
  jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
  return jaxpr_aug
def __call__(self, *args, **kwargs):
  """Trace self.fun and bind the custom_vmap primitive on the flat args."""
  assert not kwargs
  args_flat, in_tree = tree_flatten(args)
  flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
  debug = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
  assert not len(consts)
  closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
  # consts is empty (asserted above); *consts is kept for signature symmetry.
  out_flat = custom_vmap_p.bind(*consts, *args_flat, call=closed_call,
                                rule=self.vmap_rule, in_tree=in_tree)
  return tree_unflatten(out_tree(), out_flat)
def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
  """JVP rule for remat: stage the JVP jaxpr and rebind under remat_p."""
  assert not jaxpr.constvars
  in_nonzeros = [type(t) is not ad_util.Zero for t in tangents]
  jaxpr_ = core.ClosedJaxpr(jaxpr, ())
  jaxpr_jvp_, out_nonzeros = ad.jvp_jaxpr(jaxpr_, in_nonzeros, False)
  # Only nonzero tangents are passed; the JVP jaxpr was staged accordingly.
  nonzero_tangents = [t for t in tangents if type(t) is not ad_util.Zero]
  jaxpr_jvp = pe.convert_constvars_jaxpr(jaxpr_jvp_.jaxpr)
  outs = remat_p.bind(*jaxpr_jvp_.consts, *primals, *nonzero_tangents,
                      jaxpr=jaxpr_jvp, prevent_cse=prevent_cse,
                      differentiated=differentiated, policy=policy)
  # Outputs are primals followed by the nonzero tangents; re-insert Zeros.
  out_primals, out_tangents_ = split_list(outs, [len(jaxpr.outvars)])
  out_tangents_ = iter(out_tangents_)
  out_tangents = [next(out_tangents_) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_primals, out_nonzeros)]
  return out_primals, out_tangents
def custom_jvp_call_jaxpr(fun, jvp, *args):
  """A convenience wrapper to apply the custom_jvp_call_jaxpr primitive."""
  in_avals = [
      abstract_arrays.raise_to_shaped(jax_core.get_aval(x)) for x in args
  ]
  fun_jaxpr, consts = cd._initial_style_jaxpr(  # pylint: disable=protected-access
      fun, in_avals)  # consts can be tracers!
  closed_fun_jaxpr = jax_core.ClosedJaxpr(
      pe.convert_constvars_jaxpr(fun_jaxpr), ())
  # Stage the JVP lazily; it takes (primals, tangents), hence in_avals * 2.
  jvp_jaxpr_thunk = pe._memoize(  # pylint: disable=protected-access
      lambda: cd._initial_style_jaxpr(jvp, in_avals * 2))  # pylint: disable=protected-access
  return cd.custom_jvp_call_jaxpr_p.bind(*consts, *args,
                                         fun_jaxpr=closed_fun_jaxpr,
                                         jvp_jaxpr_thunk=jvp_jaxpr_thunk,
                                         num_consts=len(consts))
def wrapped(*args, **kwargs):
  """Trace the wrapped function to a ClosedJaxpr.

  Returns (typed_jaxpr, (in_tree, out_tree)). NOTE(review): `f`, `dynamic`,
  and `get_shaped_aval` are closure variables from the enclosing scope.
  """
  fun = lu.wrap_init(f, kwargs)
  flat_args, in_tree = tree_util.tree_flatten(args)
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
  flat_avals = safe_map(get_shaped_aval, flat_args)
  if dynamic:
    # Omnistaging-style dynamic tracing.
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
  else:
    # Legacy partial-value tracing.
    pvals = [pe.PartialVal.unknown(aval) for aval in flat_avals]
    jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, pvals, instantiate=True)
  typed_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
  return typed_jaxpr, (in_tree, out_tree())
def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees):
  """A convenience wrapper to apply the custom_vjp_call_jaxpr primitive."""
  in_avals = [
      abstract_arrays.raise_to_shaped(jax_core.get_aval(x)) for x in args
  ]
  fun_jaxpr, consts = cd._initial_style_jaxpr(  # pylint: disable=protected-access
      fun, in_avals)  # consts can be tracers!
  closed_fun_jaxpr = jax_core.ClosedJaxpr(
      pe.convert_constvars_jaxpr(fun_jaxpr), ())
  # Stage the forward rule lazily at the primal avals.
  fwd_jaxpr_thunk = pe._memoize(
      lambda: cd._initial_style_jaxpr(fwd, in_avals))  # pylint: disable=protected-access
  return cd.custom_vjp_call_jaxpr_p.bind(*consts, *args,
                                         fun_jaxpr=closed_fun_jaxpr,
                                         fwd_jaxpr_thunk=fwd_jaxpr_thunk,
                                         bwd=bwd, out_trees=out_trees,
                                         num_consts=len(consts))
def wrapped(*args, **kwargs):
  """Trace the wrapped function to a ClosedJaxpr (omnistaging required).

  Returns (typed_jaxpr, (in_tree, out_tree)). NOTE(review): `f`, `dynamic`,
  and `get_shaped_aval` are closure variables from the enclosing scope.
  """
  fun = lu.wrap_init(f, kwargs)
  flat_args, in_tree = tree_util.tree_flatten(args)
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
  flat_avals = safe_map(get_shaped_aval, flat_args)
  if not jax.config.omnistaging_enabled:
    raise ValueError('Oryx must be used with JAX omnistaging enabled.')
  if dynamic:
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
  else:
    # Legacy path: partial values paired with jax_core.unit.
    pvals = [
        pe.PartialVal((aval, jax_core.unit)) for aval in flat_avals
    ]
    jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, pvals, instantiate=True)
  typed_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
  return typed_jaxpr, (in_tree, out_tree())
def _sharded_jit_lowering(ctx, *in_nodes, in_parts, out_parts_thunk, nparts,
                          name, call_jaxpr, local_in_parts,
                          local_out_parts_thunk, local_nparts):
  """MLIR lowering rule for sharded_jit: wrap args/results in sharding ops."""
  # We assume any extra leading in_nodes are constants and replicate them.
  num_extra_nodes = len(in_nodes) - len(in_parts)
  assert num_extra_nodes >= 0
  in_parts = (None, ) * num_extra_nodes + in_parts
  # Annotate each input with its sharding (None means replicated: pass as-is).
  args = []
  for ns, sharding in safe_zip(
      safe_map(mlir.wrap_singleton_ir_values, in_nodes), in_parts):
    if sharding is not None:
      args.append([
          mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
          for n in ns
      ])
    else:
      args.append(ns)
  sub_ctx = ctx.module_context.replace(
      name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
  fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
                               core.ClosedJaxpr(call_jaxpr, ()))
  output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
  flat_output_types = util.flatten(output_types)
  call = std.CallOp(flat_output_types,
                    ir.FlatSymbolRefAttr.get(fn.name.value),
                    mlir.flatten_lowering_ir_args(args))
  out_nodes = util.unflatten(call.results, safe_map(len, output_types))
  out_parts = out_parts_thunk()
  # Annotate each output with its sharding, mirroring the input handling.
  outputs = []
  for ns, sharding in safe_zip(out_nodes, out_parts):
    if sharding is not None:
      outputs.append([
          mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
          for n in ns
      ])
    else:
      outputs.append(ns)
  return outputs
def __call__(self, residual_arg, linear_arg):
  """Trace self.fun and bind the custom_transpose primitive on the flat args."""
  res_arg, lin_arg = residual_arg, linear_arg
  # Separate trees for the residual and linear args are carried as params so
  # the transpose rule can unflatten each side independently.
  _, res_tree = tree_flatten(res_arg)
  _, lin_tree = tree_flatten(lin_arg)
  args_flat, in_tree = tree_flatten((res_arg, lin_arg))
  flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
  debug = pe.debug_info(self.fun, in_tree, False, "custom_transpose")
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
  assert not len(consts)
  closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
  # consts is empty (asserted above); *consts is kept for signature symmetry.
  out_flat = custom_transpose_p.bind(*consts, *args_flat, call=closed_call,
                                     rule=self.transpose, lin_tree=lin_tree,
                                     res_tree=res_tree, out_tree=out_tree())
  return tree_unflatten(out_tree(), out_flat)
def _sharded_callable(
    fun: lu.WrappedFun, nparts: Optional[int],
    in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
    out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
    local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
    local_out_parts_thunk: Callable[[], Optional[Tuple[
        pxla.PartitionsOrReplicated, ...]]],
    local_nparts: Optional[int], name: str, *abstract_args):
  """Trace, lower, and compile `fun` for spatially-partitioned execution.

  Returns a callable that shards the inputs, runs the compiled executable,
  and assembles the sharded outputs.
  """
  nrep = 1  # single replica; parallelism here is via partitioning only
  if local_in_parts is None:
    local_in_parts = in_parts
  global_abstract_args = [
      pxla.get_global_aval(arg, parts, lparts)
      for arg, parts, lparts in safe_zip(
          abstract_args, in_parts, local_in_parts)
  ]
  if logging.vlog_is_on(2):
    logging.vlog(2, "abstract_args: %s", abstract_args)
    logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
    logging.vlog(2, "in_parts: %s", in_parts)
    logging.vlog(2, "local_in_parts: %s", local_in_parts)
  jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
      fun, global_abstract_args)
  platform = xb.get_backend().platform
  # Validate the requested partition count against available devices.
  nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
  assert nparts is not None
  if nparts > xb.device_count():
    raise ValueError(
        f"sharded_jit computation requires {nparts} devices, "
        f"but only {xb.device_count()} devices are available.")
  if xb.local_device_count() < nparts < xb.device_count():
    raise NotImplementedError(
        f"sharded_jit across multiple hosts must use all available devices. "
        f"Got {nparts} out of {xb.device_count()} requested devices "
        f"(local device count: {xb.local_device_count()})")
  if local_nparts is None:
    if nparts > xb.local_device_count():
      raise ValueError(
          "Specify 'local_nparts' when using cross-process sharded_jit "
          "and all inputs and outputs are replicated.")
    else:
      local_nparts = nparts
  if local_nparts > xb.local_device_count():
    raise ValueError(
        f"sharded_jit computation requires {local_nparts} local devices, "
        f"but only {xb.local_device_count()} local devices are available.")
  if logging.vlog_is_on(2):
    logging.vlog(2, "nparts: %d local_nparts: %d", nparts, local_nparts)
  out_parts = out_parts_thunk()
  local_out_parts = local_out_parts_thunk()
  if local_out_parts is None:
    local_out_parts = out_parts
  if logging.vlog_is_on(2):
    logging.vlog(2, "out_parts: %s", out_parts)
    logging.vlog(2, "local_out_parts: %s", local_out_parts)
  local_out_avals = [
      pxla.get_local_aval(out, parts, lparts)
      for out, parts, lparts in safe_zip(
          global_out_avals, out_parts, local_out_parts)
  ]
  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  logging.log(log_priority, "Compiling %s for %d devices with args %s.",
              fun.__name__, nparts, global_abstract_args)
  # Lower the jaxpr to an MLIR module with the sharding annotations attached.
  axis_env = xla.AxisEnv(nrep, (), ())
  unordered_effects = [
      eff for eff in jaxpr.effects if eff not in core.ordered_effects
  ]
  ordered_effects = [
      eff for eff in jaxpr.effects if eff in core.ordered_effects
  ]
  module, _ = mlir.lower_jaxpr_to_module(
      f"spjit_{fun.__name__}", core.ClosedJaxpr(jaxpr, consts),
      unordered_effects, ordered_effects, platform=platform,
      axis_context=mlir.ReplicaAxisContext(axis_env),
      name_stack=new_name_stack(wrap_name(name, "sharded_jit")),
      donated_args=[False] * len(in_parts),
      arg_shardings=safe_map(xla.sharding_to_proto, in_parts),
      result_shardings=safe_map(xla.sharding_to_proto, out_parts))
  built = xc._xla.mlir.mlir_module_to_xla_computation(
      mlir.module_to_string(module), use_tuple_args=False, return_tuple=True)
  if nparts <= xb.local_device_count():
    devices = xb.local_devices()[:nparts]
  else:
    assert nparts == xb.device_count()
    devices = xb.devices()
  device_assignment = np.array([[d for d in devices]])
  device_assignment = np.reshape(device_assignment, (-1, nparts))
  # device_assignment = None  # TODO(skye): replace with default device assignment?
  compiled = dispatch.backend_compile(
      xb.get_backend(), built,
      xb.get_compile_options(nrep, nparts, device_assignment))
  # Build input sharding specs/indices and the output results handler.
  input_specs = [
      pxla.partitioned_sharding_spec(local_nparts, parts, aval)
      for parts, aval in zip(local_in_parts, abstract_args)
  ]
  input_indices = [
      pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
      for aval, spec in zip(abstract_args, input_specs)
  ]
  handle_args = partial(pxla.shard_args, compiled.local_devices(),
                        input_indices)
  handle_outs = _avals_to_results_handler(
      nrep,
      local_nparts,  # type: ignore
      local_out_parts, local_out_avals)
  return partial(_execute_spatially_partitioned, compiled, handle_args,
                 handle_outs)