Beispiel #1
0
def closure_convert(fun, in_tree, in_avals):
    """Trace ``fun`` to a jaxpr and hoist its inexact-dtype constants.

    ``fun`` is traced against the abstract inputs ``in_avals`` (structured as
    ``in_tree``). The constants captured by the trace are partitioned by
    dtype: those whose dtype is a subtype of ``jnp.inexact`` are hoisted and
    handed back to the caller, the remaining ones stay closed over.

    Args:
      fun: the callable to trace.
      in_tree: pytree structure the inputs of ``fun`` are flattened with.
      in_avals: flat abstract values used for tracing.

    Returns:
      A pair ``(converted_fun, hoisted_consts)``. ``converted_fun(y, t,
      *hconsts_args)`` takes the hoisted constants as the leading entries of
      ``hconsts_args``; whatever follows is treated as extra arguments.
    """
    use_omnistaging = config.omnistaging_enabled
    flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun),
                                                    in_tree)
    if use_omnistaging:
        jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
    else:
        unknown_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
        with core.initial_style_staging():  # type: ignore
            jaxpr, _, consts = pe.trace_to_jaxpr(
                flat_fun, unknown_pvals, instantiate=True,
                stage_out=False)  # type: ignore
    # The output tree is only available once tracing has run.
    out_treedef = out_tree_thunk()

    # We only want to closure convert for constants with respect to which we're
    # differentiating. As a proxy for that, we hoist consts with float dtype.
    # TODO(mattjj): revise this approach
    def is_float(c):
        return dtypes.issubdtype(dtypes.dtype(c), jnp.inexact)

    (closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
    num_hoisted = len(hoisted_consts)

    def converted_fun(y, t, *hconsts_args):
        # Leading entries are the hoisted constants, the rest are arguments.
        hoisted, rest = split_list(hconsts_args, [num_hoisted])
        all_consts = merge(closure_consts, hoisted)
        flat_args, call_tree = tree_flatten((y, t, *rest))
        assert in_tree == call_tree
        flat_out = core.eval_jaxpr(jaxpr, all_consts, *flat_args)
        return tree_unflatten(out_treedef, flat_out)

    return converted_fun, hoisted_consts
Beispiel #2
0
 def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:
     """Call the function through ``custom_vjp_call_p``.

     Keyword arguments are resolved into positional ones, the arguments
     listed in ``self.nondiff_argnums`` are split off as static arguments,
     and the remaining (dynamic) arguments are flattened and bound to the
     primitive together with the flattened ``fun``, ``fwd`` and ``bwd``.

     Raises:
       AttributeError: if ``self.fwd`` or ``self.bwd`` is not set.
     """
     if not (self.fwd and self.bwd):
         msg = "No VJP defined for custom_vjp function {} using defvjp."
         raise AttributeError(msg.format(self.__name__))
     args = _resolve_kwargs(self.fun, args, kwargs)
     if not self.nondiff_argnums:
         f_, dyn_args = lu.wrap_init(self.fun), args
         fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
     else:
         for static_index in self.nondiff_argnums:
             _check_for_tracers(args[static_index])
         static_indices = set(self.nondiff_argnums)
         dyn_argnums = [
             i for i, _ in enumerate(args) if i not in static_indices
         ]
         f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums,
                                        args)
         static_args = [args[i] for i in self.nondiff_argnums]
         fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args)
         # bwd receives the static arguments via _add_args.
         bwd = _add_args(lu.wrap_init(self.bwd), static_args)
     args_flat, in_tree = tree_flatten(dyn_args)
     in_avals = [core.raise_to_shaped(core.get_aval(leaf))
                 for leaf in args_flat]
     flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
     flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
     flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
     out_flat = custom_vjp_call_p.bind(flat_fun,
                                       flat_fwd,
                                       flat_bwd,
                                       *args_flat,
                                       out_trees=out_trees)
     # merge_linear_aux yields a flag plus the auxiliary value; when the flag
     # is false the auxiliary value is indexed to obtain the output tree.
     fst, aux = lu.merge_linear_aux(out_tree, out_trees)
     if fst:
         result_tree = aux
     else:
         result_tree = aux[0]
     return tree_unflatten(result_tree, out_flat)
Beispiel #3
0
def _trace_to_jaxpr_with_refs(
    f, state_tree: PyTreeDef, state_avals: Sequence[core.AbstractValue]
) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]:
    """Trace ``f`` to a jaxpr using ``state_avals`` as abstract inputs.

    The input tree handed to ``flatten_fun_nokwargs`` is a tuple of the tree
    structure of a single leaf and ``state_tree``.

    Returns:
      ``(jaxpr, consts, out_tree)`` where ``out_tree`` is the output pytree
      structure recorded while tracing.
    """
    in_tree = treedef_tuple((tree_structure(0), state_tree))
    flat_f, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_f, state_avals)
    return jaxpr, consts, out_tree_thunk()
