Ejemplo n.º 1
0
def custom_jvp_call_jaxpr(fun, jvp, *args):
    """Bind ``cd.custom_jvp_call_jaxpr_p`` for ``fun`` with JVP rule ``jvp``.

    ``fun`` is traced right away against the shaped abstract values of
    ``args``. The traced jaxpr's constvars are converted into ordinary inputs,
    so the resulting ``ClosedJaxpr`` carries no consts. The consts themselves
    are passed as the leading positional operands of the primitive.
    Tracing of ``jvp`` is deferred behind a thunk wrapped in ``pe._memoize``.
    That thunk traces ``jvp`` against the input avals repeated twice
    (presumably primals followed by tangents -- confirm against callers).

    Args:
      fun: the function to trace; handed unchanged to
        ``cd._initial_style_jaxpr``.
      jvp: the JVP rule, traced lazily by the thunk.
      *args: the primitive's operands.

    Returns:
      The result of ``cd.custom_jvp_call_jaxpr_p.bind``.
    """
    def _shaped_aval(value):
        return abstract_arrays.raise_to_shaped(jax_core.get_aval(value))

    in_avals = list(map(_shaped_aval, args))

    # The traced consts can be tracers, so they travel as explicit operands
    # (counted by num_consts) rather than being closed over by the jaxpr.
    traced_jaxpr, traced_consts = cd._initial_style_jaxpr(  # pylint: disable=protected-access
        fun, in_avals)
    closed_jaxpr = jax_core.ClosedJaxpr(
        pe.convert_constvars_jaxpr(traced_jaxpr), ())

    def _trace_jvp():
        # The JVP rule sees the input avals twice, back to back.
        return cd._initial_style_jaxpr(jvp, in_avals * 2)  # pylint: disable=protected-access

    jvp_thunk = pe._memoize(_trace_jvp)  # pylint: disable=protected-access
    operands = (*traced_consts, *args)
    return cd.custom_jvp_call_jaxpr_p.bind(
        *operands,
        fun_jaxpr=closed_jaxpr,
        jvp_jaxpr_thunk=jvp_thunk,
        num_consts=len(traced_consts))
Ejemplo n.º 2
0
def callback_jaxpr(closed_jaxpr, callback, strip_calls):
  """Re-trace ``closed_jaxpr`` through the callback machinery.

  The closed jaxpr is first turned back into a Python function. That function
  is wrapped with ``callback_subtrace`` and then with ``_callback_fun``
  (configured by ``callback`` and ``strip_calls``). The wrapped function is
  traced again against the original jaxpr's input avals.

  Args:
    closed_jaxpr: the ``ClosedJaxpr`` to transform.
    callback: forwarded to ``_callback_fun``.
    strip_calls: forwarded to ``_callback_fun``.

  Returns:
    A new ``core.ClosedJaxpr`` built from the re-traced jaxpr and its consts.
  """
  # Innermost first: jaxpr -> callable -> WrappedFun -> subtrace -> callback.
  wrapped = _callback_fun(
      callback_subtrace(lu.wrap_init(jaxpr_as_fun(closed_jaxpr))),
      callback, strip_calls)
  new_jaxpr, new_consts = cd._initial_style_jaxpr(
      wrapped, closed_jaxpr.in_avals)
  return core.ClosedJaxpr(new_jaxpr, new_consts)
Ejemplo n.º 3
0
def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees):
    """A convenience wrapper to apply the custom_vjp_call_jaxpr primitive.

    ``fun`` is traced immediately against the shaped abstract values of
    ``args``. Its constvars are converted into ordinary inputs, so the
    resulting ``ClosedJaxpr`` has no consts. The consts are instead passed as
    the leading positional operands, and ``num_consts`` records how many there
    are. Tracing of ``fwd`` is deferred behind a thunk wrapped in
    ``pe._memoize``. Unlike the JVP case, ``fwd`` is traced against the input
    avals only once, not doubled. ``bwd`` is not traced here; it is forwarded
    to the primitive as-is.

    Args:
      fun: the function to trace; handed unchanged to
        ``cd._initial_style_jaxpr``.
      fwd: the forward rule, traced lazily by the thunk.
      bwd: the backward rule, passed straight through to ``bind``.
      *args: the primitive's operands.
      out_trees: keyword-only; passed straight through to ``bind``
        (presumably describes the output pytree structure -- confirm
        against callers).

    Returns:
      The result of ``cd.custom_vjp_call_jaxpr_p.bind``.
    """
    in_avals = [
        abstract_arrays.raise_to_shaped(jax_core.get_aval(x)) for x in args
    ]
    fun_jaxpr, consts = cd._initial_style_jaxpr(  # pylint: disable=protected-access
        fun, in_avals)  # consts can be tracers!
    closed_fun_jaxpr = jax_core.ClosedJaxpr(
        pe.convert_constvars_jaxpr(fun_jaxpr), ())
    fwd_jaxpr_thunk = pe._memoize(  # pylint: disable=protected-access
        lambda: cd._initial_style_jaxpr(fwd, in_avals))  # pylint: disable=protected-access
    return cd.custom_vjp_call_jaxpr_p.bind(*consts,
                                           *args,
                                           fun_jaxpr=closed_fun_jaxpr,
                                           fwd_jaxpr_thunk=fwd_jaxpr_thunk,
                                           bwd=bwd,
                                           out_trees=out_trees,
                                           num_consts=len(consts))