Ejemplo n.º 1
0
def _trace_to_jaxpr_with_refs(
    f, state_tree: PyTreeDef, state_avals: Sequence[core.AbstractValue]
) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]:
    """Trace ``f`` into a jaxpr over flattened inputs.

    The inputs of ``f`` are described by a 2-tuple pytree: a single leaf
    followed by ``state_tree``. ``state_avals`` supplies the abstract values
    of all of those flattened leaves, in order.

    Returns:
      A ``(jaxpr, consts, out_tree)`` triple. ``out_tree`` is the output
      pytree structure observed while tracing.
    """
    in_tree = treedef_tuple((tree_structure(0), state_tree))
    flat_fun, get_out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
    jaxpr, _unused_out_avals, consts = pe.trace_to_jaxpr_dynamic(
        flat_fun, state_avals)
    # The output tree is only known after tracing has run ``flat_fun``.
    out_tree = get_out_tree()
    return jaxpr, consts, out_tree
Ejemplo n.º 2
0
def custom_transpose_transpose_rule(cts, *args, out_types, res_tree, lin_tree,
                                    out_tree, **params):
    """Transpose rule for the custom-transpose primitive.

    Two parameterizations are accepted. In the first, ``params`` has
    ``transpose_jaxpr_thunk`` (together with ``call_jaxpr``). In the second,
    it has ``call`` together with a ready-made ``transpose`` callable.

    Args:
      cts: flat output cotangents. Symbolic ``ad_util.Zero`` entries are
        instantiated as dense zeros before the user transpose is called.
      *args: flat operands, laid out as the leaves of
        ``(residuals, linear_inputs)`` per ``res_tree`` / ``lin_tree``.
      out_types: unused here.
      res_tree, lin_tree, out_tree: pytree structures of the residual
        inputs, linear inputs and outputs respectively.

    Returns:
      ``None`` for every residual leaf, followed by the flattened linear
      cotangents. ``None`` entries returned by the transpose are broadcast
      across the matching subtree of ``lin_tree``.
    """
    if 'transpose_jaxpr_thunk' not in params:
        assert 'call' in params
        transpose = params['transpose']
    else:
        assert 'call_jaxpr' in params
        transpose = make_transpose_from_thunk(params['transpose_jaxpr_thunk'],
                                              lin_tree)

    # TODO(frostig,mattjj): the linear half of ``args`` marks the inputs with
    # respect to which we are transposing (via ``ad.is_undefined_primal``).
    # Consider passing this information to the custom transpose rule?
    res_arg, _ = tree_unflatten(treedef_tuple((res_tree, lin_tree)), args)
    res_leaves = tree_leaves(res_arg)
    assert all(not ad.is_undefined_primal(x) for x in res_leaves)

    # The user-supplied transpose expects concrete arrays, not symbolic zeros.
    dense_cts = []
    for ct in cts:
        if type(ct) is ad_util.Zero:
            ct = ad_util.zeros_like_aval(ct.aval)
        dense_cts.append(ct)

    ct_lin = transpose(res_arg, tree_unflatten(out_tree, dense_cts))
    check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin))

    def _is_none(x):
        return x is None

    # Keep ``None`` as a leaf so that it expands to one entry per linear leaf.
    broadcast_ct_lin = tree_broadcast(lin_tree, ct_lin, is_leaf=_is_none)
    ct_lin_flat, _ = tree_flatten(broadcast_ct_lin, is_leaf=_is_none)
    return [None] * len(res_leaves) + ct_lin_flat
Ejemplo n.º 3
0
 def testTreedefTupleFromChildren(self):
   # https://github.com/google/jax/issues/7377
   tree = ((1, 2, (3, 4)), (5,))
   leaves, treedef1 = tree_util.tree_flatten(tree)
   treedef2 = tree_util.treedef_tuple(treedef1.children())
   self.assertEqual(treedef1.num_leaves, len(leaves))
   self.assertEqual(treedef1.num_leaves, treedef2.num_leaves)
   self.assertEqual(treedef1.num_nodes, treedef2.num_nodes)