Beispiel #4
0
 def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:
     """Call the function through ``custom_jvp_call_p``.

     Keyword arguments are resolved into positional ones. Arguments listed
     in ``self.nondiff_argnums`` are passed through ``_stop_gradient`` and
     supplied to the JVP rule as static arguments; the remaining arguments
     are flattened and bound to the primitive.

     Raises:
       AttributeError: if ``self.jvp`` is not set.
     """
     if not self.jvp:
         msg = "No JVP defined for custom_jvp function {} using defjvp."
         raise AttributeError(msg.format(self.__name__))
     args = _resolve_kwargs(self.fun, args, kwargs)
     if not self.nondiff_argnums:
         f_, dyn_args = lu.wrap_init(self.fun), args
         jvp = lu.wrap_init(self.jvp)
     else:
         static_indices = set(self.nondiff_argnums)
         args = tuple(
             _stop_gradient(arg) if pos in static_indices else arg
             for pos, arg in enumerate(args))
         diff_argnums = [
             pos for pos, _ in enumerate(args) if pos not in static_indices
         ]
         f_, dyn_args = argnums_partial(lu.wrap_init(self.fun),
                                        diff_argnums, args)
         static_args = [args[i] for i in self.nondiff_argnums]
         jvp = _add_args(lu.wrap_init(self.jvp), static_args)
     args_flat, in_tree = tree_flatten(dyn_args)
     flat_fun, fun_out_tree = flatten_fun_nokwargs(f_, in_tree)
     flat_jvp, jvp_out_tree = _flatten_jvp(jvp, in_tree)
     out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
     # Only the auxiliary value (the output tree) is needed here.
     _, result_tree = lu.merge_linear_aux(fun_out_tree, jvp_out_tree)
     return tree_unflatten(result_tree, out_flat)
Beispiel #5
0
    def _init_parameters_dict(self, rng, *example_inputs, reuse, reuse_only):
        """Build the submodule parameter dict, ordered by jaxpr position.

        The wrapped function is traced with the example inputs while the
        submodule call order is recorded. Parameters are then obtained from
        ``_get_submodule_params`` and, when there is more than one entry,
        reordered with the permutation returned by
        ``parametrized._permutation_to_jaxpr_order``.
        """
        flat_inputs, in_tree = tree_util.tree_flatten(example_inputs)
        flat_fun, _ = api_util.flatten_fun_nokwargs(self._wrapped_fun, in_tree)

        def trace():
            return pe.trace_to_jaxpr(
                flat_fun, parametrized._partialize(flat_inputs))

        traced, call_order = \
            parametrized._submodule_call_order_tracing.nested(
                self, trace, do_trace_submodules=True)
        jaxpr, _, consts = traced

        submodule_params = _get_submodule_params(
            rng, jaxpr, consts, [], OrderedDict(), *example_inputs,
            reuse=reuse, reuse_only=reuse_only)

        # TODO cleanup, needed whenever parent of scan is used as submodule,
        # since cell tracing is leaking into modules above:
        # submodules_in_execution_order = list(
        #    filter(lambda x: x in submodule_params, submodules_in_execution_order))

        assert len(submodule_params) == len(call_order)

        # Zero or one entry: there is nothing to reorder.
        if len(submodule_params) <= 1:
            return submodule_params

        permutation = parametrized._permutation_to_jaxpr_order(
            jaxpr, call_order)
        assert len(submodule_params) == len(permutation)
        pairs_in_call_order = list(submodule_params.items())
        pairs_in_jaxpr_order = [pairs_in_call_order[i] for i in permutation]
        return OrderedDict(pairs_in_jaxpr_order)
Beispiel #6
0
 def fwd(*args, **kwargs):
     """Run ``fun`` and trace the rule it returns into a jaxpr.

     ``fun`` must return a pair ``(ans, rule)``. ``rule`` is traced with
     abstract values derived from the leaves of ``ans`` and packaged,
     together with the relevant tree structures and constants, into a
     ``Residuals`` value.

     Returns:
       ``(ans, residuals)``.
     """
     ans, rule = fun(*args, **kwargs)
     ans_leaves, ans_tree = tree_flatten((ans, ))
     flat_rule, rule_out_tree = flatten_fun_nokwargs(lu.wrap_init(rule),
                                                     ans_tree)
     leaf_avals = [core.get_aval(leaf).at_least_vspace()
                   for leaf in ans_leaves]
     jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_rule, leaf_avals)
     residuals = Residuals(jaxpr, rule_out_tree(), ans_tree, consts)
     return ans, residuals
