def _initial_style_jaxprs_with_common_consts(
    funs: Sequence[Callable], in_tree, in_avals, primitive_name: str):
  """Stage every function in `funs` so that all results share one signature.

  When staging the branches of a conditional into jaxprs, constants are
  extracted from each branch and converted to jaxpr arguments. To use the
  staged jaxprs as the branches of a conditional *primitive*, their (input)
  signatures need to match. This function "joins" the staged jaxprs: for each
  one, it makes another that accepts *all* constants, but only uses those
  that it needs (dropping the rest).

  Returns:
    A triple `(closed_jaxprs, consts, all_out_trees)`: one `ClosedJaxpr` per
    function (each closed over an empty tuple of constants), the
    concatenation of the constants extracted from every function, and the
    per-function output trees produced by `_initial_style_open_jaxpr`.
  """
  staged = (_initial_style_open_jaxpr(fun, in_tree, in_avals, primitive_name)
            for fun in funs)
  jaxprs, all_consts, all_out_trees = unzip3(staged)

  # One fresh placeholder variable per constant of every function. A given
  # jaxpr keeps its own constvars and receives the placeholders that stand
  # in for the constants of all the *other* functions.
  newvar = core.gensym(jaxprs, suffix='_')
  all_const_avals = [map(_abstractify, fun_consts) for fun_consts in all_consts]
  placeholder_vars = [map(newvar, avals) for avals in all_const_avals]

  consts = util.concatenate(all_consts)

  closed_jaxprs = []
  for idx, jaxpr in enumerate(jaxprs):
    leading = util.concatenate(placeholder_vars[:idx])
    trailing = util.concatenate(placeholder_vars[idx + 1:])
    padded = jaxpr.replace(
        constvars=[*leading, *jaxpr.constvars, *trailing])
    closed_jaxprs.append(
        core.ClosedJaxpr(pe.convert_constvars_jaxpr(padded), ()))
  return closed_jaxprs, consts, all_out_trees
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
  """Stage an `xmap_p` call into the jaxpr being built by this trace.

  The body `f` is traced on the input avals with their mapped axes deleted,
  inside an axis environment extended with the global axis sizes. The
  resulting output avals get their axes re-inserted using the local axis
  sizes, are checked against `out_axes`, and an equation binding
  `primitive` is appended to `self.frame.eqns`. Constants found while
  tracing become extra leading inputs of the equation.

  Returns:
    The list of output tracers created for the new equation.
  """
  from jax.interpreters.partial_eval import (
      trace_to_subjaxpr_dynamic, DynamicJaxprTracer, source_info_util,
      convert_constvars_jaxpr, new_jaxpr_eqn)
  assert primitive is xmap_p

  global_axis_sizes = params['global_axis_sizes']
  in_axes = params['in_axes']
  in_avals = [t.aval for t in tracers]
  mapped_in_avals = [_delete_aval_axes(aval, axes)
                     for aval, axes in zip(in_avals, in_axes)]
  with core.extend_axis_env_nd(global_axis_sizes.items()):
    jaxpr, mapped_out_avals, consts = trace_to_subjaxpr_dynamic(
        f, self.main, mapped_in_avals)
  # The thunk is only called once the body has been traced.
  out_axes = params['out_axes_thunk']()

  resource_count = _get_axis_resource_count(params['axis_resources'],
                                            params['resource_env'])
  local_axis_sizes = {}
  for axis, global_size in global_axis_sizes.items():
    local_axis_sizes[axis] = resource_count[axis].to_local(global_size)
  out_avals = [_insert_aval_axes(aval, axes, local_axis_sizes)
               for aval, axes in zip(mapped_out_avals, out_axes)]
  _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)

  source_info = source_info_util.current()
  out_tracers = [DynamicJaxprTracer(self, aval, source_info)
                 for aval in out_avals]
  invars = map(self.getvar, tracers)
  constvars = map(self.getvar, map(self.instantiate_const, consts))
  outvars = map(self.makevar, out_tracers)

  # Constants are prepended to the inputs, so the per-input parameters need
  # matching leading entries: no mapped axes and never donated.
  num_consts = len(consts)
  new_params = dict(
      params,
      in_axes=(AxisNamePos(user_repr='{}'),) * num_consts + in_axes,
      out_axes=out_axes,
      donated_invars=(False,) * num_consts + params['donated_invars'],
      call_jaxpr=convert_constvars_jaxpr(jaxpr))
  del new_params['out_axes_thunk']

  eqn = new_jaxpr_eqn([*constvars, *invars], outvars, primitive,
                      new_params, source_info)
  self.frame.eqns.append(eqn)
  return out_tracers
