def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
  """Turns a set of input/output tracers into a TypedJaxpr plus its constants.

  The output tracers are fully lowered, raised into ``trace`` and passed
  through ``pe.instantiate_const_at`` before the jaxpr is extracted. The
  jaxpr's constvars are then converted via ``pe.convert_constvars_jaxpr`` so
  that the constants' avals come first in the resulting ``in_avals``.

  Args:
    in_tracers: tracers standing for the inputs of the traced computation.
    out_tracers: tracers standing for its outputs.
    trace: the trace used to raise and instantiate the outputs.
    instantiate: a single flag that is replicated once per output and handed
      to ``pe.instantiate_const_at``.

  Returns:
    A pair ``(typed_jaxpr, consts)``.
  """
  # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
  per_output_flags = [instantiate] * len(out_tracers)
  lowered = safe_map(core.full_lower, out_tracers)
  raised = safe_map(trace.full_raise, lowered)
  out_tracers = safe_map(
      partial(pe.instantiate_const_at, trace), per_output_flags, raised)
  jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
  out_pvals = [tracer.pval for tracer in out_tracers]

  # TODO: this is from partial_eval.trace_to_jaxpr. Share.
  assert not env

  # TODO: this is from the final part of lax_control_flow._initial_style_jaxpr
  to_shaped = abstract_arrays.raise_to_shaped
  out_avals = safe_map(to_shaped, unzip2(out_pvals)[0])
  const_avals = tuple(to_shaped(core.get_aval(const)) for const in consts)
  in_pvals = [tracer.pval for tracer in in_tracers]
  in_avals = tuple(safe_map(to_shaped, unzip2(in_pvals)[0]))

  converted = pe.convert_constvars_jaxpr(jaxpr)
  typed_jaxpr = core.TypedJaxpr(
      converted, (), const_avals + in_avals, out_avals)
  return typed_jaxpr, consts
def rearrange_binders(f, typed_jaxpr):
  """Returns a copy of ``typed_jaxpr`` with its input binders rewritten by ``f``.

  ``f`` is called once with the jaxpr's invars and once with the input avals
  (both splatted as positional arguments); its results become the new invars
  and the new ``in_avals`` respectively. Literals and the output aval are
  carried over unchanged.
  """
  new_jaxpr = typed_jaxpr.jaxpr.copy()
  new_jaxpr.invars = f(*new_jaxpr.invars)
  new_in_avals = f(*typed_jaxpr.in_avals)
  # Only validate the rewritten jaxpr when checks are not globally disabled.
  if not core.skip_checks:
    core.check_jaxpr(new_jaxpr)
  return core.TypedJaxpr(
      new_jaxpr, typed_jaxpr.literals, new_in_avals, typed_jaxpr.out_aval)
def _make_typed_jaxpr(traceable, in_avals):
  """Traces ``traceable`` on abstract inputs and packages it as a TypedJaxpr.

  Each input aval is paired with ``core.unit`` in a ``pe.PartialVal`` and the
  function is traced with ``instantiate=True``. The output avals are the first
  components of the resulting output partial values.
  """
  input_pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
  traced = pe.trace_to_jaxpr(traceable, input_pvals, instantiate=True)
  jaxpr, pvals_out, consts = traced
  out_avals, _ = unzip2(pvals_out)
  return core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
def _make_typed_jaxpr(traceable, in_avals):
  """Traces ``traceable`` on abstract inputs and packages it as a TypedJaxpr.

  This variant expects the trace to produce a single output partial value
  whose first component is a ``core.AbstractValue``; that value becomes the
  TypedJaxpr's output aval.
  """
  input_pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
  traced = pe.trace_to_jaxpr(traceable, input_pvals, instantiate=True)
  jaxpr, pval_out, consts = traced
  out_aval, _ = pval_out
  assert isinstance(out_aval, core.AbstractValue)
  return core.TypedJaxpr(jaxpr, consts, in_avals, out_aval)
