def bind(self, *args, **kwargs):
    """Like Primitive.bind, but finds the top trace even with no arguments.

    The trace is taken from ``_top_trace()`` instead of being derived from
    ``args``. Unless ``jax.core.skip_checks`` is truthy, two assertions run:

    * every argument is a ``Tracer`` or passes ``valid_jaxtype``;
    * ``find_top_trace(args)`` is either ``None`` or has the same ``master``
      as the trace returned by ``_top_trace()``.

    Args:
        *args: Operands, each raised into the trace with ``full_raise``.
        **kwargs: Passed to ``process_primitive`` as a single dict.

    Returns:
        The output of ``process_primitive`` passed through ``full_lower``;
        mapped over the outputs when ``self.multiple_results`` is truthy.
    """
    if not jax.core.skip_checks:
        assert all(
            isinstance(operand, Tracer) or valid_jaxtype(operand)
            for operand in args
        ), args

    active_trace = _top_trace()

    if not jax.core.skip_checks:
        # The lookup stays inside the assert so it is skipped together with
        # the assertion itself when Python runs with -O.
        assert (
            find_top_trace(args) is None
            or find_top_trace(args).master is active_trace.master
        ), args

    raised = map(active_trace.full_raise, args)
    result = active_trace.process_primitive(self, raised, kwargs)
    if self.multiple_results:
        return map(full_lower, result)
    return full_lower(result)
def bind(self, f, *args, **params):
    """Bind this primitive to ``f`` and ``args`` under the current top trace.

    Dispatch, as implemented below:

    * no top trace: run ``self.impl`` inside ``jax_core.new_sublevel()``;
    * a ``batching.BatchTrace`` with ``self`` in ``custom_batch_rules``:
      call that registered rule;
    * otherwise: call ``process_call`` on the trace, using
      ``self.subcall('jvp')`` as the primitive for an ``ad.JVPTrace`` and
      ``self`` for any other trace, then ``full_lower`` each output.

    Args:
        f: The function being called; it is first passed through
            ``jax_core.process_env_traces``.
        *args: Operands of the call.
        **params: Primitive parameters, also forwarded to
            ``process_env_traces`` as a tuple of items.

    Returns:
        The outputs after ``jax_core.apply_todos`` has applied the todos
        produced by ``process_env_traces``.
    """
    trace = jax_core.find_top_trace(args)
    trace_stack = jax_core.thread_local_state.trace_state.trace_stack
    if trace is None:
        level = trace_stack.next_level(True)
    else:
        level = trace.level

    f, env_trace_todo = jax_core.process_env_traces(
        f, self, level, tuple(params.items()))

    if trace is None:
        with jax_core.new_sublevel():
            outs = self.impl(f, *args, **params)
    else:
        raised = safe_map(trace.full_raise, args)
        has_custom_batch_rule = (
            isinstance(trace, batching.BatchTrace)
            and self in custom_batch_rules
        )
        if has_custom_batch_rule:
            outs = custom_batch_rules[self](trace, f, raised, params)
        else:
            prim = self.subcall('jvp') if isinstance(trace, ad.JVPTrace) else self
            outs = safe_map(
                jax_core.full_lower,
                trace.process_call(prim, f, raised, params))

    # env_trace_todo is only called here, after the outputs were computed.
    return jax_core.apply_todos(env_trace_todo(), outs)
def bind(self, fun, jvp, *args):
    """Bind ``fun`` and ``jvp`` to ``args`` via the top trace.

    The arguments are lowered with ``core.full_lower``, the top trace is found
    from them, and both ``fun`` and ``jvp`` are wrapped by
    ``core.process_env_traces``. The call is then handed to the trace's
    ``process_custom_jvp_call``.

    Args:
        fun: The primal function.
        jvp: The JVP rule paired with ``fun``.
        *args: Operands of the call.

    Returns:
        The lowered outputs after ``_apply_todos`` has applied the merged
        env-trace todos.
    """
    # NOTE(review): ``args`` is consumed twice below, so ``map`` here is
    # presumably a list-returning variant rather than the builtin - confirm.
    args = map(core.full_lower, args)
    trace = core.find_top_trace(args)
    # A falsy trace is passed through unchanged as the "level".
    level = trace and trace.level

    fun, fun_todo = core.process_env_traces(fun, self, level, (), None)
    jvp, jvp_todo = core.process_env_traces(jvp, self, level, (), None)

    raised = map(trace.full_raise, args)  # type: ignore
    outs = trace.process_custom_jvp_call(self, fun, jvp, raised)  # type: ignore

    # Only the second element (the todos) of the merged result is used.
    _, merged_todo = lu.merge_linear_aux(fun_todo, jvp_todo)
    return _apply_todos(merged_todo, map(core.full_lower, outs))
