def initial_ildj(incells, outcells, *, jaxpr, num_consts, **_):
  """Runs InverseAndILDJ propagation through an inner jaxpr.

  The first `num_consts` entries of `incells` are constant cells; the
  remaining entries line up with `jaxpr.invars`.

  Returns:
    A triple `(incells, outcells, state)`:
    - `incells` is the constant cells followed by the cells read back for
      `jaxpr.invars`.
    - `outcells` is the cells read back for `jaxpr.outvars`.
    - `state` is the state produced by `propagate.propagate`.
  """
  const_cells, arg_cells = jax_util.split_list(incells, [num_consts])
  env, state = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr,  # pytype: disable=wrong-arg-types
                                   const_cells, arg_cells, outcells)
  read = env.read
  updated_incells = const_cells + [read(v) for v in jaxpr.invars]
  updated_outcells = [read(v) for v in jaxpr.outvars]
  return updated_incells, updated_outcells, state
def initial_ildj(incells, outcells, *, jaxpr, num_consts, **_):
  """Runs InverseAndILDJ propagation through an inner jaxpr with constants.

  The leading `num_consts` cells of `incells` are passed to the propagation
  as constant cells; the rest correspond to `jaxpr.invars`.

  Returns:
    `(incells, outcells, None)`, where:
    - `incells` is the constant cells followed by the cells read back for
      `jaxpr.invars`.
    - `outcells` is the cells read back for `jaxpr.outvars`.
  """
  consts, args = jax_util.split_list(incells, [num_consts])
  env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr, consts,
                            args, outcells)
  in_reads = [env.read(v) for v in jaxpr.invars]
  out_reads = [env.read(v) for v in jaxpr.outvars]
  return consts + in_reads, out_reads, None
def log_prob_jaxpr(jaxpr, constcells, flat_incells, flat_outcells):
  """Runs log_prob propagation on a Jaxpr.

  Re-uses the InverseAndILDJ propagation with `log_prob_rules`. A running
  log probability is threaded through the propagation state by `reducer`,
  starting from `0.`.

  Args:
    jaxpr: the jaxpr to propagate through.
    constcells: cells for the jaxpr's constants.
    flat_incells: flat list of cells for the jaxpr's input variables.
    flat_outcells: flat list of cells for the jaxpr's output variables.

  Returns:
    The accumulated log probability.

  Raises:
    ValueError: if a log probability could not be computed somewhere in
      the jaxpr.
  """

  def reducer(env, eqn, curr_log_prob, new_log_prob):
    # A `FailedLogProb` on either side means a log_prob could not be
    # computed elsewhere, so the whole propagation has failed; keep
    # returning the failure sentinel.
    if (isinstance(curr_log_prob, FailedLogProb) or
        isinstance(new_log_prob, FailedLogProb)):
      return failed_log_prob
    if eqn.primitive in log_prob_registry and new_log_prob is None:
      # We are unable to compute a log_prob for this primitive.
      return failed_log_prob
    if new_log_prob is not None:
      # Fold in the summed ILDJs of this equation's output cells whose
      # `top()` is true.
      cells = [env.read(var) for var in eqn.outvars]
      ildjs = sum(cell.ildj.sum() for cell in cells if cell.top())
      return curr_log_prob + new_log_prob + ildjs
    return curr_log_prob

  # Re-use the InverseAndILDJ propagation but silently fail instead of
  # erroring when we hit a primitive we can't invert. We accumulate the log
  # probability values using the propagator state.
  _, final_log_prob = propagate.propagate(
      InverseAndILDJ, log_prob_rules, jaxpr, constcells, flat_incells,
      flat_outcells, reducer=reducer, initial_state=0.)
  # Check by type, consistent with the reducer's failure checks
  # (`failed_log_prob` itself is a `FailedLogProb`).
  if isinstance(final_log_prob, FailedLogProb):
    raise ValueError('Cannot compute log_prob of function.')
  return final_log_prob
