def scan(f, init, xs):
  """Scan a function over leading array axes while carrying along state.

  The type signature in brief is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  where we use [t] here to denote the type t with an additional leading axis.
  That is, if t is an array type then [t] represents the type with an
  additional leading axis, and if t is a pytree (container) type with array
  leaves then [t] represents the type with the same pytree structure and
  corresponding leaves each with an additional leading axis.

  When both ``a`` and ``b`` are array types, the semantics of ``scan`` are
  given by this Python implementation::

    def scan(f, init, xs):
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
  types, and so multiple arrays can be scanned over at once and produce
  multiple output arrays.

  Also unlike that Python version, ``scan`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation
  times for jit-compiled functions, since native Python loop constructs in an
  ``@jit`` function are unrolled, leading to large XLA computations.

  Args:
    f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
      that ``f`` accepts two arguments where the first is a value of the loop
      carry and the second is a slice of ``xs`` along its leading axis, and
      that ``f`` returns a pair where the first element represents a new value
      for the loop carry and the second represents a slice of the output.
    init: an initial loop carry value of type ``c``, which can be a scalar,
      array, or any pytree (nested Python tuple/list/dict) thereof,
      representing the initial loop carry value.
    xs: the value of type ``[a]`` over which to scan along the leading axis,
      where ``[a]`` can be an array or any pytree (nested Python
      tuple/list/dict) thereof with consistent leading axis sizes.

  Returns:
    A pair of type ``(c, [b])`` where the first element represents the final
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the
    inputs.
  """
  (init, xs), in_trees = unzip2(map(pytree_to_jaxtupletree, (init, xs)))
  f, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(f), in_trees)
  carry_pval = carry_aval, _ = _abstractify(init)
  xs_aval, _ = _abstractify(xs)
  x_aval = _demote_aval_rank(xs_aval)
  x_pval = pe.PartialVal((x_aval, core.unit))
  jaxpr, pval_out, consts = pe.trace_to_jaxpr(
      f, (carry_pval, x_pval), instantiate=True)
  pv_out, const_out = pval_out
  assert isinstance(pv_out, core.AbstractValue) and const_out == core.unit
  if not isinstance(pv_out, core.AbstractTuple) or len(pv_out) != 2:
    msg = ("scanned function must have signature `c -> a -> (c, b)`, but the "
           "output was not a pair: got type {}.")
    raise TypeError(msg.format(pv_out))
  carry_aval_out, y_aval = pv_out
  if carry_aval != carry_aval_out:
    msg = ("scanned function carry output does not match carry input: "
           "input carry is {} and output carry is {}.")
    raise TypeError(msg.format(carry_aval, carry_aval_out))
  lifted_jaxpr = pe._closure_convert_jaxpr(jaxpr)
  consts_aval, _ = _abstractify(core.pack(consts))
  in_avals = (consts_aval, carry_aval, x_aval)
  out_aval = core.AbstractTuple((carry_aval, y_aval))
  jaxpr = core.TypedJaxpr(lifted_jaxpr, (), in_avals, out_aval)
  length = _leading_dim_size(xs)
  out = scan_p.bind(core.pack(consts), init, xs,
                    forward=True, length=length, jaxpr=jaxpr)
  return build_tree(out_tree(), out)
def propagate(cell_type: Type[Cell],
              rules: Dict[jax_core.Primitive, PropagationRule],
              jaxpr: pe.Jaxpr,
              constcells: List[Cell],
              incells: List[Cell],
              outcells: List[Cell]) -> Environment:
  """Propagates cells in a Jaxpr using a set of rules.

  Args:
    cell_type: used to instantiate literals into cells
    rules: maps JAX primitives to propagation rule functions
    jaxpr: used to construct the propagation graph
    constcells: used to populate the Jaxpr's constvars
    incells: used to populate the Jaxpr's invars
    outcells: used to populate the Jaxpr's outvars

  Returns:
    The Jaxpr environment after propagation has terminated
  """
  env = Environment(cell_type, jaxpr)
  safe_map(env.write, jaxpr.constvars, constcells)
  safe_map(env.write, jaxpr.invars, incells)
  safe_map(env.write, jaxpr.outvars, outcells)
  eqns = safe_map(Equation.from_jaxpr_eqn, jaxpr.eqns)

  get_neighbor_eqns = construct_graph_representation(eqns)
  # Initialize propagation queue with equations neighboring constvars, invars,
  # and outvars.
  out_eqns = set()
  for var in it.chain(jaxpr.outvars, jaxpr.invars, jaxpr.constvars):
    out_eqns.update(get_neighbor_eqns(var))
  queue = collections.deque(out_eqns)
  done_eqns = set()
  # checked_eqns is used to stop propagation if all equations in the queue
  # have been checked without the propagation progressing.
  checked_eqns = set()
  while queue:
    eqn = queue.popleft()
    assert eqn not in done_eqns

    incells = safe_map(env.read, eqn.invars)
    outcells = safe_map(env.read, eqn.outvars)

    rule = rules[eqn.primitive]
    call_jaxpr, params = jax_core.extract_call_jaxpr(eqn.primitive, eqn.params)
    if call_jaxpr:
      subfuns = [
          lu.wrap_init(
              functools.partial(propagate, cell_type, rules, call_jaxpr, ()))
      ]
    else:
      subfuns = []
    new_incells, new_outcells, done, subenv = rule(subfuns + incells,
                                                   outcells, **params)
    if subenv:
      env.write_subenv(eqn, subenv)

    safe_map(env.write, eqn.invars, new_incells)
    safe_map(env.write, eqn.outvars, new_outcells)

    update_queue_state(queue, eqn, get_neighbor_eqns, done, done_eqns,
                       checked_eqns, incells, outcells, new_incells,
                       new_outcells)
    if checked_eqns == set(queue):
      break
  return env
def jet(fun, primals, series):
  r"""Taylor-mode higher-order automatic differentiation.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays,
      scalars, or standard Python containers of arrays or scalars. It should
      return an array, scalar, or standard Python container of arrays or
      scalars.
    primals: The primal values at which the Taylor approximation of ``fun``
      should be evaluated. Should be either a tuple or a list of arguments,
      and its length should be equal to the number of positional parameters
      of ``fun``.
    series: Higher order Taylor-series-coefficients. Together, `primals` and
      `series` make up a truncated Taylor polynomial. Should be either a tuple
      or a list of tuples or lists, and its length dictates the degree of the
      truncated Taylor polynomial.

  Returns:
    A ``(primals_out, series_out)`` pair, where ``primals_out`` is
    ``fun(*primals)``, and together, ``primals_out`` and ``series_out`` are a
    truncated Taylor polynomial of :math:`f(h(\cdot))`.
    The ``primals_out`` value has the same Python tree structure as
    ``primals``, and the ``series_out`` value the same Python tree structure
    as ``series``.

  For example:

  >>> import jax
  >>> import jax.numpy as np

  Consider the function :math:`h(z) = z^3`, :math:`x = 0.5`, and the first
  few Taylor coefficients :math:`h_0=x^3`, :math:`h_1=3x^2`, and
  :math:`h_2=6x`. Let :math:`f(y) = \sin(y)`.

  >>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
  >>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)

  :func:`jet` returns the Taylor coefficients of :math:`f(h(z)) = \sin(z^3)`
  according to Faà di Bruno's formula:

  >>> f0, (f1, f2) = jet(f, (h0,), ((h1, h2),))
  >>> print(f0, f(h0))
  0.12467473 0.12467473

  >>> print(f1, df(h0) * h1)
  0.7441479 0.74414825

  >>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
  2.9064622 2.9064634
  """
  try:
    order, = set(map(len, series))
  except ValueError:
    msg = "jet terms have inconsistent lengths for different arguments"
    raise ValueError(msg) from None

  # TODO(mattjj): consider supporting pytree inputs
  for i, (x, terms) in enumerate(zip(primals, series)):
    treedef = tree_structure(x)
    if not treedef_is_leaf(treedef):
      raise ValueError(
          "primal value at position {} is not an array".format(i))
    for j, t in enumerate(terms):
      treedef = tree_structure(t)
      if not treedef_is_leaf(treedef):
        raise ValueError(
            "term {} for argument {} is not an array".format(j, i))

  @lu.transformation_with_aux
  def flatten_fun_output(*args):
    ans = yield args, {}
    yield tree_flatten(ans)

  f, out_tree = flatten_fun_output(lu.wrap_init(fun))
  out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(
      primals, series)
  return (tree_unflatten(out_tree(), out_primals),
          tree_unflatten(out_tree(), out_terms))
def propagate(cell_type: Type[Cell],
              rules: Dict[jax_core.Primitive, PropagationRule],
              jaxpr: pe.Jaxpr,
              constcells: List[Cell],
              incells: List[Cell],
              outcells: List[Cell],
              reducer: Callable[[Environment, Equation, State, State],
                                State] = identity_reducer,
              initial_state: State = None) -> Tuple[Environment, State]:
  """Propagates cells in a Jaxpr using a set of rules.

  Args:
    cell_type: used to instantiate literals into cells
    rules: maps JAX primitives to propagation rule functions
    jaxpr: used to construct the propagation graph
    constcells: used to populate the Jaxpr's constvars
    incells: used to populate the Jaxpr's invars
    outcells: used to populate the Jaxpr's outvars
    reducer: An optional callable used to reduce over the state at each
      equation in the Jaxpr. `reducer` takes in `(env, eqn, state, new_state)`
      as arguments and should return an updated state. The `new_state` value
      is provided by each equation.
    initial_state: The initial `state` value used in the reducer

  Returns:
    A tuple of the Jaxpr environment after propagation has terminated and the
    final state produced by `reducer`.
  """
  env = Environment(cell_type, jaxpr)
  safe_map(env.write, jaxpr.constvars, constcells)
  safe_map(env.write, jaxpr.invars, incells)
  safe_map(env.write, jaxpr.outvars, outcells)

  eqns = safe_map(Equation.from_jaxpr_eqn, jaxpr.eqns)

  get_neighbor_eqns = construct_graph_representation(eqns)
  # Initialize propagation queue with equations neighboring constvars, invars,
  # and outvars.
  out_eqns = set()
  for eqn in jaxpr.eqns:
    for var in it.chain(eqn.invars, eqn.outvars):
      env.write(var, cell_type.unknown(var.aval))

  for var in it.chain(jaxpr.outvars, jaxpr.invars, jaxpr.constvars):
    out_eqns.update(get_neighbor_eqns(var))
  queue = collections.deque(out_eqns)
  while queue:
    eqn = queue.popleft()

    incells = safe_map(env.read, eqn.invars)
    outcells = safe_map(env.read, eqn.outvars)

    call_jaxpr, params = jax_core.extract_call_jaxpr(eqn.primitive, eqn.params)
    if call_jaxpr:
      subfuns = [
          lu.wrap_init(
              functools.partial(propagate, cell_type, rules, call_jaxpr, (),
                                initial_state=initial_state, reducer=reducer))
      ]
      if eqn.primitive not in rules:
        rule = default_call_rules.get(eqn.primitive)
      else:
        rule = rules[eqn.primitive]
    else:
      subfuns = []
      rule = rules[eqn.primitive]
    new_incells, new_outcells, eqn_state = rule(subfuns + incells, outcells,
                                                **params)
    env.write_state(eqn, eqn_state)

    new_incells = safe_map(env.write, eqn.invars, new_incells)
    new_outcells = safe_map(env.write, eqn.outvars, new_outcells)

    update_queue_state(queue, eqn, get_neighbor_eqns, incells, outcells,
                       new_incells, new_outcells)
  state = initial_state
  for eqn in eqns:
    state = reducer(env, eqn, state, env.read_state(eqn))
  return env, state
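
# Illustration (added; hedged). A reducer must match the documented
# `(env, eqn, state, new_state) -> state` signature; the `identity_reducer`
# default is presumably equivalent to this state-preserving function:

def _example_identity_reducer(env, eqn, state, new_state):
  del env, eqn, new_state  # unused: keep the running state unchanged
  return state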
def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
                linear_args):
  """Call a linear function, with a custom implementation for its transpose.

  The type signatures of ``fun`` and ``fun_transpose`` are:

  .. code-block:: haskell

    fun           :: r -> a -o b
    fun_transpose :: r -> b -o a

  where the ``-o`` arrow indicates a linear function, ``r`` is the residual
  input type and ``a`` is the linear input type.

  The functions ``fun`` and ``fun_transpose`` are coupled as transposes of one
  another. Specifically, the transpose of a ``linear_call`` primitive is
  another ``linear_call`` to ``fun_transpose``, with ``fun`` as its custom
  transposition.

  For example:

  >>> import jax
  >>> from functools import partial

  >>> def f(r, x):
  ...   return x / r

  >>> def t(r, t):
  ...   return t / r

  >>> def div_add(x, denom):
  ...   return x + linear_call(f, t, denom, x)

  >>> def transpose(f, x_example):
  ...   def transposed(y):
  ...     x, = jax.linear_transpose(f, x_example)(y)
  ...     return x
  ...   return transposed

  >>> div_add(9., 3.)
  DeviceArray(12., dtype=float32)

  >>> transpose(partial(div_add, denom=3.), 1.)(18.)  # custom
  DeviceArray(24., dtype=float32)

  >>> transpose(lambda x: x + x / 3., 1.)(18.)  # reference
  DeviceArray(24., dtype=float32)

  The above definition of ``f`` illustrates the purpose of a residual
  argument: division is linear in one of its inputs (the dividend ``x``) but
  not the other (the divisor ``r``).

  As another example:

  >>> def custom_id(x):
  ...   def f(_, x): return x
  ...   def t(_, t): return 7.
  ...   return linear_call(f, t, (), x)
  >>> custom_id(1.)
  1.0
  >>> transpose(custom_id, 1.)(1.)
  7.0
  >>> transpose(transpose(custom_id, 1.), 1.)(1.)
  1.0
  >>> transpose(transpose(transpose(custom_id, 1.), 1.), 1.)(1.)
  7.0

  Args:
    fun: a Python callable specifying a linear function. It should take two
      arguments: one of "residual" inputs (type ``r``), i.e. inputs in which
      the function is not necessarily linear, and one of "linear" inputs
      (type ``a``). It should return output whose components are linear in
      the linear input (type ``b``).
    fun_transpose: a Python callable specifying a structurally linear function
      that is the transpose of ``fun`` with respect to its linear inputs. Its
      first argument is the same residual inputs (``r``) as ``fun``. Its
      second argument is of type ``b``. Finally, its output is of type ``a``
      and each of its components is linear in its second argument (the ``b``
      inputs).
    residual_args: Argument in which ``fun`` and ``fun_transpose`` are not
      necessarily linear. Not involved in transposition.
    linear_args: Argument in which ``fun`` and ``fun_transpose`` are linear
      and with respect to which the two are transposes.

  Returns:
    The call result, i.e. ``fun(residual_args, linear_args)``.
  """
  operands_res, res_tree = tree_flatten(residual_args)
  operands_lin, lin_tree = tree_flatten(linear_args)

  f_in_tree = treedef_tuple((res_tree, lin_tree))
  f, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), f_in_tree)

  res_avals = map(abstractify, operands_res)
  lin_avals = map(abstractify, operands_lin)
  f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
  f_jaxpr = _close_jaxpr(f_jaxpr)
  out_avals = map(core.raise_to_shaped, f_jaxpr.out_avals)

  t_in_tree = treedef_tuple((res_tree, out_tree()))
  t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose), t_in_tree)

  t_jaxpr, t_consts = _initial_style_jaxpr(t, (*res_avals, *out_avals))
  t_jaxpr = _close_jaxpr(t_jaxpr)

  if t_out_tree() != lin_tree:
    raise TypeError(
        'transpose output pytree structure must match that of linear inputs, '
        f'got output structure {t_out_tree()} '
        f'and input structure {lin_tree}.')

  out = linear_call_p.bind(*f_consts, *t_consts, *operands_res, *operands_lin,
                           callee=f_jaxpr,
                           transpose=t_jaxpr,
                           num_callee_consts=len(f_consts),
                           num_transpose_consts=len(t_consts),
                           num_res=len(operands_res))
  return tree_unflatten(out_tree(), out)
def doit():
  f = lu.wrap_init(fun)
  args_flat, in_tree = tree_util.tree_flatten((args, {}))
  flat_fun, out_tree = flatten_fun(f, in_tree)
  out_flat = _interpret_fun(flat_fun, args_flat)
  return tree_util.tree_unflatten(out_tree(), out_flat)
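
# Usage sketch (added for illustration). doit() follows the recurring
# flatten -> transform -> unflatten pattern; the flatten/unflatten half of
# that pattern, using the public jax.tree_util API:

def _flatten_roundtrip_example():
  from jax import tree_util

  tree = {'w': [1.0, 2.0], 'b': 3.0}
  leaves, treedef = tree_util.tree_flatten(tree)  # leaves: [3.0, 1.0, 2.0]
  return tree_util.tree_unflatten(treedef, leaves)  # reconstructs the dict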
def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out,
                      cotangents_in):
  if all(type(ct) is ad.Zero for ct in cotangents_in):
    return map(lambda v: ad.Zero(v.aval), jaxpr.invars)

  def write_cotangent(v, ct):
    # assert v not in primal_env
    if ct is not None and type(v) is not Literal:
      ct_env[v] = ad.add_tangents(ct_env[v], ct) if v in ct_env else ct

  def read_cotangent(v):
    return ct_env.get(v, ad.Zero(v.aval))

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, ad.UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if type(v) is Literal:
      return
    primal_env.setdefault(v, val)

  # Invert while computing cotangents
  ct_env: Dict[Any, Any] = {}
  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.invars, primals_in)
  map(write_primal, jaxpr.outvars, primals_out)
  map(write_primal, jaxpr.constvars, consts)
  map(write_cotangent, jaxpr.outvars, cotangents_in)

  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.invars)
    primals_out = map(read_primal, eqn.outvars)
    cts_in = map(read_cotangent, eqn.outvars)
    should_invert = any(type(primal) is not ad.UndefinedPrimal
                        for primal in primals_out)
    should_vjp = any(type(ct) is not ad.Zero for ct in cts_in)
    assert not eqn.primitive.call_primitive

    # Skip primals equations that are only jvp coefficients and don't affect
    # primal outputs.
    if not should_invert and not should_vjp:
      continue

    def abstract(value):
      return raise_to_shaped(value.aval if ad.is_undefined_primal(value)
                             else get_aval(value))

    # Get the ivjp_jaxpr
    if eqn.primitive is custom_ivjp_p:
      ivjp_jaxpr = eqn.params['ivjp_jaxpr']
    else:
      if eqn.primitive in primitive_ivjps:
        complete_ivjp = lu.wrap_init(primitive_ivjps[eqn.primitive])
      else:
        complete_ivjp = lu.wrap_init(
            partial(synthesize_ivjp, eqn,
                    map(ad.is_undefined_primal, primals_in)))
      _, in_tree = tree_flatten(
          tuple(map(abstract, x)
                for x in (primals_in, primals_out, primals_out)))
      complete_ivjp_flat, _ = flatten_fun_nokwargs(complete_ivjp, in_tree)

      in_avals = map(abstract, primals_in + primals_out + primals_out)
      # TODO: Actually we do know some of the inputs, because they might
      # be literals!
      ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
          complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals),
          instantiate=True)
      assert not ivjp_jaxpr.constvars  # That might happen some time, but don't bother until then
      ivjp_jaxpr = core.ClosedJaxpr(ivjp_jaxpr, [])

    # Once we know what the ivjp can do exactly, we have to isolate the part
    # we are actually able to compute with the values we have at hand.
    num_inputs = len(eqn.invars)
    unknowns = (map(ad.is_undefined_primal, primals_in) +
                map(ad.is_undefined_primal, primals_out) +
                [False] * len(cts_in))
    jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr(  # type: ignore
        ivjp_jaxpr, unknowns, instantiate=False)  # type: ignore
    unknown_rec_primals_in, unknown_cotangents = split_list(
        out_unknowns, [num_inputs])
    # Make sure we're able to compute all cotangents. We don't really care
    # whether we can reconstruct the primals or not, although failure to do
    # so might result in failing to compute cotangents later.
    assert not any(unknown_cotangents)
    # Remove residual outputs -- we won't be computing the unknown jaxpr
    # anyway.
    num_outputs = len(jaxpr_unknown.jaxpr.outvars)
    jaxpr_known.jaxpr.outvars = jaxpr_known.jaxpr.outvars[:num_outputs]
    # TODO: We could drop the outputs that correspond to primals that we
    # already know. This only matters in eager mode, so leaving it out for
    # now...
    ivjp = core.jaxpr_as_fun(jaxpr_known)
    rec_primals_in, cts_out = split_list(
        ivjp(*primals_in, *primals_out, *cts_in), [num_inputs])
    # Unknown rec_primals_in are core.units, so we replace them with the
    # previous primal values (possibly UndefinedPrimals).
    rec_primals_in = [prev if unknown else rec
                      for prev, rec, unknown
                      in zip(primals_in, rec_primals_in,
                             unknown_rec_primals_in)]
    map(write_primal, eqn.invars, rec_primals_in)
    map(write_cotangent, [v for v in eqn.invars if type(v) is not Literal],
        cts_out)

  # NOTE: We keep the cotangents associated with primal variables, while the
  # contract of a transpose is to return them in positions associated with
  # tangent variables, which is what causes this whole confusion.
  return map(read_cotangent, jaxpr.invars)
def _batch(args, dims, **params):
  batched, out_dims = batching.batch_fun2(
      lu.wrap_init(self.impl, params), dims)
  return batched.call_wrapped(*args), out_dims()
def eval_sparse(
    jaxpr: core.Jaxpr,
    consts: Sequence[Array],  # all consts are dense
    spvalues: Sequence[SparsifyValue],  # mix of sparse and dense pointers into spenv
    spenv: SparsifyEnv,
) -> Sequence[SparsifyValue]:
  env: Dict[core.Var, SparsifyValue] = {}

  def read(var: core.Var) -> Union[Array, SparsifyValue]:
    # all literals are dense
    if isinstance(var, core.Literal):
      return spenv.dense(var.val)
    else:
      return env[var]

  def write_buffer(var: core.Var, a: Array) -> None:
    if isinstance(var, core.DropVar):
      return
    env[var] = spenv.dense(a)

  def write(var: core.Var, a: SparsifyValue) -> None:
    if isinstance(var, core.DropVar):
      return
    assert a is not None
    env[var] = a

  # TODO: handle unitvar at all?
  # write_buffer(core.unitvar, core.unit)
  safe_map(write_buffer, jaxpr.constvars, consts)
  safe_map(write, jaxpr.invars, spvalues)

  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    invals = safe_map(read, eqn.invars)

    if any(val.is_sparse() for val in invals):
      if prim not in sparse_rules:
        raise NotImplementedError(f"sparse rule for {prim}")
      out = sparse_rules[prim](spenv, *invals, **eqn.params)
    else:
      if prim is xla.xla_call_p:
        # TODO(vanderplas,frostig): workaround for binding call primitives
        # within a jaxpr interpreter
        params = eqn.params.copy()
        fun = lu.wrap_init(
            core.jaxpr_as_fun(pe.ClosedJaxpr(params.pop('call_jaxpr'), ())))
        out_bufs = prim.bind(fun, *(spenv.data(val) for val in invals),
                             **params)
      else:
        out_bufs = prim.bind(*(spenv.data(val) for val in invals),
                             **eqn.params)
      out_bufs = out_bufs if prim.multiple_results else [out_bufs]
      out = []
      for buf, outvar in safe_zip(out_bufs, eqn.outvars):
        if isinstance(outvar, core.DropVar):
          out.append(None)
        else:
          out.append(spenv.dense(buf))
    safe_map(write, eqn.outvars, out)

  return safe_map(read, jaxpr.outvars)
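
# Usage sketch (added; assumes the public jax.experimental.sparse API).
# eval_sparse is the interpreter behind sparsify; a typical entry point:

def _sparsify_usage_example():
  import jax.numpy as jnp
  from jax.experimental import sparse

  mat = sparse.BCOO.fromdense(jnp.eye(3))
  f = sparse.sparsify(lambda m, v: m @ v)  # dense primitives get sparse rules
  return f(mat, jnp.arange(3.0))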
def f_(*args, **kwargs):
  f = lu.wrap_init(fun, kwargs)
  f_partial, dyn_args = argnums_partial(
      f, argnums, args, require_static_args_hashable=False)
  return scan_fun(lambda x: f_partial.call_wrapped(*x), dyn_args)
def core_closed_call(f, *args):
  args, in_tree = tree_flatten(args)
  f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
  out = core.closed_call_p.bind(f, *args)
  return tree_unflatten(out_tree(), out)
def submerge_consts(jaxpr, consts, invals=None):
  """Replace constvars with literals in jaxpr and its sub-jaxprs."""
  # TODO(j-towns): check that consts are in jax.core.literalable_types
  consts = dict(zip(jaxpr.constvars, consts))
  if invals is not None:
    # We're in a call_jaxpr
    new_jaxpr_invars = []
    for var, val in zip(jaxpr.invars, invals):
      if isinstance(val, Var):
        new_jaxpr_invars.append(var)
      else:
        consts[var] = val
  else:
    new_jaxpr_invars = jaxpr.invars
  new_eqns = []
  for eqn in jaxpr.eqns:
    if all(isinstance(var, Literal) or var in consts for var in eqn.invars):
      # Perform constant folding if all inputs to an eqn are known
      in_vals = [var.val if isinstance(var, Literal) else consts[var]
                 for var in eqn.invars]
      call_jaxpr, params = jc.extract_call_jaxpr(eqn.primitive, eqn.params)
      if call_jaxpr:
        subfuns = [lu.wrap_init(partial(jc.eval_jaxpr, call_jaxpr, ()))]
      else:
        subfuns = []
      ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
      if eqn.primitive.multiple_results:
        for outvar, out in zip(eqn.outvars, ans):
          consts[outvar] = out
      else:
        outvar, = eqn.outvars
        consts[outvar] = ans
    else:
      new_invars = [consts[var] if (isinstance(var, Var) and var in consts)
                    else var for var in eqn.invars]
      new_params = dict(eqn.params)
      if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
        new_params['call_jaxpr'] = submerge_consts(
            eqn.params['call_jaxpr'], [], new_invars)
        new_invars = [var for var in new_invars if isinstance(var, Var)]
      else:
        new_invars = [var if isinstance(var, (Var, Literal)) else Literal_(var)
                      for var in new_invars]
      new_eqns.append(JaxprEqn(invars=new_invars, outvars=eqn.outvars,
                               primitive=eqn.primitive, params=new_params,
                               source_info=eqn.source_info))
  return Jaxpr([], new_jaxpr_invars, jaxpr.outvars, new_eqns)
def jvp_jaxpr(jaxpr):
  f = lu.wrap_init(jaxpr_as_fun(jaxpr))
  dimvars = dict((v, v.aval) for v in jaxpr.in_dim_binders)
  in_avals = [_replace_vars_with_avals(dimvars, v.aval)
              for v in jaxpr.in_binders]
  jaxpr, consts, _ = trace_to_jaxpr_dynamic(jvp_traceable(ad.jvp(f)),
                                            in_avals * 2)
  return jaxpr, consts
def make_djaxpr(fun, *args, **kwargs):
  args, in_tree = tree_flatten((args, kwargs))
  f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
  return trace_to_jaxpr_dynamic(f, in_avals)
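
# Usage sketch (added for comparison). make_djaxpr mirrors the stable
# jax.make_jaxpr helper, which traces a function to a closed jaxpr using
# abstract values:

def _make_jaxpr_example():
  import jax
  import jax.numpy as jnp

  closed_jaxpr = jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0)(1.0)
  return closed_jaxpr  # inspect closed_jaxpr.jaxpr and closed_jaxpr.consts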
def while_loop(cond_fun, body_fun, init_val):
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation
  times for jit-compiled functions, since native Python loop constructs in an
  ``@jit`` function are unrolled, leading to large XLA computations.

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``.
  """
  init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
  flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(
      lu.wrap_init(body_fun), (in_tree,))
  flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(
      lu.wrap_init(cond_fun), (in_tree,))

  carry_pval_flat = carry_aval, _ = _abstractify(init_val_flat)
  cond_jaxpr, cond_pval_out, cond_consts = pe.trace_to_jaxpr(
      flat_cond_fun, (carry_pval_flat,))
  body_jaxpr, body_pval_out, body_consts = pe.trace_to_jaxpr(
      flat_body_fun, (carry_pval_flat,), instantiate=True)
  carry_aval_out, _ = body_pval_out
  assert isinstance(carry_aval_out, core.AbstractValue)
  assert carry_aval == core.lattice_join(carry_aval, carry_aval_out)

  cond_pv, cond_const = cond_pval_out
  if cond_pv is None:
    # cond_fun evaluates to a constant, so don't need to generate a while_loop
    if cond_const:
      raise ValueError("infinite loop with no effects")
    else:
      return init_val
  else:
    assert isinstance(cond_pv, core.AbstractValue)
    if (not isinstance(cond_pv, ShapedArray) or cond_pv.shape
        or cond_pv.dtype != onp.bool_):
      msg = "while_loop cond_fun must return a scalar boolean, got {}."
      raise TypeError(msg.format(cond_pv))

    # We don't want to promote literal constants as loop arguments; there are
    # sometimes many of them. We pass tracers as loop arguments, but leave
    # nontracers as constants. We also sort the constants so the nontracers
    # are first.
    def split_tracers_and_nontracers(jaxpr, consts):
      tracer = []
      nontracer = []
      for x in zip(jaxpr.constvars, consts):
        # TODO(phawkins): We avoid treating DeviceArrays as constant literals
        # so we don't copy large arrays back to the host. We probably should
        # relax this and either always copy small constants, or
        # opportunistically use DeviceArray values for which we already know
        # npy_value.
        not_literal_const = isinstance(x[1], (core.Tracer, xla.DeviceArray))
        (tracer if not_literal_const else nontracer).append(x)
      tracer_vars, tracer_consts = unzip2(tracer)
      nontracer_vars, nontracer_consts = unzip2(nontracer)
      return nontracer_vars + tracer_vars, nontracer_consts, tracer_consts

    cond_split = split_tracers_and_nontracers(cond_jaxpr, cond_consts)
    cond_jaxpr.constvars, cond_nontracer_consts, cond_tracer_consts = cond_split
    body_split = split_tracers_and_nontracers(body_jaxpr, body_consts)
    body_jaxpr.constvars, body_nontracer_consts, body_tracer_consts = body_split

    if out_tree() != in_tree:
      raise TypeError(
          "body_fun input and output must have identical structure")
    out_flat = while_p.bind(
        init_val_flat, core.pack(cond_tracer_consts),
        core.pack(body_tracer_consts),
        cond_consts=lax._OpaqueParam(cond_nontracer_consts),
        body_consts=lax._OpaqueParam(body_nontracer_consts),
        aval_out=carry_aval_out, cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
    return build_tree(out_tree(), out_flat)
def _jvp(primals, tangents, **params):
  return ad.jvp(lu.wrap_init(self.impl, params)).call_wrapped(
      primals, tangents)
def jvp_of_rule_rule(axis_size, in_batched, primals, tangents):
  in_batched_ps, in_batched_ts = in_batched

  mutually_batched = tree_map(operator.and_, in_batched_ps, in_batched_ts)
  extra_batched_ps = tree_map(lambda pb, tb: 0 if pb and not tb else None,
                              in_batched_ps, in_batched_ts)
  extra_batched_ts = tree_map(lambda pb, tb: 0 if tb and not pb else None,
                              in_batched_ps, in_batched_ts)

  out_mutually_batched = lu.Store()
  flat_ps_ts, tree_ps_ts = tree_flatten((primals, tangents))
  flat_extra_batched_ps_ts, tree_ps_ts2 = tree_flatten(
      (extra_batched_ps, extra_batched_ts),
      is_leaf=lambda x: x is None)

  # TODO(frostig): assert these also equal:
  #   treedef_tuple((in_tree, in_tree))
  # once https://github.com/google/jax/issues/9066 is fixed
  assert tree_ps_ts == tree_ps_ts2
  del tree_ps_ts2

  def to_jvp(*primals):
    outs, out_batched = call_rule(rule, axis_size, mutually_batched, primals)
    check_vmap_rule_trees(
        rule, tree_structure(outs), tree_structure(out_batched))
    out_mutually_batched.store(out_batched)
    return outs

  def to_vmap_over_extra_batched_dims(primals, tangents):
    return jax.jvp(to_jvp, primals, tangents)

  to_vmap_over_extra_batched_dims_flat, out_tree = flatten_fun_nokwargs(
      lu.wrap_init(to_vmap_over_extra_batched_dims), tree_ps_ts)

  flat_out_ps_ts, flat_out_axes = vmap_unrestricted(
      to_vmap_over_extra_batched_dims_flat, *flat_ps_ts,
      in_axes=flat_extra_batched_ps_ts,
      axis_name=core.no_axis_name, axis_size=axis_size)

  n, ragged = divmod(len(flat_out_ps_ts), 2)
  assert not ragged
  flat_out_ps, flat_out_ts = flat_out_ps_ts[:n], flat_out_ps_ts[n:]
  flat_out_axes_p, flat_out_axes_t = flat_out_axes[:n], flat_out_axes[n:]
  flat_out_ps = map(maybe_bdim_at_front, flat_out_ps, flat_out_axes_p)
  flat_out_extra_batched_ps = [d is not not_mapped for d in flat_out_axes_p]
  flat_out_ts = map(maybe_bdim_at_front, flat_out_ts, flat_out_axes_t)
  flat_out_extra_batched_ts = [d is not not_mapped for d in flat_out_axes_t]

  out_ps, out_ts = tree_unflatten(
      out_tree(), [*flat_out_ps, *flat_out_ts])
  out_extra_batched_ps, out_extra_batched_ts = tree_unflatten(
      out_tree(), [*flat_out_extra_batched_ps, *flat_out_extra_batched_ts])

  out_batched_ps = tree_map(
      operator.or_, out_mutually_batched.val, out_extra_batched_ps)
  out_batched_ts = tree_map(
      operator.or_, out_mutually_batched.val, out_extra_batched_ts)

  return (out_ps, out_ts), (out_batched_ps, out_batched_ts)
def _batch(args, dims, **params):
  return batching.batch_fun(lu.wrap_init(self.impl, params), args, dims)
def _interpret_jaxpr(jaxpr: core.TypedJaxpr, *args: TfVal) -> Sequence[TfVal]:
  """Evaluate a Jaxpr with tf.Tensor arguments."""
  fun: lu.WrappedFun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  out_vals = _interpret_fun(fun, args)
  return out_vals
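
# Usage sketch (added; hedged). The wrap-and-evaluate pattern above, in terms
# of public-ish APIs: a closed jaxpr can be turned back into a Python callable
# with jax.core.jaxpr_as_fun, which takes and returns flat lists of values.

def _jaxpr_as_fun_example():
  import jax
  import jax.numpy as jnp

  closed = jax.make_jaxpr(jnp.sin)(0.5)
  fun = jax.core.jaxpr_as_fun(closed)  # callable over flattened args
  return fun(0.5)  # flat list of outputs: [sin(0.5)]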
def wrapped(*args, **kwargs):
  fun = lu.wrap_init(f, kwargs)
  flat_args, in_tree = jax.tree_flatten(args)
  flat_fun, out_tree = jax.flatten_fun_nokwargs(fun, in_tree)
  ans = jax_core.call_p.bind(flat_fun, *flat_args)
  return jax.tree_unflatten(out_tree(), ans)
def ravel_first_arg(f, unravel):
  return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped
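
# Usage sketch (added; assumes jax.flatten_util). The `unravel` argument above
# is typically produced by ravel_pytree, which flattens a pytree into a single
# 1D array plus an inverse function:

def _ravel_pytree_example():
  import jax.numpy as jnp
  from jax.flatten_util import ravel_pytree

  flat, unravel = ravel_pytree({'a': jnp.zeros(2), 'b': jnp.ones(3)})
  assert flat.shape == (5,)
  return unravel(flat)  # round-trips back to the original pytree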
def while_loop(cond_fun, body_fun, init_val):
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation
  times for jit-compiled functions, since native Python loop constructs in an
  ``@jit`` function are unrolled, leading to large XLA computations.

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``.
  """
  init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
  flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(
      lu.wrap_init(body_fun), (in_tree,))
  flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(
      lu.wrap_init(cond_fun), (in_tree,))

  carry_pval_flat = carry_aval, _ = _abstractify(init_val_flat)
  cond_jaxpr, cond_pval_out, cond_consts = pe.trace_to_jaxpr(
      flat_cond_fun, (carry_pval_flat,))
  body_jaxpr, body_pval_out, body_consts = pe.trace_to_jaxpr(
      flat_body_fun, (carry_pval_flat,), instantiate=True)
  carry_aval_out, _ = body_pval_out
  assert isinstance(carry_aval_out, core.AbstractValue)
  assert carry_aval == core.lattice_join(carry_aval, carry_aval_out)

  cond_pv, cond_const = cond_pval_out
  if cond_pv is None:
    # cond_fun evaluates to a constant, so don't need to generate a while_loop
    if cond_const:
      raise ValueError("infinite loop with no effects")
    else:
      return init_val
  else:
    assert isinstance(cond_pv, core.AbstractValue)
    if (not isinstance(cond_pv, ShapedArray) or cond_pv.shape
        or cond_pv.dtype != onp.bool_):
      msg = "while_loop cond_fun must return a scalar boolean, got {}."
      raise TypeError(msg.format(cond_pv))

    if out_tree() != in_tree:
      raise TypeError(
          "body_fun input and output must have identical structure")
    out_flat = while_p.bind(init_val_flat, core.pack(cond_consts),
                            core.pack(body_consts), aval_out=carry_aval_out,
                            cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
    return build_tree(out_tree(), out_flat)
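
# Usage sketch (added for illustration). With the stable jax.lax.while_loop
# API, the docstring's `a -> a` carry pattern looks like this:

def _while_loop_usage_example():
  from jax import lax

  # Sum the integers 0..9: the carry value of type `a` is an (i, total) pair.
  def cond_fun(val):
    i, _ = val
    return i < 10

  def body_fun(val):
    i, total = val
    return i + 1, total + i

  return lax.while_loop(cond_fun, body_fun, (0, 0))  # -> (10, 45)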
def checked_fun(*args, **kwargs):
  args_flat, in_tree = tree_flatten((args, kwargs))
  f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
  (err, code, out_flat), msgs = checkify_flat(f, errors, *args_flat)
  out = tree_unflatten(out_tree(), out_flat)
  return Error(err, code, msgs), out
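
# Usage sketch (added; assumes the public jax.experimental.checkify API).
# checked_fun is the wrapper produced by checkify; the public entry point is
# used like this:

def _checkify_usage_example():
  import jax.numpy as jnp
  from jax.experimental import checkify

  def f(x):
    checkify.check(x > 0, "x must be positive!")
    return jnp.log(x)

  checked_f = checkify.checkify(f)
  err, out = checked_f(2.0)  # err.get() is None when no check failed
  return err, out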
def custom_layer_cau_batch(vals, dims, *, num_consts, in_tree, out_tree,
                           kwargs, **params):
  """Batching rule for layer_cau primitive to handle custom layers."""
  if all(dim is batching.not_mapped for dim in dims):
    return layer_cau_p.bind(*vals, num_consts=num_consts, in_tree=in_tree,
                            out_tree=out_tree, kwargs=kwargs, **params)
  orig_vals, orig_dims = vals, dims
  vals, dims = vals[num_consts:], dims[num_consts:]
  args = tree_util.tree_unflatten(in_tree, vals)
  dims_ = [not_mapped if dim is None else dim for dim in dims]
  layer, args = args[0], args[1:]
  if hasattr(layer, '_call_and_update_batched'):
    num_params = len(tree_util.tree_leaves(layer))
    layer_dims, arg_dims = dims_[:num_params], dims_[num_params:]
    if kwargs['has_rng']:
      rng, args = args[0], args[1:]
      rng_dim, arg_dims = arg_dims[0], arg_dims[1:]
    mapping_over_layer = all(layer_dim is not not_mapped
                             for layer_dim in layer_dims)
    mapping_over_args = all(arg_dim is not not_mapped
                            for arg_dim in arg_dims)
    assert mapping_over_layer or mapping_over_args, (layer_dims, arg_dims)
    if not mapping_over_layer and mapping_over_args:
      if kwargs['has_rng']:
        if rng_dim is not not_mapped:
          arg_dims = tuple(None if dim is not_mapped else dim
                           for dim in arg_dims)
          map_fun = jax.vmap(
              lambda layer, rng, *args: _layer_cau_batched(  # pylint: disable=unnecessary-lambda, g-long-lambda
                  layer, rng, *args, **kwargs),
              in_axes=(None, rng_dim) + (None,) * len(arg_dims))
        else:
          map_fun = lambda layer, *args: _layer_cau_batched(  # pylint: disable=unnecessary-lambda, g-long-lambda
              layer, *args, **kwargs)
        vals_out, update_out = map_fun(layer, rng, *args)
      else:
        vals_out, update_out = _layer_cau_batched(layer, *args, **kwargs)
      vals_out = tree_util.tree_leaves(vals_out)
      update_out = tree_util.tree_leaves(update_out)
      assert all(dim == 0 for dim in arg_dims)
      # Assume dimensions out are consistent
      dims_out = (0,) * len(vals_out)
      dims_update = (None,) * len(update_out)
      assert len(vals_out) == len(dims_out)
      assert len(update_out) == len(dims_update)
      return vals_out + update_out, dims_out + dims_update
  batched, out_dims = primitive.batch_fun(
      lu.wrap_init(layer_cau_p.impl,
                   dict(params, num_consts=num_consts, in_tree=in_tree,
                        out_tree=out_tree, kwargs=kwargs)),
      orig_dims)
  return batched.call_wrapped(*orig_vals), out_dims()
def wrapped_fun(*args):
  args_flat, in_tree = tree_flatten(args)
  f = lu.wrap_init(fun)
  flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
  out_flat = callback_fun(flat_fun, args_flat, callback, strip_calls)
  return tree_unflatten(out_tree(), out_flat)
def _wrapped(*args):
  args_flat, in_tree = tree_flatten(args, is_leaf=_is_bcoo)
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  out = sparsify_fun(wrapped_fun, args_flat)
  return tree_unflatten(out_tree(), out)
def checkify_jaxpr(jaxpr, error):
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  return checkify_fun_to_jaxpr(f, error, jaxpr.in_avals)
def trace_jaxpr(fun, operand):
  op_flat, in_tree = pytree_to_flatjaxtuple(operand)
  fun_flat, out_tree = pytree_fun_to_flatjaxtuple_fun(
      lu.wrap_init(fun), (in_tree,))
  jaxpr, pvout, consts = pe.trace_to_jaxpr(fun_flat, (_abstractify(op_flat),))
  return op_flat, jaxpr, consts, pvout, out_tree