Beispiel #7
0
 def wrapped(*args, **kwargs):
   """Bind ``f`` to ``nest_p`` under the enclosing ``scope``.

   Positional arguments are flattened before binding and the flat outputs
   are rebuilt with the output tree recorded during the call.
   """
   fun = lu.wrap_init(f, kwargs)
   leaves, in_tree = tree_util.tree_flatten(args)
   flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(fun, in_tree)
   # Fall back to a placeholder for callables without a __name__.
   fun_name = getattr(f, '__name__', '<no name>')
   flat_out = nest_p.bind(
       flat_fun, *leaves, scope=scope, mode='strict', name=fun_name)
   return tree_util.tree_unflatten(out_tree_thunk(), flat_out)
Beispiel #8
0
 def wrapped(*args):
     """Run ``f`` under ``doubling_transform`` and sum each output pair.

     Every flattened input is paired with a zeros array of the same shape;
     every output pair is collapsed by adding its two components.
     """
     leaves, in_tree = tree_flatten(args)
     flat_f, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
     paired_inputs = [(leaf, jnp.zeros_like(leaf)) for leaf in leaves]
     paired_outputs = doubling_transform(flat_f).call_wrapped(*paired_inputs)
     summed = [first + second for first, second in paired_outputs]
     return tree_unflatten(out_tree_thunk(), summed)
Beispiel #9
0
def _initial_style_open_jaxpr(fun: Callable,
                              in_tree,
                              in_avals,
                              primitive_name: Optional[str] = None):
    """Trace ``fun`` to a jaxpr, returning its constants separately.

    Args:
      fun: the callable to trace.
      in_tree: pytree structure of the inputs.
      in_avals: flat abstract values to trace with.
      primitive_name: name used in the debug info; ``"<unknown>"`` is used
        when it is ``None`` or empty.

    Returns:
      ``(jaxpr, consts, out_tree)``.
    """
    flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    debug_name = primitive_name or "<unknown>"
    debug = pe.debug_info(fun, in_tree, False, debug_name)
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
    return jaxpr, consts, out_tree_thunk()
Beispiel #10
0
def lazy_eval(fun, *args):
    """Trace ``fun`` with ``fastar_jaxpr`` and evaluate the jaxpr lazily.

    Each result of ``core.lazy_eval_jaxpr`` is asserted to have the abstract
    value reported by the trace before the outputs are unflattened.
    """
    leaves, in_tree = tree_util.tree_flatten(args)
    flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    jaxpr, consts, _, expected_avals = fastar_jaxpr(flat_fun, *leaves)
    results = core.lazy_eval_jaxpr(jaxpr, consts, *leaves)
    for result, expected in zip(results, expected_avals):
        assert core.get_aval(result) == expected
    return tree_util.tree_unflatten(out_tree_thunk(), results)
Beispiel #11
0
 def wrapped(*args, **kwargs):
     """Run ``f`` under a ``HarvestTrace`` and return its outputs and reaps.

     Returns:
       ``(outputs, reaps)`` where ``outputs`` is the unflattened result of
       ``f`` and ``reaps`` is the second value produced by the reaping
       function.
     """
     fun = lu.wrap_init(f, kwargs)
     leaves, in_tree = tree_util.tree_flatten(args)
     flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(fun, in_tree)
     with jax_core.new_main(HarvestTrace) as main:
         reaping_fun = reap_function(flat_fun, main, settings, False)
         # The flat arguments are passed as a single list argument.
         out_flat, reaps = reaping_fun.call_wrapped(leaves)
         del main
     return tree_util.tree_unflatten(out_tree_thunk(), out_flat), reaps
Beispiel #12
0
def _initial_style_jaxpr(fun, in_tree, in_avals):
  """Trace ``fun`` into a ``TypedJaxpr`` whose constants become inputs.

  The traced jaxpr is passed through ``pe.closure_convert_jaxpr`` and the
  abstract values of the constants are prepended to ``in_avals`` in the
  resulting typed jaxpr.

  Returns:
    ``(typed_jaxpr, consts, out_tree)``.
  """
  unknown_pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
  flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
      flat_fun, unknown_pvals, instantiate=True)
  traced_avals = unzip2(out_pvals)[0]
  out_avals = _map(raise_to_shaped, traced_avals)
  const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
  converted = pe.closure_convert_jaxpr(jaxpr)
  typed_jaxpr = core.TypedJaxpr(converted, (), const_avals + in_avals,
                                out_avals)
  return typed_jaxpr, consts, out_tree_thunk()
