def test_should_not_inline_calls_without_variables(self):
    """Checks that `call_p` equations remain in both the init and apply jaxprs."""

    def inline_call(x):
        # One call before and one call after the variable tagged 'x'.
        before = call(lambda x: x + 1)(x)
        tagged = variable(before, name='x')
        return call(lambda x: x + 1)(tagged)

    def primitives_of(closed_jaxpr):
        """Collects the set of primitives appearing in the staged jaxpr."""
        return {eqn.primitive for eqn in closed_jaxpr.jaxpr.eqns}

    init, apply = unzip_variable(inline_call)(1.)
    self.assertDictEqual(init(1.), {'x': 2.})

    staged_init = trace_util.stage(init)(1.)[0]
    self.assertIn(call_p, primitives_of(staged_init))

    staged_apply = trace_util.stage(apply)(init(1.))[0]
    self.assertIn(call_p, primitives_of(staged_apply))
def wrapped(*args, **kwargs):
    """Inverts `f` at the given outputs and reports the matching ILDJ terms.

    `args` are treated as output values of `f`. The function is staged on
    `trace_args` when any were supplied, otherwise on `args` themselves.

    Returns:
      A `(vals, ildj)` pair. `ildj` is a single summed value when
      `reduce_ildj` is set, otherwise it mirrors the structure of `vals`.

    Raises:
      ValueError: if any input cell fails its `top()` check after propagation.
    """
    if len(trace_args):
        forward_args = trace_args
    else:
        forward_args = args
    closed_jaxpr, (in_tree, _) = trace_util.stage(f, dynamic=False)(
        *forward_args, **kwargs)

    forward_leaves, _ = tree_util.tree_flatten(forward_args)
    output_leaves, _ = tree_util.tree_flatten(args)

    const_cells = safe_map(InverseAndILDJ.new, closed_jaxpr.literals)
    # Inputs start out unknown, carrying only the abstract value of each leaf.
    in_cells = [
        InverseAndILDJ.unknown(trace_util.get_shaped_aval(leaf))
        for leaf in forward_leaves
    ]
    out_cells = safe_map(InverseAndILDJ.new, output_leaves)

    env = propagate.propagate(InverseAndILDJ, ildj_registry,
                              closed_jaxpr.jaxpr, const_cells, in_cells,
                              out_cells)

    solved_cells = [env.read(invar) for invar in closed_jaxpr.jaxpr.invars]
    if not all(cell.top() for cell in solved_cells):
        raise ValueError('Cannot invert function.')

    flat_vals, flat_ildjs = jax_util.unzip2(
        [(cell.val, cell.ildj) for cell in solved_cells])

    # A single forward argument is returned bare rather than as a 1-tuple.
    unwrap = len(forward_args) == 1
    vals = tree_util.tree_unflatten(in_tree, flat_vals)
    if unwrap:
        vals = vals[0]

    if reduce_ildj:
        return vals, sum(np.sum(term) for term in flat_ildjs)

    ildjs = tree_util.tree_unflatten(in_tree, flat_ildjs)
    if unwrap:
        ildjs = ildjs[0]
    return vals, ildjs
def wrapped(*args, **kwargs):
    """Stages `f` and converts the resulting jaxpr into a bound expression.

    Returns:
      A 3-tuple of: a `BoundExpression` pairing the jaxpr's expressions with a
      mapping from stringified constvars to their values, the tuple of
      stringified invars, and the `(in_tree, out_tree)` pair from staging.
    """
    staged, trees = trace_util.stage(f, dynamic=False)(*args, **kwargs)
    in_tree, out_tree = trees
    inner_jaxpr = staged.jaxpr
    const_values = staged.literals

    expressions = jaxpr_to_expressions(inner_jaxpr)
    bindings = {
        str(constvar): value
        for constvar, value in zip(inner_jaxpr.constvars, const_values)
    }
    bound = BoundExpression(expressions, bindings)
    input_names = tuple(str(invar) for invar in inner_jaxpr.invars)
    return bound, input_names, (in_tree, out_tree)