def _sparsify_jaxpr(spenv, jaxpr, *argspecs):
  """Trace the sparse evaluation of a closed jaxpr into a new closed jaxpr.

  Args:
    spenv: sparsification environment passed to `eval_sparse` and to the
      argspec/array conversion helpers.
    jaxpr: closed jaxpr to evaluate; its `.jaxpr` and `.consts` are used.
    *argspecs: argument specifications; `argspecs_to_arrays` turns them into
      the example arrays whose avals drive the trace.

  Returns:
    A pair `(sp_jaxpr, out_tree)`: the traced closed jaxpr, whose inputs are
    the flattened arrays of `argspecs`, and the pytree structure of the
    arrays it returns.
  """
  # TODO(jakevdp): currently this approach discards all information about
  #   shared data & indices when generating the sparsified jaxpr. The
  #   current approach produces valid sparsified while loops, but they
  #   don't work in corner cases (see associated TODO in sparsify_test.py)
  out_tree = None

  @lu.wrap_init
  def wrapped(*args_flat):
    nonlocal out_tree
    args = tree_unflatten(in_tree, args_flat)
    argspecs = arrays_to_argspecs(spenv, args)
    result = eval_sparse(jaxpr.jaxpr, jaxpr.consts, argspecs, spenv)
    out = argspecs_to_arrays(spenv, result)
    out_flat, out_tree = tree_flatten(out)
    return out_flat

  args = argspecs_to_arrays(spenv, argspecs)
  args_flat, in_tree = tree_flatten(args)
  avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
  sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
  # Close over the traced constants directly. Lifting the constvars to
  # inputs first (as this code used to do) leaves the jaxpr with no constvars
  # for `consts` to bind to, and changes its input signature so that it no
  # longer matches `args_flat`.
  sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
  return sp_jaxpr, out_tree
def _custom_derivative_call_jaxpr_callback_rule(primitive, trace, *tracers,
                                                fun_jaxpr, num_consts,
                                                **params):
  """Callback-transform rule for the custom JVP/VJP call-jaxpr primitives.

  Rewrites `fun_jaxpr` with `callback_jaxpr`, wraps the primitive's
  derivative thunk so that the jaxpr it yields is rewritten the same way
  (and, for the VJP primitive, wraps `bwd` with `callback_subtrace`), then
  re-binds `primitive` on the unwrapped tracer values.

  Raises:
    NotImplementedError: if `primitive` is neither
      `cd.custom_jvp_call_jaxpr_p` nor `cd.custom_vjp_call_jaxpr_p`.
  """
  main = trace.main
  vals = [t.val for t in tracers]
  new_closed_jaxpr = callback_jaxpr(fun_jaxpr, trace.callback,
                                    strip_calls=trace.strip_calls)

  if primitive == cd.custom_jvp_call_jaxpr_p:
    thunk_name = 'jvp_jaxpr_thunk'
  elif primitive == cd.custom_vjp_call_jaxpr_p:
    thunk_name = 'fwd_jaxpr_thunk'
    params['bwd'] = callback_subtrace(params['bwd'], main)
  else:
    raise NotImplementedError(primitive)

  original_thunk = params.pop(thunk_name)

  @pe._memoize
  def new_thunk():
    rewritten = callback_jaxpr(core.ClosedJaxpr(*original_thunk()),
                               trace.callback, trace.strip_calls)
    return rewritten.jaxpr, rewritten.literals

  params[thunk_name] = new_thunk

  # Constants of the rewritten jaxpr become extra leading operands, so the
  # constant count passed to the primitive grows by the same amount.
  new_fun_jaxpr = new_closed_jaxpr.jaxpr
  new_consts = new_closed_jaxpr.literals
  closed_fun_jaxpr = core.ClosedJaxpr(
      pe.convert_constvars_jaxpr(new_fun_jaxpr), ())
  out = primitive.bind(*new_consts, *vals,
                       fun_jaxpr=closed_fun_jaxpr,
                       num_consts=len(new_consts) + num_consts,
                       **params)
  return safe_map(trace.pure, out)