Beispiel #13
0
 def inner():
     """Apply the wrapped function to ``inputs`` under an ``ApplyTrace``.

     ``inputs`` and ``params`` come from the enclosing scope. The inputs are
     flattened, the flat function is wrapped with
     ``ApplyTrace._apply_subtrace`` and the flat outputs are rebuilt with the
     output tree recorded during the call.
     """
     flat_inputs, in_tree = tree_util.tree_flatten(inputs)
     flat_fun, out_tree = api_util.flatten_fun_nokwargs(
         self._wrapped_fun, in_tree)
     with jc.new_master(ApplyTrace) as master:
         flat_fun = ApplyTrace._apply_subtrace(
             flat_fun, master, WrapHashably(params))
         # flat_fun was built for `in_tree`, so it has to receive the
         # flattened leaves. Passing `inputs` directly only happens to work
         # when none of the inputs is a nested pytree.
         flat_outputs = flat_fun.call_wrapped(*flat_inputs)
         del master
     return tree_util.tree_unflatten(out_tree(), flat_outputs)
Beispiel #14
0
 def wrapped(plants, *args, **kwargs):
     """Run ``f`` under a ``HarvestTrace`` with ``plants`` supplied.

     ``plants`` and the flattened positional arguments are flattened
     together; the combined leaves are passed to the planting function as a
     single list argument.
     """
     fun = lu.wrap_init(f, kwargs)
     leaves, in_tree = tree_util.tree_flatten(args)
     flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(fun, in_tree)
     combined_leaves, combined_tree = tree_util.tree_flatten((plants, leaves))
     with jax_core.new_main(HarvestTrace) as main:
         planting_fun = plant_function(flat_fun, main, settings, combined_tree)
         out_flat = planting_fun.call_wrapped(combined_leaves)
         del main
     return tree_util.tree_unflatten(out_tree_thunk(), out_flat)
Beispiel #15
0
 def wrapped(spenv: SparseEnv, *argspecs: ArgSpec, **params: Any) -> Tuple[Sequence[ArgSpec], bool]:
   """Trace ``f`` on the avals of ``argspecs`` and evaluate it sparsely.

   Args:
     spenv: the sparse environment the argspecs refer to.
     *argspecs: argument specs, converted to abstract values for tracing.
     **params: accepted but not used by this wrapper.

   Returns:
     ``(result, out_tree)`` where ``result`` is the value returned by
     ``eval_sparse`` and ``out_tree`` is the output tree recorded while
     tracing. NOTE(review): the return annotation declares the second
     element as ``bool``; confirm against callers.

   Raises:
     Exception: if ``eval_sparse`` returns a different number of values than
       the trace produced output avals.
   """
   in_avals = argspecs_to_avals(spenv, argspecs)
   in_avals_flat, in_tree = tree_flatten(in_avals)
   wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
   jaxpr, out_avals_flat, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
   result = eval_sparse(jaxpr, consts, argspecs, spenv)
   if len(out_avals_flat) != len(result):
     # The second fragment must be an f-string, otherwise the placeholders
     # are emitted literally instead of the offending values.
     raise Exception("Internal: eval_sparse does not return expected number of arguments. "
                     f"Got {result} for avals {out_avals_flat}")
   return result, out_tree()
Beispiel #16
0
def lazy_eval_fixed_point(fun, mock_arg):
    """Lazily evaluate ``fun`` after passing its trace to ``tie_the_knot``.

    ``mock_arg`` is only used for tracing; the resulting jaxpr is evaluated
    without arguments. Each result is asserted to have the abstract value
    reported by the trace.
    """
    leaves, in_tree = tree_util.tree_flatten([mock_arg])
    flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    traced = fastar_jaxpr(flat_fun, *leaves)
    jaxpr, consts, _, expected_avals = tie_the_knot(traced)
    results = core.lazy_eval_jaxpr(jaxpr, consts)
    for result, expected in zip(results, expected_avals):
        assert core.get_aval(result) == expected
    return tree_util.tree_unflatten(out_tree_thunk(), results)