def bind(self, call, *args, **params):
    """Bind ``call`` to ``args`` through the top trace.

    Args:
        call: Forwarded unchanged to ``process_custom_transpose``.
        *args: Operands; each is raised into the top trace.
        **params: Forwarded as keyword arguments to
            ``process_custom_transpose``.

    Returns:
        Whatever the trace's ``process_custom_transpose`` returns, unmodified.
    """
    # TODO(frostig,mattjj): This doesn't handle closures yet, which is
    # a bit involved. Closures are complicated by us binding `call`
    # twice in the JVP rule for custom transpose. The `env_trace_todo`
    # output by `process_env_traces` due to one of those two bindings
    # should be passable to the other, and need to be passed onward
    # since the second bind is deferred by partial eval (since it
    # typically receives unknowns).
    trace = core.find_top_trace(args)
    raised = map(trace.full_raise, args)
    return trace.process_custom_transpose(self, call, raised, **params)
def bind(self, fun, fwd, bwd, *args, out_trees):
    """Bind ``fun``, ``fwd`` and ``bwd`` to ``args`` via the top trace.

    The arguments are lowered with ``core.full_lower``, the top trace is found
    from them, and ``fun`` and ``fwd`` are wrapped by
    ``core.process_env_traces``. ``bwd`` is passed on as given. The trace's
    ``process_custom_vjp_call`` runs inside
    ``core.maybe_new_sublevel(top_trace)``.

    Args:
        fun: The primal function.
        fwd: The forward rule.
        bwd: The backward rule.
        *args: Operands of the call.
        out_trees: Forwarded by keyword to ``process_custom_vjp_call``.

    Returns:
        The lowered outputs after ``_apply_todos`` has applied the merged
        env-trace todos.
    """
    # NOTE(review): ``args`` is consumed twice below, so ``map`` here is
    # presumably a list-returning variant rather than the builtin - confirm.
    args = map(core.full_lower, args)
    trace = core.find_top_trace(args)
    # A falsy trace is passed through unchanged as the "level".
    level = trace and trace.level

    fun, fun_todo = core.process_env_traces(fun, self, level, (), None)
    fwd, fwd_todo = core.process_env_traces(fwd, self, level, (), None)

    raised = map(trace.full_raise, args)  # type: ignore
    with core.maybe_new_sublevel(trace):
        outs = trace.process_custom_vjp_call(
            self, fun, fwd, bwd, raised, out_trees=out_trees)

    # Only the second element (the todos) of the merged result is used.
    _, merged_todo = lu.merge_linear_aux(fun_todo, fwd_todo)
    return _apply_todos(merged_todo, map(core.full_lower, outs))
def bind(self, fun, fwd, bwd, *args, out_trees):
    """Bind ``fun``, ``fwd`` and ``bwd`` to ``args`` via the top trace.

    ``fun`` is wrapped by ``process_env_traces`` and ``fwd`` by
    ``process_env_traces_fwd``. ``bwd`` is handed to the trace indirectly,
    through a wrapper that looks up the local ``bwd`` only when it is called.

    Args:
        fun: The primal function.
        fwd: The forward rule.
        bwd: The backward rule; must expose ``call_wrapped``.
        *args: Operands of the call.
        out_trees: Forwarded to ``process_env_traces_fwd`` and, by keyword,
            to ``process_custom_vjp_call``.

    Returns:
        The lowered outputs after ``_apply_todos`` has applied the env-trace
        todos selected by ``lu.merge_linear_aux``.
    """
    # NOTE(review): ``args`` is consumed twice below, so ``map`` here is
    # presumably a list-returning variant rather than the builtin - confirm.
    args = map(core.full_lower, args)
    trace = core.find_top_trace(args)
    # A falsy trace is passed through unchanged as the "level".
    level = trace and trace.level

    fun, fun_todo = process_env_traces(fun, self, level, False)
    fwd, fwd_todo = process_env_traces_fwd(fwd, level, out_trees)

    raised = map(trace.full_raise, args)  # type: ignore
    # The lambda reads ``bwd`` at call time, so the rebinding of ``bwd`` at the
    # bottom of this method is visible to it. Do not capture ``bwd`` eagerly.
    bwd_ = lu.wrap_init(lambda *bwd_args: bwd.call_wrapped(*bwd_args))
    outs = trace.process_custom_vjp_call(
        self, fun, fwd, bwd_, raised, out_trees=out_trees)

    fst, env_trace_todo = lu.merge_linear_aux(fun_todo, fwd_todo)
    if not fst:
        # The fwd-side aux output is a (todos, bwd_transform) pair.
        env_trace_todo, bwd_transform = env_trace_todo
        bwd = _apply_bwd_transform(bwd_transform, bwd)
    return _apply_todos(env_trace_todo, map(core.full_lower, outs))