Example #1
0
File: core.py  Project: j-towns/jaxnet
def bind(self, *args, **kwargs):
    """Bind this primitive to the current top trace, even with no operands.

    ``Primitive.bind`` derives the trace from its arguments. This variant asks
    ``_top_trace()`` for the trace instead, so it still works when ``args`` is
    empty.

    Args:
        *args: operands. Unless ``jax.core.skip_checks`` is set, each must be
            a ``Tracer`` or satisfy ``valid_jaxtype`` (checked by assertion).
        **kwargs: primitive parameters, handed to ``process_primitive`` as a
            single dict.

    Returns:
        ``map(full_lower, out)`` when ``self.multiple_results`` is true,
        otherwise ``full_lower(out)``.
    """
    assert jax.core.skip_checks or all(
        isinstance(operand, Tracer) or valid_jaxtype(operand)
        for operand in args), args

    current = _top_trace()

    # Sanity check: any trace discovered among the operands must share the
    # master of the trace returned by _top_trace().
    assert (jax.core.skip_checks
            or find_top_trace(args) is None
            or find_top_trace(args).master is current.master), args

    lifted = map(current.full_raise, args)
    result = current.process_primitive(self, lifted, kwargs)
    if not self.multiple_results:
        return full_lower(result)
    return map(full_lower, result)
Example #2
0
 def bind(self, f, *args, **params):
     """Bind this call primitive: apply ``f`` to ``args`` under the top trace.

     When no trace is found among ``args``, ``self.impl`` runs directly
     inside ``jax_core.new_sublevel()``. Otherwise the operands are raised
     into the top trace, and one of two paths handles the call:

     * a rule registered in ``custom_batch_rules``, used when the trace is a
       ``batching.BatchTrace``;
     * ``process_call``, which receives ``self.subcall('jvp')`` under an
       ``ad.JVPTrace`` and ``self`` otherwise. Its outputs are lowered with
       ``jax_core.full_lower``.

     In every case ``f`` is first wrapped by ``jax_core.process_env_traces``,
     and the resulting todos are applied to the outputs at the end.

     Args:
         f: the function being called.
         *args: operands.
         **params: call parameters. They are passed as keywords to
             ``self.impl``, as a dict to ``process_call`` or the custom batch
             rule, and as a tuple of items to ``process_env_traces``.

     Returns:
         The call outputs after ``jax_core.apply_todos``.
     """
     trace = jax_core.find_top_trace(args)
     stack = jax_core.thread_local_state.trace_state.trace_stack
     if trace is None:
         level = stack.next_level(True)
     else:
         level = trace.level
     f, env_todo = jax_core.process_env_traces(
         f, self, level, tuple(params.items()))

     if trace is None:
         # No tracer among the operands: evaluate eagerly in a new sublevel.
         with jax_core.new_sublevel():
             outs = self.impl(f, *args, **params)
     else:
         raised = safe_map(trace.full_raise, args)
         use_custom_batching = (isinstance(trace, batching.BatchTrace)
                                and self in custom_batch_rules)
         if use_custom_batching:
             outs = custom_batch_rules[self](trace, f, raised, params)
         else:
             prim = (self.subcall('jvp')
                     if isinstance(trace, ad.JVPTrace) else self)
             outs = safe_map(jax_core.full_lower,
                             trace.process_call(prim, f, raised, params))
     return jax_core.apply_todos(env_todo(), outs)
Example #3
0
 def bind(self, fun, jvp, *args):
   """Bind a custom-JVP call of ``fun`` with JVP rule ``jvp`` on ``args``.

   The operands are lowered with ``core.full_lower`` and the top trace is
   located from them. Both ``fun`` and ``jvp`` are wrapped with
   ``core.process_env_traces`` at that trace's level. After
   ``process_custom_jvp_call``, the two todo stores are merged with
   ``lu.merge_linear_aux`` and applied to the lowered outputs.

   Args:
     fun: the primal function.
     jvp: the user-supplied JVP rule.
     *args: operands.

   Returns:
     The lowered outputs after ``_apply_todos``.
   """
   lowered = map(core.full_lower, args)
   trace = core.find_top_trace(lowered)
   # Evaluates to ``trace`` itself when it is falsy, else to its level.
   level = trace and trace.level
   fun, todo_fun = core.process_env_traces(fun, self, level, (), None)
   jvp, todo_jvp = core.process_env_traces(jvp, self, level, (), None)
   in_tracers = map(trace.full_raise, lowered)  # type: ignore
   out_vals = trace.process_custom_jvp_call(self, fun, jvp, in_tracers)  # type: ignore
   _, env_trace_todo = lu.merge_linear_aux(todo_fun, todo_jvp)
   return _apply_todos(env_trace_todo, map(core.full_lower, out_vals))