Beispiel #17
0
 def wrapped(*args, **kwargs):
   """Runs a function and binds it to a call primitive.

   Positional arguments are flattened and passed to ``prim.bind`` together
   with the flattened function; the flat result is rebuilt with the output
   tree recorded during the call.
   """
   fun = lu.wrap_init(f, kwargs)
   flat_args, in_tree = tree_util.tree_flatten(args)
   flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
   # Placeholder read by the `out_tree` lambda below. The lambda looks the
   # variable up when it is called, so it returns None until the assignment
   # after `bind` has run.
   out_tree_dest = None
   out = prim.bind(flat_fun, *flat_args, num_args=len(flat_args),
                   name=f.__name__,
                   in_tree=in_tree,
                   out_tree=lambda: out_tree_dest,
                   **params)
   # Rebinding here is what makes the lambda above return the real tree.
   out_tree_dest = out_tree()
   return tree_util.tree_unflatten(out_tree_dest, out)
Beispiel #18
0
    def wrapped(*args, **kwargs):
        """Callable returned by unzip.

        Traces ``f`` under an ``UnzipTrace`` and splits the trace into an
        init jaxpr and an apply jaxpr.

        Returns:
          A pair ``(init, apply)`` of functions that evaluate the respective
          jaxprs.

        Raises:
          ValueError: if the unzip tracing reports failure.
        """
        with jax_core.new_master(UnzipTrace) as master:
            # Preparing args to be traced
            trace = UnzipTrace(master, jax_core.cur_sublevel())
            fun = lu.wrap_init(f, kwargs)
            avals = tree_util.tree_map(trace_util.get_shaped_aval, args)
            flat_avals, flat_keys, in_tree = (flatten_args_into_keys(
                avals, key_args))
            flat_pvals = [pe.PartialVal.unknown(aval) for aval in flat_avals]
            flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)

            # Trace to jaxpr
            settings = UnzipSettings(tag, False)
            fun = unzip_to_init_apply_subjaxprs(flat_fun, trace, settings)  # pylint: disable=no-value-for-parameter
            success, results = fun.call_wrapped(flat_keys, flat_pvals)
            if not success:
                raise ValueError('Variables do not cut dependence graph.')
            init_out, apply_out, _, metadata = results
            # Both jaxprs are expected to come back with an empty environment.
            init_jaxpr, init_consts, init_env = init_out
            assert not init_env

            apply_jaxpr, apply_consts, apply_env = apply_out
            assert not apply_env

            names, variable_tree, _ = metadata
            # Replace the thunk with the tree it produces; `apply` closes
            # over this value.
            out_tree = out_tree()

            # Final functions
            def init(*args):
                """Evaluate the init jaxpr and return a name -> variable dict."""
                flat_args, _ = tree_util.tree_flatten(args)
                flat_params = jax_core.eval_jaxpr(init_jaxpr, init_consts,
                                                  *flat_args)
                flat_variables = tree_util.tree_unflatten(
                    variable_tree, flat_params)
                return {
                    name: var
                    for name, var in safe_zip(names, flat_variables)
                }

            def apply(params, *args):
                """Evaluate the apply jaxpr on the named variables and args."""
                # Variables are looked up in the order given by `names`.
                flat_variables, _ = tree_util.tree_flatten(
                    [params[name] for name in names])
                flat_args, _ = tree_util.tree_flatten(args)
                out = jax_core.eval_jaxpr(apply_jaxpr, apply_consts,
                                          *(flat_variables + flat_args))
                return tree_util.tree_unflatten(out_tree, out)

            del master
        return init, apply
Beispiel #19
0
 def wrapped(*args, **kwargs):
     """Trace ``f`` into a ``TypedJaxpr`` using staged-out tracing.

     Returns:
       ``(typed_jaxpr, (in_tree, out_tree))``.
     """
     fun = lu.wrap_init(f, kwargs)
     leaves, in_tree = tree_util.tree_flatten(args)
     flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(fun, in_tree)
     in_avals = safe_map(get_shaped_aval, leaves)
     unknown_pvals = [
         pe.PartialVal((aval, jax_core.unit)) for aval in in_avals
     ]
     jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
         flat_fun, unknown_pvals, instantiate=True, stage_out=True,
         trace_type=pe.StagingJaxprTrace)
     out_avals = [out_pval.get_aval() for out_pval in out_pvals]
     typed_jaxpr = jax_core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
     trees = (in_tree, out_tree_thunk())
     return typed_jaxpr, trees