def wrapped(*args, **kwargs):
  """Function wrapper that takes in inverse arguments.

  Steps:
  1. Stages `f` on `trace_args`, or on `args` when `trace_args` is empty.
  2. Seeds the jaxpr outputs with cells built from `args`.
  3. Propagates InverseAndILDJ cells to recover the inputs and their ILDJs.

  Returns:
    `(vals, ildj_)`.
    - `vals` is unflattened with the staged input tree.
    - `ildj_` is a scalar sum when `reduce_ildj` is set, and otherwise is
      unflattened like `vals`.
    - With exactly one forward argument, `vals` (and `ildj_` when it is not
      reduced) are indexed at `[0]`.

  Raises:
    ValueError: if any input cell is not `top()` after propagation.
  """
  forward_args = args if len(trace_args) == 0 else trace_args
  staged = trace_util.stage(f, dynamic=False)
  jaxpr, (in_tree, _) = staged(*forward_args, **kwargs)
  flat_forward_args = tree_util.tree_flatten(forward_args)[0]
  flat_args = tree_util.tree_flatten(args)[0]

  const_cells = safe_map(InverseAndILDJ.new, jaxpr.literals)
  avals = [trace_util.get_shaped_aval(a) for a in flat_forward_args]
  in_cells = [InverseAndILDJ.unknown(aval) for aval in avals]
  out_cells = safe_map(InverseAndILDJ.new, flat_args)

  env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr.jaxpr,
                            const_cells, in_cells, out_cells)
  solved = [env.read(v) for v in jaxpr.jaxpr.invars]
  if not all(cell.top() for cell in solved):
    raise ValueError('Cannot invert function.')

  flat_vals, flat_ildjs = jax_util.unzip2(
      [(cell.val, cell.ildj) for cell in solved])
  vals = tree_util.tree_unflatten(in_tree, flat_vals)
  if reduce_ildj:
    ildj_ = sum(np.sum(i) for i in flat_ildjs)
  else:
    ildj_ = tree_util.tree_unflatten(in_tree, flat_ildjs)

  if len(forward_args) != 1:
    return vals, ildj_
  return vals[0], (ildj_ if reduce_ildj else ildj_[0])
def test_should_propagate_accumulated_values_in_chain_function(self):
  """ILDJ of exp(x) + 2. inverted at output 4. is -log(2.)."""
  def f(x):
    return np.exp(x) + 2.
  jaxpr, _ = trace_util.stage(f)(4.)
  jaxpr, consts = jaxpr.jaxpr, jaxpr.literals
  # This `propagate` API (incells built with `Inverse.unknown(aval)`)
  # returns `(env, state)`; only the environment is inspected here.
  env, _ = propagate(ILDJ, ildj_rules, jaxpr, list(map(ILDJ.new, consts)),
                     [Inverse.unknown(var.aval) for var in jaxpr.invars],
                     list(map(ILDJ.new, (4.,))))
  inval = env[jaxpr.invars[0]]
  # exp(x) + 2 == 4 gives x == log(2); the ILDJ of exp at y == 2 is -log(2).
  self.assertEqual(inval.ildj, -np.log(2.))
def test_should_propagate_accumulated_values_in_one_op_function(self):
  """ILDJ of a single exp, inverted at output 2., is -log(2.)."""
  def f(x):
    return np.exp(x)
  staged, _ = trace_util.stage(f)(2.)
  inner, literals = staged.jaxpr, staged.literals
  const_cells = [ILDJ.new(c) for c in literals]
  in_cells = [unknown] * len(inner.invars)
  out_cells = [ILDJ.new(o) for o in (2.,)]
  env = propagate(ILDJ, ildj_rules, inner, const_cells, in_cells, out_cells)
  self.assertEqual(env[inner.invars[0]].ildj, -np.log(2.))
def test_should_propagate_to_invars_for_one_op_function(self):
  """Inverting exp at output 1. recovers the input 0."""
  def f(x):
    return np.exp(x)
  staged, _ = trace_util.stage(f)(1.)
  inner, literals = staged.jaxpr, staged.literals
  const_cells = [Inverse.new(c) for c in literals]
  in_cells = [unknown] * len(inner.invars)
  out_cells = [Inverse.new(o) for o in (1.,)]
  env = propagate(Inverse, inverse_rules, inner, const_cells, in_cells,
                  out_cells)
  recovered = env[inner.invars[0]]
  self.assertEqual(recovered.val, 0.)
def test_should_propagate_forward_and_backward(self):
  """Inverts (x + 1, exp(x + 1) + y) at outputs (0., 2.) to (-1., 1.)."""
  def f(x, y):
    return x + 1., np.exp(x + 1.) + y
  staged, _ = trace_util.stage(f)(0., 2.)
  inner = staged.jaxpr
  env, _ = propagate(
      Inverse, inverse_rules, inner,
      [Inverse.new(c) for c in staged.literals],
      [Inverse.unknown(v.aval) for v in inner.invars],
      [Inverse.new(o) for o in (0., 2.)])
  recovered = [env[v].val for v in inner.invars]
  onp.testing.assert_allclose(recovered, (-1., 1.))
