def jet(fun, primals, series):
  """Taylor-mode higher-order automatic differentiation.

  Args:
    fun: function to differentiate. Its positional arguments must be arrays or
      scalars (pytree container arguments raise ``ValueError``); it may return
      any pytree of arrays or scalars.
    primals: tuple or list of primal arguments, one per positional parameter
      of ``fun``.
    series: tuple or list with exactly one entry per primal argument; each
      entry is a tuple or list of that argument's Taylor coefficients. All
      entries must have the same length, which is the expansion order.

  Returns:
    A ``(primals_out, series_out)`` pair. ``primals_out`` has the pytree
    structure of ``fun``'s output; ``series_out`` has that same structure with
    each leaf replaced by its output Taylor coefficients.

  Raises:
    ValueError: if ``primals`` and ``series`` differ in length, if they are
      empty, if the entries of ``series`` have inconsistent lengths, or if a
      primal value or series term is not a single pytree leaf.
  """
  # Without this check, zip() below would silently truncate and skip
  # validation of the unmatched arguments.
  if len(primals) != len(series):
    raise ValueError(
        f"jet got {len(primals)} primal value(s) but {len(series)} series "
        "entries; there must be exactly one series entry per primal")
  orders = set(map(len, series))
  if not orders:
    raise ValueError("jet requires at least one primal argument")
  if len(orders) != 1:
    raise ValueError(
        "jet terms have inconsistent lengths for different arguments")
  order, = orders

  # TODO(mattjj): consider supporting pytree inputs
  for i, (x, terms) in enumerate(zip(primals, series)):
    if not treedef_is_leaf(tree_structure(x)):
      raise ValueError(f"primal value at position {i} is not an array")
    for j, t in enumerate(terms):
      if not treedef_is_leaf(tree_structure(t)):
        raise ValueError(f"term {j} for argument {i} is not an array")

  # Flatten fun's output and expose its tree structure as auxiliary data, so
  # the flat results of jet_fun can be rebuilt into that structure below.
  @lu.transformation_with_aux
  def flatten_fun_output(*args):
    ans = yield args, {}
    yield tree_flatten(ans)

  f, out_tree = flatten_fun_output(lu.wrap_init(fun))
  out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(
      primals, series)
  out_treedef = out_tree()
  return (tree_unflatten(out_treedef, out_primals),
          tree_unflatten(out_treedef, out_terms))
def while_loop(cond_fun, body_fun, init_val):
  """Repeatedly apply ``body_fun`` to a carry while ``cond_fun`` holds.

  In brief, the type signature is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  and the semantics match this pure-Python reference implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  In contrast to the reference version, ``while_loop`` is a JAX primitive
  lowered to a single XLA While HLO. This keeps compile times down for
  jit-compiled functions: native Python loops inside an ``@jit`` function are
  unrolled, which produces large XLA computations.

  Also unlike a native Python loop, ``while_loop`` cannot be differentiated in
  reverse mode, because XLA computations require static bounds on memory
  requirements.

  Args:
    cond_fun: function of type ``a -> Bool``; it must return a boolean scalar.
    body_fun: function of type ``a -> a``; its output must have the same
      pytree structure as ``init_val``.
    init_val: value of type ``a`` giving the initial loop carry. It may be a
      scalar, an array, or any pytree (nested Python tuple/list/dict) thereof.

  Returns:
    The output from the final iteration of ``body_fun``, of type ``a``.

  Raises:
    TypeError: if ``cond_fun`` does not return a boolean scalar, or if the
      output pytree structure of ``body_fun`` differs from that of
      ``init_val``.
  """
  # Both functions take the carry as their single argument, so the carry is
  # flattened wrapped in a 1-tuple to describe the argument list.
  flat_init, in_tree = tree_flatten((init_val,))
  init_avals = tuple(_map(_abstractify, flat_init))

  # Trace the condition first, then the body.
  cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
      cond_fun, in_tree, init_avals)
  body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
      body_fun, in_tree, init_avals)

  if not treedef_is_leaf(cond_tree):
    raise TypeError(
        f"cond_fun must return a boolean scalar, but got pytree {cond_tree}.")
  if cond_jaxpr.out_avals != [ShapedArray((), onp.bool_)]:
    raise TypeError(
        "cond_fun must return a boolean scalar, but got output type(s) "
        f"{cond_jaxpr.out_avals}.")
  carry_trees = treedef_children(in_tree)
  if carry_trees != [body_tree]:
    raise TypeError(
        "body_fun output pytree structure must match init_val, got "
        f"{body_tree} and {carry_trees[0]}.")

  # Operand layout expected by while_p: cond consts, body consts, then carry.
  operands = (*cond_consts, *body_consts, *flat_init)
  results = while_p.bind(
      *operands,
      cond_nconsts=len(cond_consts),
      cond_jaxpr=cond_jaxpr,
      body_nconsts=len(body_consts),
      body_jaxpr=body_jaxpr)
  return tree_unflatten(body_tree, results)