Beispiel #20
0
 def wrapped(*args, **kwargs):
   """Trace ``f`` into a ``ClosedJaxpr``.

   The enclosing ``dynamic`` flag selects between
   ``pe.trace_to_jaxpr_dynamic`` and ``pe.trace_to_jaxpr`` with unknown
   partial values.

   Returns:
     ``(closed_jaxpr, (in_tree, out_tree))``.
   """
   fun = lu.wrap_init(f, kwargs)
   leaves, in_tree = tree_util.tree_flatten(args)
   flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(fun, in_tree)
   in_avals = safe_map(get_shaped_aval, leaves)
   if not dynamic:
     unknown_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
     jaxpr, _, consts = pe.trace_to_jaxpr(
         flat_fun, unknown_pvals, instantiate=True)
   else:
     jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
   closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
   return closed_jaxpr, (in_tree, out_tree_thunk())
Beispiel #21
0
def _get_harvest_metadata(closed_jaxpr, settings, *args):
    """Probes a jaxpr for metadata like its sown values.

    The jaxpr is re-traced under a ``HarvestTrace`` with a copy of
    ``settings`` whose last field is set to ``True``; the metadata collected
    by ``_reap_metadata_wrapper`` during that trace is returned.
    """
    fun = lu.wrap_init(jax_core.jaxpr_as_fun(closed_jaxpr))
    with jax_core.new_main(HarvestTrace) as main:
        probe_settings = HarvestSettings(settings.tag, settings.blocklist,
                                         settings.allowlist, True)
        fun = reap_function(fun, main, probe_settings, True)
        fun, metadata_thunk = _reap_metadata_wrapper(fun)
        leaves, in_tree = tree_util.tree_flatten(args)
        flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(fun, in_tree)

        def shaped_aval(value):
            return abstract_arrays.raise_to_shaped(jax_core.get_aval(value))

        in_avals = jax_util.safe_map(shaped_aval, leaves)
        # Only the side effects of tracing are needed; the jaxpr is dropped.
        pe.trace_to_jaxpr_final(flat_fun, in_avals)
        metadata = metadata_thunk()
        # The output tree thunk is called and its result discarded.
        out_tree_thunk()
    return metadata
Beispiel #22
0
 def wrapped(*args, **kwargs):
     """Trace ``f`` into a ``ClosedJaxpr``; requires omnistaging.

     Returns:
       ``(closed_jaxpr, (in_tree, out_tree))``.

     Raises:
       ValueError: if ``jax.config.omnistaging_enabled`` is false.
     """
     fun = lu.wrap_init(f, kwargs)
     leaves, in_tree = tree_util.tree_flatten(args)
     flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(fun, in_tree)
     in_avals = safe_map(get_shaped_aval, leaves)
     if not jax.config.omnistaging_enabled:
         raise ValueError('Oryx must be used with JAX omnistaging enabled.')
     if not dynamic:
         unknown_pvals = [
             pe.PartialVal((aval, jax_core.unit)) for aval in in_avals
         ]
         jaxpr, _, consts = pe.trace_to_jaxpr(
             flat_fun, unknown_pvals, instantiate=True)
     else:
         jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
     closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
     return closed_jaxpr, (in_tree, out_tree_thunk())
Beispiel #23
0
def closure_convert(fun, in_tree, in_avals):
    """Trace ``fun`` to a jaxpr and expose all of its constants.

    Returns:
      A pair ``(converted_fun, consts)``. ``converted_fun(y, t,
      *consts_args)`` takes the constants as the leading entries of
      ``consts_args``; whatever follows is treated as extra arguments.
    """
    unknown_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
    flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    with core.initial_style_staging():
        jaxpr, _, consts = pe.trace_to_jaxpr(
            flat_fun, unknown_pvals, instantiate=True, stage_out=False)
    out_treedef = out_tree_thunk()
    num_consts = len(consts)

    def converted_fun(y, t, *consts_args):
        leading_consts, rest = split_list(consts_args, [num_consts])
        flat_args, call_tree = tree_flatten((y, t, *rest))
        assert in_tree == call_tree
        flat_out = core.eval_jaxpr(jaxpr, leading_consts, *flat_args)
        return tree_unflatten(out_treedef, flat_out)

    return converted_fun, consts
Beispiel #24
0
def _closure_convert_for_avals(fun, in_tree, in_avals):
    """Trace ``fun`` and hoist the constants selected by ``_is_perturbed``.

    Returns:
      A pair ``(converted_fun, hoisted_consts)``. ``converted_fun`` takes the
      original arguments followed by the hoisted constants.
    """
    flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
    out_treedef = out_tree_thunk()

    partitioned, merge = partition_list(_is_perturbed, consts)
    closure_consts, hoisted_consts = partitioned
    num_hoisted = len(hoisted_consts)

    def converted_fun(*args_hconsts):
        # The trailing `num_hoisted` entries are the hoisted constants.
        split_at = len(args_hconsts) - num_hoisted
        call_args, hoisted = split_list(args_hconsts, [split_at])
        all_consts = merge(closure_consts, hoisted)
        flat_args, call_tree = tree_flatten(tuple(call_args))
        assert in_tree == call_tree
        flat_out = core.eval_jaxpr(jaxpr, all_consts, *flat_args)
        return tree_unflatten(out_treedef, flat_out)

    return converted_fun, hoisted_consts
