def test_tree_update_stats(self):
  """Checks EMAParamsTree only advances its averages when update_stats=True."""

  def f():
    return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

  init_fn, _ = base.transform(f)
  params = init_fn(random.PRNGKey(428))

  def g(x):
    """This should never update internal stats."""
    return moving_averages.EMAParamsTree(0.2)(x, update_stats=False)

  init_fn, apply_fn_g = base.without_apply_rng(
      base.transform_with_state(g))
  _, params_state = init_fn(None, params)
  # Let's modify our params.
  changed_params = tree.map_structure(lambda t: 2. * t, params)
  ema_params, params_state = apply_fn_g(None, params_state, changed_params)
  ema_params2, params_state = apply_fn_g(None, params_state, changed_params)

  # ema_params should be the same as ema_params2 with update_stats=False!
  for p1, p2 in zip(tree.flatten(ema_params2), tree.flatten(ema_params)):
    self.assertEqual(p1.shape, p2.shape)
    np.testing.assert_allclose(p1, p2)

  def h(x):
    """This will behave like normal."""
    return moving_averages.EMAParamsTree(0.2)(x, update_stats=True)

  init_fn, apply_fn_h = base.without_apply_rng(
      base.transform_with_state(h))
  _, params_state = init_fn(None, params)
  params, params_state = apply_fn_h(None, params_state, params)
  # Let's modify our params.
  changed_params = tree.map_structure(lambda t: 2. * t, params)
  ema_params, params_state = apply_fn_h(None, params_state, changed_params)
  ema_params2, params_state = apply_fn_h(None, params_state, changed_params)

  # With update_stats=True each apply advances the moving average, so
  # ema_params2 should differ from ema_params (previous comment wrongly
  # said update_stats=False here).
  for p1, p2 in zip(tree.flatten(ema_params2), tree.flatten(ema_params)):
    self.assertEqual(p1.shape, p2.shape)
    with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
      np.testing.assert_allclose(p1, p2, atol=1e-6)
def test_get_state_no_init_raises(self):
  """Reading state with no initializer must raise in init and in apply."""
  init_net, apply_net = base.transform_with_state(
      lambda: base.get_state("i"))
  with self.assertRaisesRegex(ValueError, "set an init function"):
    init_net(None)
  empty = {"~": {}}
  with self.assertRaisesRegex(ValueError, "set an init function"):
    apply_net(empty, empty, None)
def outer_fn(x):
  """Applies a lifted inner network to every row of a rank-2 input.

  NOTE(review): relies on `inner_fn` and `Bias` defined elsewhere in this
  file — confirm their contracts against the surrounding code.
  """
  assert x.ndim == 2
  x = Bias()(x)
  inner = base.without_apply_rng(base.transform_with_state(inner_fn))
  # Lift the inner transform's params/state into the outer module's scope,
  # initializing from a single example row (x[0]).
  inner_p, inner_s = lift.lift(inner.init)(base.next_rng_key(), x[0])
  # Share params and state across the batch; map only over inputs (axis 0).
  vmap_inner = jax.vmap(inner.apply, in_axes=(None, None, 0))
  # inner.apply returns (out, state); keep only the output.
  return vmap_inner(inner_p, inner_s, x)[0]
def test_argspec(self):
  """Transformed init/apply expose the documented positional arguments."""
  transformed = base.transform_with_state(lambda: None)
  spec_of = inspect.getfullargspec
  self.assertEqual(spec_of(transformed.init).args, ["rng"])
  self.assertEqual(spec_of(transformed.apply).args, ["params", "state", "rng"])
def test_ema_on_changing_data(self):
  """EMA output must lag behind params that change between steps."""
  net_init, _ = base.transform(
      lambda: basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6])))
  params = net_init(random.PRNGKey(428))

  ema_init, ema_apply = base.without_apply_rng(
      base.transform_with_state(
          lambda x: moving_averages.EMAParamsTree(0.2)(x)))
  _, ema_state = ema_init(None, params)
  params, ema_state = ema_apply(None, ema_state, params)

  # Double every parameter, then take a single EMA step towards them.
  changed_params = tree.map_structure(lambda t: 2. * t, params)
  ema_params, ema_state = ema_apply(None, ema_state, changed_params)

  # One EMA step cannot reach the doubled values, so outputs must differ.
  tree.assert_same_structure(changed_params, ema_params)
  for before, averaged in zip(tree.flatten(params), tree.flatten(ema_params)):
    self.assertEqual(before.shape, averaged.shape)
    with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
      np.testing.assert_allclose(before, averaged, atol=1e-6)
