def custom_jvp_call_jaxpr(fun, jvp, *args):
  """A convenience wrapper to apply the custom_jvp_call_jaxpr primitive.

  Traces `fun` at the shaped abstract values of `args`, lifts the traced
  constants into leading inputs of a constant-free ClosedJaxpr, and binds
  `cd.custom_jvp_call_jaxpr_p` with a memoized thunk that traces `jvp` lazily.

  Args:
    fun: the primal function to trace.
    jvp: the custom JVP rule. It is traced only when the thunk is forced, at
      the input avals repeated twice (presumably primals followed by tangents).
    *args: the operands; their avals drive tracing and they are passed to bind.

  Returns:
    Whatever `cd.custom_jvp_call_jaxpr_p.bind` returns.
  """
  def _shaped_aval(value):
    return abstract_arrays.raise_to_shaped(jax_core.get_aval(value))

  # Must remain a list: the JVP thunk below relies on list repetition.
  in_avals = list(map(_shaped_aval, args))

  traced_jaxpr, consts = cd._initial_style_jaxpr(  # pylint: disable=protected-access
      fun, in_avals)
  # consts can be tracers! They become leading invars of the jaxpr (so the
  # ClosedJaxpr holds no constants) and are handed to bind ahead of `args`.
  closed_fun_jaxpr = jax_core.ClosedJaxpr(
      pe.convert_constvars_jaxpr(traced_jaxpr), ())

  def _trace_jvp():
    return cd._initial_style_jaxpr(jvp, in_avals * 2)  # pylint: disable=protected-access

  jvp_jaxpr_thunk = pe._memoize(_trace_jvp)  # pylint: disable=protected-access
  num_consts = len(consts)
  return cd.custom_jvp_call_jaxpr_p.bind(
      *consts, *args,
      fun_jaxpr=closed_fun_jaxpr,
      jvp_jaxpr_thunk=jvp_jaxpr_thunk,
      num_consts=num_consts)
def callback_jaxpr(closed_jaxpr, callback, strip_calls):
  """Retraces `closed_jaxpr` through the callback wrappers.

  The jaxpr's evaluation function is wrapped with `lu.wrap_init`, then
  `callback_subtrace`, then `_callback_fun(…, callback, strip_calls)` (the
  latter two are defined elsewhere; presumably they install `callback` on
  each primitive application — verify against their definitions). The
  result is re-traced at the original input avals.

  Args:
    closed_jaxpr: the ClosedJaxpr to rewrite.
    callback: forwarded to `_callback_fun`.
    strip_calls: forwarded to `_callback_fun`.

  Returns:
    A new `core.ClosedJaxpr` built from the retraced jaxpr and its constants.
  """
  # Innermost wrapper is applied first, matching a step-by-step rebinding.
  wrapped_fun = _callback_fun(
      callback_subtrace(lu.wrap_init(jaxpr_as_fun(closed_jaxpr))),
      callback, strip_calls)
  new_jaxpr, new_consts = cd._initial_style_jaxpr(  # pylint: disable=protected-access
      wrapped_fun, closed_jaxpr.in_avals)
  return core.ClosedJaxpr(new_jaxpr, new_consts)
def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees):
  """A convenience wrapper to apply the custom_vjp_call_jaxpr primitive.

  Traces `fun` at the shaped abstract values of `args`, lifts the traced
  constants into leading inputs of a constant-free ClosedJaxpr, and binds
  `cd.custom_vjp_call_jaxpr_p` with a memoized thunk that traces `fwd` lazily.

  Args:
    fun: the primal function to trace.
    fwd: the forward rule, traced only when the thunk is forced, at the same
      input avals as `fun`.
    bwd: the backward rule; forwarded to bind unchanged.
    *args: the operands; their avals drive tracing and they are passed to bind.
    out_trees: keyword-only; forwarded to bind unchanged.

  Returns:
    Whatever `cd.custom_vjp_call_jaxpr_p.bind` returns.
  """
  def _shaped_aval(value):
    return abstract_arrays.raise_to_shaped(jax_core.get_aval(value))

  in_avals = list(map(_shaped_aval, args))

  traced_jaxpr, consts = cd._initial_style_jaxpr(  # pylint: disable=protected-access
      fun, in_avals)
  # consts can be tracers! They become leading invars of the jaxpr (so the
  # ClosedJaxpr holds no constants) and are handed to bind ahead of `args`.
  closed_fun_jaxpr = jax_core.ClosedJaxpr(
      pe.convert_constvars_jaxpr(traced_jaxpr), ())

  def _trace_fwd():
    return cd._initial_style_jaxpr(fwd, in_avals)  # pylint: disable=protected-access

  fwd_jaxpr_thunk = pe._memoize(_trace_fwd)  # pylint: disable=protected-access
  num_consts = len(consts)
  return cd.custom_vjp_call_jaxpr_p.bind(
      *consts, *args,
      fun_jaxpr=closed_fun_jaxpr,
      fwd_jaxpr_thunk=fwd_jaxpr_thunk,
      bwd=bwd,
      out_trees=out_trees,
      num_consts=num_consts)