Beispiel #25
0
def _closure_convert_for_avals(fun, in_tree, in_avals):
    """Trace ``fun`` and hoist its inexact-dtype constants.

    Returns:
      A pair ``(converted_fun, hoisted_consts)``. ``converted_fun`` takes the
      original arguments followed by the hoisted constants.
    """
    flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
    out_treedef = out_tree_thunk()

    # We only want to closure convert for constants with respect to which we're
    # differentiating. As a proxy for that, we hoist consts with float dtype.
    # TODO(frostig,mattjj): revise this approach
    from jax.numpy import inexact

    def is_float(c):
        return dtypes.issubdtype(dtypes.dtype(c), inexact)

    partitioned, merge = partition_list(is_float, consts)
    closure_consts, hoisted_consts = partitioned
    num_hoisted = len(hoisted_consts)

    def converted_fun(*args_hconsts):
        # The trailing `num_hoisted` entries are the hoisted constants.
        split_at = len(args_hconsts) - num_hoisted
        call_args, hoisted = split_list(args_hconsts, [split_at])
        all_consts = merge(closure_consts, hoisted)
        flat_args, call_tree = tree_flatten(tuple(call_args))
        assert in_tree == call_tree
        flat_out = core.eval_jaxpr(jaxpr, all_consts, *flat_args)
        return tree_unflatten(out_treedef, flat_out)

    return converted_fun, hoisted_consts
Beispiel #26
0
def default_call_interpreter_rule(primitive: jax_core.CallPrimitive,
                                  rules: Rules, state: Value,
                                  invals: Sequence[Value],
                                  call_jaxpr: jax_core.Jaxpr,
                                  **params: Any) -> Tuple[Value, Value]:
  """Handles simple call primitives like `jax_core.call_p`.

  The interpreter `state` is threaded through the call primitive as an extra
  input and comes back as an extra output. The state and the regular inputs
  are flattened together, and the primitive is bound to a function that
  recursively runs `eval_jaxpr_with_state` on `call_jaxpr`.

  Args:
    primitive: A `jax_core.CallPrimitive` such as `jax_core.call_p`.
    rules: A `dict` that maps JAX primitives to functions that take in `(state,
      *args)` and return `(output, new_state)`.
    state: The interpreter `state` value at the time the call primitive is
      evaluated.
    invals: The input values to the call primitive.
    call_jaxpr: The `jax_core.Jaxpr` that corresponds to the body of the call
      primitive.
    **params: The parameters of the call primitive.

  Returns:
    A tuple of the output of the call primitive and its output state.
  """
  # Recursively use the effect handler for the call primitive's JAXpr.
  body = functools.partial(eval_jaxpr_with_state, call_jaxpr, rules, [])
  fun = lu.wrap_init(body)

  flat_inputs, in_tree = tree_util.tree_flatten((state, *invals))
  flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(fun, in_tree)
  flat_outputs = primitive.bind(flat_fun, *flat_inputs, **params)
  return tree_util.tree_unflatten(out_tree_thunk(), flat_outputs)
Beispiel #27
0
def core_closed_call(f, *args):
    """Bind ``f`` to ``core.closed_call_p`` with flattened arguments.

    The flat result of the bind is rebuilt with the output tree recorded
    during the call.
    """
    leaves, in_tree = tree_flatten(args)
    flat_f, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
    flat_out = core.closed_call_p.bind(flat_f, *leaves)
    return tree_unflatten(out_tree_thunk(), flat_out)
