Example 1
def while_loop(cond_fun, body_fun, init_val):
    """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

    The type signature in brief is

    .. code-block:: haskell

      while_loop :: (a -> Bool) -> (a -> a) -> a -> a

    The semantics of ``while_loop`` are given by this Python implementation::

      def while_loop(cond_fun, body_fun, init_val):
        val = init_val
        while cond_fun(val):
          val = body_fun(val)
        return val

    Unlike that Python version, ``while_loop`` is a JAX primitive and is
    lowered to a single XLA While HLO. That makes it useful for reducing
    compilation times for jit-compiled functions, since native Python loop
    constructs in an ``@jit`` function are unrolled, leading to large XLA
    computations.

    Another difference from using Python-native loop constructs is that
    ``while_loop`` is not reverse-mode differentiable, because XLA computations
    require static bounds on memory requirements.

    Args:
      cond_fun: function of type ``a -> Bool``.
      body_fun: function of type ``a -> a``.
      init_val: value of type ``a``, a type that can be a scalar, array, or any
        pytree (nested Python tuple/list/dict) thereof, representing the
        initial loop carry value.

    Returns:
      The output from the final iteration of ``body_fun``, of type ``a``.
    """
    init_vals, in_tree = tree_flatten((init_val,))
    init_avals = tuple(_map(_abstractify, init_vals))
    cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
        cond_fun, in_tree, init_avals)
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
        body_fun, in_tree, init_avals)
    if not treedef_is_leaf(cond_tree):
        msg = "cond_fun must return a boolean scalar, but got pytree {}."
        raise TypeError(msg.format(cond_tree))
    if cond_jaxpr.out_avals != [ShapedArray((), onp.bool_)]:
        msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
        raise TypeError(msg.format(cond_jaxpr.out_avals))
    if not treedef_children(in_tree) == [body_tree]:
        msg = "body_fun output pytree structure must match init_val, got {} and {}."
        raise TypeError(msg.format(body_tree, treedef_children(in_tree)[0]))
    outs = while_p.bind(*itertools.chain(cond_consts, body_consts, init_vals),
                        cond_nconsts=len(cond_consts),
                        cond_jaxpr=cond_jaxpr,
                        body_nconsts=len(body_consts),
                        body_jaxpr=body_jaxpr)
    return tree_unflatten(body_tree, outs)
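
A minimal usage sketch (not from the original source; assumes this function is exposed as ``jax.lax.while_loop``): summing the integers 0..9 with a ``(counter, total)`` carry. Both ``cond_fun`` and ``body_fun`` must preserve the carry's pytree structure and dtypes on every iteration.

from jax import lax

def cond_fun(carry):
    i, _ = carry
    return i < 10

def body_fun(carry):
    i, total = carry
    return (i + 1, total + i)

# Carry is a (counter, total) pair; the loop runs until cond_fun is False.
final_count, total = lax.while_loop(cond_fun, body_fun, (0, 0))
# final_count == 10, total == 45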
Example 2
def root(f, initial_guess, solve, tangent_solve):
  """Differentiably solve for a roots of a function.

  This is a low-level routine, mostly intended for internal use in JAX.
  Gradients of root() are defined with respect to closed-over variables from
  the provided function f.

  Args:
    f: function for which to find a root. Should accept a single argument and
      return a tree of arrays with the same structure as its input.
    initial_guess: initial guess for a zero of f.
    solve: function to solve for the roots of f. Should take two positional
      arguments, f and initial_guess, and return a solution with the same
      structure as initial_guess such that f(solution) = 0. In other words,
      the following is assumed to be true (but not checked)::

        solution = solve(f, initial_guess)
        error = f(solution)
        assert all(error == 0)

    tangent_solve: function to solve the tangent system. Should take two
      positional arguments, a linear function ``g`` (the function ``f``
      linearized at its root) and a tree of array(s) ``y`` with the same
      structure as initial_guess, and return a solution ``x`` such that
      ``g(x)=y``:

      - For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
      - For vector ``y``, you could use a linear solve with the Jacobian, if
        the dimensionality of ``y`` is not too large:
        ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.

  Returns:
    The result of calling solve(f, initial_guess) with gradients defined via
    implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
  """
  guess_flat, in_args_tree = tree_flatten((initial_guess,))
  guess_avals = tuple(_map(_abstractify, guess_flat))
  jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_args_tree, guess_avals)

  in_tree, = treedef_children(in_args_tree)
  if in_tree != out_tree:
    raise TypeError(_root_tree_error_template("f").format(out_tree, in_tree))

  solve_flat = _flatten_higher_order_func(
      solve, in_tree, _root_tree_error_template("solve"))
  tangent_solve_flat = _flatten_higher_order_func(
      tangent_solve, in_tree, _root_tree_error_template("tangent_solve"))

  out_flat = root_p.bind(*itertools.chain(consts, guess_flat),
                         num_consts=len(consts), jaxpr=jaxpr, solve=solve_flat,
                         tangent_solve=tangent_solve_flat)
  return tree_unflatten(out_tree, out_flat)
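
A hedged usage sketch (the Newton solver and its fixed step count are illustrative assumptions, not part of the source): computing sqrt(c) as the root of x**2 - c, with the gradient with respect to the closed-over ``c`` supplied by implicit differentiation.

import jax

def newton_solve(f, x0):
    # Illustrative solver: a fixed number of Newton steps, convergence assumed.
    x = x0
    for _ in range(10):
        x = x - f(x) / jax.grad(f)(x)
    return x

def sqrt(c):
    # Gradients are defined w.r.t. the closed-over c, not through the solver.
    return root(lambda x: x ** 2 - c, 1.0, newton_solve,
                lambda g, y: y / g(1.0))

print(sqrt(2.0))            # ~1.41421356
print(jax.grad(sqrt)(2.0))  # ~0.35355339, i.e. 1 / (2 * sqrt(2))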
Example 3
def custom_linear_solve(
    matvec, b, solve, transpose_solve=None, symmetric=False):
  """Perform a matrix-free linear solve with implicitly defined gradients.

  This function allows for overriding or defining gradients for a linear
  solve directly via implicit differentiation at the solution, rather than by
  differentiating *through* the solve operation. This can sometimes be much
  faster or more numerically stable; differentiating through the solve
  operation may not even be possible (e.g., if ``solve`` uses
  ``lax.while_loop``).

  Required invariant::

      x = solve(matvec, b)  # solve the linear equation
      assert matvec(x) == b  # not checked

  Args:
    matvec: linear function to invert. Must be differentiable.
    b: constant right-hand side of the equation. May be any nested structure
      of arrays.
    solve: higher level function that solves for the solution to the linear
      equation, i.e., ``matvec(solve(matvec, x)) == x`` for all ``x`` of the
      same form as ``b``. This function need not be differentiable.
    transpose_solve: higher level function for solving the transpose linear
      equation, i.e., ``vecmat(transpose_solve(vecmat, x)) == x``, where
      ``vecmat`` is the transpose of the linear map ``matvec`` (computed
      automatically with autodiff). Required for backwards mode automatic
      differentiation, unless ``symmetric=True``, in which case ``solve``
      provides the default value.
    symmetric: bool indicating if it is safe to assume the linear map
      corresponds to a symmetric matrix, i.e., ``matvec == vecmat``.

  Returns:
    Result of ``solve(matvec, b)``, with gradients defined assuming that the
    solution ``x`` satisfies the linear equation ``matvec(x) == b``.
  """
  if transpose_solve is None and symmetric:
    transpose_solve = solve

  b_flat, in_args_tree = tree_flatten((b,))
  b_avals = tuple(_map(_abstractify, b_flat))
  matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
      matvec, in_args_tree, b_avals)

  tree, = treedef_children(in_args_tree)
  _check_tree("matvec", "b", out_tree, tree)

  solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
      partial(solve, matvec), in_args_tree, b_avals)
  _check_tree("solve", "b", out_tree, tree)

  if transpose_solve is None:
    vecmat_jaxpr = tr_solve_jaxpr = None
    vecmat_consts = tr_solve_consts = []
  else:
    if symmetric:
      vecmat = matvec
      vecmat_jaxpr = matvec_jaxpr
      vecmat_consts = matvec_consts
    else:
      vecmat = _transpose_function(matvec, b)
      vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr(
          vecmat, in_args_tree, b_avals)
      assert out_tree == tree

    tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr(
        partial(transpose_solve, vecmat), in_args_tree, b_avals)
    _check_tree("transpose_solve", "b", out_tree, tree)

  all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts]
  const_lengths = _LinearSolveTuple(*_map(len, all_consts))
  jaxprs = _LinearSolveTuple(
      matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr)

  out_flat = custom_linear_solve_p.bind(
      *(_flatten(all_consts) + b_flat),
      const_lengths=const_lengths, jaxprs=jaxprs, tree=tree)
  return tree_unflatten(tree, out_flat)
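
A hedged usage sketch (the concrete matrix ``A`` and the ``jnp.linalg.solve``-based solver are illustrative assumptions, not part of the source): since ``A`` is symmetric, ``symmetric=True`` lets ``solve`` double as ``transpose_solve``.

import jax.numpy as jnp

A = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])  # symmetric, so vecmat == matvec
b = jnp.array([1.0, 2.0])

def matvec(x):
    # The linear map to invert; must be differentiable.
    return A @ x

def black_box_solve(matvec, b):
    # Any solver may be used here; it need not be differentiable.
    return jnp.linalg.solve(A, b)

x = custom_linear_solve(matvec, b, black_box_solve, symmetric=True)
# matvec(x) ≈ b, and reverse-mode gradients are defined via the transpose solve.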
Example 4
def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False):
    """Differentiably solve for a roots of a function.

    This is a low-level routine, mostly intended for internal use in JAX.
    Gradients of custom_root() are defined with respect to closed-over
    variables from the provided function ``f`` via the implicit function
    theorem: https://en.wikipedia.org/wiki/Implicit_function_theorem

    Args:
      f: function for which to find a root. Should accept a single argument and
        return a tree of arrays with the same structure as its input.
      initial_guess: initial guess for a zero of f.
      solve: function to solve for the roots of f. Should take two positional
        arguments, f and initial_guess, and return a solution with the same
        structure as initial_guess such that f(solution) = 0. In other words,
        the following is assumed to be true (but not checked)::

          solution = solve(f, initial_guess)
          error = f(solution)
          assert all(error == 0)

      tangent_solve: function to solve the tangent system. Should take two
        positional arguments, a linear function ``g`` (the function ``f``
        linearized at its root) and a tree of array(s) ``y`` with the same
        structure as initial_guess, and return a solution ``x`` such that
        ``g(x)=y``:

        - For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
        - For vector ``y``, you could use a linear solve with the Jacobian, if
          the dimensionality of ``y`` is not too large:
          ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.
      has_aux: bool indicating whether the ``solve`` function returns auxiliary
        data like solver diagnostics as a second output.

    Returns:
      The result of calling solve(f, initial_guess) with gradients defined via
      implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
    """
    guess_flat, in_args_tree = tree_flatten((initial_guess,))
    guess_avals = tuple(_map(_abstractify, guess_flat))
    f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(f, in_args_tree,
                                                       guess_avals)

    in_tree, = treedef_children(in_args_tree)
    _check_tree("f", "initial_guess", out_tree, in_tree, False)

    solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
        partial(solve, f), in_args_tree, guess_avals)
    _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux)

    def linearize_and_solve(x, b):
        unchecked_zeros, f_jvp = jax.linearize(f, x)
        return tangent_solve(f_jvp, b)

    l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
        linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2)
    _check_tree("tangent_solve", "x", out_tree, in_tree, False)

    all_consts = [f_consts, solve_consts, l_and_s_consts]
    const_lengths = _RootTuple(*_map(len, all_consts))
    jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr)

    solution_flat = _custom_root(const_lengths, jaxprs,
                                 *(_flatten(all_consts) + guess_flat))
    return tree_unflatten(solution_tree, solution_flat)
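
A hedged sketch of the ``has_aux`` path (the Newton solver and the residual diagnostic are illustrative assumptions, not part of the source): with ``has_aux=True``, ``solve`` returns ``(solution, aux)`` and ``custom_root`` passes the auxiliary output through untouched.

import jax

def sqrt_with_residual(c):
    f = lambda x: x ** 2 - c

    def solve(f, x0):
        # Illustrative solver: fixed Newton iterations, convergence assumed.
        x = x0
        for _ in range(10):
            x = x - f(x) / jax.grad(f)(x)
        return x, f(x)  # second output is auxiliary solver diagnostics

    return custom_root(f, 1.0, solve, lambda g, y: y / g(1.0), has_aux=True)

solution, residual = sqrt_with_residual(2.0)
# solution ≈ sqrt(2); residual reports how close f(solution) is to zero.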