def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
  """Build a `TypedJaxpr` (plus its constants) from traced inputs/outputs.

  Every output tracer is lowered, raised to `trace`, and instantiated
  according to `instantiate` (the single flag is applied to all outputs).
  The jaxpr's constvars are converted to leading inputs, so the returned
  typed jaxpr takes `consts` followed by the original inputs.

  Returns:
    A pair `(typed_jaxpr, consts)`.
  """
  # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
  instantiate_flags = [instantiate] * len(out_tracers)
  lowered = safe_map(core.full_lower, out_tracers)
  out_tracers = safe_map(trace.full_raise, lowered)
  out_tracers = safe_map(partial(pe.instantiate_const_at, trace),
                         instantiate_flags, out_tracers)
  jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
  # TODO: this is from partial_eval.trace_to_jaxpr. Share.
  assert not env

  # TODO: this is from the final part of lax_control_flow._initial_style_jaxpr
  out_pvs = unzip2([t.pval for t in out_tracers])[0]
  out_avals = safe_map(abstract_arrays.raise_to_shaped, out_pvs)
  const_avals = tuple(
      abstract_arrays.raise_to_shaped(core.get_aval(c)) for c in consts)
  in_pvs = unzip2([t.pval for t in in_tracers])[0]
  in_avals = tuple(safe_map(abstract_arrays.raise_to_shaped, in_pvs))

  typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr), (),
                                const_avals + in_avals, out_avals)
  return typed_jaxpr, consts
def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params):
  """Transpose rule for `remat_p`.

  The transposition itself is traced into a new jaxpr: inside the traced
  function the primal jaxpr is partially evaluated on the known primals and
  `ad.backward_pass` is run on what remains. The new jaxpr is then bound
  with `remat_p` again.

  Returns:
    The input cotangents, unflattened with the tree structure recorded while
    tracing; positions whose primal was not undefined hold `ad_util.Zero`.
  """
  assert not jaxpr.constvars
  # A bare function object serves as a mutable attribute holder, letting the
  # traced closure below hand back the treedef of its result.
  cell = lambda: None

  @lu.wrap_init
  def transposed(*args):
    primals, cts = tree_unflatten(treedef, args)
    in_pvals = [pe.PartialVal.unknown(p.aval) if ad.is_undefined_primal(p)
                else pe.PartialVal.known(p)
                for p in primals]
    primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
    t_jaxpr, _, t_consts = pe.trace_to_jaxpr_nounits(
        primal_fun, in_pvals, False)
    dummy_args = [ad.UndefinedPrimal(v.aval) for v in t_jaxpr.invars]
    computed = iter(ad.backward_pass(t_jaxpr, reduce_axes, False, t_consts,
                                     dummy_args, cts))
    # `backward_pass` yields one cotangent per undefined primal; every other
    # position gets a symbolic zero.
    in_cts = [next(computed) if ad.is_undefined_primal(p)
              else ad_util.Zero(p.aval)
              for p in primals]
    assert next(computed, None) is None
    in_cts, cell.treedef = tree_flatten(in_cts)
    return in_cts

  flat_args, treedef = tree_flatten((in_primals, out_cts))
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in flat_args]
  transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(transposed,
                                                           in_avals)
  transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr_)
  flat_in_cts = remat_p.bind(*consts, *flat_args, jaxpr=transposed_jaxpr,
                             **params)
  return tree_unflatten(cell.treedef, flat_in_cts)  # type: ignore
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
  """Stage an `xmap_p` call into the jaxpr being built by this trace.

  The body `f` is traced on the input avals with their mapped axes deleted,
  inside an axis environment extended with `params['axis_sizes']`. The
  output avals get their axes re-inserted, and an equation binding
  `primitive` is appended to `self.frame.eqns`. Constants found while
  tracing become extra leading inputs of the equation.

  Returns:
    The list of output tracers created for the new equation.
  """
  from jax.interpreters.partial_eval import (
      trace_to_subjaxpr_dynamic, DynamicJaxprTracer, source_info_util,
      convert_constvars_jaxpr, call_param_updaters, new_jaxpr_eqn)
  assert primitive is xmap_p

  axis_sizes = params['axis_sizes']
  in_axes = params['in_axes']
  in_avals = [t.aval for t in tracers]
  mapped_in_avals = [_delete_aval_axes(aval, axes)
                     for aval, axes in zip(in_avals, in_axes)]
  with core.extend_axis_env_nd(axis_sizes.items()):
    jaxpr, mapped_out_avals, consts = trace_to_subjaxpr_dynamic(
        f, self.main, mapped_in_avals)
  # The thunk is only called once the body has been traced.
  out_axes = params['out_axes_thunk']()
  out_avals = [_insert_aval_axes(aval, axes, axis_sizes)
               for aval, axes in zip(mapped_out_avals, out_axes)]

  source_info = source_info_util.current()
  out_tracers = [DynamicJaxprTracer(self, aval, source_info)
                 for aval in out_avals]
  invars = map(self.getvar, tracers)
  constvars = map(self.getvar, map(self.instantiate_const, consts))
  outvars = map(self.makevar, out_tracers)

  # Constants are prepended to the inputs; they get `None` as their in_axes
  # entry.
  new_params = dict(
      params,
      in_axes=(None,) * len(consts) + in_axes,
      out_axes=out_axes,
      call_jaxpr=convert_constvars_jaxpr(jaxpr))
  del new_params['out_axes_thunk']
  update_params = call_param_updaters.get(primitive)
  if update_params:
    new_params = update_params(new_params, [True] * len(tracers))

  eqn = new_jaxpr_eqn([*constvars, *invars], outvars, primitive,
                      new_params, source_info)
  self.frame.eqns.append(eqn)
  return out_tracers
