def test_without_apply_rng_multi_transform(self):
    """without_apply_rng strips the rng arg from multi-transformed applies."""
    def make_net(name):
        def add_param(x):
            p = base.get_parameter(name, [], init=jnp.zeros)
            return p + x
        return add_param

    def module_fn():
        first = make_net(name='one')
        second = make_net(name='two')
        def init_fn(x):
            return second(first(x))
        return init_fn, (first, second)

    # Stateful variant: each apply takes (params, state, x) — no rng.
    f = multi_transform.without_apply_rng(
        multi_transform.multi_transform_with_state(module_fn))
    self.assertIsInstance(f, multi_transform.MultiTransformedWithState)
    params, state = f.init(None, jnp.ones(()))
    f.apply[0](params, state, jnp.ones(()))
    f.apply[1](params, state, jnp.ones(()))

    # Stateless variant: each apply takes (params, x) — no rng.
    f = multi_transform.without_apply_rng(
        multi_transform.multi_transform(module_fn))
    self.assertIsInstance(f, multi_transform.MultiTransformed)
    params = f.init(None, jnp.ones(()))
    f.apply[0](params, jnp.ones(()))
    f.apply[1](params, jnp.ones(()))
def test_tree_update_stats(self):
    """EMAParamsTree only advances its stats when update_stats=True."""
    def make_params():
        return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

    init_fn, _ = transform.transform(make_params)
    params = init_fn(random.PRNGKey(428))

    def frozen_ema(x):
        """This should never update internal stats."""
        return moving_averages.EMAParamsTree(0.2)(x, update_stats=False)

    init_fn, apply_fn_g = multi_transform.without_apply_rng(
        transform.transform_with_state(frozen_ema))
    _, params_state = init_fn(None, params)

    # Double every parameter so the EMA would move if stats were updated.
    changed_params = tree.map_structure(lambda t: 2. * t, params)
    ema_params, params_state = apply_fn_g(None, params_state, changed_params)
    ema_params2, params_state = apply_fn_g(None, params_state, changed_params)

    # With update_stats=False repeated applications yield identical EMAs.
    for p1, p2 in zip(tree.flatten(ema_params2), tree.flatten(ema_params)):
        self.assertEqual(p1.shape, p2.shape)
        np.testing.assert_allclose(p1, p2)

    def live_ema(x):
        """This will behave like normal."""
        return moving_averages.EMAParamsTree(0.2)(x, update_stats=True)

    init_fn, apply_fn_h = multi_transform.without_apply_rng(
        transform.transform_with_state(live_ema))
    _, params_state = init_fn(None, params)
    params, params_state = apply_fn_h(None, params_state, params)

    # Again double the params and take two EMA steps.
    changed_params = tree.map_structure(lambda t: 2. * t, params)
    ema_params, params_state = apply_fn_h(None, params_state, changed_params)
    ema_params2, params_state = apply_fn_h(None, params_state, changed_params)

    # With update_stats=True the stats advance, so consecutive EMAs differ.
    for p1, p2 in zip(tree.flatten(ema_params2), tree.flatten(ema_params)):
        self.assertEqual(p1.shape, p2.shape)
        with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
            np.testing.assert_allclose(p1, p2, atol=1e-6)
def test_ema_on_changing_data(self):
    """A single EMA step lags behind a sudden change in the params."""
    def make_params():
        return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

    init_fn, _ = transform.transform(make_params)
    params = init_fn(random.PRNGKey(428))

    def ema_fn(x):
        return moving_averages.EMAParamsTree(0.2)(x)

    init_fn, apply_fn = multi_transform.without_apply_rng(
        transform.transform_with_state(ema_fn))
    _, params_state = init_fn(None, params)
    params, params_state = apply_fn(None, params_state, params)

    # Double the params; one averaging step must not catch up all the way.
    changed_params = tree.map_structure(lambda t: 2. * t, params)
    ema_params, params_state = apply_fn(None, params_state, changed_params)

    # Same tree structure, but the averaged leaves differ from the originals.
    tree.assert_same_structure(changed_params, ema_params)
    for p1, p2 in zip(tree.flatten(params), tree.flatten(ema_params)):
        self.assertEqual(p1.shape, p2.shape)
        with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
            np.testing.assert_allclose(p1, p2, atol=1e-6)
def outer_fn(x):
    """Lift an inner transform's params and vmap its apply over the batch."""
    assert x.ndim == 2
    x = Bias()(x)
    inner_t = multi_transform.without_apply_rng(
        transform.transform(inner_fn))
    # Initialise the inner params from a single row, then share them
    # (in_axes=None) across the whole batch.
    inner_params = lift.lift(inner_t.init)(base.next_rng_key(), x[0])
    batched_apply = jax.vmap(inner_t.apply, in_axes=(None, 0))
    return batched_apply(inner_params, x)
def build_and_init_stack(module_class):
    """Wrap module_class in a depth-1 layer_stack and run its init."""
    def stack_fn(x):
        return layer_stack.layer_stack(1)(module_class())(x)
    stacked = multi_transform.without_apply_rng(
        transform.transform(stack_fn))
    stacked.init(jax.random.PRNGKey(1729), jnp.ones([5]))
def test_ema_is_identity_on_unchanged_data(self):
    """Repeatedly averaging a constant input leaves it (nearly) unchanged."""
    def ema_fn(x):
        return moving_averages.ExponentialMovingAverage(0.5)(x)

    inp_value = 1.0
    init_fn, apply_fn = multi_transform.without_apply_rng(
        transform.transform_with_state(ema_fn))
    _, params_state = init_fn(None, inp_value)

    # The output should stay fixed as long as the input does.
    value = inp_value
    for _ in range(10):
        value, params_state = apply_fn(None, params_state, value)

    # Floating point error creeps up to 1e-7 (the default), so relax rtol.
    np.testing.assert_allclose(inp_value, value, rtol=1e-6)
def test_ignore_regex(self):
    """Params matching ignore_regex pass through the EMA untouched."""
    def make_params():
        return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

    init_fn, _ = transform.transform(make_params)
    params = init_fn(random.PRNGKey(428))

    def ema_fn(x):
        return moving_averages.EMAParamsTree(0.2, ignore_regex=".*w")(x)

    init_fn, apply_fn = multi_transform.without_apply_rng(
        transform.transform_with_state(ema_fn))
    _, params_state = init_fn(None, params)
    params, params_state = apply_fn(None, params_state, params)

    # Double every parameter, then take one EMA step.
    changed_params = tree.map_structure(lambda t: 2. * t, params)
    ema_params, params_state = apply_fn(None, params_state, changed_params)

    # "b" is averaged, so it lags the doubled value...
    self.assertTrue(
        (changed_params["linear"]["b"] != ema_params["linear"]["b"]).all())
    # ...but "w" matches the ignore regex and is forwarded unchanged.
    self.assertTrue(
        (changed_params["linear"]["w"] == ema_params["linear"]["w"]).all())