Ejemplo n.º 4
0
 def testTreedefTupleFromChildren(self):
   # Regression test for https://github.com/google/jax/issues/7377:
   # a tuple treedef rebuilt from another treedef's children should keep
   # the same leaf and node counts.
   # TODO(frostig): remove the version guard once the minimum jaxlib is
   # 0.1.70 or newer.
   jaxlib_too_old = jax.lib._xla_extension_version < 29
   if jaxlib_too_old:
     self.skipTest("fixed in future jaxlib")
   nested = ((1, 2, (3, 4)), (5,))
   flat, original_def = tree_util.tree_flatten(nested)
   rebuilt_def = tree_util.treedef_tuple(original_def.children())
   self.assertEqual(original_def.num_leaves, len(flat))
   self.assertEqual(original_def.num_leaves, rebuilt_def.num_leaves)
   self.assertEqual(original_def.num_nodes, rebuilt_def.num_nodes)
Ejemplo n.º 5
0
def custom_transpose_transpose_rule(cts, *args, call, rule, res_tree, lin_tree,
                                    out_tree):
    """Transpose rule that delegates to a user-supplied transpose ``rule``.

    Args:
      cts: flat output cotangents, one per entry of ``call.out_avals``.
        Symbolic ``ad_util.Zero`` entries are replaced by dense zeros.
      *args: flat operands, laid out as the leaves of
        ``(residuals, linear_inputs)`` per ``res_tree`` / ``lin_tree``.
      call: object whose ``out_avals`` give the abstract output values.
      rule: callable ``rule(residuals, output_cotangent)`` returning the
        cotangent of the linear inputs.
      res_tree, lin_tree, out_tree: pytree structures of the residual
        inputs, linear inputs and outputs.

    Returns:
      ``None`` for every residual leaf, followed by the flattened linear
      cotangents produced by ``rule``.
    """
    res_arg, lin_arg = tree_unflatten(treedef_tuple((res_tree, lin_tree)), args)
    # Linear inputs must all be undefined primals (being transposed);
    # residuals must all be known values.
    assert all(ad.is_undefined_primal(leaf) for leaf in tree_leaves(lin_arg))
    res_leaves = tree_leaves(res_arg)
    assert not any(ad.is_undefined_primal(leaf) for leaf in res_leaves)

    dense_cts = []
    for ct, ct_aval in zip(cts, call.out_avals):
        dense_cts.append(
            ad_util.zeros_like_aval(ct_aval) if type(ct) is ad_util.Zero else ct)

    ct_lin = rule(res_arg, tree_unflatten(out_tree, dense_cts))
    ct_lin_flat, ct_lin_tree = tree_flatten(ct_lin)
    check_transpose_rule_trees(rule, lin_tree, ct_lin_tree)
    return [None] * len(res_leaves) + ct_lin_flat