def _bound_output_tracers(self, primitive, params, jaxpr, consts, env,
                          in_tracers, out_pvs, out_consts, out_keys, name,
                          is_map):
  """Takes a traced function and binds the Jaxpr to output tracers.

  An `UnzipTracer` is created for every `(pv, const, key)` output triple,
  and all of them share one equation recipe. The recipe's inputs are the
  constant tracers, the environment tracers and `in_tracers`, in that order.
  The recipe's params are `params` with `name` and the const-lifted jaxpr
  filled in, `donated_invars` adjusted when present, and, for maps,
  `out_axes_thunk` replaced by all-zero `out_axes`.

  Returns:
    The list of output tracers, each with `recipe` set.
  """
  lifted_jaxpr = pe.convert_constvars_jaxpr(jaxpr)
  const_tracers = safe_map(self.new_instantiated_const, consts)
  env_tracers = safe_map(self.instantiate_const, env)

  out_tracers = []
  for pv, const, key in safe_zip(out_pvs, out_consts, out_keys):
    out_tracers.append(
        UnzipTracer(self, pe.PartialVal((pv, const)), None, key))

  new_params = dict(params, name=name, call_jaxpr=lifted_jaxpr)

  if 'donated_invars' in params:
    # Constants and environment values are never donated; of the original
    # flags only those whose tracer is not known are kept.
    kept_flags = tuple(
        flag for flag, tracer in zip(params['donated_invars'], in_tracers)
        if not tracer.pval.is_known())
    new_params['donated_invars'] = (
        (False,) * len(const_tracers)
        + (False,) * len(env_tracers)
        + kept_flags)

  if is_map:
    out_axes = params['out_axes_thunk']()
    assert all(out_axis == 0 for out_axis in out_axes)
    new_params['out_axes'] = (0,) * len(out_tracers)
    del new_params['out_axes_thunk']

  recipe_inputs = tuple(const_tracers + env_tracers + in_tracers)
  eqn = pe.new_eqn_recipe(recipe_inputs, out_tracers, primitive, new_params,
                          source_info_util.current())  # pytype: disable=wrong-arg-types
  for tracer in out_tracers:
    tracer.recipe = eqn
  return out_tracers
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals,
                         primitive_name: Optional[str] = None):
  """Stage `fun` into a closed jaxpr whose constants are explicit inputs.

  Returns:
    A triple `(closed_jaxpr, consts, out_tree)`. The closed jaxpr is closed
    over an empty tuple of constants; the constants found by
    `_initial_style_open_jaxpr` are returned separately and have become the
    jaxpr's leading inputs.
  """
  open_jaxpr, consts, out_tree = _initial_style_open_jaxpr(
      fun, in_tree, in_avals, primitive_name)
  lifted = pe.convert_constvars_jaxpr(open_jaxpr)
  return core.ClosedJaxpr(lifted, ()), consts, out_tree