def wrapped(*args, **kwargs):
    """Stages `f` and binds the staged jaxpr to the call primitive `prim`.

    The jaxpr's constants are passed to `bind` ahead of the flattened
    positional arguments, with `num_consts` recording how many there are.

    Returns:
      The primitive's outputs, unflattened with the staged `out_tree`.
    """
    closed_jaxpr, (in_tree, out_tree) = trace_util.stage(f, dynamic=True)(
        *args, **kwargs)
    consts = closed_jaxpr.literals
    arg_leaves = tree_util.tree_leaves(args)
    flat_outs = prim.bind(
        *consts,
        *arg_leaves,
        jaxpr=closed_jaxpr.jaxpr,
        in_tree=in_tree,
        out_tree=out_tree,
        num_consts=len(consts),
        **params)
    return tree_util.tree_unflatten(out_tree, flat_outs)
def forward(state: Value, *args: Value, **kwargs) -> Tuple[Value, Value]:
    """Stages `f` and evaluates its jaxpr while threading `state` through.

    Args:
      state: Value handed to `eval_jaxpr_with_state` alongside the jaxpr.
      *args: Positional arguments for `f`; their leaves feed the interpreter.
      **kwargs: Keyword arguments used only when staging `f`.

    Returns:
      A pair of the unflattened outputs and the state returned by the
      interpreter.
    """
    # Stage `f` into a closed jaxpr; only the output tree is needed afterwards.
    staged, (_, out_tree) = trace_util.stage(f, dynamic=True)(*args, **kwargs)
    arg_leaves = tree_util.tree_leaves(args)
    # Run the jaxpr under the interpreter configured by `handlers`.
    flat_out, new_state = eval_jaxpr_with_state(
        staged.jaxpr, handlers, staged.literals, state, *arg_leaves)
    result = tree_util.tree_unflatten(out_tree, flat_out)
    return result, new_state
def test_should_propagate_accumulated_values_in_one_op_function(self):
    """Propagates ILDJ cells through `np.exp` from output 2. and checks the ildj."""

    def f(x):
        return np.exp(x)

    closed_jaxpr, _ = trace_util.stage(f)(2.)
    inner_jaxpr = closed_jaxpr.jaxpr
    const_cells = [ILDJ.new(const) for const in closed_jaxpr.literals]
    in_cells = [unknown for _ in inner_jaxpr.invars]
    out_cells = [ILDJ.new(2.)]

    env = propagate(ILDJ, ildj_rules, inner_jaxpr, const_cells, in_cells,
                    out_cells)

    input_cell = env[inner_jaxpr.invars[0]]
    self.assertEqual(input_cell.ildj, -np.log(2.))
def test_should_propagate_accumulated_values_in_chain_function(self):
    """Propagates ILDJ cells through `np.exp(x) + 2.` from output 4."""

    def f(x):
        return np.exp(x) + 2.

    closed_jaxpr, _ = trace_util.stage(f)(4.)
    inner_jaxpr = closed_jaxpr.jaxpr
    const_cells = [ILDJ.new(const) for const in closed_jaxpr.literals]
    in_cells = [Inverse.unknown(invar.aval) for invar in inner_jaxpr.invars]
    out_cells = [ILDJ.new(4.)]

    env = propagate(ILDJ, ildj_rules, inner_jaxpr, const_cells, in_cells,
                    out_cells)

    input_cell = env[inner_jaxpr.invars[0]]
    self.assertEqual(input_cell.ildj, -np.log(2.))
def test_should_propagate_to_invars_for_one_op_function(self):
    """Propagates Inverse cells through `np.exp` from output 1.; expects input 0."""

    def f(x):
        return np.exp(x)

    closed_jaxpr, _ = trace_util.stage(f)(1.)
    inner_jaxpr = closed_jaxpr.jaxpr
    const_cells = [Inverse.new(const) for const in closed_jaxpr.literals]
    in_cells = [unknown for _ in inner_jaxpr.invars]
    out_cells = [Inverse.new(1.)]

    env = propagate(Inverse, inverse_rules, inner_jaxpr, const_cells,
                    in_cells, out_cells)

    input_cell = env[inner_jaxpr.invars[0]]
    self.assertEqual(input_cell.val, 0.)
