def bind(self, fun, jvp, *args):
  """Bind this custom-JVP call primitive to ``fun``/``jvp`` applied to ``args``.

  Both ``fun`` and ``jvp`` are wrapped with ``core.process_env_traces`` at the
  level of the top trace found among ``args``. The operands are raised into
  that trace and handed to its ``process_custom_jvp_call``. The env-trace todos
  produced by the two wrapped callables are merged with ``lu.merge_linear_aux``
  and applied, via ``_apply_todos``, to the lowered outputs.

  Args:
    fun: the primal callable (presumably a ``lu`` wrapped function, given the
      ``lu.merge_linear_aux`` use below -- TODO confirm).
    jvp: the custom JVP rule, wrapped exactly like ``fun``.
    *args: positional operands; each is passed through ``core.full_lower``.

  Returns:
    The outputs of ``top_trace.process_custom_jvp_call``, each passed through
    ``core.full_lower``, after the merged env-trace todos are applied.
  """
  # NOTE(review): `args` is iterated twice below (by find_top_trace and by the
  # full_raise map), so this relies on the module-level `map` returning a
  # reusable sequence rather than Python's one-shot builtin iterator --
  # presumably `map` is rebound to a list-returning safe_map; confirm.
  args = map(core.full_lower, args)
  top_trace = core.find_top_trace(args)
  # `top_trace and top_trace.level` tolerates a None top trace here, but the
  # full_raise / process_custom_jvp_call lines dereference top_trace
  # unconditionally (hence their `type: ignore`).
  # NOTE(review): presumably find_top_trace never returns None in this core
  # version and the `and` only appeases the type checker -- confirm.
  fun, env_trace_todo1 = core.process_env_traces(
      fun, self, top_trace and top_trace.level, (), None)
  jvp, env_trace_todo2 = core.process_env_traces(
      jvp, self, top_trace and top_trace.level, (), None)
  tracers = map(top_trace.full_raise, args)  # type: ignore
  outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers)  # type: ignore
  # The two todo stores are merged only after the call above. Presumably the
  # wrapped callables run (and populate their stores) during that call; the
  # first element of the merge result is discarded -- confirm its meaning in lu.
  _, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
  return _apply_todos(env_trace_todo, map(core.full_lower, outs))
def bind(self, f, *args, **params):
  """Bind this higher-order primitive, applying it to ``f`` on ``args``.

  ``f`` is first wrapped by ``jax_core.process_env_traces`` at the level of
  the top trace among ``args``. When there is no top trace, the level is
  ``trace_stack.next_level(True)`` instead. ``params`` is passed to the
  wrapper as a tuple of its items. Dispatch then depends on the top trace:

  * no top trace: ``self.impl(f, *args, **params)`` runs inside
    ``jax_core.new_sublevel()``;
  * a ``batching.BatchTrace`` for which ``custom_batch_rules`` holds an entry
    for this primitive: that rule computes the outputs;
  * any other trace: ``top_trace.process_call`` is invoked with
    ``self.subcall('jvp')`` under an ``ad.JVPTrace`` (``self`` otherwise), and
    every output is passed through ``jax_core.full_lower``.

  Returns:
    The outputs after ``jax_core.apply_todos`` applies the env-trace todos.
  """
  top_trace = jax_core.find_top_trace(args)
  stack = jax_core.thread_local_state.trace_state.trace_stack
  if top_trace is None:
    level = stack.next_level(True)
  else:
    level = top_trace.level
  f, env_trace_todo = jax_core.process_env_traces(
      f, self, level, tuple(params.items()))

  # No tracer among the operands: evaluate directly in a fresh sublevel.
  if top_trace is None:
    with jax_core.new_sublevel():
      outs = self.impl(f, *args, **params)
    return jax_core.apply_todos(env_trace_todo(), outs)

  tracers = safe_map(top_trace.full_raise, args)
  has_custom_batch_rule = (isinstance(top_trace, batching.BatchTrace)
                           and self in custom_batch_rules)
  if has_custom_batch_rule:
    # NOTE(review): unlike the generic path, these outputs are not passed
    # through jax_core.full_lower -- confirm this is intentional.
    outs = custom_batch_rules[self](top_trace, f, tracers, params)
  else:
    call_prim = (self.subcall('jvp') if isinstance(top_trace, ad.JVPTrace)
                 else self)
    outs = safe_map(jax_core.full_lower,
                    top_trace.process_call(call_prim, f, tracers, params))
  return jax_core.apply_todos(env_trace_todo(), outs)
def bind(self, fun, fwd, bwd, *args, out_trees):
  """Bind this custom-VJP call primitive to ``fun``/``fwd``/``bwd`` on ``args``.

  ``fun`` and ``fwd`` are wrapped with ``core.process_env_traces`` at the level
  of the top trace found among ``args``; ``bwd`` is forwarded unwrapped. The
  operands are raised into the top trace and handed, together with
  ``out_trees``, to its ``process_custom_vjp_call``. That call runs inside
  ``core.maybe_new_sublevel(top_trace)``. The env-trace todos from the two
  wrapped callables are merged with ``lu.merge_linear_aux`` and applied, via
  ``_apply_todos``, to the lowered outputs.

  Args:
    fun: the primal callable.
    fwd: the forward-pass rule, wrapped exactly like ``fun``.
    bwd: the backward-pass rule; passed through to the trace as-is.
    *args: positional operands; each is passed through ``core.full_lower``.
    out_trees: keyword-only; forwarded unchanged to
      ``process_custom_vjp_call`` (presumably describes the output pytree
      structures -- TODO confirm).

  Returns:
    The outputs of ``top_trace.process_custom_vjp_call``, each passed through
    ``core.full_lower``, after the merged env-trace todos are applied.
  """
  # NOTE(review): `args` is iterated twice below (by find_top_trace and by the
  # full_raise map), so this relies on the module-level `map` returning a
  # reusable sequence rather than Python's one-shot builtin iterator --
  # presumably `map` is rebound to a list-returning safe_map; confirm.
  args = map(core.full_lower, args)
  top_trace = core.find_top_trace(args)
  # `top_trace and top_trace.level` tolerates a None top trace here, but
  # top_trace is dereferenced unconditionally below (hence `type: ignore`).
  # NOTE(review): presumably find_top_trace never returns None in this core
  # version and the `and` only appeases the type checker -- confirm.
  fun, env_trace_todo1 = core.process_env_traces(
      fun, self, top_trace and top_trace.level, (), None)
  fwd, env_trace_todo2 = core.process_env_traces(
      fwd, self, top_trace and top_trace.level, (), None)
  tracers = map(top_trace.full_raise, args)  # type: ignore
  with core.maybe_new_sublevel(top_trace):
    outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers,
                                             out_trees=out_trees)
  # The two todo stores are merged only after the call above. Presumably the
  # wrapped callables run (and populate their stores) during that call; the
  # first element of the merge result is discarded -- confirm its meaning in lu.
  _, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
  return _apply_todos(env_trace_todo, map(core.full_lower, outs))