def while_loop(cond_fun, body_fun, init_val):
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation
  times for jit-compiled functions, since native Python loop constructs in an
  ``@jit`` function are unrolled, leading to large XLA computations.

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``.
  """
  # Flatten the carry to a list of leaves and trace cond/body to jaxprs over
  # the corresponding abstract values.
  carry_flat, in_tree = tree_flatten((init_val,))
  carry_avals = tuple(_map(_abstractify, carry_flat))
  cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
      cond_fun, in_tree, carry_avals)
  body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
      body_fun, in_tree, carry_avals)

  # The condition must produce exactly one boolean scalar.
  if not treedef_is_leaf(cond_tree):
    msg = "cond_fun must return a boolean scalar, but got pytree {}."
    raise TypeError(msg.format(cond_tree))
  if cond_jaxpr.out_avals != [ShapedArray((), onp.bool_)]:
    msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
    raise TypeError(msg.format(cond_jaxpr.out_avals))
  # The body must map the carry to a value of the same pytree structure.
  if treedef_children(in_tree) != [body_tree]:
    msg = "body_fun output pytree structure must match init_val, got {} and {}."
    raise TypeError(msg.format(body_tree, treedef_children(in_tree)[0]))

  outs = while_p.bind(
      *cond_consts, *body_consts, *carry_flat,
      cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
      body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
  return tree_unflatten(body_tree, outs)
def root(f, initial_guess, solve, tangent_solve):
  """Differentiably solve for a root of a function.

  This is a low-level routine, mostly intended for internal use in JAX.
  Gradients of root() are defined with respect to closed-over variables from
  the provided function f.

  Args:
    f: function for which to find a root. Should accept a single argument,
      return a tree of arrays with the same structure as its input.
    initial_guess: initial guess for a zero of f.
    solve: function to solve for the roots of f. Should take two positional
      arguments, f and initial_guess, and return a solution with the same
      structure as initial_guess such that func(solution) = 0. In other words,
      the following is assumed to be true (but not checked)::

        solution = solve(f, initial_guess)
        error = f(solution)
        assert all(error == 0)

    tangent_solve: function to solve the tangent system. Should take two
      positional arguments, a linear function ``g`` (the function ``f``
      linearized at its root) and a tree of array(s) ``y`` with the same
      structure as initial_guess, and return a solution ``x`` such that
      ``g(x)=y``:

      - For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
      - For vector ``y``, you could use a linear solve with the Jacobian, if
        dimensionality of ``y`` is not too large:
        ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.

  Returns:
    The result of calling solve(f, initial_guess) with gradients defined via
    implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
  """
  # Flatten the guess and trace f into a jaxpr over its abstract values.
  flat_guess, in_args_tree = tree_flatten((initial_guess,))
  avals = tuple(_map(_abstractify, flat_guess))
  jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_args_tree, avals)

  # f must be structure-preserving: output pytree matches the guess pytree.
  in_tree, = treedef_children(in_args_tree)
  if in_tree != out_tree:
    raise TypeError(_root_tree_error_template("f").format(out_tree, in_tree))

  # Wrap the user-supplied solvers so they operate on flat lists of leaves.
  solve_flat = _flatten_higher_order_func(
      solve, in_tree, _root_tree_error_template("solve"))
  tangent_solve_flat = _flatten_higher_order_func(
      tangent_solve, in_tree, _root_tree_error_template("tangent_solve"))

  out_flat = root_p.bind(
      *consts, *flat_guess,
      num_consts=len(consts), jaxpr=jaxpr, solve=solve_flat,
      tangent_solve=tangent_solve_flat)
  return tree_unflatten(out_tree, out_flat)
def custom_linear_solve(
    matvec, b, solve, transpose_solve=None, symmetric=False):
  """Perform a matrix-free linear solve with implicitly defined gradients.

  This function allows for overriding or defining gradients for a linear
  solve directly via implicit differentiation at the solution, rather than by
  differentiating *through* the solve operation. This can sometimes be much
  faster or more numerically stable, or differentiating through the solve
  operation may not even be implemented (e.g., if ``solve`` uses
  ``lax.while_loop``).

  Required invariant::

      x = solve(matvec, b)  # solve the linear equation
      assert matvec(x) == b  # not checked

  Args:
    matvec: linear function to invert. Must be differentiable.
    b: constant right-hand side of the equation. May be any nested structure
      of arrays.
    solve: higher-level function that solves for solution to the linear
      equation, i.e., ``matvec(solve(matvec, x)) == x`` for all ``x`` of the
      same form as ``b``. This function need not be differentiable.
    transpose_solve: higher-level function for solving the transpose linear
      equation, i.e., ``vecmat(transpose_solve(vecmat, x)) == x``, where
      ``vecmat`` is the transpose of the linear map ``matvec`` (computed
      automatically with autodiff). Required for backwards mode automatic
      differentiation, unless ``symmetric=True``, in which case ``solve``
      provides the default value.
    symmetric: bool indicating if it is safe to assume the linear map
      corresponds to a symmetric matrix, i.e., ``matvec == vecmat``.

  Returns:
    Result of ``solve(matvec, b)``, with gradients defined assuming that the
    solution ``x`` satisfies the linear equation ``matvec(x) == b``.
  """
  # For a symmetric operator the forward solver doubles as the transpose
  # solver.
  if transpose_solve is None and symmetric:
    transpose_solve = solve

  b_flat, in_args_tree = tree_flatten((b,))
  b_avals = tuple(_map(_abstractify, b_flat))

  # Trace matvec and solve; both must preserve the pytree structure of b.
  matvec_jaxpr, matvec_consts, matvec_tree = _initial_style_jaxpr(
      matvec, in_args_tree, b_avals)
  tree, = treedef_children(in_args_tree)
  _check_tree("matvec", "b", matvec_tree, tree)

  solve_jaxpr, solve_consts, solve_tree = _initial_style_jaxpr(
      partial(solve, matvec), in_args_tree, b_avals)
  _check_tree("solve", "b", solve_tree, tree)

  if transpose_solve is None:
    # No transpose path: reverse-mode differentiation is unavailable.
    vecmat_jaxpr = tr_solve_jaxpr = None
    vecmat_consts = tr_solve_consts = []
  else:
    if symmetric:
      # Reuse the forward operator as its own transpose.
      vecmat = matvec
      vecmat_jaxpr = matvec_jaxpr
      vecmat_consts = matvec_consts
    else:
      # Derive the transposed linear map with autodiff.
      vecmat = _transpose_function(matvec, b)
      vecmat_jaxpr, vecmat_consts, vecmat_tree = _initial_style_jaxpr(
          vecmat, in_args_tree, b_avals)
      assert vecmat_tree == tree
    tr_solve_jaxpr, tr_solve_consts, tr_solve_tree = _initial_style_jaxpr(
        partial(transpose_solve, vecmat), in_args_tree, b_avals)
    _check_tree("transpose_solve", "b", tr_solve_tree, tree)

  all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts]
  const_lengths = _LinearSolveTuple(*_map(len, all_consts))
  jaxprs = _LinearSolveTuple(
      matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr)

  out_flat = custom_linear_solve_p.bind(
      *_flatten(all_consts), *b_flat,
      const_lengths=const_lengths, jaxprs=jaxprs, tree=tree)
  return tree_unflatten(tree, out_flat)
def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False):
  """Differentiably solve for a root of a function.

  This is a low-level routine, mostly intended for internal use in JAX.
  Gradients of custom_root() are defined with respect to closed-over
  variables from the provided function ``f`` via the implicit function
  theorem: https://en.wikipedia.org/wiki/Implicit_function_theorem

  Args:
    f: function for which to find a root. Should accept a single argument,
      return a tree of arrays with the same structure as its input.
    initial_guess: initial guess for a zero of f.
    solve: function to solve for the roots of f. Should take two positional
      arguments, f and initial_guess, and return a solution with the same
      structure as initial_guess such that func(solution) = 0. In other words,
      the following is assumed to be true (but not checked)::

        solution = solve(f, initial_guess)
        error = f(solution)
        assert all(error == 0)

    tangent_solve: function to solve the tangent system. Should take two
      positional arguments, a linear function ``g`` (the function ``f``
      linearized at its root) and a tree of array(s) ``y`` with the same
      structure as initial_guess, and return a solution ``x`` such that
      ``g(x)=y``:

      - For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
      - For vector ``y``, you could use a linear solve with the Jacobian, if
        dimensionality of ``y`` is not too large:
        ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.

    has_aux: bool indicating whether the ``solve`` function returns auxiliary
      data like solver diagnostics as a second argument.

  Returns:
    The result of calling solve(f, initial_guess) with gradients defined via
    implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
  """
  # Flatten the guess and trace f; f must preserve the guess pytree.
  flat_guess, in_args_tree = tree_flatten((initial_guess,))
  avals = tuple(_map(_abstractify, flat_guess))
  f_jaxpr, f_consts, f_tree = _initial_style_jaxpr(
      f, in_args_tree, avals)
  in_tree, = treedef_children(in_args_tree)
  _check_tree("f", "initial_guess", f_tree, in_tree, False)

  # Trace the forward solver; its (possibly aux-augmented) output structure
  # determines how the flat result is unflattened.
  solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
      partial(solve, f), in_args_tree, avals)
  _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux)

  def linearize_and_solve(x, b):
    # Linearize f at x and solve the tangent system g(x) = b.
    unchecked_zeros, f_jvp = jax.linearize(f, x)
    return tangent_solve(f_jvp, b)

  l_and_s_jaxpr, l_and_s_consts, tangent_tree = _initial_style_jaxpr(
      linearize_and_solve, treedef_tuple((in_tree,) * 2), avals * 2)
  _check_tree("tangent_solve", "x", tangent_tree, in_tree, False)

  all_consts = [f_consts, solve_consts, l_and_s_consts]
  const_lengths = _RootTuple(*_map(len, all_consts))
  jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr)

  solution_flat = _custom_root(
      const_lengths, jaxprs, *_flatten(all_consts), *flat_guess)
  return tree_unflatten(solution_tree, solution_flat)