def test_should_propagate_forward_and_backward(self):
    """Recovers both inputs of a two-output function from outputs (0., 2.)."""

    def f(x, y):
        return x + 1., np.exp(x + 1.) + y

    closed_jaxpr, _ = trace_util.stage(f)(0., 2.)
    inner_jaxpr = closed_jaxpr.jaxpr
    const_cells = [Inverse.new(const) for const in closed_jaxpr.literals]
    in_cells = [Inverse.unknown(invar.aval) for invar in inner_jaxpr.invars]
    out_cells = [Inverse.new(out) for out in (0., 2.)]

    env, _ = propagate(Inverse, inverse_rules, inner_jaxpr, const_cells,
                       in_cells, out_cells)

    recovered = [env[invar].val for invar in inner_jaxpr.invars]
    onp.testing.assert_allclose(recovered, (-1., 1.))
def test_propagation_should_not_reach_invars(self):
    """Checks the input cell reports `bottom()` when `f` ignores its argument."""

    def f(x):
        del x
        return 2.

    closed_jaxpr, _ = trace_util.stage(f)(1.)
    inner_jaxpr = closed_jaxpr.jaxpr
    const_cells = [Inverse.new(const) for const in closed_jaxpr.literals]
    in_cells = [Inverse.unknown(invar.aval) for invar in inner_jaxpr.invars]
    out_cells = [Inverse.new(1.)]

    env, _ = propagate(Inverse, inverse_rules, inner_jaxpr, const_cells,
                       in_cells, out_cells)

    input_cell = env.read(inner_jaxpr.invars[0])
    self.assertTrue(input_cell.bottom())
def test_propagate_through_jit(self):
    """Propagates Inverse cells through a jitted `np.exp` plus 2., from output 3."""

    def f(x):
        return jax.jit(np.exp)(x) + 2.

    closed_jaxpr, _ = trace_util.stage(f)(3.)
    inner_jaxpr = closed_jaxpr.jaxpr
    const_cells = [Inverse.new(const) for const in closed_jaxpr.literals]
    in_cells = [Inverse.unknown(invar.aval) for invar in inner_jaxpr.invars]
    out_cells = [Inverse.new(3.)]

    env, _ = propagate(Inverse, inverse_rules, inner_jaxpr, const_cells,
                       in_cells, out_cells)

    input_cell = env[inner_jaxpr.invars[0]]
    self.assertEqual(input_cell.val, 0.)
def test_should_propagate_to_invars_for_chain_function(self):
    """Propagates Inverse cells through `2. + np.exp(x)` from output 3."""

    def f(x):
        return 2. + np.exp(x)

    closed_jaxpr, _ = trace_util.stage(f)(3.)
    inner_jaxpr = closed_jaxpr.jaxpr
    const_cells = [Inverse.new(const) for const in closed_jaxpr.literals]
    in_cells = [Inverse.unknown(invar.aval) for invar in inner_jaxpr.invars]
    out_cells = [Inverse.new(3.)]

    env, _ = propagate(Inverse, inverse_rules, inner_jaxpr, const_cells,
                       in_cells, out_cells)

    input_cell = env[inner_jaxpr.invars[0]]
    self.assertEqual(input_cell.val, 0.)
def test_correct_inverse_for_identity_function(self):
    """Checks that inverting the identity at output 1. yields input 1."""

    def f(x):
        return x

    closed_jaxpr, _ = trace_util.stage(f)(1.)
    inner_jaxpr = closed_jaxpr.jaxpr
    const_cells = [Inverse.new(const) for const in closed_jaxpr.literals]
    in_cells = [Inverse.unknown(invar.aval) for invar in inner_jaxpr.invars]
    out_cells = [Inverse.new(1.)]

    env, _ = propagate(Inverse, inverse_rules, inner_jaxpr, const_cells,
                       in_cells, out_cells)

    input_cell = env[inner_jaxpr.invars[0]]
    self.assertEqual(input_cell.val, 1.)
def test_propagation_should_not_reach_invars(self):
    """Checks the invar is absent from the env when `f` ignores its argument."""

    def f(x):
        del x
        return 2.

    closed_jaxpr, _ = trace_util.stage(f)(1.)
    inner_jaxpr = closed_jaxpr.jaxpr
    const_cells = [Inverse.new(const) for const in closed_jaxpr.literals]
    in_cells = [unknown for _ in inner_jaxpr.invars]
    out_cells = [Inverse.new(1.)]

    env = propagate(Inverse, inverse_rules, inner_jaxpr, const_cells,
                    in_cells, out_cells)

    self.assertNotIn(inner_jaxpr.invars[0], env)