def test_propagation_should_not_reach_invars(self):
  """A function ignoring its input leaves that input's cell at bottom."""
  def f(x):
    del x
    return 2.
  staged, _ = trace_util.stage(f)(1.)
  inner = staged.jaxpr
  env, _ = propagate(
      Inverse, inverse_rules, inner,
      [Inverse.new(c) for c in staged.literals],
      [Inverse.unknown(v.aval) for v in inner.invars],
      [Inverse.new(o) for o in (1.,)])
  input_cell = env.read(inner.invars[0])
  self.assertTrue(input_cell.bottom())
def test_propagate_through_jit(self):
  """Inversion reaches the input through a `jax.jit`-wrapped `np.exp`."""
  def f(x):
    return jax.jit(np.exp)(x) + 2.
  staged, _ = trace_util.stage(f)(3.)
  inner = staged.jaxpr
  env, _ = propagate(
      Inverse, inverse_rules, inner,
      [Inverse.new(c) for c in staged.literals],
      [Inverse.unknown(v.aval) for v in inner.invars],
      [Inverse.new(o) for o in (3.,)])
  # jit(exp)(x) + 2 == 3 gives x == 0.
  self.assertEqual(env[inner.invars[0]].val, 0.)
def test_should_propagate_to_invars_for_chain_function(self):
  """Inverting 2. + exp(x) at output 3. recovers x = 0."""
  def f(x):
    return 2. + np.exp(x)
  staged, _ = trace_util.stage(f)(3.)
  inner = staged.jaxpr
  env, _ = propagate(
      Inverse, inverse_rules, inner,
      [Inverse.new(c) for c in staged.literals],
      [Inverse.unknown(v.aval) for v in inner.invars],
      [Inverse.new(o) for o in (3.,)])
  recovered = env[inner.invars[0]]
  self.assertEqual(recovered.val, 0.)
def test_correct_inverse_for_identity_function(self):
  """Inverting the identity function at output 1. recovers 1."""
  def f(x):
    return x
  staged, _ = trace_util.stage(f)(1.)
  inner = staged.jaxpr
  env, _ = propagate(
      Inverse, inverse_rules, inner,
      [Inverse.new(c) for c in staged.literals],
      [Inverse.unknown(v.aval) for v in inner.invars],
      [Inverse.new(o) for o in (1.,)])
  recovered = env[inner.invars[0]]
  self.assertEqual(recovered.val, 1.)
def test_propagation_should_not_reach_invars(self):
  """A function ignoring its input never writes that input into env."""
  def f(x):
    del x
    return 2.
  staged, _ = trace_util.stage(f)(1.)
  inner, literals = staged.jaxpr, staged.literals
  const_cells = [Inverse.new(c) for c in literals]
  in_cells = [unknown] * len(inner.invars)
  out_cells = [Inverse.new(o) for o in (1.,)]
  env = propagate(Inverse, inverse_rules, inner, const_cells, in_cells,
                  out_cells)
  self.assertNotIn(inner.invars[0], env)
def test_propagate_through_jit(self):
  """Inversion passes through a `jax.jit`-wrapped `np.exp`.

  Also checks that `env.subenvs` has exactly one entry (presumably the
  environment of the jitted call — confirm against `propagate`).
  """
  def f(x):
    return jax.jit(np.exp)(x) + 2.
  staged, _ = trace_util.stage(f)(3.)
  inner, literals = staged.jaxpr, staged.literals
  const_cells = [Inverse.new(c) for c in literals]
  in_cells = [unknown] * len(inner.invars)
  out_cells = [Inverse.new(o) for o in (3.,)]
  env = propagate(Inverse, inverse_rules, inner, const_cells, in_cells,
                  out_cells)
  self.assertEqual(env[inner.invars[0]].val, 0.)
  self.assertLen(env.subenvs, 1)