def _initial_style_jaxpr(fun, in_tree, in_avals):
  """Traces ``fun`` to a TypedJaxpr whose constants are explicit inputs.

  Args:
    fun: the Python callable to trace.
    in_tree: input tree definition passed to ``flatten_fun_nokwargs``.
    in_avals: tuple of abstract values for the flattened inputs (it is
      concatenated onto the tuple of constant avals).

  Returns:
    A triple ``(typed_jaxpr, consts, out_tree)`` where ``typed_jaxpr`` takes
    the constants followed by the original inputs.
  """
  input_pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
  flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  traced = pe.trace_to_jaxpr(flat_fun, input_pvals, instantiate=True)
  jaxpr, out_pvals, consts = traced

  out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
  const_avals = tuple(raise_to_shaped(core.get_aval(const)) for const in consts)
  converted = pe.closure_convert_jaxpr(jaxpr)
  typed_jaxpr = core.TypedJaxpr(
      converted, (), const_avals + in_avals, out_avals)
  return typed_jaxpr, consts, out_tree()
def _rewrite_typed_jaxpr(
    tjaxpr: core.TypedJaxpr,
    has_input_token: bool,
    has_output_token: bool) -> Tuple[core.TypedJaxpr, bool]:
  """Rewrites a TypedJaxpr to thread the token, if needed.

  The underlying jaxpr is rewritten by ``_rewrite_jaxpr``; the input and
  output avals of the result are read back off the rewritten jaxpr's
  variables, and the original literals are reused.

  Returns:
    The rewritten TypedJaxpr, and whether it uses outfeed.
  """
  rewritten, uses_outfeed = _rewrite_jaxpr(
      tjaxpr.jaxpr, has_input_token, has_output_token)
  in_avals = tuple(var.aval for var in rewritten.invars)
  out_avals = tuple(var.aval for var in rewritten.outvars)
  new_tjaxpr = core.TypedJaxpr(rewritten, tjaxpr.literals, in_avals, out_avals)
  return new_tjaxpr, uses_outfeed
def wrapped(*args, **kwargs):
  """Traces ``f`` on the given arguments and returns a TypedJaxpr.

  ``f`` is a free name here and comes from a scope not visible in this
  snippet. Positional arguments are flattened; keyword arguments are bound
  into the wrapped function via ``lu.wrap_init``. Tracing uses
  ``pe.StagingJaxprTrace`` with ``stage_out=True``.

  Returns:
    A pair ``(typed_jaxpr, (in_tree, out_tree))``.
  """
  wrapped_fun = lu.wrap_init(f, kwargs)
  flat_args, in_tree = tree_util.tree_flatten(args)
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(wrapped_fun, in_tree)
  flat_avals = safe_map(get_shaped_aval, flat_args)
  input_pvals = [pe.PartialVal((aval, jax_core.unit)) for aval in flat_avals]
  traced = pe.trace_to_jaxpr(
      flat_fun,
      input_pvals,
      instantiate=True,
      stage_out=True,
      trace_type=pe.StagingJaxprTrace)
  jaxpr, out_pvals, consts = traced
  out_avals = [out_pval.get_aval() for out_pval in out_pvals]
  typed_jaxpr = jax_core.TypedJaxpr(jaxpr, consts, flat_avals, out_avals)
  return typed_jaxpr, (in_tree, out_tree())