def test_propagate_through_jit(self):
    """Propagates through a jitted `np.exp` and expects exactly one sub-env."""

    def f(x):
        return jax.jit(np.exp)(x) + 2.

    closed_jaxpr, _ = trace_util.stage(f)(3.)
    inner_jaxpr = closed_jaxpr.jaxpr
    const_cells = [Inverse.new(const) for const in closed_jaxpr.literals]
    in_cells = [unknown for _ in inner_jaxpr.invars]
    out_cells = [Inverse.new(3.)]

    env = propagate(Inverse, inverse_rules, inner_jaxpr, const_cells,
                    in_cells, out_cells)

    input_cell = env[inner_jaxpr.invars[0]]
    self.assertEqual(input_cell.val, 0.)
    self.assertLen(env.subenvs, 1)
def wrapped(sample, *args, **kwargs):
    """Builds propagation cells for `f` and hands them to `log_prob_jaxpr`.

    Args:
      sample: Value(s) treated as the outputs of `f`.
      *args: Arguments of `f` that follow its leading key argument.
      **kwargs: Keyword arguments used when staging `f`.

    Returns:
      Whatever `log_prob_jaxpr` returns for the staged jaxpr and cells.
    """
    # `f` is staged with a fixed key in the leading position.
    placeholder_key = random.PRNGKey(0)
    closed_jaxpr, _ = trace_util.stage(f, dynamic=False)(
        placeholder_key, *args, **kwargs)

    sample_leaves, _ = tree_util.tree_flatten(sample)
    arg_leaves, _ = tree_util.tree_flatten(args)

    const_cells = [InverseAndILDJ.new(const) for const in closed_jaxpr.literals]
    # The key input is marked unknown; every other input has a known value.
    key_cell = InverseAndILDJ.unknown(
        trace_util.get_shaped_aval(placeholder_key))
    in_cells = [key_cell]
    in_cells.extend(InverseAndILDJ.new(leaf) for leaf in arg_leaves)
    out_cells = [InverseAndILDJ.new(leaf) for leaf in sample_leaves]

    return log_prob_jaxpr(closed_jaxpr.jaxpr, const_cells, in_cells, out_cells)
def wrapped(init_key: Key, *args, **kwargs) -> FunctionModule:
    """Initializes variables for `f` and stages it into a `FunctionModule`.

    Args:
      init_key: Key passed to `init_f`, and forwarded to `f` under the keyword
        named by `init_keyword` when `f` accepts that keyword.
      *args: Positional arguments forwarded to `f`.
      **kwargs: Keyword arguments forwarded to `f`.

    Returns:
      When `name` is None, a `FunctionModule` whose variables are each wrapped
      with `module.variable`. Otherwise the result of wrapping the whole
      `FunctionModule` with `module.variable` under `name`.
    """
    # NOTE(review): the original indentation was lost; the nesting of the
    # `with` block and of `if name is not None:` below is a reconstruction --
    # confirm against the upstream file.
    has_init_key = kwargs_util.check_in_kwargs(f, init_keyword)
    if not has_init_key:
        # `f` takes no init key: there are no variables to initialize, and the
        # call-and-update function returns `variables` untouched.
        def init_f(init_key, *args, **kwargs):
            del init_key, args, kwargs
            return {}
        def cau_f(variables, *args, **kwargs):
            return f(*args, **kwargs), variables
    else:
        with unzip.new_custom_rules(custom_unzip_rules):
            # Adapter that moves the key from the first positional slot to the
            # keyword named by `init_keyword`.
            def fun(init_key, *args, **kwargs):
                kwargs = {**kwargs, init_keyword: init_key}
                return f(*args, **kwargs)
            init_f, apply_f = unzip.unzip(fun, tag=module.VARIABLE)(init_key, *args, **kwargs)
        # Harvest `module.ASSIGN`-tagged values, starting from an empty dict.
        cau_f = functools.partial(
            harvest.harvest(apply_f, tag=module.ASSIGN), {})
        if name is not None:
            init_f = harvest.nest(init_f, scope=name)
            cau_f = harvest.nest(cau_f, scope=name)
    variables = init_f(init_key)
    cau_jaxpr, (in_tree, out_tree) = trace_util.stage(cau_f, dynamic=True)(variables, *args, **kwargs)
    if name is None:
        # Unnamed module: tag each variable individually under its own key.
        variables = {
            k: module.variable(val, name=k, key=init_key)
            for k, val in variables.items()
        }
        return FunctionModule(variables, cau_jaxpr, in_tree, out_tree, name=name)
    else:
        # Named module: tag the module as a whole under `name`.
        return module.variable(FunctionModule(variables, cau_jaxpr, in_tree, out_tree, name=name), name=name, key=init_key)
