Example #1
def jet(fun, primals, series):
    try:
        order, = set(map(len, series))
    except ValueError:
        msg = "jet terms have inconsistent lengths for different arguments"
        raise ValueError(msg) from None

    # TODO(mattjj): consider supporting pytree inputs
    for i, (x, terms) in enumerate(zip(primals, series)):
        treedef = tree_structure(x)
        if not treedef_is_leaf(treedef):
            raise ValueError(
                "primal value at position {} is not an array".format(i))
        for j, t in enumerate(terms):
            treedef = tree_structure(t)
            if not treedef_is_leaf(treedef):
                raise ValueError(
                    "term {} for argument {} is not an array".format(j, i))

    @lu.transformation_with_aux
    def flatten_fun_output(*args):
        ans = yield args, {}
        yield tree_flatten(ans)

    f, out_tree = flatten_fun_output(lu.wrap_init(fun))
    out_primals, out_terms = jet_fun(jet_subtrace(f),
                                     order).call_wrapped(primals, series)
    return tree_unflatten(out_tree(),
                          out_primals), tree_unflatten(out_tree(), out_terms)
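
A brief usage sketch (not part of the listing above; it assumes the function shown is the public jax.experimental.jet.jet): every primal and every series term must be an array or scalar rather than a container, which is exactly what the treedef_is_leaf checks enforce.

import jax.numpy as jnp
from jax.experimental.jet import jet

# Order-2 Taylor expansion of sin around 0.5 with series terms (1.0, 0.0).
primals_out, series_out = jet(jnp.sin, (0.5,), ((1.0, 0.0),))
print(primals_out)  # sin(0.5) ~= 0.4794
print(series_out)   # [cos(0.5)*1.0, -sin(0.5)*1.0**2 + cos(0.5)*0.0]

# Passing a container (e.g. a tuple) as a primal trips the leaf check:
#   jet(jnp.sin, ((0.5, 0.25),), (((1.0, 1.0),),))  # ValueError: ... not an array
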
Example #2
def while_loop(cond_fun, body_fun, init_val):
    """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``.
  """
    init_vals, in_tree = tree_flatten((init_val, ))
    init_avals = tuple(_map(_abstractify, init_vals))
    cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
        cond_fun, in_tree, init_avals)
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
        body_fun, in_tree, init_avals)
    if not treedef_is_leaf(cond_tree):
        msg = "cond_fun must return a boolean scalar, but got pytree {}."
        raise TypeError(msg.format(cond_tree))
    if cond_jaxpr.out_avals != [ShapedArray((), onp.bool_)]:
        msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
        raise TypeError(msg.format(cond_jaxpr.out_avals))
    if not treedef_children(in_tree) == [body_tree]:
        msg = "body_fun output pytree structure must match init_val, got {} and {}."
        raise TypeError(msg.format(body_tree, treedef_children(in_tree)[0]))
    outs = while_p.bind(*itertools.chain(cond_consts, body_consts, init_vals),
                        cond_nconsts=len(cond_consts),
                        cond_jaxpr=cond_jaxpr,
                        body_nconsts=len(body_consts),
                        body_jaxpr=body_jaxpr)
    return tree_unflatten(body_tree, outs)
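
A minimal usage sketch of the public wrapper, jax.lax.while_loop (the toy computation and the (counter, total) carry below are illustrative choices, not taken from the source above); it behaves exactly like the pure-Python reference in the docstring.

import jax

def sum_first_n(n):
    # The carry is a (counter, running_total) pair with fixed types across iterations.
    def cond_fun(carry):
        i, _ = carry
        return i < n

    def body_fun(carry):
        i, total = carry
        return i + 1, total + i

    _, total = jax.lax.while_loop(cond_fun, body_fun, (0, 0))
    return total

print(jax.jit(sum_first_n)(5))  # 0 + 1 + 2 + 3 + 4 = 10
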
Example #3
def bwd(res, cts):
    jaxpr, in_tree, out_tree, consts = res
    # Flatten the incoming cotangents and check that their pytree structure
    # matches the one recorded for the primal outputs.
    cts_flat, out_tree_ = tree_flatten((cts,))
    if out_tree != out_tree_:
        raise TypeError(f'{out_tree}\n!=\n{out_tree_}')
    # Evaluate the saved jaxpr on the flattened cotangents, then restore the
    # pytree structure of the primal inputs.
    cts_out = core.eval_jaxpr(jaxpr, consts, *cts_flat)
    cts_out = tree_unflatten(in_tree, cts_out)
    # A bwd rule must return a tuple of cotangents, one per primal argument,
    # so wrap the result when the inputs form a single leaf.
    if treedef_is_leaf(in_tree):
        cts_out = (cts_out,)
    return cts_out
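
For context, a self-contained sketch of the convention this bwd rule follows (the square function below is a made-up example, unrelated to the code above): with jax.custom_vjp, the backward function receives residuals and output cotangents and must return a tuple of input cotangents, one per primal argument, which is why a single-leaf in_tree gets wrapped in a one-element tuple.

import jax

@jax.custom_vjp
def square(x):
    return x * x

def square_fwd(x):
    # Return the primal output plus residuals for the backward pass.
    return x * x, x

def square_bwd(res, ct):
    x = res
    # One cotangent per primal argument, returned as a tuple.
    return (2.0 * x * ct,)

square.defvjp(square_fwd, square_bwd)

print(jax.grad(square)(3.0))  # 6.0
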
Example #4
    def build_output_vals(self, scope, carried_state_names, carried_tree,
                          init_vals, body_closed_jaxpr, body_const_vals):
        # Trace the conditional function. cond_func takes 0 arguments, but
        # for lax.while we need a conditional function that takes the
        # carried_state_names. _initial_style_jaxpr will start its own trace and
        # will create tracers for all the carried state. We must put these values
        # in the scope._mutable_state before we trace the conditional
        # function.
        def cond_func_wrapped(*args):
            assert len(args) == len(carried_state_names)
            for ms, init_ms in zip(carried_state_names, args):
                scope._mutable_state[ms] = init_ms
            res = self.cond_func()
            # Conditional function is not allowed to modify the scope state
            for ms, init_ms in zip(carried_state_names, args):
                if not (scope._mutable_state[ms] is init_ms):
                    raise ValueError(
                        f"Conditional function modifies scope.{ms} field.")
            return res

        init_avals = safe_map(_BodyTracer.abstractify, init_vals)
        cond_jaxpr, cond_consts, cond_tree = (
            lax_control_flow._initial_style_jaxpr(cond_func_wrapped,
                                                  carried_tree,
                                                  tuple(init_avals)))
        # TODO: share these checks with lax_control_flow.while
        if not tree_util.treedef_is_leaf(cond_tree):
            raise TypeError(
                f"cond_fun must return a boolean scalar, but got pytree {cond_tree}."
            )
        if not safe_map(core.typecompat, cond_jaxpr.out_avals,
                        [core.ShapedArray((), np.bool_)]):
            raise TypeError(
                f"cond_fun must return a boolean scalar, but got output type(s) "
                f"{cond_jaxpr.out_avals}.")

        return lax_control_flow.while_p.bind(*cond_consts,
                                             *body_const_vals,
                                             *init_vals,
                                             cond_nconsts=len(cond_consts),
                                             cond_jaxpr=cond_jaxpr,
                                             body_nconsts=len(body_const_vals),
                                             body_jaxpr=body_closed_jaxpr)
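
If this method belongs to the (since removed) jax.experimental.loops module, as the scope and _mutable_state handling suggests, the user-facing construct it supports looks roughly like the sketch below. This is an assumption about the surrounding API, not code from the project file; note how the zero-argument condition function mirrors cond_func_wrapped above.

from jax.experimental import loops  # assumed context; removed in newer JAX releases

def count_up(x):
    with loops.Scope() as s:
        s.val = x
        # The condition takes no arguments and reads mutable scope state,
        # just like the cond_func described in the comment above.
        for _ in s.while_range(lambda: s.val < 10.):
            s.val = s.val + 2.
        return s.val
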
Example #5
def is_leaf(x):
    return tree_util.treedef_is_leaf(tree_util.tree_flatten(x)[1])
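
A quick illustration of what this helper reports (the example values are arbitrary): arrays and scalars flatten to a single-leaf treedef, while containers do not.

import jax.numpy as jnp
from jax import tree_util

def is_leaf(x):
    return tree_util.treedef_is_leaf(tree_util.tree_flatten(x)[1])

print(is_leaf(jnp.ones(3)))         # True: an array is a single leaf
print(is_leaf(2.0))                 # True: so is a scalar
print(is_leaf((jnp.ones(3), 2.0)))  # False: a tuple is a container node
print(is_leaf({"a": 1}))            # False: so is a dict
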
Example #6
File: jet.py Project: 0x0is1/jax
def jet(fun, primals, series):
    r"""Taylor-mode higher-order automatic differentiation.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: The primal values at which the Taylor approximation of ``fun`` should be
      evaluated. Should be either a tuple or a list of arguments,
      and its length should be equal to the number of positional parameters of
      ``fun``.
    series: Higher-order Taylor series coefficients.
      Together, ``primals`` and ``series`` make up a truncated Taylor polynomial.
      Should be either a tuple or a list of tuples or lists,
      and its length dictates the degree of the truncated Taylor polynomial.

  Returns:
    A ``(primals_out, series_out)`` pair, where ``primals_out`` is ``fun(*primals)``,
    and together, ``primals_out`` and ``series_out`` are a
    truncated Taylor polynomial of :math:`f(h(\cdot))`.
    The ``primals_out`` value has the same Python tree structure as ``primals``,
    and the ``series_out`` value the same Python tree structure as ``series``.

  For example:

  >>> import jax
  >>> import jax.numpy as np

  Consider the function :math:`h(z) = z^3`, :math:`x = 0.5`,
  and the first few Taylor coefficients
  :math:`h_0=x^3`, :math:`h_1=3x^2`, and :math:`h_2=6x`.
  Let :math:`f(y) = \sin(y)`.

  >>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
  >>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)

  :func:`jet` returns the Taylor coefficients of :math:`f(h(z)) = \sin(z^3)`
  according to Faà di Bruno's formula:

  >>> f0, (f1, f2) =  jet(f, (h0,), ((h1, h2),))
  >>> print(f0,  f(h0))
  0.12467473 0.12467473

  >>> print(f1, df(h0) * h1)
  0.7441479 0.74414825

  >>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
  2.9064622 2.9064634
  """
    try:
        order, = set(map(len, series))
    except ValueError:
        msg = "jet terms have inconsistent lengths for different arguments"
        raise ValueError(msg) from None

    # TODO(mattjj): consider supporting pytree inputs
    for i, (x, terms) in enumerate(zip(primals, series)):
        treedef = tree_structure(x)
        if not treedef_is_leaf(treedef):
            raise ValueError(
                "primal value at position {} is not an array".format(i))
        for j, t in enumerate(terms):
            treedef = tree_structure(t)
            if not treedef_is_leaf(treedef):
                raise ValueError(
                    "term {} for argument {} is not an array".format(j, i))

    @lu.transformation_with_aux
    def flatten_fun_output(*args):
        ans = yield args, {}
        yield tree_flatten(ans)

    f, out_tree = flatten_fun_output(lu.wrap_init(fun))
    out_primals, out_terms = jet_fun(jet_subtrace(f),
                                     order).call_wrapped(primals, series)
    return tree_unflatten(out_tree(),
                          out_primals), tree_unflatten(out_tree(), out_terms)
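
As a complementary check to the docstring example (this cross-check is an illustration, not part of the file): with a single first-order series term of 1.0, the jet coefficient reduces to the ordinary derivative, so it should agree with jax.grad up to floating-point error.

import jax
import jax.numpy as jnp
from jax.experimental.jet import jet

f = lambda x: jnp.sin(jnp.cos(x))
x = 0.7

_, (f1,) = jet(f, (x,), ((1.0,),))
print(f1, jax.grad(f)(x))  # both equal f'(0.7) = -sin(0.7) * cos(cos(0.7))
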