Ejemplo n.º 1
0
Archivo: ad.py Proyecto: jbampton/jax
def jvp_subtrace_aux(main, primals, tangents):
  """JVP sub-trace generator that also post-processes auxiliary outputs.

  Written as a two-step generator. The first ``yield`` hands out the wrapped
  input tracers together with an empty dict, and receives ``(ans, aux)`` back
  via ``send``. The second ``yield`` emits the split primal/tangent outputs
  along with the processed ``aux`` values.
  NOTE(review): the decorator that drives this protocol (presumably
  ``lu.transformation_with_aux``) is not shown in this snippet — confirm.

  Args:
    main: object passed through to ``JVPTrace`` to build this sub-trace.
    primals: input primal values.
    tangents: input tangent values, paired positionally with ``primals``.

  Yields:
    First: ``(map(partial(JVPTracer, trace), primals, tangents), {})``.
    Second: ``((out_primals, out_tangents), aux_primals)``.
  """
  trace = JVPTrace(main, core.cur_sublevel())
  # Sanity check: any Tracer among the inputs must belong to a strictly lower
  # trace level than the trace just created.
  for x in list(primals) + list(tangents):
    if isinstance(x, Tracer):
      assert x._trace.level < trace.level
  # Pair each primal with its tangent in a JVPTracer on this trace. The caller
  # sends back the wrapped computation's outputs (`ans`) and auxiliary values
  # (`aux`).
  # NOTE(review): the `{}` is presumably the kwargs for the wrapped call — confirm.
  ans, aux = yield map(partial(JVPTracer, trace), primals, tangents), {}
  # Raise every output to this trace so each one exposes .primal and .tangent.
  ans_tracers = map(trace.full_raise, ans)
  out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in ans_tracers)
  # Aux values that are JVPTracers of *this* trace level are reduced to their
  # fully lowered primal, which drops the tangent. All other aux values pass
  # through unchanged.
  aux_primals = [core.full_lower(x.primal)
                 if isinstance(x, JVPTracer) and x._trace.level == trace.level
                 else x for x in aux]
  yield (out_primals, out_tangents), aux_primals
Ejemplo n.º 2
0
def bind(self, *args, **kwargs):
    """Bind this primitive on the top trace, even when called with no arguments.

    Variant of ``Primitive.bind``: the trace comes from ``_top_trace()``, so a
    call with an empty ``args`` still reaches a trace. Unless
    ``jax.core.skip_checks`` is set, two assertions are made:

    * every argument is a ``Tracer`` or a valid JAX type;
    * the top trace found from ``args`` is either the dynamic trace or the
      trace returned by ``_top_trace()``.

    Returns:
      ``map(full_lower, ...)`` over the outputs when ``self.multiple_results``
      is true; otherwise ``full_lower`` of the single output.
    """
    assert jax.core.skip_checks or all(
        isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args), args

    top = _top_trace()
    found_main = find_top_trace(args).main
    dynamic_main = thread_local_state.trace_state.trace_stack.dynamic
    assert (jax.core.skip_checks
            or found_main is dynamic_main
            or found_main is top.main), args

    raised_args = map(top.full_raise, args)
    result = top.process_primitive(self, raised_args, kwargs)
    if self.multiple_results:
        return map(full_lower, result)
    return full_lower(result)
Ejemplo n.º 3
0
def _maybe_perturbed(x: Any) -> bool:
  """Conservatively report whether ``x`` might carry a nontrivial tangent.

  Returns False only when ``x`` (after ``core.full_lower``) is known, up to
  heuristics, not to be an AD-perturbed value. Returns True whenever that
  cannot be ruled out.
  Motivation: https://github.com/google/jax/issues/6415
  """
  lowered = core.full_lower(x)
  # A plain (non-Tracer) value has no tangent attached.
  if not isinstance(lowered, core.Tracer):
    return False
  # While staging out, differentiation could still happen later. The
  # exceptions are types whose tangents are always trivial (tokens, float0).
  if isinstance(lowered, pe.DynamicJaxprTracer):
    space = lowered.aval.at_least_vspace()
    always_trivial = (space is core.abstract_token
                      or getattr(space, 'dtype', None) is dtypes.float0)
    return not always_trivial
  # A JVPTracer may or may not be perturbed; we can't be sure, so say yes.
  if isinstance(lowered, ad.JVPTracer):
    return True
  # Any other Tracer: recurse into the values it reports as its contents.
  return any(_maybe_perturbed(child) for _, child in lowered._contents())
Ejemplo n.º 4
0
Archivo: jet.py Proyecto: nhanwei/jax
 def full_lower(self):
     """Lower to the primal when every term is zero, else return ``self``.

     ``self`` is kept only if ``self.terms`` is not the ``zero_series``
     sentinel and at least one entry is not ``zero_term``. Otherwise the
     result is ``core.full_lower(self.primal)``.
     """
     terms = self.terms
     # De Morgan form of "all terms are zero". It short-circuits the same way,
     # so `terms` is not iterated when it is the zero_series sentinel.
     carries_terms = terms is not zero_series and any(
         term is not zero_term for term in terms)
     if carries_terms:
         return self
     return core.full_lower(self.primal)
Ejemplo n.º 5
0
 def full_lower(self):
     """Return ``core.full_lower(self.head)`` when there is no tail, else ``self``."""
     return self if self.tail is not None else core.full_lower(self.head)
Ejemplo n.º 6
0
 def full_lower(self):
   """Return the lowered ``val`` when ``batch_dim`` is ``not_mapped``.

   For any other ``batch_dim`` the tracer itself is returned.
   """
   if self.batch_dim is not not_mapped:
     return self
   return core.full_lower(self.val)
Ejemplo n.º 7
0
Archivo: ad.py Proyecto: jbampton/jax
 def full_lower(self):
   """Return the lowered primal when the tangent is a ``Zero``, else ``self``.

   The check is an exact ``type(...) is Zero`` test rather than
   ``isinstance``, so instances of ``Zero`` subclasses keep the tracer.
   """
   tangent_is_zero = type(self.tangent) is Zero
   return core.full_lower(self.primal) if tangent_is_zero else self
Ejemplo n.º 8
0
 def full_lower(self):
   """Return ``core.full_lower(self.val)`` if ``is_pure()`` holds, else ``self``."""
   return core.full_lower(self.val) if self.is_pure() else self
Ejemplo n.º 9
0
 def full_lower(self):
     """Return the lowered constant from ``pval`` when pure, else ``self``.

     ``self.pval`` is unpacked as a 2-tuple, and its second element is lowered
     with ``jax_core.full_lower``.
     """
     if not self.is_pure():
         return self
     # Keep the tuple unpack (not indexing) so a malformed pval still raises.
     _, constant = self.pval
     return jax_core.full_lower(constant)