def tie_the_knot(typed_jaxpr):
  """Feeds a jaxpr's outputs back into its own inputs.

  Every equation input that is one of the jaxpr's invars is replaced by the
  outvar at the same position. The result is a TypedJaxpr with no inputs, the
  same constvars, outvars, literals and output avals, and the rewritten
  equations. Input and output avals must match pairwise.
  """
  jaxpr, _, in_avals, out_avals = typed_jaxpr
  assert all(i == o for i, o in zip(in_avals, out_avals))
  in2out = dict(zip(jaxpr.invars, jaxpr.outvars))

  def substitute(var):
    # Only jc.Var instances are looked up in the mapping; anything else
    # (e.g. a literal) passes through untouched.
    if isinstance(var, jc.Var) and var in in2out:
      return in2out[var]
    return var

  def rewrite(eqn):
    new_invars = [substitute(var) for var in eqn.invars]
    return jc.JaxprEqn(
        new_invars, eqn.outvars, eqn.primitive, eqn.params, eqn.source_info)

  new_eqns = [rewrite(eqn) for eqn in jaxpr.eqns]
  new_jaxpr = jc.Jaxpr(jaxpr.constvars, [], jaxpr.outvars, new_eqns)
  return jc.TypedJaxpr(
      new_jaxpr, typed_jaxpr.literals, [], typed_jaxpr.out_avals)
def _move_stuff_and_add_add(typed_jaxpr):
  """Re-plumbs a lifted transposed jaxpr into the transposed calling convention.

  In the notation used by the original comments:

    jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
    jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)

  On the input side a pack equation is prepended that rebuilds the original
  ``(CT c, CT b)`` argument from the new binders. On the output side the
  original output is unpacked, the incoming ``CT d`` is added to the newly
  produced one, and the results are repacked as ``((CT c, CT d), CT a)``.

  Returns:
    A new ``core.TypedJaxpr`` over a copy of the input's jaxpr.
  """
  res_aval, (CTc_aval, CTb_aval) = typed_jaxpr.in_avals
  CTd_aval, CTc_aval2, CTa_aval = typed_jaxpr.out_aval
  assert CTc_aval == CTc_aval2

  new_in_avals = (
      core.AbstractTuple(()),
      core.AbstractTuple((CTc_aval, CTd_aval)),
      core.AbstractTuple((CTb_aval, res_aval)),
  )
  carry_ct_aval = core.AbstractTuple((CTc_aval, CTd_aval))
  new_out_aval = core.AbstractTuple((carry_ct_aval, CTa_aval))

  jaxpr = typed_jaxpr.jaxpr.copy()
  # assume the jaxpr isn't restructuring any inputs
  assert not any(type(invar) is tuple for invar in jaxpr.invars)

  # munge input side
  CTc_in = _scan_newvar()
  CTb_in = _scan_newvar()
  CTd_in = _scan_newvar()
  res_in, CTc_CTb_in = jaxpr.invars
  jaxpr.invars = ((), (CTc_in, CTd_in), (CTb_in, res_in))
  prologue = [pe._pack_eqn([CTc_in, CTb_in], CTc_CTb_in)]
  jaxpr.eqns = prologue + jaxpr.eqns

  # munge output side
  CTd_new = _scan_newvar()
  CTd_sum = _scan_newvar()
  CTc = _scan_newvar()
  CTa = _scan_newvar()
  partial_out = _scan_newvar()
  outvar = _scan_newvar()
  epilogue = [
      pe._unpack_eqn(jaxpr.outvar, [CTd_new, CTc, CTa]),
      _add_any_eqn(CTd_sum, CTd_new, CTd_in),
      pe._pack_eqn([CTc, CTd_sum], partial_out),
      pe._pack_eqn([partial_out, CTa], outvar),
  ]
  jaxpr.eqns = jaxpr.eqns + epilogue
  jaxpr.outvar = outvar

  # TODO(mattjj): add a check_typed_jaxpr and use it here
  if not core.skip_checks:
    core.check_jaxpr(jaxpr)
  return core.TypedJaxpr(
      jaxpr, typed_jaxpr.literals, new_in_avals, new_out_aval)