Ejemplo n.º 6
0
def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
                linear_args):
    """Call a linear function, with a custom implementation for its transpose.

    The type signatures of ``fun`` and ``fun_transpose`` are:

    .. code-block:: haskell

      fun           :: r -> a -o b
      fun_transpose :: r -> b -o a

    where the ``-o`` arrow indicates a linear function, ``r`` is the
    residual input type, ``a`` is the linear input type and ``b`` is the
    linear output type.

    The functions ``fun`` and ``fun_transpose`` are coupled as
    transposes of one another. Specifically, the transpose of a
    ``linear_call`` primitive is another ``linear_call`` to
    ``fun_transpose``, with ``fun`` as its custom transposition.

    For example:

    >>> def f(r, x):
    ...   return x / r

    >>> def t(r, t):
    ...   return t / r

    >>> def div_add(x, denom):
    ...   return x + linear_call(f, t, denom, x)

    >>> def transpose(f, x_example):
    ...   def transposed(y):
    ...     x, = jax.linear_transpose(f, x_example)(y)
    ...     return x
    ...   return transposed

    >>> div_add(9., 3.)
    DeviceArray(12., dtype=float32, weak_type=True)

    >>> transpose(partial(div_add, denom=3.), 1.)(18.)  # custom
    DeviceArray(24., dtype=float32, weak_type=True)

    >>> transpose(lambda x: x + x / 3., 1.)(18.)  # reference
    DeviceArray(24., dtype=float32, weak_type=True)

    The above definition of ``f`` illustrates the purpose of a residual
    argument: division is linear in one of its inputs (the dividend
    ``x``) but not the other (the divisor ``r``).

    As another example:

    >>> def custom_id(x):
    ...   def f(_, x): return x
    ...   def t(_, t): return 7.
    ...   return linear_call(f, t, (), x)
    >>> custom_id(1.)
    1.0
    >>> transpose(custom_id, 1.)(1.)
    7.0
    >>> transpose(transpose(custom_id, 1.), 1.)(1.)
    1.0
    >>> transpose(transpose(transpose(custom_id, 1.), 1.), 1.)(1.)
    7.0

    Args:
      fun: a Python callable specifying a linear function. It should
        take two arguments: one of "residual" inputs (type ``r``),
        i.e. inputs in which the function is not necessarily linear, and
        one of "linear" inputs (type ``a``).  It should return output
        whose components are linear in the linear input (type ``b``).
      fun_transpose: a Python callable specifying a structurally linear
        function that is the transpose of ``fun`` with respect to its
        linear inputs. Its first argument is the same residual inputs
        (``r``) as ``fun``. Its second argument is of type
        ``b``. Finally, its output is of type ``a`` and each of its
        components is linear in its second argument (the ``b`` inputs).
      residual_args: Argument in which ``fun`` and ``fun_transpose`` are
        not necessarily linear. Not involved in transposition.
      linear_args: Argument in which ``fun`` and ``fun_transpose`` are
        linear and with respect to which the two are transposes.

    Returns:
      The call result, i.e. ``fun(residual_args, linear_args)``.

    Raises:
      TypeError: if the output pytree structure of ``fun_transpose`` does
        not match the structure of ``linear_args``.
    """
    operands_res, res_tree = tree_flatten(residual_args)
    operands_lin, lin_tree = tree_flatten(linear_args)

    f_in_tree = treedef_tuple((res_tree, lin_tree))
    f, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), f_in_tree)

    # Materialize as lists: ``res_avals`` is consumed a second time when the
    # transpose is traced below, so it must not be a one-shot iterator
    # (which it would be if ``map`` were the lazy builtin).
    res_avals = [abstractify(x) for x in operands_res]
    lin_avals = [abstractify(x) for x in operands_lin]
    f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
    f_jaxpr = _close_jaxpr(f_jaxpr)
    out_avals = [core.raise_to_shaped(aval) for aval in f_jaxpr.out_avals]

    # ``out_tree`` is only populated once ``fun`` has been traced above.
    t_in_tree = treedef_tuple((res_tree, out_tree()))
    t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose),
                                         t_in_tree)

    t_jaxpr, t_consts = _initial_style_jaxpr(t, (*res_avals, *out_avals))
    t_jaxpr = _close_jaxpr(t_jaxpr)

    transpose_out_tree = t_out_tree()
    if transpose_out_tree != lin_tree:
        raise TypeError(
            'transpose output pytree structure must match that of linear inputs, '
            f'got output structure {transpose_out_tree} '
            f'and input structure {lin_tree}.')

    out = linear_call_p.bind(*f_consts,
                             *t_consts,
                             *operands_res,
                             *operands_lin,
                             callee=f_jaxpr,
                             transpose=t_jaxpr,
                             num_callee_consts=len(f_consts),
                             num_transpose_consts=len(t_consts),
                             num_res=len(operands_res))

    return tree_unflatten(out_tree(), out)