def test_stateful_module(self):
  """A module's internal counter is threaded through init and apply."""
  counting = base.transform_with_state(lambda: CountingModule()())  # pylint: disable=unnecessary-lambda
  params, state = counting.init(None)
  self.assertEqual(state, {"counting_module": {"count": 0}})
  _, state = counting.apply(params, state, None)
  self.assertEqual(state, {"counting_module": {"count": 10}})
def wrapper(*a, **k):
  """Runs init and apply of f."""
  # NOTE(review): `f`, `seed` and `run_apply` are closed over from the
  # enclosing scope (not visible here) — confirm against the outer function.
  rng = random.PRNGKey(seed) if seed is not None else None
  transformed = base.transform_with_state(lambda: f(*a, **k))
  params, state = transformed.init(rng)
  # Apply is optional; some callers only want to exercise init.
  if run_apply:
    transformed.apply(params, state, rng)
def test_get_state_no_shape_raises(self):
  """get_state with an init but no shape/dtype must raise everywhere."""
  net = base.transform_with_state(
      lambda: base.get_state("i", init=jnp.zeros))
  with self.assertRaisesRegex(ValueError, "provide shape and dtype"):
    net.init(None)
  empty = {"~": {}}
  with self.assertRaisesRegex(ValueError, "provide shape and dtype"):
    net.apply(empty, empty, None)
def test_invalid_rng_state(self):
  """Passing a non-RNG where a key is expected raises a helpful error."""
  transformed = base.transform_with_state(lambda: None)
  with self.assertRaisesRegex(
      ValueError, "Init must be called with an RNG as the first argument"):
    transformed.init("nonsense")
  with self.assertRaisesRegex(
      ValueError, "Apply must be called with an RNG as the third argument"):
    transformed.apply({}, {"x": {}}, "nonsense")
def test_without_state(self):
  """without_state hides the state argument for a state-free function."""
  def net():
    return base.get_parameter("w", [], init=jnp.zeros)

  init_fn, apply_fn = base.without_state(base.transform_with_state(net))
  params = init_fn(None)
  self.assertEqual(apply_fn(params, None), 0)
def test_grad_and_jit(self):
  """stateful.grad composes with jax.jit for both init and apply."""
  def net(x):
    return stateful.grad(SquareModule())(x)

  transformed = base.transform_with_state(net)
  x = jnp.array(3.)
  params, state = jax.jit(transformed.init)(None, x)
  g, state = jax.jit(transformed.apply)(params, state, None, x)
  # d/dx x^2 = 2x.
  np.testing.assert_allclose(g, 2 * x, rtol=1e-3)
def test_without_state_raises_if_state_used(self):
  """without_state must reject a function that actually uses state."""
  def counter():
    for _ in range(10):
      count = base.get_state("count", (), jnp.int32, jnp.zeros)
      base.set_state("count", count + 1)
    return count

  init_fn, _ = base.without_state(base.transform_with_state(counter))
  with self.assertRaisesRegex(ValueError, "use.*transform_with_state"):
    init_fn(None)
def test_stateful(self):
  """State mutated inside the function is returned and threaded by apply."""
  def counter():
    for _ in range(10):
      count = base.get_state("count", (), jnp.int32, jnp.zeros)
      base.set_state("count", count + 1)
    return count

  init_fn, apply_fn = base.transform_with_state(counter)
  params, state = init_fn(None)
  # Init reports the freshly-created (zero) state.
  self.assertEqual(state, {"~": {"count": 0}})
  _, state = apply_fn(params, state, None)
  # Apply runs the ten increments.
  self.assertEqual(state, {"~": {"count": 10}})
def test_cond(self):
  """stateful.cond threads module state through whichever branch runs."""
  def net(x):
    mod = SquareModule()
    return stateful.cond(x == 2, x, mod, x, lambda x: mod(x + 1))

  transformed = base.transform_with_state(net)
  # (input, expected): true branch squares x; false branch squares x + 1.
  for value, expected in ((1, 4), (2, 4), (3, 16)):
    value, expected = map(jnp.array, (value, expected))
    params, state = transformed.init(None, value)
    out, state = transformed.apply(params, state, None, value)
    self.assertEqual(state, {"square_module": {"y": expected}})
    self.assertEqual(out, expected)
def test_set_then_get(self):
  """get_state after set_state returns the freshly-set value."""
  def net():
    base.set_state("i", 1)
    return base.get_state("i")

  init_fn, apply_fn = base.transform_with_state(net)
  params, state = init_fn(None)
  self.assertEqual(state, {"~": {"i": 1}})
  # Whatever state comes in, the function overwrites it with 1.
  for incoming in range(10):
    y, state_out = apply_fn(params, {"~": {"i": incoming}}, None)
    self.assertEqual(y, 1)
    self.assertEqual(state_out, {"~": {"i": 1}})