def wrapped(*args, **kwargs):
  """Traces ``f`` on the given arguments and returns a TypedJaxpr.

  ``f`` and ``dynamic`` are free names here and come from a scope not visible
  in this snippet. When ``dynamic`` is truthy the function is traced with
  ``pe.trace_to_jaxpr_dynamic``; otherwise it goes through
  ``pe.trace_to_jaxpr`` with ``instantiate=True``.

  Returns:
    A pair ``(typed_jaxpr, (in_tree, out_tree))``.

  Raises:
    ValueError: if JAX omnistaging is not enabled.
  """
  wrapped_fun = lu.wrap_init(f, kwargs)
  flat_args, in_tree = tree_util.tree_flatten(args)
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(wrapped_fun, in_tree)
  flat_avals = safe_map(get_shaped_aval, flat_args)

  omnistaging_on = jax.config.omnistaging_enabled
  if not omnistaging_on:
    raise ValueError('Oryx must be used with JAX omnistaging enabled.')

  if not dynamic:
    input_pvals = [pe.PartialVal((aval, jax_core.unit)) for aval in flat_avals]
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
        flat_fun, input_pvals, instantiate=True)
    out_avals = [out_pval.get_aval() for out_pval in out_pvals]
  else:
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)

  typed_jaxpr = jax_core.TypedJaxpr(jaxpr, consts, flat_avals, out_avals)
  return typed_jaxpr, (in_tree, out_tree())
def _mk_typed_jaxpr(jaxpr: core.Jaxpr, literals: Sequence) -> core.TypedJaxpr:
  """Wraps ``jaxpr`` in a TypedJaxpr, reading avals off its own variables.

  The input avals are the ``.aval`` of each invar and the output avals the
  ``.aval`` of each outvar, both collected into tuples.
  """
  in_avals = tuple(var.aval for var in jaxpr.invars)
  out_avals = tuple(var.aval for var in jaxpr.outvars)
  return core.TypedJaxpr(jaxpr, literals, in_avals, out_avals)