def bwd(res, cts):
  """Map output cotangents ``cts`` back through a stored jaxpr.

  Args:
    res: 4-tuple ``(jaxpr, in_tree, out_tree, consts)``. ``jaxpr`` is
      evaluated with ``consts`` on the flattened cotangents; ``out_tree`` is
      the pytree structure ``(cts,)`` must flatten to; ``in_tree`` is used to
      unflatten the jaxpr's outputs.
    cts: cotangent pytree.

  Returns:
    The jaxpr outputs unflattened with ``in_tree``. When ``in_tree`` is a
    single leaf, that value is wrapped in a 1-tuple. Presumably this yields
    one entry per primal input for a custom VJP rule; confirm with the caller.

  Raises:
    TypeError: if ``(cts,)`` does not flatten to ``out_tree``.
  """
  jaxpr, in_tree, out_tree, consts = res
  flat_cts, cts_tree = tree_flatten((cts,))
  if out_tree != cts_tree:
    raise TypeError(f'{out_tree}\n!=\n{cts_tree}')
  flat_out = core.eval_jaxpr(jaxpr, consts, *flat_cts)
  result = tree_unflatten(in_tree, flat_out)
  return (result,) if treedef_is_leaf(in_tree) else result
def build_output_vals(self, scope, carried_state_names, carried_tree,
                      init_vals, body_closed_jaxpr, body_const_vals):
  """Trace ``self.cond_func`` as a loop condition and bind ``while_p``.

  Args:
    scope: object whose ``_mutable_state`` mapping holds carried values by
      name; the entries for ``carried_state_names`` are overwritten while the
      condition is traced.
    carried_state_names: names of the carried state fields, positionally
      matching ``init_vals``.
    carried_tree: pytree structure of the carried state, passed to
      ``_initial_style_jaxpr`` when tracing the condition.
    init_vals: initial values of the carried state, passed to ``while_p`` as
      the loop carry operands.
    body_closed_jaxpr: jaxpr of the loop body, passed as ``body_jaxpr``.
    body_const_vals: constant operands of the loop body.

  Returns:
    The outputs of ``lax_control_flow.while_p.bind``.

  Raises:
    TypeError: if the condition does not return a single boolean scalar.
    ValueError: if the condition modifies any carried ``scope`` field.
  """
  # self.cond_func takes 0 arguments, but lax.while needs a conditional
  # function that takes the carried state. _initial_style_jaxpr starts its
  # own trace and creates tracers for all the carried state, so those values
  # must be put in scope._mutable_state before the condition is called.
  def cond_func_wrapped(*args):
    assert len(args) == len(carried_state_names)
    for ms, init_ms in zip(carried_state_names, args):
      scope._mutable_state[ms] = init_ms
    res = self.cond_func()
    # The conditional function is not allowed to modify the scope state.
    for ms, init_ms in zip(carried_state_names, args):
      if scope._mutable_state[ms] is not init_ms:
        raise ValueError(
            f"Conditional function modifies scope.{ms} field.")
    return res
  # NOTE(review): scope._mutable_state is not restored after tracing, so it
  # still holds the traced values; presumably the caller overwrites it.

  init_avals = safe_map(_BodyTracer.abstractify, init_vals)
  cond_jaxpr, cond_consts, cond_tree = (
      lax_control_flow._initial_style_jaxpr(cond_func_wrapped, carried_tree,
                                            tuple(init_avals)))
  # TODO: share these checks with lax_control_flow.while
  if not tree_util.treedef_is_leaf(cond_tree):
    raise TypeError(
        f"cond_fun must return a boolean scalar, but got pytree {cond_tree}."
    )
  # The per-output compatibility results must be reduced with all(). The
  # collection returned by safe_map is truthy whenever it is non-empty,
  # whatever its elements are, so `not safe_map(...)` never rejects anything.
  bool_scalar = [core.ShapedArray((), np.bool_)]
  if not all(safe_map(core.typecompat, cond_jaxpr.out_avals, bool_scalar)):
    raise TypeError(
        f"cond_fun must return a boolean scalar, but got output type(s) "
        f"{cond_jaxpr.out_avals}.")
  return lax_control_flow.while_p.bind(
      *cond_consts, *body_const_vals, *init_vals,
      cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
      body_nconsts=len(body_const_vals), body_jaxpr=body_closed_jaxpr)
