def test_basic_noncentering_parameterization_behaves_correctly(self):
    """A noncentering rule keeps the output and moves loc/scale out of the draw.

    The rule replaces `random_normal(key, loc, scale)` with a
    `random_normal(key)` draw that is scaled and shifted afterwards. The
    handled program must return the same value as the original, and every
    `random_normal_p` equation in its jaxpr must carry loc 0 and scale 1.
    """

    def noncentering_rule(state, key, loc, scale):
        # Draw without loc/scale, then apply them outside the primitive.
        standard_draw = random_normal(key)
        return [standard_draw * scale + loc], state

    def centered_program(key):
        return random_normal(key, 2., 1.)

    handler = effect_handler.make_effect_handler(
        {random_normal_p: noncentering_rule})
    noncentered_program = handler(centered_program)

    # Programs should be semantically identical.
    centered_value = centered_program(random.PRNGKey(0))
    noncentered_value = noncentered_program(None, random.PRNGKey(0))[0]
    self.assertEqual(centered_value, noncentered_value)

    # We should be sampling from an isotropic normal in the noncentered
    # variant, so inspect the traced equations of the handled program.
    traced = jax.make_jaxpr(noncentered_program)(None, random.PRNGKey(0))
    for equation in traced.jaxpr.eqns:
        if equation.primitive is not random_normal_p:
            continue
        loc_value = equation.invars[1].val
        scale_value = equation.invars[2].val
        self.assertEqual(loc_value, 0.)
        self.assertEqual(scale_value, 1.)
def test_effect_handler_with_no_rules_should_be_identity(self):
    """A handler built from an empty rule set leaves output and state alone."""

    def sample(key):
        return random_normal(key)

    identity_handler = effect_handler.make_effect_handler({})
    handled_sample = identity_handler(sample)

    handled_out, final_state = handled_sample(None, random.PRNGKey(0))

    # The state passed in (None) must come back untouched.
    self.assertIsNone(final_state)
    # The handled program must agree with the plain program on the same key.
    self.assertEqual(handled_out, sample(random.PRNGKey(0)))
def test_effect_handler_can_override_primitive_behavior(self):
    """A rule can replace the `random_normal` draw with a constant."""

    def always_zero_rule(state, key, *_):
        del key  # The replacement output does not depend on the key.
        return [0.], state

    def sample(key):
        return random_normal(key)

    override_handler = effect_handler.make_effect_handler(
        {random_normal_p: always_zero_rule})
    deterministic_sample = override_handler(sample)

    handled_out, final_state = deterministic_sample(None, random.PRNGKey(0))

    # The rule returns its state unchanged, so None must come back as None.
    self.assertIsNone(final_state)
    self.assertEqual(handled_out, 0.)
def test_effect_handler_correctly_maintains_python_structures(self):
    """A dict-valued, jitted program keeps its output structure when handled.

    The rule counts each `random_normal` it handles; the program draws twice,
    so the count must end at 2 and the output dict must equal the unhandled
    program's output.
    """

    def counting_rule(count, key, *_):
        return [random_normal(key)], count + 1

    @jax.jit
    def sample_pair(key):
        first_key, second_key = random.split(key)
        return {'x': random_normal(first_key), 'y': random_normal(second_key)}

    counting_handler = effect_handler.make_effect_handler(
        {random_normal_p: counting_rule})
    counted_sample_pair = counting_handler(sample_pair)

    handled_out, draw_count = counted_sample_pair(0, random.PRNGKey(0))

    self.assertEqual(draw_count, 2)
    self.assertEqual(handled_out, sample_pair(random.PRNGKey(0)))
def test_correctly_updates_state_inside_call_primitive(self):
    """State updates made by a rule are kept when the program is jitted.

    The program is wrapped in `jax.jit` and draws twice, so the counting rule
    must leave the count at 2 while the output matches the unhandled program.
    """

    def counting_rule(count, key, *_):
        return [random_normal(key)], count + 1

    @jax.jit
    def sum_of_draws(key):
        first_key, second_key = random.split(key)
        return random_normal(first_key) + random_normal(second_key)

    counting_handler = effect_handler.make_effect_handler(
        {random_normal_p: counting_rule})
    counted_sum = counting_handler(sum_of_draws)

    handled_out, draw_count = counted_sum(0, random.PRNGKey(0))

    self.assertEqual(draw_count, 2)
    self.assertEqual(handled_out, sum_of_draws(random.PRNGKey(0)))
def test_effect_handler_correctly_updates_state(self):
    """The counter state is threaded through repeated handled calls.

    Each call to the program draws twice, so feeding the returned count back
    in must give 2 after the first call and 4 after the second, with outputs
    matching the unhandled program for the same keys.
    """

    def counting_rule(count, key, *_):
        return [random_normal(key)], count + 1

    def sum_of_draws(key):
        first_key, second_key = random.split(key)
        return random_normal(first_key) + random_normal(second_key)

    counting_handler = effect_handler.make_effect_handler(
        {random_normal_p: counting_rule})
    counted_sum = counting_handler(sum_of_draws)

    initial_count = 0

    # First call: two draws starting from zero.
    first_out, count_after_first = counted_sum(
        initial_count, random.PRNGKey(0))
    self.assertEqual(count_after_first, 2)
    self.assertEqual(first_out, sum_of_draws(random.PRNGKey(0)))

    # Second call: continue from the count returned by the first call.
    second_out, count_after_second = counted_sum(
        count_after_first, random.PRNGKey(1))
    self.assertEqual(count_after_second, 4)
    self.assertEqual(second_out, sum_of_draws(random.PRNGKey(1)))