Beispiel #28
0
def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
                linear_args):
    """Call a linear function, with a custom implementation for its transpose.

    The type signatures of ``fun`` and ``fun_transpose`` are:

    .. code-block:: haskell

      fun           :: r -> a -o b
      fun_transpose :: r -> b -o a

    where the ``-o`` arrow indicates a linear function, ``r`` is the
    residual input type and ``a`` is the linear input type.

    The functions ``fun`` and ``fun_transpose`` are coupled as
    transposes of one another. Specifically, the transpose of a
    ``linear_call`` primitive is another ``linear_call`` to
    ``fun_transpose``, with ``fun`` as its custom transposition.

    For example:

    >>> def f(r, x):
    ...   return x / r

    >>> def t(r, t):
    ...   return t / r

    >>> def div_add(x, denom):
    ...   return x + linear_call(f, t, denom, x)

    >>> def transpose(f, x_example):
    ...   def transposed(y):
    ...     x, = jax.linear_transpose(f, x_example)(y)
    ...     return x
    ...   return transposed

    >>> div_add(9., 3.)
    DeviceArray(12., dtype=float32)

    >>> transpose(partial(div_add, denom=3.), 1.)(18.)  # custom
    DeviceArray(24., dtype=float32)

    >>> transpose(lambda x: x + x / 3., 1.)(18.)  # reference
    DeviceArray(24., dtype=float32)

    The above definition of ``f`` illustrates the purpose of a residual
    argument: division is linear in one of its inputs (the dividend
    ``x``) but not the other (the divisor ``r``).

    As another example:

    >>> def custom_id(x):
    ...   def f(_, x): return x
    ...   def t(_, t): return 7.
    ...   return linear_call(f, t, (), x)
    >>> custom_id(1.)
    1.0
    >>> transpose(custom_id, 1.)(1.)
    7.0
    >>> transpose(transpose(custom_id, 1.), 1.)(1.)
    1.0
    >>> transpose(transpose(transpose(custom_id, 1.), 1.), 1.)(1.)
    7.0

    Args:
      fun: a Python callable specifying a linear function. It should
        take two arguments: one of "residual" inputs (type ``r``),
        i.e. inputs in which the function is not necessarily linear, and
        one of "linear" inputs (type ``a``).  It should return output
        whose components are linear in the linear input (type ``b``).
      fun_transpose: a Python callable specifying a structurally linear
        function that is the transpose of ``fun`` with respect to its
        linear inputs. Its first argument is the same residual inputs
        (``r``) as ``fun``. Its second argument is of type
        ``b``. Finally, its output is of type ``a`` and each of its
        components is linear in its second argument (the ``b`` inputs).
      residual_args: Argument in which ``fun`` and ``fun_transpose`` are
        not necessarily linear. Not involved in transposition.
      linear_args: Argument in which ``fun`` and ``fun_transpose`` are
        linear and with respect to which the two are transposes.

    Returns:
      The call result, i.e. ``fun(residual_args, linear_args)``.

    Raises:
      TypeError: if the output pytree structure of ``fun_transpose`` does
        not match the structure of ``linear_args``.
    """
    operands_res, res_tree = tree_flatten(residual_args)
    operands_lin, lin_tree = tree_flatten(linear_args)

    f_in_tree = treedef_tuple((res_tree, lin_tree))
    f, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), f_in_tree)

    # Materialize the avals as lists: `res_avals` is unpacked twice below
    # (once for `fun`, once for `fun_transpose`), so a one-shot iterator
    # would be empty on the second use.
    res_avals = list(map(abstractify, operands_res))
    lin_avals = list(map(abstractify, operands_lin))
    f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
    f_jaxpr = _close_jaxpr(f_jaxpr)
    out_avals = list(map(core.raise_to_shaped, f_jaxpr.out_avals))

    # The transpose takes the residuals plus values shaped like the outputs
    # of `fun`; `out_tree()` is available because `fun` was traced above.
    t_in_tree = treedef_tuple((res_tree, out_tree()))
    t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose),
                                         t_in_tree)

    t_jaxpr, t_consts = _initial_style_jaxpr(t, (*res_avals, *out_avals))
    t_jaxpr = _close_jaxpr(t_jaxpr)

    t_out_treedef = t_out_tree()
    if t_out_treedef != lin_tree:
        raise TypeError(
            'transpose output pytree structure must match that of linear inputs, '
            f'got output structure {t_out_treedef} '
            f'and input structure {lin_tree}.')

    out = linear_call_p.bind(*f_consts,
                             *t_consts,
                             *operands_res,
                             *operands_lin,
                             callee=f_jaxpr,
                             transpose=t_jaxpr,
                             num_callee_consts=len(f_consts),
                             num_transpose_consts=len(t_consts),
                             num_res=len(operands_res))

    return tree_unflatten(out_tree(), out)
Beispiel #29
0
 def wrapped_fun(*args):
     """Run ``fun`` through ``callback_fun`` on flattened arguments.

     ``callback`` and ``strip_calls`` come from the enclosing scope; the
     flat result is rebuilt with the output tree recorded during the call.
     """
     leaves, in_tree = tree_flatten(args)
     flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
     flat_out = callback_fun(flat_fun, leaves, callback, strip_calls)
     return tree_unflatten(out_tree_thunk(), flat_out)