def is_leaf(x):
  """Return whether the pytree structure of ``x`` is a single leaf.

  Flattens ``x`` with ``tree_util.tree_flatten`` and asks
  ``tree_util.treedef_is_leaf`` about the resulting structure; the leaves
  themselves are discarded.
  """
  _, structure = tree_util.tree_flatten(x)
  return tree_util.treedef_is_leaf(structure)
def jet(fun, primals, series):
  r"""Taylor-mode higher-order automatic differentiation.

  Args:
    fun: Function to be differentiated. Its positional arguments should be
      arrays or scalars; pytree (container) arguments are not yet supported
      and raise ``ValueError``. It may return an array, scalar, or standard
      Python container of arrays or scalars.
    primals: The primal values at which the Taylor approximation of ``fun``
      should be evaluated. Should be either a tuple or a list of arguments,
      and its length should be equal to the number of positional parameters
      of ``fun``.
    series: Higher order Taylor-series-coefficients. Together, ``primals`` and
      ``series`` make up a truncated Taylor polynomial. Should be either a
      tuple or a list with one entry per argument (so its length must equal
      ``len(primals)``). Each entry is a tuple or list of that argument's
      coefficients. All entries must have the same length, which dictates the
      degree of the truncated Taylor polynomial.

  Returns:
    A ``(primals_out, series_out)`` pair, where ``primals_out`` is
    ``fun(*primals)``. Together, ``primals_out`` and ``series_out`` are a
    truncated Taylor polynomial of :math:`f(h(\cdot))`, where :math:`h` is the
    input polynomial described by ``primals`` and ``series`` and :math:`f` is
    ``fun``. ``primals_out`` has the same Python tree structure as the output
    of ``fun``. ``series_out`` has that same structure, with each leaf replaced
    by the sequence of its output Taylor coefficients.

  Raises:
    ValueError: if ``primals`` and ``series`` have different lengths, if they
      are empty, if the entries of ``series`` have inconsistent lengths, or if
      any primal value or series term is not a single pytree leaf (e.g. is a
      tuple, list or dict).

  For example:

  >>> import jax
  >>> import jax.numpy as np

  Consider the function :math:`h(z) = z^3`, :math:`x = 0.5`,
  and the first few Taylor coefficients
  :math:`h_0=x^3`, :math:`h_1=3x^2`, and :math:`h_2=6x`.
  Let :math:`f(y) = \sin(y)`.

  >>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
  >>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)

  :func:`jet` returns the Taylor coefficients of :math:`f(h(z)) = \sin(z^3)`
  according to Faà di Bruno's formula:

  >>> f0, (f1, f2) = jet(f, (h0,), ((h1, h2),))
  >>> print(f0, f(h0))
  0.12467473 0.12467473

  >>> print(f1, df(h0) * h1)
  0.7441479 0.74414825

  >>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
  2.9064622 2.9064634
  """
  # Without this check, zip() below would silently truncate and skip
  # validation of the unmatched arguments.
  if len(primals) != len(series):
    raise ValueError(
        f"jet got {len(primals)} primal value(s) but {len(series)} series "
        "entries; there must be exactly one series entry per primal")
  orders = set(map(len, series))
  if not orders:
    raise ValueError("jet requires at least one primal argument")
  if len(orders) != 1:
    raise ValueError(
        "jet terms have inconsistent lengths for different arguments")
  order, = orders

  # TODO(mattjj): consider supporting pytree inputs
  for i, (x, terms) in enumerate(zip(primals, series)):
    if not treedef_is_leaf(tree_structure(x)):
      raise ValueError(f"primal value at position {i} is not an array")
    for j, t in enumerate(terms):
      if not treedef_is_leaf(tree_structure(t)):
        raise ValueError(f"term {j} for argument {i} is not an array")

  # Flatten fun's output and expose its tree structure as auxiliary data, so
  # the flat results of jet_fun can be rebuilt into that structure below.
  @lu.transformation_with_aux
  def flatten_fun_output(*args):
    ans = yield args, {}
    yield tree_flatten(ans)

  f, out_tree = flatten_fun_output(lu.wrap_init(fun))
  out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(
      primals, series)
  out_treedef = out_tree()
  return (tree_unflatten(out_treedef, out_primals),
          tree_unflatten(out_treedef, out_terms))