def test_tree_update_stats(self):
  """Checks that `update_stats` decides whether repeated applies change the EMA.

  With `update_stats=False` two consecutive applies on the same input must
  return identical values; with `update_stats=True` they must differ.
  """

  def make_linear():
    return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

  def double(leaf):
    return 2. * leaf

  linear_init, _ = transform.transform(make_linear)
  params = linear_init(random.PRNGKey(428))

  # --- Case 1: statistics are frozen. ---
  def frozen_ema(x):
    """This should never update internal stats."""
    return moving_averages.EMAParamsTree(0.2)(x, update_stats=False)

  frozen = transform.without_apply_rng(
      transform.transform_with_state(frozen_ema))
  _, ema_state = frozen.init(None, params)

  doubled = tree.map_structure(double, params)
  first, ema_state = frozen.apply(None, ema_state, doubled)
  second, ema_state = frozen.apply(None, ema_state, doubled)

  # Both applies ran with update_stats=False, so the outputs must match.
  for later, earlier in zip(tree.flatten(second), tree.flatten(first)):
    self.assertEqual(later.shape, earlier.shape)
    np.testing.assert_allclose(later, earlier)

  # --- Case 2: statistics are updated on every call. ---
  def live_ema(x):
    """This will behave like normal."""
    return moving_averages.EMAParamsTree(0.2)(x, update_stats=True)

  live = transform.without_apply_rng(transform.transform_with_state(live_ema))
  _, ema_state = live.init(None, params)
  params, ema_state = live.apply(None, ema_state, params)

  doubled = tree.map_structure(double, params)
  first, ema_state = live.apply(None, ema_state, doubled)
  second, ema_state = live.apply(None, ema_state, doubled)

  # Both applies ran with update_stats=True, so the outputs must differ.
  for later, earlier in zip(tree.flatten(second), tree.flatten(first)):
    self.assertEqual(later.shape, earlier.shape)
    with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
      np.testing.assert_allclose(later, earlier, atol=1e-6)
def test_ema_on_changing_data(self):
  """Checks that the EMA output matches neither the old nor the new params.

  The parameters are fed once, doubled, and fed again. The resulting average
  is compared against both the values fed first and the doubled values; it
  must differ from each of them.
  """

  def f():
    return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

  init_fn, _ = transform.transform(f)
  params = init_fn(random.PRNGKey(428))

  def g(x):
    return moving_averages.EMAParamsTree(0.2)(x)

  init_fn, apply_fn = transform.without_apply_rng(
      transform.transform_with_state(g))
  _, params_state = init_fn(None, params)
  params, params_state = apply_fn(None, params_state, params)

  # Let's modify our params.
  changed_params = tree.map_structure(lambda t: 2. * t, params)
  ema_params, params_state = apply_fn(None, params_state, changed_params)

  tree.assert_same_structure(changed_params, ema_params)
  leaves = zip(tree.flatten(params), tree.flatten(changed_params),
               tree.flatten(ema_params))
  for old, new, ema in leaves:
    self.assertEqual(old.shape, ema.shape)
    self.assertEqual(new.shape, ema.shape)
    # ema_params should be different from the params fed before the change...
    with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
      np.testing.assert_allclose(old, ema, atol=1e-6)
    # ...and different from the changed params as well. The original loop
    # only compared against the old values, leaving this claim untested.
    with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
      np.testing.assert_allclose(new, ema, atol=1e-6)
