Example #1
0
 def wrapped(*args, **kwargs):
     """Run inverse propagation on the staged ``f`` for the given values.

     ``f``, ``trace_args`` and ``reduce_ildj`` come from the enclosing scope.
     The call arguments are used as the known output cells of the staged
     function, while its inputs start out unknown.

     Returns:
       A ``(vals, ildj)`` pair. ``vals`` holds the values read back for the
       jaxpr inputs, and ``ildj`` is either a single summed term (when
       ``reduce_ildj`` is truthy) or a structure mirroring ``vals``. Both
       are unwrapped from their container when there is exactly one forward
       argument.

     Raises:
       ValueError: if any jaxpr input cell does not report ``top()`` after
         propagation.
     """
     # Trace with the explicitly supplied arguments when there are any;
     # otherwise the call arguments double as the trace arguments.
     if len(trace_args):
         forward_args = trace_args
     else:
         forward_args = args
     staged = trace_util.stage(f, dynamic=False)
     jaxpr, (in_tree, _) = staged(*forward_args, **kwargs)
     forward_leaves = tree_util.tree_flatten(forward_args)[0]
     given_leaves = tree_util.tree_flatten(args)[0]
     const_cells = safe_map(InverseAndILDJ.new, jaxpr.literals)
     unknown_in_cells = [
         InverseAndILDJ.unknown(trace_util.get_shaped_aval(leaf))
         for leaf in forward_leaves
     ]
     known_out_cells = safe_map(InverseAndILDJ.new, given_leaves)
     env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr.jaxpr,
                               const_cells, unknown_in_cells,
                               known_out_cells)
     solved_cells = [env.read(invar) for invar in jaxpr.jaxpr.invars]
     if not all(cell.top() for cell in solved_cells):
         raise ValueError('Cannot invert function.')
     flat_vals, flat_ildjs = jax_util.unzip2(
         [(cell.val, cell.ildj) for cell in solved_cells])
     vals = tree_util.tree_unflatten(in_tree, flat_vals)
     ildj_ = (sum(np.sum(term) for term in flat_ildjs)
              if reduce_ildj
              else tree_util.tree_unflatten(in_tree, flat_ildjs))
     if len(forward_args) != 1:
         return vals, ildj_
     # A single forward argument is returned bare rather than as a 1-tuple;
     # a reduced ILDJ is already a single value and needs no unwrapping.
     return vals[0], (ildj_ if reduce_ildj else ildj_[0])
Example #2
0
 def aval(self):
     """Return the abstract value described by ``self.pval``.

     ``self.pval`` is unpacked as a ``(pv, const)`` pair.

     Returns:
       ``pv`` itself when it is a ``jax_core.AbstractValue``; otherwise,
       when ``pv`` is ``None``, the result of
       ``trace_util.get_shaped_aval(const)``.

     Raises:
       TypeError: if ``pv`` is neither an ``AbstractValue`` nor ``None``.
     """
     pv, const = self.pval
     if isinstance(pv, jax_core.AbstractValue):
         return pv
     if pv is None:
         return trace_util.get_shaped_aval(const)
     # Every valid case has returned above, so anything else is an error.
     # (A trailing ``return self.val`` used to follow here but could never
     # execute, so it was removed.)
     raise TypeError(pv)
Example #3
0
 def instantiate_const_abstracted(self, tracer):
     """Return a tracer whose partial value is an abstract value.

     ``tracer.pval`` is unpacked as ``(pv, const)``. If ``pv`` is already a
     ``jax_core.AbstractValue`` the tracer is returned unchanged. If ``pv``
     is ``None`` a new ``UnzipTracer`` is built around an unknown partial
     value derived from ``const``.

     Raises:
       TypeError: if ``pv`` is neither an ``AbstractValue`` nor ``None``.
     """
     pv, const = tracer.pval
     if isinstance(pv, jax_core.AbstractValue):
         return tracer
     if pv is not None:
         raise TypeError(pv)
     base_aval = trace_util.get_shaped_aval(const)
     # The second argument records whether ``const`` is a scalar according
     # to ``onp.isscalar``.
     raised_aval = abstract_arrays.raise_to_shaped(base_aval,
                                                   onp.isscalar(const))
     unknown_pval = pe.PartialVal.unknown(raised_aval)
     return UnzipTracer(self, unknown_pval, pe.ConstVar(const),
                        tracer.is_key())
