Example #1
    def handle_call_primitive(self, call_primitive, f, tracers, params,
                              is_map):
        """Handler for call_primitives, like jit or layer_call.

    When an UnzipTracer hits a call primitive, there is either a variable
    inside of the call primitive, in which case the input
    function needs to be unzipped into two, or there are no variables
    in the function, so the call_primitive is recorded in the trace as-is.

    We use `unzip_eval_wrapper`, which returns whether or not an unzip
    was successful or not. If it was successful, we record two new
    Jaxprs into the trace (one for init, one for apply). Otherwise, we
    just record the Jaxpr corresponding to the function call.

    Args:
      call_primitive: a call primitive like xla_call
      f: a jax.linear_util wrapped function to be called
      tracers: inputs to the function
      params: parameters of the primitives
      is_map: whether or not the primitive is a map primitive (e.g. xla_pmap)

    Returns:
      A list of output tracers
    """
        name = params.get('name', f.__name__)
        settings = trace_util.get_dynamic_context(self).settings
        tracers = safe_map(self.instantiate_const_abstracted, tracers)
        if call_primitive in current_custom_rules():
            return current_custom_rules()[call_primitive](self, f, *tracers,
                                                          **params)
        if call_primitive in pe.call_partial_eval_rules:
            raise NotImplementedError
        in_pvals = [t.pval for t in tracers]
        if is_map:
            unknown = pe.PartialVal.unknown
            in_pvals = [
                pval if pval.is_known() or in_axis is None else unknown(
                    mapped_aval(params['axis_size'], in_axis, pval[0]))
                for pval, in_axis in zip(in_pvals, params['in_axes'])
            ]
        pvs, in_consts = jax_util.unzip2(t.pval for t in tracers)
        keys = tuple(t.is_key() for t in tracers)
        new_settings = UnzipSettings(settings.tag, call_primitive
                                     in block_registry)
        fun, aux = unzip_eval(f, self, keys, tuple(pvs), new_settings)
        out_flat = call_primitive.bind(fun, *in_consts, **params)
        success, results = aux()
        if not success:
            out_pvs, out_keys, jaxpr, env = results
            out_pv_consts, consts = jax_util.split_list(
                out_flat, [len(out_pvs)])
            out_tracers = self._bound_output_tracers(call_primitive, params,
                                                     jaxpr, consts, env,
                                                     tracers, out_pvs,
                                                     out_pv_consts, out_keys,
                                                     name, is_map)
            return out_tracers
        init_name = jax_util.wrap_name(name, 'init')
        apply_name = jax_util.wrap_name(name, 'apply')
        init_pvs, num_init_consts, apply_pvs = results[0]
        init_jaxpr, apply_jaxpr = results[1]
        init_env, apply_env = results[2]
        variable_names, variable_tree, apply_keys = results[3]

        key_tracers = [t for t in tracers if t.is_key()]
        abstract_tracers = [t for t in tracers if not t.is_key()]
        all_init_consts, all_apply_consts = jax_util.split_list(
            out_flat, [len(init_pvs) + num_init_consts])
        init_pv_consts, init_consts = jax_util.split_list(
            all_init_consts, [len(init_pvs)])
        apply_pv_consts, apply_consts = jax_util.split_list(
            all_apply_consts, [len(apply_pvs)])

        variable_tracers = self._bound_output_tracers(call_primitive, params,
                                                      init_jaxpr, init_consts,
                                                      init_env, key_tracers,
                                                      init_pvs, init_pv_consts,
                                                      [True] * len(init_pvs),
                                                      init_name, is_map)

        unflat_variables = tree_util.tree_unflatten(variable_tree,
                                                    variable_tracers)
        if call_primitive is harvest.nest_p:
            variable_dict = harvest.sow(dict(
                safe_zip(variable_names, unflat_variables)),
                                        tag=settings.tag,
                                        name=params['scope'],
                                        mode='strict')
            unflat_variables = tuple(variable_dict[name]
                                     for name in variable_names)
        else:
            unflat_variables = [
                harvest.sow(  # pylint: disable=g-complex-comprehension
                    unflat_variable,
                    tag=settings.tag,
                    name=name,
                    mode='strict') for unflat_variable, name in safe_zip(
                        unflat_variables, variable_names)
            ]
        variable_tracers = tree_util.tree_leaves(unflat_variables)

        out_tracers = self._bound_output_tracers(
            call_primitive, params, apply_jaxpr, apply_consts, apply_env,
            variable_tracers + abstract_tracers, apply_pvs, apply_pv_consts,
            apply_keys, apply_name, is_map)
        return out_tracers
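
The handler above carries variables through the call primitive as flat tracers plus a treedef, then rebuilds the structured container with tree_util.tree_unflatten before sowing it. A minimal sketch of that round trip, with hypothetical variable names and values:

from jax import tree_util

variables = {'w': 1.0, 'b': 2.0}            # structured variables (illustrative)
leaves, variable_tree = tree_util.tree_flatten(variables)
# ... the flat leaves travel through the call primitive as tracers ...
rebuilt = tree_util.tree_unflatten(variable_tree, leaves)
assert rebuilt == variables
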
Example #2
 def params(self):
     return tree_util.tree_unflatten(self.params_tree, self.params_flat)
Example #3
def unflatten_tree(tree, xs):
    """Inverse operation of `flatten_tree`."""
    return tree_util.tree_unflatten(tree_util.tree_structure(tree), xs)
Example #4
 def structured(self):
   if self._structured is None:
     self._structured = tree_util.tree_unflatten(self.treedef, self.leaves)
   return self._structured
Example #5
def block(arrays):
  leaves, tree = tree_flatten(arrays, is_leaf=lambda a: isinstance(a, JaxArray))
  leaves = [(l.value if isinstance(l, JaxArray) else l) for l in leaves]
  arrays = tree_unflatten(tree, leaves)
  return JaxArray(jnp.block(arrays))
Example #6
def flat_propagate(tree, *flat_invals):
    invals, outvals = tree_util.tree_unflatten(tree, flat_invals)
    subenv = yield ((invals, outvals), {})
    subenv_vals, subenv_tree = tree_util.tree_flatten(subenv)
    yield subenv_vals, subenv_tree
Example #7
 def testRoundtripWithFlattenUpTo(self, inputs):
   _, tree = tree_util.tree_flatten(inputs)
   xs = tree.flatten_up_to(inputs)
   actual = tree_util.tree_unflatten(tree, xs)
   self.assertEqual(actual, inputs)
Example #8
 def doit():
   f = lu.wrap_init(fun)
   args_flat, in_tree = tree_util.tree_flatten((args, {}))
   flat_fun, out_tree = flatten_fun(f, in_tree)
   out_flat = _interpret_fun(flat_fun, args_flat)
   return tree_util.tree_unflatten(out_tree(), out_flat)
Example #9
def scan(f, init, xs):
    """Scan a function over leading array axes while carrying along state.

  The type signature in brief is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  where we use [t] here to denote the type t with an additional leading axis.
  That is, if t is an array type then [t] represents the type with an additional
  leading axis, and if t is a pytree (container) type with array leaves then [t]
  represents the type with the same pytree structure and corresponding leaves
  each with an additional leading axis.

  When both ``a`` and ``b`` are array types, the semantics of ``scan`` are given
  by this Python implementation::

    def scan(f, init, xs):
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
  types, and so multiple arrays can be scanned over at once and produce multiple
  output arrays.

  Also unlike that Python version, ``scan`` is a JAX primitive and is lowered to
  a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Args:
    f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
      that ``f`` accepts two arguments where the first is a value of the loop
      carry and the second is a slice of ``xs`` along its leading axis, and that
      ``f`` returns a pair where the first element represents a new value for
      the loop carry and the second represents a slice of the output.
    init: an initial loop carry value of type ``c``, which can be a scalar,
      array, or any pytree (nested Python tuple/list/dict) thereof, representing
      the initial loop carry value.
    xs: the value of type ``[a]`` over which to scan along the leading axis,
      where ``[a]`` can be an array or any pytree (nested Python
      tuple/list/dict) thereof with consistent leading axis sizes.

  Returns:
    A pair of type ``(c, [b])`` where the first element represents the final
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the inputs.
  """
    num_carry = len(tree_flatten(init)[0])
    in_flat, in_tree = tree_flatten((init, xs))
    init_flat, xs_flat = in_flat[:num_carry], in_flat[num_carry:]
    try:
        length, = {x.shape[0] for x in xs_flat}
    except AttributeError:
        msg = "scan got value with no leading axis to scan over: {}."
        raise ValueError(
            msg.format([x for x in xs_flat if not hasattr(x, 'shape')]))
    except ValueError:
        msg = "scan got values with different leading axis sizes: {}."
        raise ValueError(msg.format([x.shape[0] for x in xs_flat]))

    carry_avals = tuple(_map(_abstractify, init_flat))
    x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
    x_dtypes = [x.dtype for x in xs_flat]
    x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))
    jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree,
                                                   carry_avals + x_avals)
    carry_avals_out, y_avals = split_list(jaxpr.out_avals, [num_carry])
    if tuple(carry_avals_out) != carry_avals:
        msg = "scan carry output type must match carry input type, got {} and {}."
        raise TypeError(msg.format(tuple(carry_avals_out), carry_avals))
    out = scan_p.bind(*itertools.chain(consts, in_flat),
                      forward=True,
                      length=length,
                      jaxpr=jaxpr,
                      num_consts=len(consts),
                      num_carry=num_carry,
                      linear=(False, ) * (len(consts) + len(in_flat)))
    return tree_unflatten(out_tree, out)
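
The docstring's reference semantics can be exercised directly through lax.scan; a minimal usage sketch (the step function and names here are illustrative, not part of the source) computing a running sum:

import jax.numpy as jnp
from jax import lax

def cumsum_step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry            # (new carry, per-step output)

total, partials = lax.scan(cumsum_step, 0.0, jnp.arange(1.0, 5.0))
# total == 10.0, partials == [1., 3., 6., 10.]
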
Example #10
 def checked_fun(*args, **kwargs):
     args_flat, in_tree = tree_flatten((args, kwargs))
     f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
     (err, code, out_flat), msgs = checkify_flat(f, errors, *args_flat)
     out = tree_unflatten(out_tree(), out_flat)
     return Error(err, code, msgs), out
Example #11
 def unravel_pytree(arr):
     return tree_unflatten(treedef, unravel_list(arr))
Example #12
 def _wrapped(*args):
     args_flat, in_tree = tree_flatten(args, is_leaf=_is_bcoo)
     wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun),
                                                  in_tree)
     out = sparsify_fun(wrapped_fun, args_flat)
     return tree_unflatten(out_tree(), out)
Example #13
def jet(fun, primals, series):
    r"""Taylor-mode higher-order automatic differentiation.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: The primal values at which the Taylor approximation of ``fun`` should be
      evaluated. Should be either a tuple or a list of arguments,
      and its length should be equal to the number of positional parameters of
      ``fun``.
    series: Higher order Taylor-series-coefficients.
      Together, `primals` and `series` make up a truncated Taylor polynomial.
      Should be either a tuple or a list of tuples or lists,
      and its length dictates the degree of the truncated Taylor polynomial.

  Returns:
    A ``(primals_out, series_out)`` pair, where ``primals_out`` is ``fun(*primals)``,
    and together, ``primals_out`` and ``series_out`` are a
    truncated Taylor polynomial of :math:`f(h(\cdot))`.
    The ``primals_out`` value has the same Python tree structure as ``primals``,
    and the ``series_out`` value the same Python tree structure as ``series``.

  For example:

  >>> import jax
  >>> import jax.numpy as np

  Consider the function :math:`h(z) = z^3`, :math:`x = 0.5`,
  and the first few Taylor coefficients
  :math:`h_0=x^3`, :math:`h_1=3x^2`, and :math:`h_2=6x`.
  Let :math:`f(y) = \sin(y)`.

  >>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
  >>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)

  :func:`jet` returns the Taylor coefficients of :math:`f(h(z)) = \sin(z^3)`
  according to Faà di Bruno's formula:

  >>> f0, (f1, f2) =  jet(f, (h0,), ((h1, h2),))
  >>> print(f0,  f(h0))
  0.12467473 0.12467473

  >>> print(f1, df(h0) * h1)
  0.7441479 0.74414825

  >>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
  2.9064622 2.9064634
  """
    try:
        order, = set(map(len, series))
    except ValueError:
        msg = "jet terms have inconsistent lengths for different arguments"
        raise ValueError(msg) from None

    # TODO(mattjj): consider supporting pytree inputs
    for i, (x, terms) in enumerate(zip(primals, series)):
        treedef = tree_structure(x)
        if not treedef_is_leaf(treedef):
            raise ValueError(f"primal value at position {i} is not an array")
        for j, t in enumerate(terms):
            treedef = tree_structure(t)
            if not treedef_is_leaf(treedef):
                raise ValueError(f"term {j} for argument {i} is not an array")

    @lu.transformation_with_aux
    def flatten_fun_output(*args):
        ans = yield args, {}
        yield tree_flatten(ans)

    f, out_tree = flatten_fun_output(lu.wrap_init(fun))
    out_primals, out_terms = jet_fun(jet_subtrace(f),
                                     order).call_wrapped(primals, series)
    return tree_unflatten(out_tree(),
                          out_primals), tree_unflatten(out_tree(), out_terms)
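
The body above flattens fun's output with a linear_util transformation and later recovers the output structure with tree_unflatten; a minimal standalone sketch of that output-flattening pattern (the lambda and values are illustrative, and this assumes the jax.linear_util module of the same era as the code above):

from jax import linear_util as lu
from jax.tree_util import tree_flatten, tree_unflatten

@lu.transformation_with_aux
def flatten_fun_output(*args):
    ans = yield args, {}
    yield tree_flatten(ans)    # (flat leaves, treedef); the treedef is the aux output

f, out_tree = flatten_fun_output(lu.wrap_init(lambda x: {'y': x, 'z': 2 * x}))
flat = f.call_wrapped(3.0)     # [3.0, 6.0]
print(tree_unflatten(out_tree(), flat))    # {'y': 3.0, 'z': 6.0}
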
Example #14
def sow_unzip(in_tracers, out_tracers, name=None, tree=None, tag=None, **_):
    del tag
    if tree:
        in_tracers = tree_util.tree_unflatten(tree, in_tracers)
        out_tracers = tree_util.tree_unflatten(tree, out_tracers)
    return name, in_tracers, out_tracers
Example #15
def _optimization_barrier(arg):
    flat_args, treedef = tree_flatten(arg)
    return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args))
Example #16
 def tree_get_params(opt_state):
     states_flat, tree, subtrees = opt_state
     states = map(tree_unflatten, subtrees, states_flat)
     params = map(get_params, states)
     return tree_unflatten(tree, params)
Example #17
 def wrapped(*args, **params):
     spenv = SparseEnv()
     argspecs = arrays_to_argspecs(spenv, args)
     argspecs_out, out_tree = f_raw(spenv, *argspecs, **params)
     out = argspecs_to_arrays(spenv, argspecs_out)
     return tree_unflatten(out_tree, out)
Example #18
 def _cau_jaxpr(self, *args, **kwargs):
     flat_args = tree_util.tree_leaves(args)
     out_flat = eval_jaxpr_with_kwargs(self._jaxpr.jaxpr,
                                       self._jaxpr.literals, *flat_args,
                                       **kwargs)
     return tree_util.tree_unflatten(self._out_tree, out_flat)
Example #19
 def tree_unflatten(cls, meta, data):
   if not tree_util.all_leaves(data):
     data, meta = tree_util.tree_flatten(tree_util.tree_unflatten(meta, data))
   return FlatCache(None, leaves=data, treedef=meta)
Example #20
 def todo(x):
     primals, series = tree_unflatten(treedef, x)
     trace = JetTrace(main, core.cur_sublevel())
     return map(partial(JetTracer, trace), primals, series)
Example #21
 def testRoundtripIsLeaf(self, tree):
   xs, treedef = tree_util.tree_flatten(
       tree, is_leaf=lambda t: isinstance(t, tuple))
   recon_tree = tree_util.tree_unflatten(treedef, xs)
   self.assertEqual(recon_tree, tree)
Example #22
def traceable(in_tree_def, *primals_and_series):
    primals_in, series_in = tree_unflatten(in_tree_def, primals_and_series)
    primals_out, series_out = yield (primals_in, series_in), {}
    out_flat, out_tree_def = tree_flatten((primals_out, series_out))
    yield out_flat, out_tree_def
Example #23
 def func(flat_args):
   unflat_args = tree_util.tree_unflatten(in_tree, flat_args)
   return fn(unflat_args)
Example #24
def _scan_harvest_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                       num_consts, num_carry, linear, unroll):
    """Collects and injects values into/from the scan body."""
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    values = [t.val for t in tracers]
    consts, init, xs = jax_util.split_list(values, [num_consts, num_carry])

    active_sows = _find_sows(jaxpr, settings.tag)
    active_modes = [params['mode'] for params in active_sows]
    if any(mode == 'strict' for mode in active_modes):
        raise ValueError('Cannot use strict mode in a scan.')
    active_names = [params['name'] for params in active_sows]
    sow_modes = {name: mode for name, mode in zip(active_names, active_modes)}
    carry_plants = {
        name: context.plants[name]
        for name in active_names
        if name in context.plants and sow_modes[name] == 'clobber'
    }
    xs_plants = {
        name: context.plants[name]
        for name in active_names
        if name in context.plants and sow_modes[name] == 'append'
    }

    def jaxpr_fun(carry, x):
        body_out = jax_core.eval_jaxpr(jaxpr.jaxpr, jaxpr.literals,
                                       *(consts + carry + x))
        carry, y = jax_util.split_list(body_out, [num_carry])
        return carry, y

    harvest_body = harvest(jaxpr_fun,
                           tag=settings.tag,
                           allowlist=settings.allowlist,
                           blocklist=settings.blocklist,
                           mode=settings.mode)

    def body(carry, x):
        x_plants, x_vals = x
        (carry, y), reaps = harvest_body({
            **carry_plants,
            **x_plants
        }, carry, x_vals)
        return carry, (y, reaps)

    xs_flat = tree_util.tree_leaves((xs_plants, xs))
    x_avals = []
    for x in xs_flat:
        x_aval = jax_core.get_aval(x)
        if x_aval is jax_core.abstract_unit:
            x_avals.append(x_aval)
        else:
            x_shape, x_dtype = masking.padded_shape_as_value(
                x.shape[1:]), x.dtype
            x_avals.append(abstract_arrays.ShapedArray(x_shape, x_dtype))
    x_avals = tuple(x_avals)
    init_avals = tuple(
        abstract_arrays.raise_to_shaped(jax_core.get_aval(a)) for a in init)
    in_flat, in_tree = tree_util.tree_flatten((init, (xs_plants, xs)))
    body_jaxpr, new_consts, out_tree = (
        lax_control_flow._initial_style_jaxpr(  # pylint: disable=protected-access
            body, in_tree, init_avals + x_avals))
    new_values = list(new_consts) + in_flat
    num_xs_plants = len(new_values) - len(init) - len(xs) - len(new_consts)
    remaining_linear = linear[num_consts:]
    new_linear = ((False, ) * len(new_consts) + remaining_linear[:len(init)] +
                  (False, ) * num_xs_plants + remaining_linear[len(init):])
    assert len(new_linear) == len(new_values)

    outs = lax.scan_p.bind(*new_values,
                           length=length,
                           reverse=reverse,
                           jaxpr=body_jaxpr,
                           num_consts=len(new_consts),
                           num_carry=num_carry,
                           linear=new_linear,
                           unroll=unroll)
    outs = safe_map(trace.pure, outs)
    carry, (ys, reaps) = tree_util.tree_unflatten(out_tree, outs)
    out_reaps = {}
    for k, val in reaps.items():
        mode = sow_modes.get(k, 'strict')
        if mode == 'append':
            val = tree_util.tree_map(np.concatenate, val)
        elif mode == 'clobber':
            val = tree_util.tree_map(lambda x: x[-1], val)
        out_reaps[k] = sow(val, tag=settings.tag, name=k, mode='strict')
    (carry, ys) = prim.tie_in(out_reaps, (carry, ys))
    return carry + ys
Example #25
def lexsort(keys, axis=-1):
  leaves, tree = tree_flatten(keys, is_leaf=lambda x: isinstance(x, JaxArray))
  leaves = [_remove_jaxarray(l) for l in leaves]
  keys = tree_unflatten(tree, leaves)
  return JaxArray(jnp.lexsort(keys, axis))
Example #26
    def par_from_array(arr):
        value_flat = jnp.split(arr, section_sizes)
        value_flat = [x.reshape(s) for x, s in zip(value_flat, section_shapes)]

        params = tree_unflatten(value_tree, value_flat)
        return params
Example #27
 def testRoundtrip(self, inputs):
     xs, tree = tree_util.tree_flatten(inputs)
     actual = tree_util.tree_unflatten(tree, xs)
     self.assertEqual(actual, inputs)
Example #28
 def tree_get_params(opt_state):
     packed_state, tree, subtrees = opt_state
     states = map(tree_unflatten, subtrees, packed_state)
     params = map(get_params, states)
     return tree_unflatten(tree, params)
Example #29
def _psum_transpose_rule(cts, axis_name, axis_index_groups):
    nonzero_out_cts, treedef = tree_util.tree_flatten(cts)
    nonzero_in_cts = psum_p.bind(*nonzero_out_cts,
                                 axis_name=axis_name,
                                 axis_index_groups=axis_index_groups)
    return tree_util.tree_unflatten(treedef, nonzero_in_cts)
Example #30
def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
                linear_args):
    """Call a linear function, with a custom implementation for its transpose.

  The type signatures of ``fun`` and ``fun_transpose`` are:

  .. code-block:: haskell

    fun           :: r -> a -o b
    fun_transpose :: r -> b -o a

  where the ``-o`` arrow indicates a linear function, ``r`` is the
  residual input type and ``a`` is the linear input type.

  The functions ``fun`` and ``fun_transpose`` are coupled as
  transposes of one another. Specifically, the transpose of a
  ``linear_call`` primitive is another ``linear_call`` to
  ``fun_transpose``, with ``fun`` as its custom transposition.

  For example:

  >>> def f(r, x):
  ...   return x / r

  >>> def t(r, t):
  ...   return t / r

  >>> def div_add(x, denom):
  ...   return x + linear_call(f, t, denom, x)

  >>> def transpose(f, x_example):
  ...   def transposed(y):
  ...     x, = jax.linear_transpose(f, x_example)(y)
  ...     return x
  ...   return transposed

  >>> div_add(9., 3.)
  DeviceArray(12., dtype=float32, weak_type=True)

  >>> transpose(partial(div_add, denom=3.), 1.)(18.)  # custom
  DeviceArray(24., dtype=float32, weak_type=True)

  >>> transpose(lambda x: x + x / 3., 1.)(18.)  # reference
  DeviceArray(24., dtype=float32, weak_type=True)

  The above definition of ``f`` illustrates the purpose of a residual
  argument: division is linear in one of its inputs (the dividend
  ``x``) but not the other (the divisor ``r``).

  As another example:

  >>> def custom_id(x):
  ...   def f(_, x): return x
  ...   def t(_, t): return 7.
  ...   return linear_call(f, t, (), x)
  >>> custom_id(1.)
  1.0
  >>> transpose(custom_id, 1.)(1.)
  7.0
  >>> transpose(transpose(custom_id, 1.), 1.)(1.)
  1.0
  >>> transpose(transpose(transpose(custom_id, 1.), 1.), 1.)(1.)
  7.0

  Args:
    fun: a Python callable specifying a linear function. It should
      take two arguments: one of "residual" inputs (type ``r``),
      i.e. inputs in which the function is not necessarily linear, and
      one of "linear" inputs (type ``a``).  It should return output
      whose components are linear in the linear input (type ``b``).
    fun_transpose: a Python callable specifying a structurally linear
      function that is the transpose of ``fun`` with respect to its
      linear inputs. Its first argument is the same residual inputs
      (``r``) as ``fun``. Its second argument is of type
      ``b``. Finally, its output is of type ``a`` and each of its
      components is linear in its second argument (the ``b`` inputs).
    residual_args: Argument in which ``fun`` and ``fun_transpose`` are
      not necessarily linear. Not involved in transposition.
    linear_args: Argument in which ``fun`` and ``fun_transpose`` are
      linear and with respect to which the two are transposes.

  Returns:
    The call result, i.e. ``fun(residual_args, linear_args)``.

  """
    operands_res, res_tree = tree_flatten(residual_args)
    operands_lin, lin_tree = tree_flatten(linear_args)

    f_in_tree = treedef_tuple((res_tree, lin_tree))
    f, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), f_in_tree)

    res_avals = map(abstractify, operands_res)
    lin_avals = map(abstractify, operands_lin)
    f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
    f_jaxpr = _close_jaxpr(f_jaxpr)
    out_avals = map(core.raise_to_shaped, f_jaxpr.out_avals)

    t_in_tree = treedef_tuple((res_tree, out_tree()))
    t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose),
                                         t_in_tree)

    t_jaxpr, t_consts = _initial_style_jaxpr(t, (*res_avals, *out_avals))
    t_jaxpr = _close_jaxpr(t_jaxpr)

    if t_out_tree() != lin_tree:
        raise TypeError(
            'transpose output pytree structure must match that of linear inputs, '
            f'got output structure {t_out_tree()} '
            f'and input structure {lin_tree}.')

    out = linear_call_p.bind(*f_consts,
                             *t_consts,
                             *operands_res,
                             *operands_lin,
                             callee=f_jaxpr,
                             transpose=t_jaxpr,
                             num_callee_consts=len(f_consts),
                             num_transpose_consts=len(t_consts),
                             num_res=len(operands_res))

    return tree_unflatten(out_tree(), out)
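
linear_call flattens the residual and linear argument groups separately and joins their treedefs with treedef_tuple before flattening the callable; a minimal sketch of that pattern with illustrative values:

from jax.tree_util import tree_flatten, tree_unflatten, treedef_tuple

res_leaves, res_tree = tree_flatten({'r': 3.0})    # residual arguments
lin_leaves, lin_tree = tree_flatten((1.0, 2.0))    # linear arguments
in_tree = treedef_tuple((res_tree, lin_tree))      # combined input treedef
print(tree_unflatten(in_tree, res_leaves + lin_leaves))   # ({'r': 3.0}, (1.0, 2.0))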