def test_transparent_lift_with_state_nested(self):
  """Checks state naming and updates when lifting transparently in a module.

  The inner transformed function increments a scalar state "w" on each call.
  It is lifted with `transparent_lift_with_state` inside `Outer.__call__`; the
  test asserts the outer state is keyed "outer/~" and advances by one per
  apply, while the inner state stays keyed "~".
  """

  @transform.transform_with_state
  def inner():
    w = base.get_state("w", [], init=jnp.zeros)
    w += 1
    base.set_state("w", w)
    return w

  class Outer(module.Module):

    def __call__(self):
      lifted, updater = lift.transparent_lift_with_state(inner.init)
      params, state = lifted(None)
      out, state = inner.apply(params, state, None)
      # Push the inner state produced by apply back into the outer transform.
      updater.update(state)
      return out, state

  outer = transform.transform_with_state(
      lambda: Outer()())  # pylint: disable=unnecessary-lambda
  params, state = outer.init(None)
  self.assertEmpty(params)
  # `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map`.
  self.assertEqual(
      jax.tree_util.tree_map(int, state), {"outer/~": {"w": 0}})

  for expected in (1, 2, 3):
    (w, inner_state), state = outer.apply(params, state, None)
    self.assertEqual(
        jax.tree_util.tree_map(int, inner_state), {"~": {"w": expected}})
    self.assertEqual(w, expected)
    self.assertEmpty(params)
    self.assertEqual(state, {"outer/~": {"w": expected}})
def test_stateful_module(self):
  """Checks CountingModule's "count" state is 0 after init and 10 after apply."""
  transformed = transform.transform_with_state(
      lambda: CountingModule()())  # pylint: disable=unnecessary-lambda

  params, initial_state = transformed.init(None)
  self.assertEqual(initial_state, {"counting_module": {"count": 0}})

  _, final_state = transformed.apply(params, initial_state, None)
  self.assertEqual(final_state, {"counting_module": {"count": 10}})
def test_lift_with_state(self):
  """Checks that `lift_with_state` exposes inner state under "lifted/~".

  The inner transformed function increments a scalar state "w" on each call.
  The outer function lifts the inner init, applies the inner function and
  reports the new state through the updater; the test asserts the outer state
  advances by one per apply.
  """

  @transform.transform_with_state
  def inner():
    w = base.get_state("w", [], init=jnp.zeros)
    w += 1
    base.set_state("w", w)
    return w

  def outer():
    lifted, updater = lift.lift_with_state(inner.init)
    params, state = lifted(None)
    self.assertEmpty(params)
    out, state = inner.apply(params, state, None)
    # Push the inner state produced by apply back into the outer transform.
    updater.update(state)
    return out, state

  outer = transform.transform_with_state(outer)
  params, state = outer.init(None)
  self.assertEmpty(params)
  # `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map`.
  self.assertEqual(
      jax.tree_util.tree_map(int, state), {"lifted/~": {"w": 0}})

  for expected in (1, 2, 3):
    (w, inner_state), state = outer.apply(params, state, None)
    self.assertEqual(
        jax.tree_util.tree_map(int, inner_state), {"~": {"w": expected}})
    self.assertEqual(w, expected)
    self.assertEmpty(params)
    self.assertEqual(state, {"lifted/~": {"w": expected}})
def test_eval_shape(self):
  """Checks `stateful.eval_shape` reports the output shape without counting."""

  def take_first_row(x):
    return x[0, :]

  def net(x):
    counter = CountingModule(op=take_first_row)
    # eval_shape must leave the module's state untouched.
    shape_struct = stateful.eval_shape(counter, x)
    return counter(x), shape_struct

  net = transform.transform_with_state(net)
  rng = jax.random.PRNGKey(42)
  in_shape = (10, 10)
  x = jnp.ones(in_shape)

  params, state = net.init(rng, x)
  self.assertEqual(list(state), ["counting_module"])
  module_state = state["counting_module"]
  self.assertEqual(list(module_state), ["count"])
  np.testing.assert_allclose(module_state["count"], 0, rtol=1e-4)

  (out, shape_struct), state = net.apply(params, state, rng, x)
  # Only the real call advances the counter, so it reads exactly one.
  np.testing.assert_allclose(state["counting_module"]["count"], 1, rtol=1e-4)
  np.testing.assert_allclose(out, take_first_row(x), rtol=1e-4)
  self.assertEqual(shape_struct.shape, (in_shape[1],))
def test_get_state_no_init(self):
  """Checks that reading existing state via `get_state` leaves it unchanged."""
  transformed = transform.transform_with_state(lambda: base.get_state("i"))
  for value in range(10):
    state_in = {"~": {"i": value}}
    _, state_out = transformed.apply({}, state_in, None)
    self.assertEqual(state_in, state_out)