Ejemplo n.º 7
0
def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree, out_tree):
    """JVP rule for the ``custom_vmap`` primitive.

    Differentiates by re-binding ``custom_vmap_p`` on the JVP of ``call``.
    The batching rule attached to that new bind is ``jvp_of_rule_rule``
    (defined below), which runs ``jax.jvp`` through the user's vmap ``rule``.

    Args:
      primals: flat primal inputs.
      tangents: flat tangent inputs. Symbolic zeros are instantiated via
        ``ad.instantiate_zeros`` before use.
      call: the callee, passed to ``ad.jvp_jaxpr`` (presumably a closed
        jaxpr -- confirm against caller).
      rule: the user's custom vmap rule, invoked through ``call_rule`` as
        ``rule(axis_size, in_batched, *args) -> (out, out_batched)``.
      in_tree, out_tree: pytree structures of the inputs and outputs of
        ``call``.

    Returns:
      ``(out_primals, out_tangents)``: the first and second halves of the
      flat outputs of the JVP bind.
    """
    def jvp_of_rule_rule(axis_size, in_batched, primals, tangents):
        # Batching rule for the JVP'd primitive. ``in_batched`` is a pair of
        # batched-flag trees: one for the primals, one for the tangents.
        in_batched_ps, in_batched_ts = in_batched

        # Inputs batched in both primal and tangent are handed to the user
        # rule as batched.
        mutually_batched = tree_map(operator.and_, in_batched_ps,
                                    in_batched_ts)
        # Inputs batched on only one side get in_axes=0 on that side (None
        # elsewhere). They are vmapped over outside the user rule via
        # ``vmap_unrestricted`` below.
        extra_batched_ps = tree_map(
            lambda pb, tb: 0
            if pb and not tb else None, in_batched_ps, in_batched_ts)
        extra_batched_ts = tree_map(
            lambda pb, tb: 0
            if tb and not pb else None, in_batched_ps, in_batched_ts)

        # Captures the user rule's output batched-flags from inside tracing.
        out_mutually_batched = lu.Store()
        flat_ps_ts, tree_ps_ts = tree_flatten((primals, tangents))
        # ``None`` must be kept as a leaf so the in_axes line up one-to-one
        # with ``flat_ps_ts``.
        flat_extra_batched_ps_ts, tree_ps_ts2 = tree_flatten(
            (extra_batched_ps, extra_batched_ts), is_leaf=lambda x: x is None)

        # TODO(frostig): assert these also equal:
        #   treedef_tuple((in_tree, in_tree))
        # once https://github.com/google/jax/issues/9066 is fixed
        assert tree_ps_ts == tree_ps_ts2
        del tree_ps_ts2

        def to_jvp(*primals):
            # Apply the user rule on the mutually batched inputs, validate its
            # output trees, and record its output batched-flags.
            out, out_batched = call_rule(rule, axis_size, mutually_batched,
                                         primals)
            check_vmap_rule_trees(rule, out_tree, tree_structure(out),
                                  tree_structure(out_batched))
            out_mutually_batched.store(out_batched)
            return out

        def to_vmap_over_extra_batched_dims(primals, tangents):
            return jax.jvp(to_jvp, primals, tangents)

        to_vmap_over_extra_batched_dims_flat, out_tree2 = flatten_fun_nokwargs(
            lu.wrap_init(to_vmap_over_extra_batched_dims), tree_ps_ts)

        flat_out_ps_ts, flat_out_axes = vmap_unrestricted(
            to_vmap_over_extra_batched_dims_flat,
            *flat_ps_ts,
            in_axes=flat_extra_batched_ps_ts,
            axis_name=core.no_axis_name,
            axis_size=axis_size)

        # Flat outputs are primal outputs followed by tangent outputs, so they
        # must split evenly in two.
        n, ragged = divmod(len(flat_out_ps_ts), 2)
        assert not ragged
        flat_out_ps, flat_out_ts = flat_out_ps_ts[:n], flat_out_ps_ts[n:]
        flat_out_axes_p, flat_out_axes_t = flat_out_axes[:n], flat_out_axes[n:]
        # NOTE(review): maybe_bdim_at_front presumably moves any batch axis to
        # position 0 (leaving unmapped outputs alone) -- confirm.
        flat_out_ps = map(maybe_bdim_at_front, flat_out_ps, flat_out_axes_p)
        # An output is "extra batched" if the outer vmap mapped it.
        flat_out_extra_batched_ps = [
            d is not not_mapped for d in flat_out_axes_p
        ]
        flat_out_ts = map(maybe_bdim_at_front, flat_out_ts, flat_out_axes_t)
        flat_out_extra_batched_ts = [
            d is not not_mapped for d in flat_out_axes_t
        ]

        out_ps, out_ts = tree_unflatten(out_tree2(),
                                        [*flat_out_ps, *flat_out_ts])
        out_extra_batched_ps, out_extra_batched_ts = tree_unflatten(
            out_tree2(),
            [*flat_out_extra_batched_ps, *flat_out_extra_batched_ts])

        # An output is batched if either the user rule or the outer vmap
        # batched it.
        out_batched_ps = tree_map(operator.or_, out_mutually_batched.val,
                                  out_extra_batched_ps)
        out_batched_ts = tree_map(operator.or_, out_mutually_batched.val,
                                  out_extra_batched_ts)

        return (out_ps, out_ts), (out_batched_ps, out_batched_ts)

    # Replace symbolic-zero tangents with concrete zeros.
    tangents = map(ad.instantiate_zeros, tangents)
    # All primals are marked as having nonzero tangents. The trailing ``True``
    # is presumably ``instantiate`` for output tangents -- confirm against
    # ``ad.jvp_jaxpr``.
    jvp_call, _ = ad.jvp_jaxpr(call, [True] * len(primals), True)
    # The JVP'd primitive takes/returns (primals, tangents) pairs.
    jvp_in_tree = treedef_tuple((in_tree, in_tree))
    jvp_out_tree = treedef_tuple((out_tree, out_tree))
    outs = custom_vmap_p.bind(*primals,
                              *tangents,
                              call=jvp_call,
                              rule=jvp_of_rule_rule,
                              in_tree=jvp_in_tree,
                              out_tree=jvp_out_tree)
    assert len(outs) % 2 == 0, len(outs)
    out_primals, out_tangents = util.split_list(outs, [len(outs) // 2])
    return out_primals, out_tangents
Ejemplo n.º 8
0
 def testTreedefTupleComparesEqual(self):
   # https://github.com/google/jax/issues/9066
   self.assertEqual(tree_util.tree_structure((3,)),
                    tree_util.treedef_tuple((tree_util.tree_structure(3),)))
Ejemplo n.º 9
0
def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False):
    """Differentiably solve for the roots of a function.

    This is a low-level routine, mostly intended for internal use in JAX.
    Gradients of custom_root() are defined with respect to closed-over variables
    from the provided function ``f`` via the implicit function theorem:
    https://en.wikipedia.org/wiki/Implicit_function_theorem

    Args:
      f: function for which to find a root. Should accept a single argument,
        return a tree of arrays with the same structure as its input.
      initial_guess: initial guess for a zero of f.
      solve: function to solve for the roots of f. Should take two positional
        arguments, f and initial_guess, and return a solution with the same
        structure as initial_guess such that f(solution) = 0. In other words,
        the following is assumed to be true (but not checked)::

          solution = solve(f, initial_guess)
          error = f(solution)
          assert all(error == 0)

      tangent_solve: function to solve the tangent system. Should take two
        positional arguments, a linear function ``g`` (the function ``f``
        linearized at its root) and a tree of array(s) ``y`` with the same
        structure as initial_guess, and return a solution ``x`` such that
        ``g(x)=y``:

        - For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
        - For vector ``y``, you could use a linear solve with the Jacobian, if
          dimensionality of ``y`` is not too large:
          ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.
      has_aux: bool indicating whether the ``solve`` function returns
        auxiliary data like solver diagnostics as a second output.

    Returns:
      The result of calling solve(f, initial_guess) with gradients defined via
      implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
    """
    guess_flat, in_args_tree = tree_flatten((initial_guess, ))
    guess_avals = tuple(_map(_abstractify, guess_flat))
    f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(f, in_args_tree,
                                                       guess_avals)

    # ``in_args_tree`` describes the 1-tuple ``(initial_guess,)``; unwrap it
    # to get the structure of ``initial_guess`` itself.
    in_tree, = treedef_children(in_args_tree)
    _check_tree("f", "initial_guess", out_tree, in_tree, False)

    solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
        partial(solve, f), in_args_tree, guess_avals)
    _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux)

    def linearize_and_solve(x, b):
        # The value f(x) is discarded; it is assumed (not checked) to be zero.
        _, f_jvp = jax.linearize(f, x)
        return tangent_solve(f_jvp, b)

    l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
        linearize_and_solve, treedef_tuple((in_tree, ) * 2), guess_avals * 2)
    _check_tree("tangent_solve", "x", out_tree, in_tree, False)

    all_consts = [f_consts, solve_consts, l_and_s_consts]
    const_lengths = _RootTuple(*_map(len, all_consts))
    jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr)

    solution_flat = _custom_root(const_lengths, jaxprs,
                                 *(_flatten(all_consts) + guess_flat))
    return tree_unflatten(solution_tree, solution_flat)