def wrapped(*args, **kwargs):
  """Function wrapper that takes in inverse arguments.

  Stages `f` into a jaxpr and builds the cells for propagation:
  - input cells start unknown, with only their shaped avals set;
  - output cells are built from `args`;
  - constant cells come from the jaxpr literals.
  It then propagates `InverseAndILDJ` cells through the jaxpr with
  `ildj_registry` and reads back the values and ILDJ terms of the jaxpr's
  input variables.

  Args:
    *args: values to invert. When `trace_args` is empty, these are also the
      example arguments used to trace `f`.
    **kwargs: keyword arguments forwarded to `f` while tracing.

  Returns:
    A `(vals, ildj_)` pair. `vals` is unflattened with the input tree from
    staging. `ildj_` is the sum of `np.sum` over every per-input ILDJ when
    `reduce_ildj` is truthy; otherwise it is unflattened with the same tree.
    When exactly one forward argument was traced, both results are unwrapped
    to their first element. With `reduce_ildj`, only `vals` is unwrapped.

  Raises:
    ValueError: if any input cell is not `top()` after propagation.
  """
  # Prefer the explicitly supplied trace arguments; fall back to this call's.
  traced_args = trace_args if len(trace_args) > 0 else args
  closed_jaxpr, (in_tree, _) = trace_util.stage(f, dynamic=False)(
      *traced_args, **kwargs)
  flat_traced, _ = tree_util.tree_flatten(traced_args)
  flat_outputs, _ = tree_util.tree_flatten(args)
  const_cells = safe_map(InverseAndILDJ.new, closed_jaxpr.literals)
  traced_avals = [trace_util.get_shaped_aval(x) for x in flat_traced]
  # Input cells start out unknown; tracing only fixes their abstract values.
  in_cells = [InverseAndILDJ.unknown(a) for a in traced_avals]
  out_cells = safe_map(InverseAndILDJ.new, flat_outputs)
  env = propagate.propagate(InverseAndILDJ, ildj_registry, closed_jaxpr.jaxpr,
                            const_cells, in_cells, out_cells)
  solved = [env.read(v) for v in closed_jaxpr.jaxpr.invars]
  if not all(cell.top() for cell in solved):
    raise ValueError('Cannot invert function.')
  flat_vals, flat_ildjs = jax_util.unzip2(
      [(cell.val, cell.ildj) for cell in solved])
  vals = tree_util.tree_unflatten(in_tree, flat_vals)
  if reduce_ildj:
    ildj_ = sum(np.sum(term) for term in flat_ildjs)
  else:
    ildj_ = tree_util.tree_unflatten(in_tree, flat_ildjs)
  if len(traced_args) != 1:
    return vals, ildj_
  # A single forward argument is returned unwrapped from its container.
  return vals[0], (ildj_ if reduce_ildj else ildj_[0])
def aval(self):
  """Returns the abstract value for this object's partial value `self.pval`.

  `self.pval` is unpacked as a `(pv, const)` pair.

  Returns:
    `pv` itself when it is already a `jax_core.AbstractValue`. When `pv` is
    None, returns the shaped aval derived from `const`.

  Raises:
    TypeError: if `pv` is neither an abstract value nor None.
  """
  # The original ended with an unreachable `return self.val` after the
  # raise; every path above already returns or raises, so it was dropped.
  pv, const = self.pval
  if isinstance(pv, jax_core.AbstractValue):
    return pv
  if pv is None:
    # No abstract value is stored; derive one from the constant.
    return trace_util.get_shaped_aval(const)
  raise TypeError(pv)
def instantiate_const_abstracted(self, tracer):
  """Returns `tracer` in a form whose partial value is abstract.

  Args:
    tracer: a tracer whose `pval` unpacks as a `(pv, const)` pair.

  Returns:
    `tracer` unchanged if `pv` is already a `jax_core.AbstractValue`.
    If `pv` is None, returns a new `UnzipTracer` with these parts:
    - an unknown partial value built from the raised, shaped aval of `const`;
    - a `pe.ConstVar(const)` recipe;
    - the original tracer's `is_key()` flag.

  Raises:
    TypeError: if `pv` is neither an abstract value nor None.
  """
  pv, const = tracer.pval
  if isinstance(pv, jax_core.AbstractValue):
    return tracer
  if pv is not None:
    raise TypeError(pv)
  shaped = trace_util.get_shaped_aval(const)
  # Second argument is `onp.isscalar(const)`; presumably the weak-type flag
  # of `raise_to_shaped` -- confirm against the JAX version in use.
  scalar_flag = onp.isscalar(const)
  abstract = abstract_arrays.raise_to_shaped(shaped, scalar_flag)
  unknown_pval = pe.PartialVal.unknown(abstract)
  recipe = pe.ConstVar(const)
  return UnzipTracer(self, unknown_pval, recipe, tracer.is_key())