def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotangents_in):
  """Walks ``jaxpr`` backwards, reconstructing primals while computing cotangents.

  For every equation (visited last to first) an "ivjp" jaxpr is obtained:
  from the equation's params for ``custom_ivjp_p``, from ``primitive_ivjps``
  when registered, or synthesized with ``synthesize_ivjp`` otherwise. The
  part of it that is computable from the values at hand is evaluated, and
  the results are written to the primal and cotangent environments.

  Args:
    jaxpr: the jaxpr to process.
    consts: values for ``jaxpr.constvars``.
    primals_in: values for ``jaxpr.invars``.
    primals_out: values for ``jaxpr.outvars``.
    cotangents_in: cotangents for ``jaxpr.outvars``.

  Returns:
    The cotangents read for ``jaxpr.invars`` (``ad.Zero`` where none was
    written). If every incoming cotangent is an ``ad.Zero``, zeros are
    returned immediately.
  """
  if all(type(ct) is ad.Zero for ct in cotangents_in):
    return map(lambda v: ad.Zero(v.aval), jaxpr.invars)

  # Invert while computing cotangents
  ct_env: Dict[Any, Any] = {}
  primal_env: Dict[Any, Any] = {}

  def write_cotangent(v, ct):
    # assert v not in primal_env
    if ct is None or type(v) is Literal:
      return
    if v in ct_env:
      ct_env[v] = ad.add_tangents(ct_env[v], ct)
    else:
      ct_env[v] = ct

  def read_cotangent(v):
    return ct_env.get(v, ad.Zero(v.aval))

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    return primal_env.get(v, ad.UndefinedPrimal(v.aval))

  def write_primal(v, val):
    # First write wins: an existing entry is never overwritten.
    if type(v) is not Literal:
      primal_env.setdefault(v, val)

  def abstract(value):
    aval = value.aval if ad.is_undefined_primal(value) else get_aval(value)
    return raise_to_shaped(aval)

  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.invars, primals_in)
  map(write_primal, jaxpr.outvars, primals_out)
  map(write_primal, jaxpr.constvars, consts)
  map(write_cotangent, jaxpr.outvars, cotangents_in)

  for eqn in reversed(jaxpr.eqns):
    eqn_primals_in = map(read_primal, eqn.invars)
    eqn_primals_out = map(read_primal, eqn.outvars)
    cts_in = map(read_cotangent, eqn.outvars)

    should_invert = any(type(primal) is not ad.UndefinedPrimal
                        for primal in eqn_primals_out)
    should_vjp = any(type(ct) is not ad.Zero for ct in cts_in)
    assert not eqn.primitive.call_primitive
    assert should_invert == should_vjp  # Either both true or both false

    # Skip primals equations that are only jvp coefficients and don't affect
    # primal outputs.
    if not (should_invert or should_vjp):
      continue

    # Get the ivjp_jaxpr
    if eqn.primitive is custom_ivjp_p:
      ivjp_jaxpr = eqn.params['ivjp_jaxpr']
    else:
      if eqn.primitive in primitive_ivjps:
        complete_ivjp = lu.wrap_init(primitive_ivjps[eqn.primitive])
      else:
        complete_ivjp = lu.wrap_init(
            partial(synthesize_ivjp, eqn,
                    map(ad.is_undefined_primal, eqn_primals_in)))
      _, in_tree = tree_flatten(
          tuple(map(abstract, x)
                for x in (eqn_primals_in, eqn_primals_out, eqn_primals_out)))
      complete_ivjp_flat, _ = flatten_fun_nokwargs(complete_ivjp, in_tree)

      in_avals = map(abstract,
                     eqn_primals_in + eqn_primals_out + eqn_primals_out)
      # TODO: Actually we do know some of the inputs, because they might be literals!
      trace_kwargs: Dict[str, Any] = {'instantiate': True}
      if not config.omnistaging_enabled:
        trace_kwargs['stage_out'] = False
      ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
          complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals),
          **trace_kwargs)
      # That might happen some time, but don't bother until then
      assert not ivjp_jaxpr.constvars
      out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
      ivjp_jaxpr = core.TypedJaxpr(ivjp_jaxpr, [], in_avals, out_avals)

    # Once we know what the ivjp can do exactly, we have to isolate the part we are
    # actually able to compute with the values we have at hand.
    num_inputs = len(eqn.invars)
    unknowns = (map(ad.is_undefined_primal, eqn_primals_in) +
                map(ad.is_undefined_primal, eqn_primals_out) +
                [False] * len(cts_in))
    peval_kwargs: Dict[str, Any] = {'instantiate': False}
    if not config.omnistaging_enabled:
      peval_kwargs['trace_type'] = None
    jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr(  # type: ignore
        ivjp_jaxpr, unknowns, **peval_kwargs)
    unknown_rec_primals_in, unknown_cotangents = split_list(
        out_unknowns, [num_inputs])

    # Make sure we're able to compute all cotangents. We don't really care if we
    # can reconstruct the primals or not, although failure to do so might result
    # in failing to compute cotangents later.
    assert not any(unknown_cotangents)

    # Remove residual outputs -- we won't be computing the unknown jaxpr anyway.
    num_outputs = len(jaxpr_unknown.jaxpr.outvars)
    jaxpr_known.out_avals = jaxpr_known.out_avals[:num_outputs]
    jaxpr_known.jaxpr.outvars = jaxpr_known.jaxpr.outvars[:num_outputs]

    # TODO: We could drop the outputs that correspond to primals that we already know.
    #       This only matters in eager mode, so leaving it out for now...
    ivjp = core.jaxpr_as_fun(jaxpr_known)
    ivjp_outs = ivjp(*eqn_primals_in, *eqn_primals_out, *cts_in)
    rec_primals_in, cts_out = split_list(ivjp_outs, [num_inputs])

    # Outputs flagged as unknown are swapped back for the previously read
    # primal before being written. (The original comment says unknown
    # rec_primals_in are core.units and refers to the replacements as
    # "UnknownPrimals".)
    rec_primals_in = [
        prev if unknown else rec
        for prev, rec, unknown in zip(
            eqn_primals_in, rec_primals_in, unknown_rec_primals_in)
    ]
    map(write_primal, eqn.invars, rec_primals_in)
    non_literal_invars = [v for v in eqn.invars if type(v) is not Literal]
    map(write_cotangent, non_literal_invars, cts_out)

  # NOTE: We keep the cotangents associated with primal variables, while the contract of a
  #       transpose is to return them in positions associated with tangent variables, which
  #       is what causes this whole confusion.
  return map(read_cotangent, jaxpr.invars)