def test_without_state_raises_if_state_used_on_apply(self):
  """Checks `without_state` rejects a function that calls `set_state`."""
  stateless = transform.without_state(
      transform.transform_with_state(lambda: base.set_state("~", 1)))
  key = jax.random.PRNGKey(42)
  # NOTE(review): init runs inside the assertRaises block, so an error raised
  # by init alone also satisfies this test -- confirm which call should raise.
  with self.assertRaisesRegex(ValueError, "use.*transform_with_state"):
    params = stateless.init(key)
    stateless.apply(params, key)
def test_simple_training_cross_replica_axis_index_groups(self):
  """Compares grouped cross-replica BatchNorm with a per-group NumPy reference.

  Skipped when fewer than two local devices are available.
  """
  num_devices = jax.local_device_count()
  if num_devices < 2:
    self.skipTest("Cross-replica test requires at least 2 devices.")

  num_groups = num_devices // 2
  group_size = num_devices // num_groups
  # e.g. 8 devices are split into [[0, 1], [2, 3], [4, 5], [6, 7]].
  groups = np.arange(num_devices).reshape(num_groups, group_size).tolist()

  def net(x, is_training=True):
    bn = batch_norm.BatchNorm(
        create_scale=False,
        create_offset=False,
        decay_rate=0.9,
        cross_replica_axis="i",
        cross_replica_axis_index_groups=groups,
    )
    return bn(x, is_training=is_training)

  net = transform.transform_with_state(net)

  inputs = np.arange(num_devices * 4).reshape(num_devices, 4)
  inputs = inputs.astype(np.float32)
  key = np.broadcast_to(jax.random.PRNGKey(42), (num_devices, 2))
  params, state = jax.pmap(net.init, axis_name="i")(key, inputs)
  result, _ = jax.pmap(net.apply, axis_name="i")(params, state, key, inputs)

  # Reference: normalise each group's rows with that group's own statistics.
  expected = np.empty_like(inputs)
  for g in range(num_groups):
    rows = slice(group_size * g, group_size * (g + 1))
    block = inputs[rows]
    mean = np.mean(block, axis=0)
    std = np.std(block, axis=0) + 1e-10
    expected[rows] = (block - mean) / std

  np.testing.assert_array_almost_equal(result, expected)
def test_vmap(self):
  """Checks `stateful.vmap` maps the output but keeps state unmapped.

  After init and after one apply the single state leaf must be a scalar
  (`ndim == 0`), with value 0 and 1 respectively, while the output has the
  mapped shape `(4,)`.
  """

  def g(x):
    return CountingModule()(x)

  def f(x):
    return stateful.vmap(g)(x)

  f = transform.transform_with_state(f)

  x = jnp.ones([4]) + 1
  params, state = f.init(None, x)

  # State should not be mapped.
  self.assertEmpty(params)
  # `jax.tree_util.tree_leaves` replaces the deprecated `jax.tree_leaves`.
  cnt, = jax.tree_util.tree_leaves(state)
  self.assertEqual(cnt.ndim, 0)
  self.assertEqual(cnt, 0)

  # The output should be mapped but state should not be.
  y, state = f.apply(params, state, None, x)
  self.assertEqual(y.shape, (4,))
  np.testing.assert_allclose(y, x ** 2)
  cnt, = jax.tree_util.tree_leaves(state)
  self.assertEqual(cnt.ndim, 0)
  self.assertEqual(cnt, 1)
def test_argspec(self):
  """Checks the positional argument names of the transformed init and apply."""
  init_fn, apply_fn = transform.transform_with_state(lambda: None)
  init_args = inspect.getfullargspec(init_fn).args
  apply_args = inspect.getfullargspec(apply_fn).args
  self.assertEqual(init_args, ["rng"])
  self.assertEqual(apply_args, ["params", "state", "rng"])
def wrapper(*a, **k):
  """Runs init and apply of f."""
  # `seed`, `f` and `run_apply` come from the enclosing scope.
  if seed is None:
    rng = None
  else:
    rng = random.PRNGKey(seed)

  init_fn, apply_fn = transform.transform_with_state(lambda: f(*a, **k))
  params, state = init_fn(rng)
  if run_apply:
    apply_fn(params, state, rng)