def test_propagate_should_accumulate_state(self):
    """Checks that the `reducer` folds per-step values into a final state of 2."""

    def f(x):
        return np.exp(x) + 2.

    closed_jaxpr, _ = trace_util.stage(f)(4.)
    inner_jaxpr = closed_jaxpr.jaxpr
    const_cells = [ILDJ.new(const) for const in closed_jaxpr.literals]
    in_cells = [Inverse.unknown(invar.aval) for invar in inner_jaxpr.invars]
    out_cells = [ILDJ.new(4.)]

    _, final_state = propagate(
        ILDJ,
        ildj_rules,
        inner_jaxpr,
        const_cells,
        in_cells,
        out_cells,
        reducer=lambda env, eqn, count, next: count + next,
        initial_state=0)

    self.assertEqual(final_state, 2)
def wrapped(sample, *args, **kwargs):
    """Computes a log-probability of `sample` under `f` via propagation.

    Args:
      sample: Value(s) treated as the outputs of `f`.
      *args: Arguments of `f` that follow its leading key argument.
      **kwargs: Keyword arguments used when staging `f`.

    Returns:
      The value produced by `_accumulate_log_probs` on the propagated
      environment.
    """
    # `f` is staged with a fixed key in the leading position.
    closed_jaxpr, _ = trace_util.stage(f)(random.PRNGKey(0), *args, **kwargs)

    sample_leaves, _ = tree_util.tree_flatten(sample)
    arg_leaves, _ = tree_util.tree_flatten(args)

    const_cells = [InverseAndILDJ.new(const) for const in closed_jaxpr.literals]
    # The key input is unknown; every other input has a known value.
    in_cells = [unknown]
    in_cells.extend(InverseAndILDJ.new(leaf) for leaf in arg_leaves)
    out_cells = [InverseAndILDJ.new(leaf) for leaf in sample_leaves]

    # Same InverseAndILDJ propagation as for inversion, but driven by
    # `log_prob_rules`.
    env = propagate.propagate(InverseAndILDJ, log_prob_rules,
                              closed_jaxpr.jaxpr, const_cells, in_cells,
                              out_cells)

    # Collect the log-prob contributions recorded in the environment.
    return _accumulate_log_probs(env)
def random_variable_batching_rule(args, dims, *, num_consts, batch_ndims,
                                  jaxpr, **params):
    """Batching (vmap) rule for the `random_variable` primitive.

    The first `num_consts` entries of `args` are closed over as constants. The
    primitive's `impl` is vmapped over the remaining arguments, the vmapped
    function is re-staged, and the primitive is re-bound to the new jaxpr with
    `batch_ndims` incremented by one.

    Returns:
      A pair of the bound results and a tuple of zeros (one per result) giving
      their batch dimensions.
    """
    closed_consts = args[:num_consts]
    batched_args = args[num_consts:]
    batched_dims = dims[num_consts:]

    def _unbatched(*inner_args):
        return random_variable_p.impl(
            *closed_consts,
            *inner_args,
            num_consts=len(closed_consts),
            jaxpr=jaxpr,
            batch_ndims=batch_ndims,
            **params)

    vmapped = jax.vmap(_unbatched, in_axes=batched_dims, out_axes=0)
    restaged, _ = trace_util.stage(vmapped, dynamic=True)(*batched_args)
    restaged_consts = restaged.literals

    outputs = random_variable_p.bind(
        *restaged_consts,
        *batched_args,
        num_consts=len(restaged_consts),
        jaxpr=restaged.jaxpr,
        batch_ndims=batch_ndims + 1,
        **params)
    # Every output is batched along axis 0 because of `out_axes=0` above.
    return outputs, (0,) * len(outputs)