Example #4
0
 def wrapped(sample, *args, **kwargs):
   """Stage ``f`` with a placeholder key and hand off to ``log_prob_jaxpr``.

   ``f`` comes from the enclosing scope. ``sample`` supplies the known
   output cells and ``args`` the known input cells; the leading key input
   is left unknown.

   Returns:
     Whatever ``log_prob_jaxpr`` returns for the staged jaxpr and cells.
   """
   # The key's value is irrelevant here: it is only used to trace ``f``
   # and to provide the abstract value of the unknown first input.
   placeholder_key = random.PRNGKey(0)
   staged = trace_util.stage(f, dynamic=False)
   jaxpr, _ = staged(placeholder_key, *args, **kwargs)
   sample_leaves = tree_util.tree_flatten(sample)[0]
   arg_leaves = tree_util.tree_flatten(args)[0]
   constcells = list(map(InverseAndILDJ.new, jaxpr.literals))
   key_cell = InverseAndILDJ.unknown(
       trace_util.get_shaped_aval(placeholder_key))
   flat_incells = [key_cell, *map(InverseAndILDJ.new, arg_leaves)]
   flat_outcells = list(map(InverseAndILDJ.new, sample_leaves))
   return log_prob_jaxpr(jaxpr.jaxpr, constcells, flat_incells, flat_outcells)
Example #5
0
 def wrapped(sample, *args, **kwargs):
     """Compute a log-probability for ``sample`` by propagating through ``f``.

     ``f`` comes from the enclosing scope. It is staged with a placeholder
     PRNG key, ``sample`` supplies the known output cells and ``args`` the
     known input cells; the leading key input is left unknown.

     Returns:
       The value produced by ``_accumulate_log_probs`` on the propagation
       environment.
     """
     # The key's value is irrelevant here: it is only used to trace ``f``
     # and to provide the abstract value of the unknown first input.
     placeholder_key = random.PRNGKey(0)
     jaxpr, _ = trace_util.stage(f)(placeholder_key, *args, **kwargs)
     sample_leaves = tree_util.tree_flatten(sample)[0]
     arg_leaves = tree_util.tree_flatten(args)[0]
     constcells = list(map(InverseAndILDJ.new, jaxpr.literals))
     key_cell = InverseAndILDJ.unknown(
         trace_util.get_shaped_aval(placeholder_key))
     flat_incells = [key_cell, *map(InverseAndILDJ.new, arg_leaves)]
     flat_outcells = list(map(InverseAndILDJ.new, sample_leaves))
     # InverseAndILDJ propagation is reused with the log_prob rule set so
     # that a primitive we cannot invert fails silently instead of raising.
     env = propagate.propagate(InverseAndILDJ, log_prob_rules, jaxpr.jaxpr,
                               constcells, flat_incells, flat_outcells)
     # Walk the resulting environment and collect the contributions of the
     # primitives that have registered log_probs.
     return _accumulate_log_probs(env)
Example #6
0
 def new_instantiated_const(self, val):
     """Wrap ``val`` in an ``UnzipTracer`` backed by a ``pe.ConstVar``.

     The tracer's partial value is unknown, using the shaped abstract value
     of ``val``.
     """
     unknown_pval = pe.PartialVal.unknown(trace_util.get_shaped_aval(val))
     const_var = pe.ConstVar(val)
     # NOTE(review): the final positional flag is presumably the tracer's
     # "is key" marker -- confirm against UnzipTracer's constructor.
     return UnzipTracer(self, unknown_pval, const_var, True)
Example #7
0
 def new_instantiated_literal(self, val):
     """Wrap ``val`` in an ``UnzipTracer`` backed by a ``jax_core.Literal``.

     The tracer's partial value is unknown, using the shaped abstract value
     of ``val``.
     """
     unknown_pval = pe.PartialVal.unknown(trace_util.get_shaped_aval(val))
     literal = jax_core.Literal(val)
     # NOTE(review): the final positional flag is presumably the tracer's
     # "is key" marker -- confirm against UnzipTracer's constructor.
     return UnzipTracer(self, unknown_pval, literal, True)
Example #8
0
 def new(cls, val):
     """Build an ``Inverse`` from ``val`` and its shaped abstract value.

     Note that ``Inverse`` is constructed directly; ``cls`` is not used.
     """
     return Inverse(trace_util.get_shaped_aval(val), val)
Example #9
0
 def new(cls, val):
     """Build an ``ILDJ`` from ``val`` with a third argument of ``0.``.

     The first argument is the shaped abstract value of ``val``. Note that
     ``ILDJ`` is constructed directly; ``cls`` is not used.
     """
     return ILDJ(trace_util.get_shaped_aval(val), val, 0.)