def test_scan_with_state(self, unroll_length):
  """Checks `stateful.scan` threads module state through every step.

  The counter must read `unroll_length` after one apply, and the scanned
  outputs must equal the squared inputs.
  """

  def net(xs):
    counter = CountingModule()

    def step(carry, x):
      self.assertEqual(carry, ())
      return carry, counter(x)

    _, ys = stateful.scan(step, (), xs)
    return ys

  net = transform.transform_with_state(net)
  rng = jax.random.PRNGKey(42)
  xs = jnp.arange(unroll_length)

  params, state = net.init(rng, xs)
  self.assertEqual(list(state), ["counting_module"])
  module_state = state["counting_module"]
  self.assertEqual(list(module_state), ["count"])
  np.testing.assert_allclose(module_state["count"], 0, rtol=1e-4)

  ys, state = net.apply(params, state, rng, xs)
  np.testing.assert_allclose(
      state["counting_module"]["count"], unroll_length, rtol=1e-4)
  np.testing.assert_allclose(ys, xs**2, rtol=1e-4)
def test_invalid_rng_state(self):
  """Checks that a non-RNG value passed as `rng` raises a ValueError."""
  transformed = transform.transform_with_state(lambda: None)
  init_error = "Init must be called with an RNG as the first argument"
  apply_error = "Apply must be called with an RNG as the third argument"

  with self.assertRaisesRegex(ValueError, init_error):
    transformed.init("nonsense")

  with self.assertRaisesRegex(ValueError, apply_error):
    transformed.apply({}, {"x": {}}, "nonsense")
def test_get_state_no_shape_raises(self):
  """Checks `get_state` with an init but no shape raises on init and apply."""
  transformed = transform.transform_with_state(
      lambda: base.get_state("i", init=jnp.zeros))
  expected_error = "provide shape and dtype"

  with self.assertRaisesRegex(ValueError, expected_error):
    transformed.init(None)

  # The same dict object is deliberately passed as both params and state.
  empty = {"~": {}}
  with self.assertRaisesRegex(ValueError, expected_error):
    transformed.apply(empty, empty, None)
def test_get_state_no_init_raises(self):
  """Checks `get_state` on missing state without an init function raises."""
  transformed = transform.transform_with_state(lambda: base.get_state("i"))
  expected_error = "set an init function"

  with self.assertRaisesRegex(ValueError, expected_error):
    transformed.init(None)

  # The same dict object is deliberately passed as both params and state.
  empty = {"~": {}}
  with self.assertRaisesRegex(ValueError, expected_error):
    transformed.apply(empty, empty, None)
def test_unused_updater(self):
  """Checks that an updater which is never used makes init raise."""

  def make_updater() -> lift.LiftWithStateUpdater:
    inner = transform.transform_with_state(lambda: None)
    _, updater = lift.lift_with_state(inner.init)
    # Returned without update()/ignore_update() ever being called on it.
    return updater

  outer = transform.transform_with_state(make_updater)

  with self.assertRaisesRegex(ValueError, "StateUpdater.*must be used"):
    outer.init(None)
def test_named_call(self):
  """Checks a module wrapped in `stateful.named_call` still squares under jit."""

  def net(x):
    squared = stateful.named_call(SquareModule(), name="square")
    return squared(x)

  x = jnp.array(2.)
  key = jax.random.PRNGKey(42)
  transformed = transform.transform_with_state(net)
  params, state = transformed.init(key, x)
  y, state = jax.jit(transformed.apply)(params, state, key, x)
  self.assertEqual(y, x ** 2)
def wrapper(*a, **k):
  """Runs init and apply of f."""
  # `seed`, `f`, `jax_transform` and `run_apply` come from the enclosing scope.
  if seed is None:
    rng = None
  else:
    rng = random.PRNGKey(seed)

  init, apply = transform.transform_with_state(lambda: f(*a, **k))
  if jax_transform:
    # Wrap both functions with the requested transformation.
    init, apply = jax_transform(init), jax_transform(apply)

  params, state = init(rng)
  if not run_apply:
    return None
  out, state = apply(params, state, rng)
  return out