Example #4
0
 def bind(self, call, *args, **params):
   """Bind a custom-transpose call ``call`` on ``args`` at the top trace.

   The top trace is located from ``args`` and the operands are raised into
   it. ``process_custom_transpose`` then receives ``params`` expanded as
   keyword arguments, and its result is returned unchanged (no
   ``full_lower``).
   """
   # TODO(frostig,mattjj): This doesn't handle closures yet, which is
   # a bit involved. Closures are complicated by us binding `call`
   # twice in the JVP rule for custom transpose. The `env_trace_todo`
   # output by `process_env_traces` due to one of those two bindings
   # should be passable to the other, and need to be passed onward
   # since the second bind is deferred by partial eval (since it
   # typically receives unknowns)
   trace = core.find_top_trace(args)
   raised_args = map(trace.full_raise, args)
   return trace.process_custom_transpose(self, call, raised_args, **params)
Example #5
0
 def bind(self, fun, fwd, bwd, *args, out_trees):
     """Bind a custom-VJP call of ``fun`` with rules ``fwd`` and ``bwd``.

     ``fun`` and ``fwd`` are each wrapped with ``core.process_env_traces`` at
     the level of the top trace found among the lowered operands. ``bwd`` is
     handed to ``process_custom_vjp_call`` as-is. The call runs inside
     ``core.maybe_new_sublevel(top_trace)``, and the merged todos from the two
     wrappers are applied to the lowered outputs.

     Args:
         fun: the primal function.
         fwd: the forward rule.
         bwd: the backward rule, forwarded without env-trace wrapping.
         *args: operands, lowered with ``core.full_lower`` before use.
         out_trees: forwarded to ``process_custom_vjp_call``.

     Returns:
         The lowered outputs after ``_apply_todos``.
     """
     lowered = map(core.full_lower, args)
     trace = core.find_top_trace(lowered)
     # Evaluates to ``trace`` itself when it is falsy, else to its level.
     level = trace and trace.level
     fun, todo_fun = core.process_env_traces(fun, self, level, (), None)
     fwd, todo_fwd = core.process_env_traces(fwd, self, level, (), None)
     in_tracers = map(trace.full_raise, lowered)  # type: ignore
     with core.maybe_new_sublevel(trace):
         out_vals = trace.process_custom_vjp_call(
             self, fun, fwd, bwd, in_tracers, out_trees=out_trees)
     _, env_trace_todo = lu.merge_linear_aux(todo_fun, todo_fwd)
     return _apply_todos(env_trace_todo, map(core.full_lower, out_vals))
 def bind(self, fun, fwd, bwd, *args, out_trees):
   """Bind a custom-VJP call, applying any ``bwd`` transform after the call.

   ``fun`` is wrapped by ``process_env_traces`` and ``fwd`` by
   ``process_env_traces_fwd``, both at the level of the top trace found among
   the lowered operands. ``bwd`` reaches ``process_custom_vjp_call`` only
   through the indirection ``bwd_``. That way, a transform reported by the
   merged todo store after the call can still take effect on the backward
   rule.

   Args:
     fun: the primal function, wrapped via ``process_env_traces``.
     fwd: the forward rule, wrapped via ``process_env_traces_fwd``, which also
       receives ``out_trees``.
     bwd: the backward rule. It must provide ``call_wrapped``, since it is
       invoked as ``bwd.call_wrapped(*args)``.
     *args: operands, lowered with ``core.full_lower`` before use.
     out_trees: forwarded to ``process_env_traces_fwd`` and to
       ``process_custom_vjp_call``.

   Returns:
     The outputs of ``process_custom_vjp_call``, lowered with
     ``core.full_lower`` and passed through ``_apply_todos``.
   """
   args = map(core.full_lower, args)
   top_trace = core.find_top_trace(args)
   fun, env_trace_todo1 = process_env_traces(
       fun, self, top_trace and top_trace.level, False)
   fwd, env_trace_todo2 = process_env_traces_fwd(
     fwd, top_trace and top_trace.level, out_trees)
   tracers = map(top_trace.full_raise, args)  # type: ignore
   # NOTE: this lambda closes over the *name* ``bwd`` (late binding), not the
   # value ``bwd`` has right now. The ``else`` branch below rebinds ``bwd``
   # after ``process_custom_vjp_call`` returns. Any later invocation of
   # ``bwd_`` therefore runs the transformed backward rule. Do not replace
   # this with a direct reference to ``bwd``. (The lambda's ``*args`` shadows
   # the outer ``args`` and receives whatever ``bwd_``'s caller passes.)
   bwd_ = lu.wrap_init(lambda *args: bwd.call_wrapped(*args))
   outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
                                            out_trees=out_trees)
   # ``fst`` selects which shape the merged todo value has. Presumably it
   # reports which of the two stores produced it. TODO confirm against
   # ``lu.merge_linear_aux``.
   fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
   if fst:
     return _apply_todos(env_trace_todo, map(core.full_lower, outs))
   else:
     # Here the todo value is a pair: the env-trace todos plus a transform
     # for the backward rule (presumably produced by process_env_traces_fwd;
     # TODO confirm). Rebinding ``bwd`` is what makes ``bwd_`` pick it up.
     env_trace_todo, bwd_transform = env_trace_todo
     bwd = _apply_bwd_transform(bwd_transform, bwd)
     return _apply_todos(env_trace_todo, map(core.full_lower, outs))