def jvp_subtrace_aux(main, primals, tangents):
  """Run a JVP sub-trace for a function that also produces auxiliary data.

  This is a generator with a two-step protocol:

  1. It yields ``(in_tracers, {})``. ``in_tracers`` are JVPTracers pairing each
     primal with its tangent on a fresh ``JVPTrace``. The ``{}`` is presumably
     the kwargs for the wrapped call; confirm against the caller.
  2. It is then sent ``(ans, aux)``. It yields
     ``((out_primals, out_tangents), aux_primals)``.

  In ``aux_primals``, values that are JVPTracers of *this* trace are replaced
  by their lowered primal. Every other value passes through untouched.
  """
  # NOTE(review): presumably decorated with a linear_util
  # transformation_with_aux wrapper outside this view; confirm.
  trace = JVPTrace(main, core.cur_sublevel())

  # Any tracer passed in must live at a strictly lower level than the trace
  # opened here.
  for value in (*primals, *tangents):
    if isinstance(value, Tracer):
      assert value._trace.level < trace.level

  ans, aux = yield map(partial(JVPTracer, trace), primals, tangents), {}

  # Raise every output onto this trace so each exposes .primal and .tangent.
  out_primals, out_tangents = unzip2(
      (tracer.primal, tracer.tangent) for tracer in map(trace.full_raise, ans))

  def lower_own(value):
    # Only tracers belonging to this exact JVP trace are unwrapped.
    if isinstance(value, JVPTracer) and value._trace.level == trace.level:
      return core.full_lower(value.primal)
    return value

  aux_primals = [lower_own(value) for value in aux]
  yield (out_primals, out_tangents), aux_primals
def bind(self, *args, **kwargs):
  """Bind this primitive on the current top trace, even when called with no args.

  The trace used for binding comes from ``_top_trace()`` rather than from the
  arguments, so a trace is found even when ``args`` is empty.

  Returns:
    A list of lowered outputs when ``self.multiple_results`` is set, otherwise
    a single lowered output.
  """
  assert jax.core.skip_checks or all(
      isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args), args
  top = _top_trace()
  args_main = find_top_trace(args).main
  dynamic_main = thread_local_state.trace_state.trace_stack.dynamic
  # The main trace implied by the arguments must be either the dynamic trace or
  # the main of the trace we are about to bind on.
  assert (jax.core.skip_checks or args_main is dynamic_main
          or args_main is top.main), args
  raised = map(top.full_raise, args)
  out = top.process_primitive(self, raised, kwargs)
  if self.multiple_results:
    return map(full_lower, out)
  return full_lower(out)
def _maybe_perturbed(x: Any) -> bool:
  """Heuristically report whether ``x`` might carry a nontrivial AD tangent.

  Returns False only when ``x`` definitely cannot be perturbed (up to the
  heuristics below), and True otherwise.
  See https://github.com/google/jax/issues/6415 for motivation.
  """
  lowered = core.full_lower(x)

  # Plain (non-Tracer) values can never be perturbed.
  if not isinstance(lowered, core.Tracer):
    return False

  # While staging out, differentiation may still happen later, but tokens and
  # float0-typed values always have trivial tangents.
  if isinstance(lowered, pe.DynamicJaxprTracer):
    vspace = lowered.aval.at_least_vspace()
    has_trivial_tangent = (
        vspace is core.abstract_token
        or getattr(vspace, 'dtype', None) is dtypes.float0)
    return not has_trivial_tangent

  # A JVPTracer may well be perturbed; we can't be sure otherwise.
  if isinstance(lowered, ad.JVPTracer):
    return True

  # Any other tracer: recurse into the values it wraps.
  return any(_maybe_perturbed(child) for _, child in lowered._contents())
def full_lower(self):
  """Return the lowered primal when every series term is zero, else ``self``.

  The terms count as all-zero when ``self.terms`` is the ``zero_series``
  sentinel or each individual term is ``zero_term``.
  """
  terms = self.terms
  all_zero = terms is zero_series or all(term is zero_term for term in terms)
  return core.full_lower(self.primal) if all_zero else self
def full_lower(self):
  """Return the lowered ``head`` when there is no ``tail``, else ``self``."""
  if self.tail is not None:
    return self
  return core.full_lower(self.head)
def full_lower(self):
  """Unwrap to the lowered ``val`` when ``batch_dim`` is ``not_mapped``.

  Otherwise the tracer itself is returned unchanged.
  """
  if self.batch_dim is not not_mapped:
    return self
  return core.full_lower(self.val)
def full_lower(self):
  """Drop this tracer in favour of its lowered primal if the tangent is ``Zero``.

  Returns ``self`` when the tangent is anything other than a ``Zero``
  instance (exact type check, not ``isinstance``).
  """
  return core.full_lower(self.primal) if type(self.tangent) is Zero else self
def full_lower(self):
  """Return the lowered ``val`` when ``self.is_pure()``, otherwise ``self``."""
  if not self.is_pure():
    return self
  return core.full_lower(self.val)
def full_lower(self):
  """Return the lowered constant from ``pval`` when pure, otherwise ``self``.

  ``self.pval`` is unpacked as a 2-tuple; only its second element is used.
  """
  if not self.is_pure():
    return self
  _, const = self.pval
  return jax_core.full_lower(const)