def test_propagate_should_accumulate_state(self):
  """A summing reducer starting at 0 ends with state 2 for exp(x) + 2."""
  def f(x):
    return np.exp(x) + 2.
  jaxpr, _ = trace_util.stage(f)(4.)
  jaxpr, consts = jaxpr.jaxpr, jaxpr.literals
  # The reducer's last parameter is named `increment` rather than `next` so
  # it does not shadow the builtin.
  _, state = propagate(
      ILDJ, ildj_rules, jaxpr, list(map(ILDJ.new, consts)),
      [Inverse.unknown(var.aval) for var in jaxpr.invars],
      list(map(ILDJ.new, (4.,))),
      reducer=lambda env, eqn, count, increment: count + increment,
      initial_state=0)
  self.assertEqual(state, 2)
def wrapped(sample, *args, **kwargs):
  """Function wrapper that takes in log_prob arguments.

  Steps:
  1. Stages `f` with a fixed PRNG key followed by `args`.
  2. Seeds the outputs with `sample` and the non-key inputs with `args`.
  3. Runs InverseAndILDJ propagation under `log_prob_rules`.

  Returns:
    Whatever `_accumulate_log_probs` computes from the resulting
    environment.
  """
  # Trace the function using a random seed
  staged, _ = trace_util.stage(f)(random.PRNGKey(0), *args, **kwargs)
  out_leaves = tree_util.tree_flatten(sample)[0]
  in_leaves = tree_util.tree_flatten(args)[0]
  const_cells = list(map(InverseAndILDJ.new, staged.literals))
  # The leading input is the PRNG key passed to `stage`; it is left unknown.
  in_cells = [unknown, *map(InverseAndILDJ.new, in_leaves)]
  out_cells = list(map(InverseAndILDJ.new, out_leaves))
  # Re-use the InverseAndILDJ propagation but silently fail instead of
  # erroring when we hit a primitive we can't invert.
  env = propagate.propagate(InverseAndILDJ, log_prob_rules, staged.jaxpr,
                            const_cells, in_cells, out_cells)
  # Traverse the resulting environment, looking for primitives that have
  # registered log_probs.
  return _accumulate_log_probs(env)
def wrapped(*args, **kwargs):
  """Function wrapper that takes in inverse arguments.

  Steps:
  1. Stages `f` on `trace_args`, or on `args` when `trace_args` is empty.
  2. Seeds the jaxpr outputs with cells built from `args`.
  3. Propagates InverseAndILDJ cells to recover the inputs and their ILDJs.

  Returns:
    `(vals, ildjs)`, both unflattened with the staged input tree. With
    exactly one forward argument, both are indexed at `[0]`.

  Raises:
    ValueError: if any input cell is still unknown after propagation.
  """
  forward_args = trace_args if len(trace_args) else args
  jaxpr, (in_tree, _) = trace_util.stage(f)(*forward_args, **kwargs)
  flat_forward_args, _ = tree_util.tree_flatten(forward_args)
  flat_args, _ = tree_util.tree_flatten(args)
  flat_constcells = safe_map(InverseAndILDJ.new, jaxpr.literals)
  flat_incells = [unknown] * len(flat_forward_args)
  flat_outcells = safe_map(InverseAndILDJ.new, flat_args)
  env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr.jaxpr,
                            flat_constcells, flat_incells, flat_outcells)
  flat_incells = [env.read(invar) for invar in jaxpr.jaxpr.invars]
  if any(flat_incell.is_unknown() for flat_incell in flat_incells):
    raise ValueError('Cannot invert function.')
  flat_vals, flat_ildjs = jax_util.unzip2([
      (flat_incell.val, flat_incell.ildj) for flat_incell in flat_incells
  ])
  vals = tree_util.tree_unflatten(in_tree, flat_vals)
  ildjs = tree_util.tree_unflatten(in_tree, flat_ildjs)
  # Unwrap based on the arguments actually traced. When `trace_args` is
  # empty, `forward_args` is `args`, and a single argument must still be
  # unwrapped; checking `trace_args` here missed that case.
  if len(forward_args) == 1:
    vals, ildjs = vals[0], ildjs[0]
  return vals, ildjs
def initial_ildj(incells, outcells, *, jaxpr, **_):
  """Propagates InverseAndILDJ cells through `jaxpr` with no constant cells.

  Returns:
    `(incells, outcells, None)`: the cells read back from the propagation
    environment for `jaxpr.invars` and `jaxpr.outvars`, respectively.
  """
  env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr, [],
                            incells, outcells)

  def read_all(variables):
    return [env.read(v) for v in variables]

  return read_all(jaxpr.invars), read_all(jaxpr.outvars), None