def test_without_apply_rng_output_type(self):
  """Checks `without_apply_rng` preserves the type of its input.

  A `transform_with_state` result must stay a `TransformedWithState` and a
  `transform` result must stay a `Transformed`.
  """

  def f():
    w = base.get_parameter("w", [], init=jnp.zeros)
    return w

  # Use a separate name for the results. Rebinding `f` made the second half
  # call `transform.transform` on a TransformedWithState, not on the function.
  with_state = transform.without_apply_rng(transform.transform_with_state(f))
  self.assertIsInstance(with_state, transform.TransformedWithState)

  stateless = transform.without_apply_rng(transform.transform(f))
  self.assertIsInstance(stateless, transform.Transformed)
def test_empty_lift_with_state(self, ignore_update):
  """Checks lifting an empty function gives empty params and empty state.

  The updater is then consumed either by `ignore_update()` or by an empty
  `update({})`, depending on `ignore_update`.
  """
  empty = transform.transform_with_state(lambda: None)
  lifted_init, updater = lift.lift_with_state(empty.init)

  params, state = lifted_init(None)
  self.assertEmpty(params)
  self.assertEmpty(state)

  if not ignore_update:
    updater.update({})
  else:
    updater.ignore_update()
def test_grad_and_jit(self):
  """Checks `stateful.grad` of SquareModule gives 2 * x under `jax.jit`."""

  def net(x):
    return stateful.grad(SquareModule())(x)

  x = jnp.array(3.)
  net = transform.transform_with_state(net)

  jitted_init = jax.jit(net.init)
  params, state = jitted_init(None, x)

  jitted_apply = jax.jit(net.apply)
  grad, state = jitted_apply(params, state, None, x)

  np.testing.assert_allclose(grad, 2 * x, rtol=1e-3)
def test_without_state(self):
  """Checks a parameter-only function works through `without_state`."""

  def net():
    return base.get_parameter("w", [], init=jnp.zeros)

  stateless = transform.without_state(transform.transform_with_state(net))
  params = stateless.init(None)
  result = stateless.apply(params, None)
  self.assertEqual(result, 0)
def testEmaCrossReplica(self):
  """Compares VectorQuantizerEMA with and without a cross-replica axis.

  The module is run for ten updates under `pmap`, once with
  `cross_replica_axis=None` and once with `'i'`. On a single device both runs
  must agree within tolerance; otherwise they must not be identical.
  """
  embedding_dim = 6
  batch_size = 16
  num_devices = jax.local_device_count()
  inputs = np.random.rand(num_devices, batch_size, embedding_dim)

  def my_function(x, axis_name):
    decay = np.array(0.9, dtype=np.float32)
    vqvae_module = vqvae.VectorQuantizerEMA(
        embedding_dim=embedding_dim,
        num_embeddings=7,
        commitment_cost=0.5,
        decay=decay,
        cross_replica_axis=axis_name,
        dtype=jnp.float32)
    outputs = vqvae_module(x, is_training=True)
    return vqvae_module.embeddings, outputs['perplexity']

  embeddings = {}
  perplexities = {}

  for axis_name in (None, 'i'):
    vqvae_f = transform.transform_with_state(
        functools.partial(my_function, axis_name=axis_name))

    key = jax.random.PRNGKey(42)
    keys = jnp.broadcast_to(key, (num_devices, key.shape[0]))

    params, state = jax.pmap(vqvae_f.init, axis_name='i')(keys, inputs)
    update_fn = jax.pmap(vqvae_f.apply, axis_name='i')

    for _ in range(10):
      outputs, state = update_fn(params, state, None, inputs)
    # Only the outputs of the final update are kept.
    embeddings[axis_name], perplexities[axis_name] = outputs

  # In the single-device case, specifying a cross_replica_axis should have
  # no effect. Otherwise, it should!
  if jax.device_count() != 1:
    self.assertFalse((embeddings[None] == embeddings['i']).all())
    self.assertFalse((perplexities[None] == perplexities['i']).all())
  else:
    # Have to use assert_allclose here rather than checking exact matches to
    # make the test pass on GPU, presumably because of nondeterministic
    # reductions.
    np.testing.assert_allclose(
        embeddings[None], embeddings['i'], rtol=1e-6, atol=1e-6)
    np.testing.assert_allclose(
        perplexities[None], perplexities['i'], rtol=1e-6, atol=1e-6)