def make_transpose_from_thunk(thunk, lin_tree):
  """Build a transpose function from a thunk yielding a jaxpr and its consts.

  `thunk` is called once, immediately. The returned function takes
  `(res_arg, ct_out)`, evaluates the jaxpr on the constants followed by the
  flattened leaves of both arguments, and unflattens the result with
  `lin_tree`.
  """
  open_jaxpr, transpose_consts = thunk()
  closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(open_jaxpr), ())

  def transpose(res_arg, ct_out):
    leaves = tree_leaves((res_arg, ct_out))
    evaluate = core.jaxpr_as_fun(closed_jaxpr)
    ct_ins = evaluate(*transpose_consts, *leaves)
    return tree_unflatten(lin_tree, ct_ins)

  return transpose
def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
  """Build a closed jaxpr (plus its constants) from traced inputs/outputs.

  Every output tracer is lowered, raised to `trace`, and instantiated
  according to `instantiate` (the single flag is applied to all outputs).
  The jaxpr's constvars are converted to leading inputs.

  Returns:
    A pair `(closed_jaxpr, consts)`; the closed jaxpr is closed over an
    empty tuple of constants.
  """
  # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
  instantiate_flags = [instantiate] * len(out_tracers)
  lowered = safe_map(core.full_lower, out_tracers)
  out_tracers = safe_map(trace.full_raise, lowered)
  out_tracers = safe_map(partial(pe.instantiate_const_at, trace),
                         instantiate_flags, out_tracers)
  jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
  assert not env

  # TODO: this is from partial_eval.trace_to_jaxpr. Share.
  lifted = pe.convert_constvars_jaxpr(jaxpr)
  return core.ClosedJaxpr(lifted, ()), consts
def fun_remat(*args, **kwargs):
  """Trace the enclosing `fun` on these arguments and bind it via `remat_p`.

  Constants found while tracing are passed to the primitive as leading
  operands, ahead of the flattened arguments. `fun`, `prevent_cse` and
  `policy` come from the enclosing scope.
  """
  flat_args, in_tree = tree_flatten((args, kwargs))
  flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
  in_avals = [core.raise_to_shaped(core.get_aval(leaf)) for leaf in flat_args]
  debug = pe.debug_info(fun, in_tree, False, "checkpoint")
  traced_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals,
                                                      debug)
  flat_out = remat_p.bind(
      *consts, *flat_args,
      jaxpr=pe.convert_constvars_jaxpr(traced_jaxpr),
      prevent_cse=prevent_cse,
      differentiated=False,
      policy=policy)
  return tree_unflatten(out_tree(), flat_out)
def make_jaxpr(fun: Callable,
               in_tree: PyTreeDef,
               in_avals: typing.Tuple[core.AbstractValue, ...],  # with DBIdx in them
               keep_inputs: typing.Tuple[bool, ...]
               ) -> typing.Tuple[core.Jaxpr, List[Any], PyTreeDef]:
  """Trace `fun` into a jaxpr whose constants have become leading inputs.

  Args:
    fun: the function to trace; it is flattened against `in_tree`.
    in_tree: pytree structure of the function's inputs.
    in_avals: abstract values handed to `pe.trace_to_jaxpr_dynamic`. This is
      a variable-length tuple, hence `Tuple[..., ...]` rather than the
      1-tuple type `Tuple[X]`.
    keep_inputs: flags forwarded as `keep_inputs` to
      `pe.trace_to_jaxpr_dynamic` (also variable-length).

  Returns:
    A triple `(jaxpr, consts, out_tree)`: the jaxpr with its constvars
    converted to inputs, the traced constants after `_canonicalize_arg`, and
    the output pytree structure.
  """
  flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
  debug = pe.debug_info(fun, in_tree, False, "dex_jit")
  jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug,
                                                keep_inputs=keep_inputs)
  jaxpr = pe.convert_constvars_jaxpr(jaxpr_)
  consts = [_canonicalize_arg(c) for c in consts]
  return jaxpr, consts, out_tree()