def wrapped(*args, **kwargs):
    """Inverts `f` at the given outputs and returns the values and ILDJ terms.

    `args` are treated as output values of `f`. The function is staged on
    `trace_args` when any were supplied, otherwise on `args` themselves.

    Returns:
      A `(vals, ildjs)` pair structured like the staged inputs; both are
      unwrapped to their first element when exactly one trace argument exists.

    Raises:
      ValueError: if any input cell is still unknown after propagation.
    """
    if len(trace_args):
        forward_args = trace_args
    else:
        forward_args = args
    closed_jaxpr, (in_tree, _) = trace_util.stage(f)(*forward_args, **kwargs)

    forward_leaves, _ = tree_util.tree_flatten(forward_args)
    output_leaves, _ = tree_util.tree_flatten(args)

    const_cells = safe_map(InverseAndILDJ.new, closed_jaxpr.literals)
    in_cells = [unknown for _ in forward_leaves]
    out_cells = safe_map(InverseAndILDJ.new, output_leaves)

    env = propagate.propagate(InverseAndILDJ, ildj_registry,
                              closed_jaxpr.jaxpr, const_cells, in_cells,
                              out_cells)

    solved_cells = [env.read(invar) for invar in closed_jaxpr.jaxpr.invars]
    for cell in solved_cells:
        if cell.is_unknown():
            raise ValueError('Cannot invert function.')

    flat_vals, flat_ildjs = jax_util.unzip2(
        [(cell.val, cell.ildj) for cell in solved_cells])
    vals = tree_util.tree_unflatten(in_tree, flat_vals)
    ildjs = tree_util.tree_unflatten(in_tree, flat_ildjs)

    if len(trace_args) == 1:
        return vals[0], ildjs[0]
    return vals, ildjs
def wrapped(init_key: Key, *args, **kwargs) -> FunctionModule:
    """Initializes variables for `f` and stages it into a `FunctionModule`.

    Args:
      init_key: Key used to initialize variables; forwarded to `f` under the
        keyword named by `init_keyword` when `f` accepts that keyword.
      *args: Positional arguments forwarded to `f`.
      **kwargs: Keyword arguments forwarded to `f`.

    Returns:
      When `name` is None, a `FunctionModule` whose variables are each wrapped
      with `module.variable`. Otherwise a named `FunctionModule`, itself
      wrapped with `module.variable` only when it holds any variables.
    """
    has_init_key = kwargs_util.check_in_kwargs(f, init_keyword)
    if not has_init_key:
        # `f` takes no init key: there are no variables to initialize, and the
        # call-and-update function returns `variables` untouched.
        def init_f(init_key, *args, **kwargs):
            del init_key, args, kwargs
            return {}

        def cau_f(variables, *args, **kwargs):
            return f(*args, **kwargs), variables
    else:
        def f_(init_key, *args, **kwargs):
            # Pass the key under the configured keyword. The acceptance check
            # above was made against `init_keyword`, so a hard-coded
            # `init_key=` would break for any other keyword name.
            return f(*args, **{**kwargs, init_keyword: init_key})

        def init_f(init_key, *args, **kwargs):
            # Collect the values tagged `module.VARIABLE` while running `f_`.
            return harvest.reap(
                f_, tag=module.VARIABLE, exclusive=True)(init_key, *args, **kwargs)

        def apply_f(variables, *args, **kwargs):
            # Re-run `f_` with `variables` planted; a fixed key fills the key
            # slot.
            return harvest.plant(
                f_, tag=module.VARIABLE)(variables, random.PRNGKey(0), *args, **kwargs)

        # Harvest `module.ASSIGN`-tagged values, starting from an empty dict.
        cau_f = functools.partial(
            harvest.harvest(apply_f, tag=module.ASSIGN), {})

    variables = init_f(init_key, *args, **kwargs)
    cau_jaxpr, (in_tree, out_tree) = trace_util.stage(
        cau_f, dynamic=True)(variables, *args, **kwargs)

    if name is None:
        # Unnamed module: tag each variable individually under its own key.
        variables = {
            k: module.variable(val, name=k, key=init_key)
            for k, val in variables.items()
        }
        return FunctionModule(variables, cau_jaxpr, in_tree, out_tree, name=name)

    mod = FunctionModule(variables, cau_jaxpr, in_tree, out_tree, name=name)
    # A named module is tagged as a whole only when it has variables.
    if variables:
        return module.variable(mod, name=name, key=init_key)
    return mod