Example #1
0
 def bind(self, fun, jvp, *args):
   """Bind this custom-JVP call primitive to ``fun`` and its rule ``jvp``.

   The steps run in this order:
     1. Every argument is passed through ``core.full_lower``, and the top
        trace among the results is located with ``core.find_top_trace``.
     2. ``fun`` and ``jvp`` are each wrapped with ``core.process_env_traces``
        at that trace's level, yielding one env-trace "todo" auxiliary per
        function.
     3. The lowered arguments are passed through ``full_raise`` of the top
        trace and handed to ``process_custom_jvp_call`` together with the
        wrapped functions.
     4. The two todo auxiliaries are combined with ``lu.merge_linear_aux``
        (its first result is discarded). The outputs are lowered with
        ``core.full_lower`` and given to ``_apply_todos``.

   Args:
     fun: the primal function (presumably an ``lu.WrappedFun`` -- confirm).
     jvp: the JVP rule (presumably an ``lu.WrappedFun`` -- confirm).
     *args: positional operands for the call.

   Returns:
     Whatever ``_apply_todos`` produces for the lowered outputs.
   """
   lowered = map(core.full_lower, args)
   trace = core.find_top_trace(lowered)
   # `trace and trace.level` evaluates to `trace` itself when it is falsy
   # (e.g. None) and to its `level` otherwise.
   # NOTE(review): if `trace` could actually be None, the `full_raise` call
   # below would raise AttributeError; presumably find_top_trace always
   # returns a trace here -- confirm.
   level = trace and trace.level
   fun, fun_todo = core.process_env_traces(fun, self, level, (), None)
   jvp, jvp_todo = core.process_env_traces(jvp, self, level, (), None)
   raised = map(trace.full_raise, lowered)  # type: ignore
   outs = trace.process_custom_jvp_call(self, fun, jvp, raised)  # type: ignore
   _, todo = lu.merge_linear_aux(fun_todo, jvp_todo)
   return _apply_todos(todo, map(core.full_lower, outs))
Example #2
0
 def bind(self, f, *args, **params):
     """Bind this higher-order primitive to ``f`` applied to ``args``.

     ``f`` is first wrapped with ``jax_core.process_env_traces`` at a level
     taken from the top trace. When there is no top trace, the level comes
     from ``trace_stack.next_level(True)`` instead. The keyword ``params``
     are passed to that wrapper as a tuple of ``(key, value)`` items.

     The outputs are then produced on one of three paths:

     * No top trace: ``self.impl(f, *args, **params)`` is run inside a
       ``jax_core.new_sublevel()`` context.
     * The top trace is a ``batching.BatchTrace`` and this primitive has an
       entry in ``custom_batch_rules``: the arguments are raised into the
       trace, and that rule is called as ``rule(trace, f, tracers, params)``.
     * Otherwise the arguments are raised into the trace and given to
       ``process_call``. Under an ``ad.JVPTrace`` the primitive handed over
       is ``self.subcall('jvp')`` rather than ``self``.

     Finally, the todo returned by ``process_env_traces`` is called with no
     arguments. Its result is applied to the outputs via
     ``jax_core.apply_todos``.

     NOTE(review): only the ``process_call`` path passes its outputs through
     ``jax_core.full_lower``; the ``impl`` and custom-batch-rule paths
     return theirs unlowered. Confirm this asymmetry is intended.
     """
     trace = jax_core.find_top_trace(args)
     stack = jax_core.thread_local_state.trace_state.trace_stack
     if trace is None:
         # No top trace was found, so ask the trace stack for a level.
         level = stack.next_level(True)
     else:
         level = trace.level
     f, todo = jax_core.process_env_traces(
         f, self, level, tuple(params.items()))

     if trace is None:
         with jax_core.new_sublevel():
             outs = self.impl(f, *args, **params)
     else:
         tracers = safe_map(trace.full_raise, args)
         has_batch_rule = (isinstance(trace, batching.BatchTrace)
                           and self in custom_batch_rules)
         if has_batch_rule:
             outs = custom_batch_rules[self](trace, f, tracers, params)
         else:
             # Under JVP tracing, hand the 'jvp' sub-call primitive onward.
             prim = (self.subcall('jvp') if isinstance(trace, ad.JVPTrace)
                     else self)
             raw_outs = trace.process_call(prim, f, tracers, params)
             outs = safe_map(jax_core.full_lower, raw_outs)
     return jax_core.apply_todos(todo(), outs)
Example #3
0
 def bind(self, fun, fwd, bwd, *args, out_trees):
     """Bind this custom-VJP call primitive to ``fun`` with rules ``fwd``/``bwd``.

     The steps run in this order:
       1. Every argument is passed through ``core.full_lower``, and the top
          trace among the results is located with ``core.find_top_trace``.
       2. ``fun`` and ``fwd`` are each wrapped with
          ``core.process_env_traces`` at that trace's level, yielding one
          env-trace "todo" auxiliary per function. ``bwd`` is NOT wrapped;
          it is forwarded unchanged.
       3. The lowered arguments are passed through ``full_raise`` of the top
          trace. Inside ``core.maybe_new_sublevel(trace)``, everything is
          handed to ``process_custom_vjp_call`` along with ``out_trees``.
       4. The two todo auxiliaries are combined with ``lu.merge_linear_aux``
          (its first result is discarded). The outputs are lowered with
          ``core.full_lower`` and given to ``_apply_todos``.

     Args:
       fun: the primal function (presumably an ``lu.WrappedFun`` -- confirm).
       fwd: the forward rule (presumably an ``lu.WrappedFun`` -- confirm).
       bwd: the backward rule, passed through untouched.
       *args: positional operands for the call.
       out_trees: keyword-only; forwarded verbatim to
         ``process_custom_vjp_call``.

     Returns:
       Whatever ``_apply_todos`` produces for the lowered outputs.
     """
     lowered = map(core.full_lower, args)
     trace = core.find_top_trace(lowered)
     # `trace and trace.level` evaluates to `trace` itself when it is falsy
     # (e.g. None) and to its `level` otherwise.
     # NOTE(review): if `trace` could actually be None, `full_raise` below
     # would raise AttributeError; presumably find_top_trace always returns
     # a trace here -- confirm.
     level = trace and trace.level
     fun, fun_todo = core.process_env_traces(fun, self, level, (), None)
     fwd, fwd_todo = core.process_env_traces(fwd, self, level, (), None)
     raised = map(trace.full_raise, lowered)  # type: ignore
     with core.maybe_new_sublevel(trace):
         outs = trace.process_custom_vjp_call(
             self, fun, fwd, bwd, raised, out_trees=out_trees)
     _, todo = lu.merge_linear_aux(fun_todo, fwd_todo)
     return _apply_todos(todo, map(core.full_lower, outs))