def __call__(self, *args, **kwargs):
  """Trace `self.fun` on `args` and bind the result with `custom_vmap_p`.

  Keyword arguments are not supported, and the traced function must not
  produce any constants; both conditions are asserted.
  """
  assert not kwargs
  flat_args, in_tree = tree_flatten(args)
  flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
  in_avals = [core.raise_to_shaped(core.get_aval(leaf)) for leaf in flat_args]
  debug = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
  traced_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals,
                                                      debug)
  assert len(consts) == 0
  closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(traced_jaxpr), ())
  flat_out = custom_vmap_p.bind(*consts, *flat_args,
                                call=closed_call,
                                rule=self.vmap_rule,
                                in_tree=in_tree)
  return tree_unflatten(out_tree(), flat_out)
def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
  """JVP rule for `remat_p`.

  The JVP of `jaxpr` is computed for the tangents that are not symbolic
  zeros, then bound with `remat_p` on the JVP jaxpr's constants, the primals
  and the non-zero tangents.

  Returns:
    A pair `(out_primals, out_tangents)`; outputs whose tangent is reported
    as zero get `ad_util.Zero.from_value` of the matching primal output.
  """
  assert not jaxpr.constvars
  in_nonzeros = [type(t) is not ad_util.Zero for t in tangents]
  closed = core.ClosedJaxpr(jaxpr, ())
  jaxpr_jvp_, out_nonzeros = ad.jvp_jaxpr(closed, in_nonzeros, False)
  nonzero_tangents = [t for t in tangents if type(t) is not ad_util.Zero]
  jaxpr_jvp = pe.convert_constvars_jaxpr(jaxpr_jvp_.jaxpr)
  outs = remat_p.bind(
      *jaxpr_jvp_.consts, *primals, *nonzero_tangents,
      jaxpr=jaxpr_jvp,
      prevent_cse=prevent_cse,
      differentiated=differentiated,
      policy=policy)
  out_primals, nonzero_out_tangents = split_list(outs, [len(jaxpr.outvars)])

  # Only the non-zero tangents were computed; consume them in order and fill
  # every other position with a symbolic zero.
  remaining = iter(nonzero_out_tangents)
  out_tangents = []
  for primal_out, is_nonzero in zip(out_primals, out_nonzeros):
    if is_nonzero:
      out_tangents.append(next(remaining))
    else:
      out_tangents.append(ad_util.Zero.from_value(primal_out))
  return out_primals, out_tangents
def custom_jvp_call_jaxpr(fun, jvp, *args):
  """A convenience wrapper to apply the custom_jvp_call_jaxpr primitive.

  `fun` is staged immediately on the avals of `args`. `jvp` is staged lazily
  (and memoized) on those avals repeated twice. Constants of the staged
  `fun` are passed as leading operands and counted in `num_consts`.
  """
  in_avals = [
      abstract_arrays.raise_to_shaped(jax_core.get_aval(arg)) for arg in args
  ]
  # consts can be tracers!
  fun_jaxpr, consts = cd._initial_style_jaxpr(fun, in_avals)  # pylint: disable=protected-access
  closed_fun_jaxpr = jax_core.ClosedJaxpr(
      pe.convert_constvars_jaxpr(fun_jaxpr), ())

  @pe._memoize  # pylint: disable=protected-access
  def jvp_jaxpr_thunk():
    return cd._initial_style_jaxpr(jvp, in_avals * 2)  # pylint: disable=protected-access

  return cd.custom_jvp_call_jaxpr_p.bind(
      *consts, *args,
      fun_jaxpr=closed_fun_jaxpr,
      jvp_jaxpr_thunk=jvp_jaxpr_thunk,
      num_consts=len(consts))