def test_cond(self):
  """Checks `stateful.cond` runs the chosen branch and records module state."""

  def net(x):
    mod = SquareModule()
    return stateful.cond(x == 2, x, mod, x, lambda x: mod(x + 1))

  net = transform.transform_with_state(net)

  # (input, expected output) pairs.
  cases = ((1, 4), (2, 4), (3, 16))
  for x, y in cases:
    x, y = jnp.array(x), jnp.array(y)
    params, state = net.init(None, x)
    out, state = net.apply(params, state, None, x)
    self.assertEqual(state, {"square_module": {"y": y}})
    self.assertEqual(out, y)
def test_without_state_raises_if_state_used(self):
  """Checks init under `without_state` raises when the function uses state."""

  def counter():
    for _ in range(10):
      count = base.get_state("count", (), jnp.int32, jnp.zeros)
      base.set_state("count", count + 1)
    return count

  stateless = transform.without_state(transform.transform_with_state(counter))
  stateless_init = stateless.init

  with self.assertRaisesRegex(ValueError, "use.*transform_with_state"):
    stateless_init(None)
def test_updater_used_in_different_inner_transform(self, updater_fn):
  """Checks an updater cannot be used from within another nested transform."""

  def outer():
    lifted_source = transform.transform_with_state(lambda: None)
    _, updater = lift.lift_with_state(lifted_source.init)
    # Touch the updater from inside a separate, nested transform.
    nested = transform.transform_with_state(lambda: updater_fn(updater))
    nested.init(None)

  outer = transform.transform_with_state(outer)
  expected_error = "must be used within the same call to init/apply"

  with self.assertRaisesRegex(ValueError, expected_error):
    outer.init(None)
def test_stateful(self):
  """Checks a counter kept in state reads 0 after init and 10 after apply."""

  def counter():
    for _ in range(10):
      count = base.get_state("count", (), jnp.int32, jnp.zeros)
      base.set_state("count", count + 1)
    return count

  transformed = transform.transform_with_state(counter)

  params, initial_state = transformed.init(None)
  self.assertEqual(initial_state, {"~": {"count": 0}})

  _, final_state = transformed.apply(params, initial_state, None)
  self.assertEqual(final_state, {"~": {"count": 10}})
def test_switch(self):
  """Checks `stateful.switch` runs the indexed branch and records its state."""

  def net(i, x):
    mod = SquareModule()
    branches = [mod, lambda x: mod(x + 1), lambda x: mod(x + 2)]
    return stateful.switch(i, branches, x)

  net = transform.transform_with_state(net)

  # (branch index, input, expected output) triples.
  cases = ((0, 1, 1), (1, 2, 9), (2, 3, 25))
  for i, x, y in cases:
    i, x, y = jnp.array(i), jnp.array(x), jnp.array(y)
    params, state = net.init(None, i, x)
    out, state = net.apply(params, state, None, i, x)
    self.assertEqual(state, {"square_module": {"y": y}})
    self.assertEqual(out, y)
def test_set_then_get(self):
  """Checks `get_state` returns the value written by a preceding `set_state`.

  Whatever value the incoming state holds, the output and the outgoing state
  must both be 1.
  """

  def net():
    base.set_state("i", 1)
    return base.get_state("i")

  transformed = transform.transform_with_state(net)
  params, state = transformed.init(None)
  self.assertEqual(state, {"~": {"i": 1}})

  for value in range(10):
    state_in = {"~": {"i": value}}
    result, state_out = transformed.apply(params, state_in, None)
    self.assertEqual(result, 1)
    self.assertEqual(state_out, {"~": {"i": 1}})