def scan(f, init, xs):
  """Scan a function over leading array axes while carrying along state.

  The type signature in brief is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  where ``[t]`` denotes the type ``t`` with an additional leading axis. If
  ``t`` is an array type, ``[t]`` is the array type with one more leading
  axis; if ``t`` is a pytree (container) type with array leaves, ``[t]`` has
  the same pytree structure with each leaf gaining a leading axis.

  When both ``a`` and ``b`` are array types, the semantics of ``scan`` are
  given by this Python implementation::

    def scan(f, init, xs):
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
  types, so several arrays can be scanned over at once and several output
  arrays produced.

  Also unlike that Python version, ``scan`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation
  times of jit-compiled functions, since native Python loop constructs in an
  ``@jit`` function are unrolled, leading to large XLA computations.

  Args:
    f: a Python function to be scanned, of type ``c -> a -> (c, b)``: it
      takes the current loop carry and a slice of ``xs`` along its leading
      axis, and returns a pair of the new loop carry and a slice of the
      output.
    init: the initial loop carry value of type ``c``; a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof.
    xs: the value of type ``[a]`` to scan over along the leading axis; an
      array or any pytree (nested Python tuple/list/dict) thereof with
      consistent leading axis sizes.

  Returns:
    A pair of type ``(c, [b])``: the final loop carry value, and the second
    outputs of ``f`` stacked along the scanned leading axis.

  Raises:
    TypeError: if the traced output of ``f`` is not a pair, or if the carry
      it returns does not have the same abstract value as ``init``.
  """
  (init, xs), in_trees = unzip2(map(pytree_to_jaxtupletree, (init, xs)))
  f, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(f), in_trees)

  carry_pval = _abstractify(init)
  carry_aval, _ = carry_pval
  xs_aval, _ = _abstractify(xs)
  x_aval = _demote_aval_rank(xs_aval)
  x_pval = pe.PartialVal((x_aval, core.unit))

  traced = pe.trace_to_jaxpr(f, (carry_pval, x_pval), instantiate=True)
  jaxpr, pval_out, consts = traced
  pv_out, const_out = pval_out
  assert isinstance(pv_out, core.AbstractValue) and const_out == core.unit

  is_pair = isinstance(pv_out, core.AbstractTuple) and len(pv_out) == 2
  if not is_pair:
    msg = ("scanned function must have signature `c -> a -> (c, b)`, but the "
           "output was not a pair: got type {}.")
    raise TypeError(msg.format(pv_out))

  carry_aval_out, y_aval = pv_out
  if carry_aval != carry_aval_out:
    msg = ("scanned function carry output does not match carry input: "
           "input carry is {} and output carry is {}.")
    raise TypeError(msg.format(carry_aval, carry_aval_out))

  # The constants become an explicit first argument of the lifted jaxpr.
  lifted_jaxpr = pe._closure_convert_jaxpr(jaxpr)
  consts_aval, _ = _abstractify(core.pack(consts))
  in_avals = (consts_aval, carry_aval, x_aval)
  out_aval = core.AbstractTuple((carry_aval, y_aval))
  typed_jaxpr = core.TypedJaxpr(lifted_jaxpr, (), in_avals, out_aval)

  length = _leading_dim_size(xs)
  out = scan_p.bind(core.pack(consts), init, xs,
                    forward=True, length=length, jaxpr=typed_jaxpr)
  return build_tree(out_tree(), out)