def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees):
  """A convenience wrapper to apply the custom_vjp_call_jaxpr primitive.

  `fun` is staged immediately on the avals of `args`. `fwd` is staged lazily
  (and memoized) on the same avals. `bwd` and `out_trees` are forwarded
  unchanged. Constants of the staged `fun` are passed as leading operands
  and counted in `num_consts`.
  """
  in_avals = [
      abstract_arrays.raise_to_shaped(jax_core.get_aval(arg)) for arg in args
  ]
  # consts can be tracers!
  fun_jaxpr, consts = cd._initial_style_jaxpr(fun, in_avals)  # pylint: disable=protected-access
  closed_fun_jaxpr = jax_core.ClosedJaxpr(
      pe.convert_constvars_jaxpr(fun_jaxpr), ())

  @pe._memoize  # pylint: disable=protected-access
  def fwd_jaxpr_thunk():
    return cd._initial_style_jaxpr(fwd, in_avals)  # pylint: disable=protected-access

  return cd.custom_vjp_call_jaxpr_p.bind(
      *consts, *args,
      fun_jaxpr=closed_fun_jaxpr,
      fwd_jaxpr_thunk=fwd_jaxpr_thunk,
      bwd=bwd,
      out_trees=out_trees,
      num_consts=len(consts))
def __call__(self, residual_arg, linear_arg):
  """Trace `self.fun` on both arguments and bind with `custom_transpose_p`.

  The pytree structures of the residual argument, the linear argument and
  the output are passed to the primitive along with `self.transpose`. The
  traced function must not produce any constants (asserted).
  """
  _, res_tree = tree_flatten(residual_arg)
  _, lin_tree = tree_flatten(linear_arg)
  flat_args, in_tree = tree_flatten((residual_arg, linear_arg))
  flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
  in_avals = [core.raise_to_shaped(core.get_aval(leaf)) for leaf in flat_args]
  debug = pe.debug_info(self.fun, in_tree, False, "custom_transpose")
  traced_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals,
                                                      debug)
  assert len(consts) == 0
  closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(traced_jaxpr), ())
  flat_out = custom_transpose_p.bind(*consts, *flat_args,
                                     call=closed_call,
                                     rule=self.transpose,
                                     lin_tree=lin_tree,
                                     res_tree=res_tree,
                                     out_tree=out_tree())
  return tree_unflatten(out_tree(), flat_out)
def _close_jaxpr(jaxpr):
  """Convert the jaxpr's constvars to inputs and close it over no constants."""
  lifted = pe.convert_constvars_jaxpr(jaxpr)
  return core.ClosedJaxpr(lifted, ())
def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals, reduce_axes):
  """Transpose a closed call by delegating to `call_transpose`.

  The closed jaxpr's constants are turned into leading inputs of the open
  jaxpr, so they are prepended to `args` to match.
  """
  open_jaxpr = pe.convert_constvars_jaxpr(jaxpr.jaxpr)
  operands = (*jaxpr.consts, *args)
  return call_transpose(core.closed_call_p, params, open_jaxpr, operands,
                        ct, cts_in_avals, reduce_axes)
def _flat_initial_style_jaxpr(fun: Callable, in_avals):
  """lax_control_flow._initial_style_jaxpr, but for flat arguments and results.

  Returns:
    A pair `(closed_jaxpr, consts)`; the traced constants have become the
    jaxpr's leading inputs and the jaxpr is closed over an empty tuple.
  """
  traced_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
  lifted = pe.convert_constvars_jaxpr(traced_jaxpr)
  return ClosedJaxpr(lifted, ()), consts
def _flat_initial_style_jaxpr(fun, in_avals):
  """lax_control_flow._initial_style_jaxpr, but for flat arguments and results.

  Returns:
    A pair `(typed_jaxpr, consts)`. The traced constants have become the
    jaxpr's leading inputs, so their avals are placed ahead of `in_avals`.
  """
  traced_jaxpr, out_avals, consts = _instantiated_trace_to_jaxpr(fun, in_avals)
  lifted = convert_constvars_jaxpr(traced_jaxpr)
  all_in_avals = _abstractified(consts) + in_avals
  shaped_out_avals = map(raise_to_shaped, out_avals)
  typed_jaxpr = TypedJaxpr(lifted, (),
                           in_avals=all_in_avals,
                           out_avals=shaped_out_avals)
  return typed_jaxpr, consts