def wrapped(sample, *args, **kwargs):
  """Function wrapper that takes in log_prob arguments.

  Stages `f` with a fixed PRNG key prepended to `args` and builds the cells:
  - the key's cell starts unknown;
  - the other inputs and the flattened `sample` are wrapped as
    `InverseAndILDJ.new` cells.
  The jaxpr and its cells are then handed to `log_prob_jaxpr`.

  Args:
    sample: the (possibly nested) output value to score.
    *args: positional arguments passed to `f` after the key.
    **kwargs: keyword arguments forwarded to `f` while tracing.

  Returns:
    Whatever `log_prob_jaxpr` returns for the staged jaxpr and cells.
  """
  # Trace the function using a fixed stand-in for the random seed.
  seed = random.PRNGKey(0)
  closed_jaxpr, _ = trace_util.stage(f, dynamic=False)(seed, *args, **kwargs)
  flat_sample, _ = tree_util.tree_flatten(sample)
  flat_inputs, _ = tree_util.tree_flatten(args)
  const_cells = [InverseAndILDJ.new(lit) for lit in closed_jaxpr.literals]
  seed_cell = InverseAndILDJ.unknown(trace_util.get_shaped_aval(seed))
  in_cells = [seed_cell] + [InverseAndILDJ.new(x) for x in flat_inputs]
  out_cells = [InverseAndILDJ.new(s) for s in flat_sample]
  return log_prob_jaxpr(closed_jaxpr.jaxpr, const_cells, in_cells, out_cells)
def wrapped(sample, *args, **kwargs):
  """Function wrapper that takes in log_prob arguments.

  Stages `f` with a fixed PRNG key prepended to `args` and builds the cells:
  - the key's cell starts unknown;
  - the other inputs and the flattened `sample` are wrapped as
    `InverseAndILDJ.new` cells.
  It propagates those cells through the jaxpr using `log_prob_rules`, then
  accumulates the result from the environment.

  Args:
    sample: the (possibly nested) output value to score.
    *args: positional arguments passed to `f` after the key.
    **kwargs: keyword arguments forwarded to `f` while tracing.

  Returns:
    The value produced by `_accumulate_log_probs` on the propagated env.
  """
  # Trace the function using a fixed stand-in for the random seed.
  seed = random.PRNGKey(0)
  closed_jaxpr, _ = trace_util.stage(f)(seed, *args, **kwargs)
  flat_sample, _ = tree_util.tree_flatten(sample)
  flat_inputs, _ = tree_util.tree_flatten(args)
  const_cells = [InverseAndILDJ.new(lit) for lit in closed_jaxpr.literals]
  seed_cell = InverseAndILDJ.unknown(trace_util.get_shaped_aval(seed))
  in_cells = [seed_cell] + [InverseAndILDJ.new(x) for x in flat_inputs]
  out_cells = [InverseAndILDJ.new(s) for s in flat_sample]
  # InverseAndILDJ propagation is reused with `log_prob_rules`. The intent
  # is to fail silently, rather than error, on a primitive that cannot be
  # inverted.
  env = propagate.propagate(InverseAndILDJ, log_prob_rules, closed_jaxpr.jaxpr,
                            const_cells, in_cells, out_cells)
  # Walk the resulting environment for primitives with registered log_probs.
  return _accumulate_log_probs(env)
def new_instantiated_const(self, val):
  """Wraps `val` in an `UnzipTracer` recorded as a `pe.ConstVar`.

  The tracer has these parts:
  - an unknown partial value carrying the shaped aval of `val`;
  - a fourth `UnzipTracer` argument of True (presumably the key flag;
    confirm against `UnzipTracer`).
  """
  shaped = trace_util.get_shaped_aval(val)
  unknown_pval = pe.PartialVal.unknown(shaped)
  return UnzipTracer(self, unknown_pval, pe.ConstVar(val), True)
def new_instantiated_literal(self, val):
  """Wraps `val` in an `UnzipTracer` recorded as a `jax_core.Literal`.

  The tracer has these parts:
  - an unknown partial value carrying the shaped aval of `val`;
  - a fourth `UnzipTracer` argument of True (presumably the key flag;
    confirm against `UnzipTracer`).
  """
  shaped = trace_util.get_shaped_aval(val)
  unknown_pval = pe.PartialVal.unknown(shaped)
  return UnzipTracer(self, unknown_pval, jax_core.Literal(val), True)
def new(cls, val):
  """Builds an `Inverse` from `val` and its shaped aval.

  `cls` is not used; the result is always an `Inverse`.
  """
  return Inverse(trace_util.get_shaped_aval(val), val)
def new(cls, val):
  """Builds an `ILDJ` from `val`, its shaped aval, and a third value of 0.

  The third constructor argument is presumably the initial ILDJ term
  (confirm against `ILDJ`). `cls` is not used.
  """
  return ILDJ(trace_util.get_shaped_aval(val), val, 0.)