def test_ema_is_identity_on_unchanged_data(self):
  """Averaging a constant input repeatedly must return (nearly) the input."""
  ema_init, ema_apply = base.without_apply_rng(
      base.transform_with_state(
          lambda x: moving_averages.ExponentialMovingAverage(0.5)(x)))

  inp_value = 1.0
  _, ema_state = ema_init(None, inp_value)

  value = inp_value
  for _ in range(10):
    value, ema_state = ema_apply(None, ema_state, value)

  # Floating point error creeps up to 1e-7 (the default).
  np.testing.assert_allclose(inp_value, value, rtol=1e-6)
def test_ema_naming_scheme(self):
  """EMA state keys follow '<ema_name>/<module_name>__<param_name>'."""
  ema_name = "this_is_a_wacky_but_valid_name"
  linear_name = "so_is_this"

  net_init, _ = base.transform(
      lambda: basic.Linear(output_size=2, name=linear_name)(jnp.zeros([6])))
  params = net_init(random.PRNGKey(428))

  ema_init, _ = base.transform_with_state(
      lambda x: moving_averages.EMAParamsTree(0.2, name=ema_name)(x))
  _, ema_state = ema_init(None, params)

  expected_keys = {
      "{}/{}__{}".format(ema_name, linear_name, s) for s in ["w", "b"]
  }
  self.assertEqual(expected_keys, set(ema_state.keys()))
def test_sn_naming_scheme(self):
  """Spectral-norm state keys follow '<sn_name>/<module_name>__<param>'."""
  sn_name = "this_is_a_wacky_but_valid_name"
  linear_name = "so_is_this"

  net_init, _ = base.transform(
      lambda: basic.Linear(output_size=2, name=linear_name)(jnp.zeros([6, 6])))
  params = net_init(random.PRNGKey(428))

  sn_init, _ = base.transform_with_state(
      lambda x: spectral_norm.SNParamsTree(ignore_regex=".*b", name=sn_name)(x))
  _, sn_state = sn_init(random.PRNGKey(428), params)

  # Only "w" gets tracked: ".*b" is excluded by ignore_regex.
  expected_keys = ["{}/{}__{}".format(sn_name, linear_name, s) for s in ["w"]]
  self.assertSameElements(expected_keys, sn_state.keys())
def test_ignore_regex(self):
  """Params matching ignore_regex pass through EMAParamsTree untouched."""

  def f():
    return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

  init_fn, _ = base.transform(f)
  params = init_fn(random.PRNGKey(428))

  def g(x):
    return moving_averages.EMAParamsTree(0.2, ignore_regex=".*w")(x)

  init_fn, apply_fn = base.without_apply_rng(
      base.transform_with_state(g))
  _, params_state = init_fn(None, params)
  params, params_state = apply_fn(None, params_state, params)
  # Let's modify our params.
  changed_params = tree.map_structure(lambda t: 2. * t, params)
  ema_params, params_state = apply_fn(None, params_state, changed_params)
  # b is averaged, so the EMA output lags behind the doubled value...
  self.assertTrue((changed_params.linear.b != ema_params.linear.b).all())
  # ...but w matches ".*w" and is ignored, so it passes through unchanged.
  # (Previous comments were attached in the opposite order to the asserts.)
  self.assertTrue((changed_params.linear.w == ema_params.linear.w).all())
def test_without_state_raises_if_state_used(self):
  """without_state must reject a module that keeps internal state."""
  stateless_init, _ = base.without_state(
      base.transform_with_state(lambda: CountingModule()()))  # pylint: disable=unnecessary-lambda
  with self.assertRaisesRegex(ValueError, "use.*transform_with_state"):
    stateless_init(None)
def test_without_state(self):
  """without_state hides the (empty) state of a state-free module."""
  stateless_init, stateless_apply = base.without_state(
      base.transform_with_state(lambda: ScalarModule()()))  # pylint: disable=unnecessary-lambda
  params = stateless_init(None)
  self.assertEqual(stateless_apply(params, None), 0)
def test_get_state_no_init(self):
  """get_state with no init reads existing state and leaves it unchanged."""
  _, apply_fn = base.transform_with_state(lambda: base.get_state("i"))
  for value in range(10):
    state_in = {"~": {"i": value}}
    _, state_out = apply_fn({}, state_in, None)
    # Reading must not mutate the threaded state.
    self.assertEqual(state_in, state_out)