class ResnetTest(parameterized.TestCase):
  """Tests for `resnet.ResNet` built with a [1, 1, 1, 1] block layout."""

  @test_utils.combined_named_parameters(test_utils.named_bools("resnet_v2"),
                                        test_utils.named_bools("bottleneck"))
  @test_utils.transform_and_run
  def test_simple(self, resnet_v2, bottleneck):
    """Logits have shape (2, 10) with `is_training` both True and False."""
    image = jnp.ones([2, 64, 64, 3])
    model = resnet.ResNet([1, 1, 1, 1], 10,
                          resnet_v2=resnet_v2,
                          bottleneck=bottleneck)

    for is_training in (True, False):
      logits = model(image, is_training=is_training)
      self.assertEqual(logits.shape, (2, 10))

  @test_utils.combined_named_parameters(test_utils.named_bools("resnet_v2"),
                                        test_utils.named_bools("bottleneck"))
  def test_local_stats(self, resnet_v2, bottleneck):
    """Eval-mode call with `test_local_stats=True` yields (2, 10) logits."""
    def forward_fn(image):
      model = resnet.ResNet([1, 1, 1, 1], 10,
                            resnet_v2=resnet_v2,
                            bottleneck=bottleneck)
      return model(image, is_training=False, test_local_stats=True)

    forward = transform.transform(forward_fn, apply_rng=True)

    rng = jax.random.PRNGKey(42)
    image = jnp.ones([2, 64, 64, 3])
    params = forward.init(rng, image)
    # `None` is passed in the rng position of `apply`.
    logits = forward.apply(params, None, image)
    self.assertEqual(logits.shape, (2, 10))

  @parameterized.parameters(3, 5)
  @test_utils.transform_and_run
  def test_error_incorrect_args_block_list(self, list_length):
    """A `blocks_per_group` list whose length is not 4 raises ValueError."""
    block_list = list(range(list_length))
    with self.assertRaisesRegex(
        ValueError,
        "blocks_per_group` must be of length 4 not {}".format(list_length)):
      resnet.ResNet(block_list, 10, {"decay_rate": 0.9, "eps": 1e-5})

  @parameterized.parameters(3, 5)
  @test_utils.transform_and_run
  def test_error_incorrect_args_channel_list(self, list_length):
    """A `channels_per_group` list whose length is not 4 raises ValueError."""
    channel_list = list(range(list_length))
    with self.assertRaisesRegex(
        ValueError,
        "channels_per_group` must be of length 4 not {}".format(list_length)):
      resnet.ResNet([1, 1, 1, 1], 10,
                    {"decay_rate": 0.9, "eps": 1e-5},
                    channels_per_group=channel_list)
class NumpyInputsTest(parameterized.TestCase):
  """Checks that NumPy inputs and/or params match results from JAX arrays."""

  @test_utils.combined_named_parameters(
      descriptors.ALL_MODULES,
      test_utils.named_bools('np_inputs'),
      test_utils.named_bools('np_params'),
      test_utils.named_bools('close_over_params'))
  def test_numpy_and_jax_results_close(
      self,
      module_fn: ModuleFn,
      shape: Tuple[int, ...],
      dtype: jnp.dtype,
      np_params: bool,
      np_inputs: bool,
      close_over_params: bool,
  ):
    # With neither NumPy inputs nor NumPy params this is a pure-JAX run.
    if not np_params and not np_inputs:
      self.skipTest('Pure JAX variants tested elsewhere')

    f = hk.transform_with_state(lambda x: module_fn()(x))  # pylint: disable=unnecessary-lambda

    def jitted_apply(p, s, key, inputs):
      """Runs `f.apply` under jit, closing over (p, s) or passing them in."""
      if close_over_params:
        return jax.jit(functools.partial(f.apply, p, s))(key, inputs)
      return jax.jit(f.apply)(p, s, key, inputs)

    key = jax.random.PRNGKey(42)
    inputs = jnp.ones(shape, dtype)
    params, state = f.init(key, inputs)
    reference_out, reference_state = jitted_apply(params, state, key, inputs)

    if np_inputs:
      # Pull the key and inputs back to host as NumPy arrays.
      key, inputs = jax.device_get((key, inputs))

    with self.subTest('init'):
      other_params, other_state = f.init(key, inputs)
      tree_assert_allclose(params, other_params)
      tree_assert_allclose(state, other_state)

    with self.subTest('apply'):
      if np_params:
        params, state = jax.device_get((params, state))
      out, new_state = jitted_apply(params, state, key, inputs)
      tree_assert_allclose(reference_out, out)
      tree_assert_allclose(reference_state, new_state)
class JaxToTfTest(parameterized.TestCase):
  """Compares `jax2tf.convert`-ed Haiku init/apply against the JAX versions."""

  @test_utils.combined_named_parameters(descriptors.ALL_MODULES,
                                        test_utils.named_bools("init"),
                                        TF_TRANSFORM,
                                        JAX_TRANSFORM)
  def test_convert(
      self,
      module_fn: ModuleFn,
      shape,
      dtype,
      init: bool,
      tf_transform,
      jax_transform,
  ):
    rng = jax.random.PRNGKey(42)
    x = (jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
         if jnp.issubdtype(dtype, jnp.integer)
         else jax.random.uniform(rng, shape, dtype))

    def g(x):
      return module_fn()(x)

    f = hk.transform_with_state(g)

    atol = CUSTOM_ATOL.get(descriptors.module_type(module_fn), DEFAULT_ATOL)
    assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol)

    def to_numpy(tree):
      # Leaves produced by the TF path are converted with `.numpy()`.
      return jax.tree_map(lambda leaf: leaf.numpy(), tree)

    if init:
      jax_fn = jax_transform(f.init)
      tf_fn = tf_transform(jax2tf.convert(f.init))
      args = (rng, x)
    else:
      params, state = f.init(rng, x)
      jax_fn = jax_transform(f.apply)
      tf_fn = tf_transform(jax2tf.convert(f.apply))
      args = (params, state, rng, x)

    jax.tree_multimap(assert_allclose, jax_fn(*args), to_numpy(tf_fn(*args)))
class HaikuTransformsTest(parameterized.TestCase):
  """Compares Haiku transform wrappers against their JAX counterparts."""

  @staticmethod
  def _random_input(rng, shape, dtype):
    """Returns a random array of the given shape/dtype drawn from `rng`."""
    if jnp.issubdtype(dtype, jnp.integer):
      return jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    return jax.random.uniform(rng, shape, dtype)

  @test_utils.combined_named_parameters(descriptors.ALL_MODULES,
                                        test_utils.named_bools('init'))
  def test_hk_jit(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
      init: bool,
  ):
    """`hk.jit` on the module matches `jax.jit` on the whole init/apply."""
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x, jit=False):
      mod = module_fn()
      if jit:
        mod = hk.jit(mod)
      return mod(x)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-4)

    # NOTE: We shard init/apply tests since some modules are expensive to jit
    # (e.g. ResNet50 takes ~60s to compile and we compile it twice per test).
    if init:
      jax.tree_multimap(assert_allclose,
                        jax.jit(f.init)(rng, x),
                        f.init(rng, x, jit=True))
    else:
      params, state = f.init(rng, x)
      jax.tree_multimap(assert_allclose,
                        jax.jit(f.apply)(params, state, rng, x),
                        f.apply(params, state, rng, x, jit=True))

  @test_utils.combined_named_parameters(
      # TODO(tomhennigan) Enable once grad for _scan_transpose implemented.
      set(descriptors.ALL_MODULES) - set(descriptors.RECURRENT_MODULES))
  def test_hk_remat(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
  ):
    """Gradients with `hk.remat` match gradients with `jax.remat`."""
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x, remat=False):
      mod = module_fn()
      if remat:
        mod = hk.remat(mod)
      out = mod(x)
      # Dict outputs are reduced via their 'loss' entry (assumes the key
      # exists for every module returning a dict).
      if isinstance(out, dict):
        out = out['loss']
      return jnp.mean(out)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    grad_jax_remat = jax.grad(jax.remat(f.apply), has_aux=True)
    grad_hk_remat = jax.grad(functools.partial(f.apply, remat=True),
                             has_aux=True)

    params, state = f.init(rng, x)
    jax.tree_multimap(assert_allclose,
                      grad_jax_remat(params, state, rng, x),
                      grad_hk_remat(params, state, rng, x))

  @test_utils.combined_named_parameters(descriptors.ALL_MODULES)
  def test_profiler_name_scopes(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
  ):
    """Enabling profiler name scopes does not change apply results."""
    # `profiler_name_scopes` is toggled from inside `g` below and stays set
    # afterwards. Register the reset as a cleanup so it also runs when an
    # assertion fails, instead of leaking the setting into later tests.
    # TODO(lenamartens): flip to True when default changes
    self.addCleanup(hk.experimental.profiler_name_scopes, enabled=False)

    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x, name_scopes=False):
      hk.experimental.profiler_name_scopes(enabled=name_scopes)
      mod = module_fn()
      return mod(x)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    params, state = f.init(rng, x)
    jax.tree_multimap(assert_allclose,
                      f.apply(params, state, rng, x),
                      f.apply(params, state, rng, x, name_scopes=True))

  @test_utils.combined_named_parameters(descriptors.ALL_MODULES)
  def test_optimize_rng_use_under_jit(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
  ):
    """`optimize_rng_use` gives the same init/apply with and without jit."""
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x):
      return module_fn()(x)

    f = hk.transform_with_state(hk.experimental.optimize_rng_use(g))

    module_type = descriptors.module_type(module_fn)
    atol = CUSTOM_ATOL.get(module_type, DEFAULT_ATOL)
    assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol)

    params, state = jax.jit(f.init)(rng, x)
    jax.tree_multimap(assert_allclose, (params, state), f.init(rng, x))

    if module_type in (hk.nets.VectorQuantizer, hk.nets.VectorQuantizerEMA):
      # For stochastic modules just test apply runs.
      jax.device_get(jax.jit(f.apply)(params, state, rng, x))
    else:
      jax.tree_multimap(assert_allclose,
                        jax.jit(f.apply)(params, state, rng, x),
                        f.apply(params, state, rng, x))
class SummariseTest(parameterized.TestCase):
  """Tests for the rows produced by `get_summary`."""

  def test_empty(self):
    self.assertEmpty(get_summary(lambda: None))

  def test_filters_ctor_only(self):
    def f():
      return IdentityModule()  # NOTE: Just calling ctor.

    self.assertEmpty(get_summary(f))

  @parameterized.parameters(*range(1, 5))
  def test_one_row_per_method_call(self, num_calls):
    x = jnp.ones([])

    def f():
      module = IdentityModule()
      for _ in range(num_calls):
        module(x)

    invocations = get_summary(f)
    self.assertLen(invocations, num_calls)
    # Every row should describe the same method.
    expected_method = invocations[0].context.method_name
    for invocation in invocations[1:]:
      self.assertEqual(expected_method, invocation.context.method_name)

  @test_utils.combined_named_parameters(test_utils.named_bools("params"),
                                        test_utils.named_range("num_elems", 8))
  def test_params_or_state(self, params, num_elems):
    getter = base.get_parameter if params else base.get_state

    def cls():
      for i in range(num_elems):
        getter(f"x{i}", [], init=jnp.zeros)

    f = lambda: basic.to_module(cls)(name="foo")()
    (invocation,) = get_summary(f)
    details = invocation.module_details
    collection = details.params if params else details.state
    expected_names = [f"foo/x{i}" for i in range(num_elems)]
    self.assertEqual(list(collection), expected_names)

  def test_jitted_f(self):
    trace_log = []

    def f(x):
      trace_log.append(None)
      return basic.Linear(1)(x)

    f = transform.transform(f)
    rng = jax.random.PRNGKey(42)
    x = jnp.zeros([1, 1])
    params = f.init(rng, x)
    trace_log.clear()

    # This layer of indirection (`g`) means summarise cannot unpack `f` and
    # strip our jit.
    jit_apply = jax.jit(f.apply)
    g = lambda params, x: jit_apply(params, None, x)

    for _ in range(2):
      g(params, x)  # Warm up JIT.
    # Only the first call should have traced `f`.
    self.assertLen(trace_log, 1)
    summary = get_summary(g, params, x)
    self.assertLen(summary, 1)
class ResnetTest(parameterized.TestCase):
  """Tests for `resnet.ResNet` and the predefined ResNet classes."""

  @test_utils.combined_named_parameters(test_utils.named_bools("resnet_v2"),
                                        test_utils.named_bools("bottleneck"))
  @test_utils.transform_and_run
  def test_simple(self, resnet_v2, bottleneck):
    """Logits have shape (2, 10) with `is_training` both True and False."""
    image = jnp.ones([2, 64, 64, 3])
    model = resnet.ResNet([1, 1, 1, 1], 10,
                          resnet_v2=resnet_v2,
                          bottleneck=bottleneck)

    for is_training in (True, False):
      logits = model(image, is_training=is_training)
      self.assertEqual(logits.shape, (2, 10))

  @test_utils.combined_named_parameters(test_utils.named_bools("resnet_v2"),
                                        test_utils.named_bools("bottleneck"))
  def test_local_stats(self, resnet_v2, bottleneck):
    """Eval-mode call with `test_local_stats=True` yields (2, 10) logits."""
    def forward_fn(image):
      model = resnet.ResNet([1, 1, 1, 1], 10,
                            resnet_v2=resnet_v2,
                            bottleneck=bottleneck)
      return model(image, is_training=False, test_local_stats=True)

    forward = transform.transform(forward_fn)

    rng = jax.random.PRNGKey(42)
    image = jnp.ones([2, 64, 64, 3])
    params = forward.init(rng, image)
    logits = forward.apply(params, None, image)
    self.assertEqual(logits.shape, (2, 10))

  @parameterized.parameters(3, 5)
  @test_utils.transform_and_run
  def test_error_incorrect_args_block_list(self, list_length):
    """A `blocks_per_group` list whose length is not 4 raises ValueError."""
    block_list = list(range(list_length))
    with self.assertRaisesRegex(
        ValueError,
        "blocks_per_group` must be of length 4 not {}".format(list_length)):
      resnet.ResNet(block_list, 10, {"decay_rate": 0.9, "eps": 1e-5})

  @parameterized.parameters(3, 5)
  @test_utils.transform_and_run
  def test_error_incorrect_args_channel_list(self, list_length):
    """A `channels_per_group` list whose length is not 4 raises ValueError."""
    channel_list = list(range(list_length))
    with self.assertRaisesRegex(
        ValueError,
        "channels_per_group` must be of length 4 not {}".format(list_length)):
      resnet.ResNet([1, 1, 1, 1], 10,
                    {"decay_rate": 0.9, "eps": 1e-5},
                    channels_per_group=channel_list)

  @test_utils.combined_named_parameters(
      [(i, (getattr(resnet, i), n))
       for i, n in zip(_RESNETS, _RESNET_NUM_PARAMS)],
      test_utils.named_bools("resnet_v2"),
  )
  def test_num_params(self, resnet_class_and_num_params, resnet_v2):
    """Total parameter count is within +/-0.2% of the expected count."""
    resnet_class, expected_num_params = resnet_class_and_num_params

    def model_func(img):
      model = resnet_class(1000, resnet_v2=resnet_v2)
      return model(img, is_training=True)

    model = hk.transform_with_state(model_func)
    image = jnp.ones([2, 64, 64, 3])
    rng = jax.random.PRNGKey(0)
    params, _ = model.init(rng, image)
    # `.size` is the element count (product of the shape) as a Python int,
    # so no per-leaf device computation is needed.
    num_params = sum(p.size for p in jax.tree_leaves(params))
    self.assertGreater(num_params, int(0.998 * expected_num_params))
    self.assertLess(num_params, int(1.002 * expected_num_params))

  @test_utils.combined_named_parameters(
      [(i, (getattr(resnet, i), p))
       for i, p in zip(_RESNETS, _RESNET_HAS_PROJECTION)],
      test_utils.named_bools("resnet_v2"),
  )
  @test_utils.transform_and_run
  def test_has_projection(self, resnet_class_and_has_projection, resnet_v2):
    """Only the first block of each block group may have `proj_conv`."""
    resnet_class, has_projection = resnet_class_and_has_projection
    model = resnet_class(1000, resnet_v2=resnet_v2)
    for i, block_group in enumerate(model.block_groups):
      if i == 0:
        self.assertEqual(hasattr(block_group.blocks[0], "proj_conv"),
                         has_projection)
      else:
        self.assertTrue(hasattr(block_group.blocks[0], "proj_conv"))

      for block in block_group.blocks[1:]:
        self.assertFalse(hasattr(block, "proj_conv"))
class HaikuTransformsTest(parameterized.TestCase):
  """Compares `hk.jit` / `hk.remat` against their JAX counterparts."""

  @staticmethod
  def _random_input(rng, shape, dtype):
    """Returns a random array of the given shape/dtype drawn from `rng`."""
    if jnp.issubdtype(dtype, jnp.integer):
      return jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    return jax.random.uniform(rng, shape, dtype)

  @test_utils.combined_named_parameters(descriptors.ALL_MODULES,
                                        test_utils.named_bools('init'))
  def test_hk_jit(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
      init: bool,
  ):
    """`hk.jit` on the module matches `jax.jit` on the whole init/apply."""
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x, jit=False):
      mod = module_fn()
      if jit:
        mod = hk.jit(mod)
      return mod(x)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    # NOTE: We shard init/apply tests since some modules are expensive to jit
    # (e.g. ResNet50 takes ~60s to compile and we compile it twice per test).
    if init:
      jax.tree_multimap(assert_allclose,
                        jax.jit(f.init)(rng, x),
                        f.init(rng, x, jit=True))
    else:
      params, state = f.init(rng, x)
      jax.tree_multimap(assert_allclose,
                        jax.jit(f.apply)(params, state, rng, x),
                        f.apply(params, state, rng, x, jit=True))

  @test_utils.combined_named_parameters(
      # TODO(tomhennigan) Enable once grad for _scan_transpose implemented.
      set(descriptors.ALL_MODULES) - set(descriptors.RECURRENT_MODULES))
  def test_hk_remat(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
  ):
    """Gradients with `hk.remat` match gradients with `jax.remat`."""
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x, remat=False):
      mod = module_fn()
      if remat:
        mod = hk.remat(mod)
      out = mod(x)
      # Dict outputs cannot be passed to `jnp.mean` directly; reduce them via
      # their 'loss' entry (assumes the key exists for such modules).
      if isinstance(out, dict):
        out = out['loss']
      return jnp.mean(out)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    grad_jax_remat = jax.grad(jax.remat(f.apply), has_aux=True)
    grad_hk_remat = jax.grad(functools.partial(f.apply, remat=True),
                             has_aux=True)

    params, state = f.init(rng, x)
    jax.tree_multimap(assert_allclose,
                      grad_jax_remat(params, state, rng, x),
                      grad_hk_remat(params, state, rng, x))
class HaikuTransformsTest(parameterized.TestCase):
  """Compares Haiku transform wrappers against their JAX counterparts."""

  @staticmethod
  def _random_input(rng, shape, dtype):
    """Returns a random array of the given shape/dtype drawn from `rng`."""
    if jnp.issubdtype(dtype, jnp.integer):
      return jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    return jax.random.uniform(rng, shape, dtype)

  @test_utils.combined_named_parameters(descriptors.ALL_MODULES,
                                        test_utils.named_bools('init'))
  def test_hk_jit(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
      init: bool,
  ):
    """`hk.jit` on the module matches `jax.jit` on the whole init/apply."""
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x, jit=False):
      mod = module_fn()
      if jit:
        mod = hk.jit(mod)
      return mod(x)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-4)

    # NOTE: We shard init/apply tests since some modules are expensive to jit
    # (e.g. ResNet50 takes ~60s to compile and we compile it twice per test).
    if init:
      jax.tree_multimap(assert_allclose,
                        jax.jit(f.init)(rng, x),
                        f.init(rng, x, jit=True))
    else:
      params, state = f.init(rng, x)
      jax.tree_multimap(assert_allclose,
                        jax.jit(f.apply)(params, state, rng, x),
                        f.apply(params, state, rng, x, jit=True))

  @test_utils.combined_named_parameters(
      # TODO(tomhennigan) Enable once grad for _scan_transpose implemented.
      set(descriptors.ALL_MODULES) - set(descriptors.RECURRENT_MODULES))
  def test_hk_remat(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
  ):
    """Gradients with `hk.remat` match gradients with `jax.remat`."""
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x, remat=False):
      mod = module_fn()
      if remat:
        mod = hk.remat(mod)
      out = mod(x)
      # Dict outputs are reduced via their 'loss' entry (assumes the key
      # exists for every module returning a dict).
      if isinstance(out, dict):
        out = out['loss']
      return jnp.mean(out)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    grad_jax_remat = jax.grad(jax.remat(f.apply), has_aux=True)
    grad_hk_remat = jax.grad(functools.partial(f.apply, remat=True),
                             has_aux=True)

    params, state = f.init(rng, x)
    jax.tree_multimap(assert_allclose,
                      grad_jax_remat(params, state, rng, x),
                      grad_hk_remat(params, state, rng, x))

  @test_utils.combined_named_parameters(descriptors.ALL_MODULES)
  def test_profiler_name_scopes(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
  ):
    """Enabling profiler name scopes does not change apply results."""
    if not hasattr(xla.xb, 'parameter'):
      self.skipTest('Need Jaxlib version > 0.1.45')

    # `profiler_name_scopes` is toggled from inside `g` below and stays set
    # afterwards. Register the reset as a cleanup so it also runs when an
    # assertion fails, instead of leaking the setting into later tests.
    # TODO(lenamartens): flip to True when default changes
    self.addCleanup(hk.experimental.profiler_name_scopes, enabled=False)

    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x, name_scopes=False):
      hk.experimental.profiler_name_scopes(enabled=name_scopes)
      mod = module_fn()
      return mod(x)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    params, state = f.init(rng, x)
    jax.tree_multimap(assert_allclose,
                      f.apply(params, state, rng, x),
                      f.apply(params, state, rng, x, name_scopes=True))
class ResnetTest(parameterized.TestCase):
  """Tests for `resnet.ResNet` and the predefined ResNet classes."""

  @test_utils.combined_named_parameters(test_utils.named_bools("resnet_v2"),
                                        test_utils.named_bools("bottleneck"))
  @test_utils.transform_and_run
  def test_simple(self, resnet_v2, bottleneck):
    """Logits have shape (2, 10) with `is_training` both True and False."""
    image = jnp.ones([2, 64, 64, 3])
    model = resnet.ResNet([1, 1, 1, 1], 10,
                          resnet_v2=resnet_v2,
                          bottleneck=bottleneck)

    for is_training in (True, False):
      logits = model(image, is_training=is_training)
      self.assertEqual(logits.shape, (2, 10))

  @test_utils.combined_named_parameters(test_utils.named_bools("resnet_v2"),
                                        test_utils.named_bools("bottleneck"))
  def test_local_stats(self, resnet_v2, bottleneck):
    """Eval-mode call with `test_local_stats=True` yields (2, 10) logits."""
    def forward_fn(image):
      model = resnet.ResNet([1, 1, 1, 1], 10,
                            resnet_v2=resnet_v2,
                            bottleneck=bottleneck)
      return model(image, is_training=False, test_local_stats=True)

    forward = transform.transform(forward_fn)

    rng = jax.random.PRNGKey(42)
    image = jnp.ones([2, 64, 64, 3])
    params = forward.init(rng, image)
    logits = forward.apply(params, None, image)
    self.assertEqual(logits.shape, (2, 10))

  @parameterized.parameters(3, 5)
  @test_utils.transform_and_run
  def test_error_incorrect_args_block_list(self, list_length):
    """A `blocks_per_group` list whose length is not 4 raises ValueError."""
    block_list = list(range(list_length))
    with self.assertRaisesRegex(
        ValueError,
        "blocks_per_group` must be of length 4 not {}".format(list_length)):
      resnet.ResNet(block_list, 10, {"decay_rate": 0.9, "eps": 1e-5})

  @parameterized.parameters(3, 5)
  @test_utils.transform_and_run
  def test_error_incorrect_args_channel_list(self, list_length):
    """A `channels_per_group` list whose length is not 4 raises ValueError."""
    channel_list = list(range(list_length))
    with self.assertRaisesRegex(
        ValueError,
        "channels_per_group` must be of length 4 not {}".format(list_length)):
      resnet.ResNet([1, 1, 1, 1], 10,
                    {"decay_rate": 0.9, "eps": 1e-5},
                    channels_per_group=channel_list)

  @test_utils.combined_named_parameters(
      [(i, (getattr(resnet, i), n))
       for i, n in zip(_RESNETS, _RESNET_NUM_PARAMS)],
      test_utils.named_bools("resnet_v2"),
  )
  def test_num_params(self, resnet_class_and_num_params, resnet_v2):
    """Total parameter count is within +/-0.2% of the expected count."""
    resnet_class, expected_num_params = resnet_class_and_num_params

    def model_func(img):
      model = resnet_class(1000, resnet_v2=resnet_v2)
      return model(img, is_training=True)

    model = hk.transform_with_state(model_func)
    image = jnp.ones([2, 64, 64, 3])
    rng = jax.random.PRNGKey(0)
    params, _ = model.init(rng, image)
    # `.size` is the element count (product of the shape) as a Python int.
    num_params = sum(p.size for p in jax.tree_leaves(params))
    self.assertGreater(num_params, int(0.998 * expected_num_params))
    self.assertLess(num_params, int(1.002 * expected_num_params))

  @test_utils.combined_named_parameters(
      [(i, (getattr(resnet, i), p))
       for i, p in zip(_RESNETS, _RESNET_HAS_PROJECTION)],
      test_utils.named_bools("resnet_v2"),
  )
  @test_utils.transform_and_run
  def test_has_projection(self, resnet_class_and_has_projection, resnet_v2):
    """Only the first block of each block group may have `proj_conv`."""
    resnet_class, has_projection = resnet_class_and_has_projection
    model = resnet_class(1000, resnet_v2=resnet_v2)
    for i, block_group in enumerate(model.block_groups):
      if i == 0:
        self.assertEqual(hasattr(block_group.blocks[0], "proj_conv"),
                         has_projection)
      else:
        self.assertTrue(hasattr(block_group.blocks[0], "proj_conv"))

      for block in block_group.blocks[1:]:
        self.assertFalse(hasattr(block, "proj_conv"))

  @test_utils.combined_named_parameters(
      [(i, getattr(resnet, i)) for i in _RESNETS],
      test_utils.named_bools("resnet_v2"),
  )
  def test_logits_config(self, resnet_class, resnet_v2):
    """`logits_config` overrides the logits layer's `w_init`."""
    def model_func_logits_config_default(img):
      model = resnet_class(1000, resnet_v2=resnet_v2)
      return model(img, is_training=True)

    def model_func_logits_config_modified(img):
      model = resnet_class(1000, resnet_v2=resnet_v2,
                           logits_config=dict(w_init=jnp.ones))
      return model(img, is_training=True)

    image = jnp.ones([2, 64, 64, 3])
    rng = jax.random.PRNGKey(0)

    model = hk.transform_with_state(model_func_logits_config_default)
    params, _ = model.init(rng, image)
    logits_keys = [k for k in params.keys() if "/logits" in k]
    self.assertLen(logits_keys, 1)

    # Check logits params are zeros
    w_logits = params[logits_keys[0]]["w"]
    np.testing.assert_allclose(jnp.zeros_like(w_logits), w_logits)

    model = hk.transform_with_state(model_func_logits_config_modified)
    params, _ = model.init(rng, image)

    # Check logits params are ones. The key found above is reused, which
    # assumes the modified model names its logits module identically.
    w_logits = params[logits_keys[0]]["w"]
    np.testing.assert_allclose(jnp.ones_like(w_logits), w_logits)

  @test_utils.combined_named_parameters(
      [(i, getattr(resnet, i)) for i in _RESNETS],
  )
  @test_utils.transform_and_run
  def test_initial_conv_config(self, resnet_cls):
    """Each `initial_conv_config` entry is reflected on `net.initial_conv`."""
    config = dict(name="custom_name", output_channels=32, kernel_shape=(3, 3),
                  stride=(1, 1), padding="VALID", with_bias=True)
    net = resnet_cls(1000, initial_conv_config=config)
    for key, value in config.items():
      self.assertEqual(getattr(net.initial_conv, key), value)
class StatefulTest(parameterized.TestCase): @test_utils.transform_and_run def test_grad(self): x = jnp.array(3.) g = stateful.grad(SquareModule())(x) np.testing.assert_allclose(g, 2 * x, rtol=1e-4) def test_grad_no_transform(self): x = jnp.array(3.) with self.assertRaises(ValueError, msg="Use jax.grad() instead"): stateful.grad(jnp.square)(x) @test_utils.transform_and_run def test_value_and_grad(self): x = jnp.array(2.) y, g = stateful.value_and_grad(SquareModule())(x) self.assertEqual(y, x ** 2) np.testing.assert_allclose(g, 2 * x, rtol=1e-4) def test_value_and_grad_no_transform(self): x = jnp.array(3.) with self.assertRaises(ValueError, msg="Use jax.grad() instead"): stateful.value_and_grad(jnp.square)(x) @test_utils.transform_and_run def test_grad_aux(self): o = object() def f(x): m = SquareModule() return m(x), o x = jnp.array(3.) g, aux = stateful.grad(f, has_aux=True)(x) np.testing.assert_allclose(g, 2 * x, rtol=1e-4) self.assertIs(aux, o) @test_utils.transform_and_run def test_value_and_grad_aux(self): o = object() def f(x): m = SquareModule() return m(x), o x = jnp.array(3.) (y, aux), g = stateful.value_and_grad(f, has_aux=True)(x) self.assertEqual(y, jnp.power(x, 2)) np.testing.assert_allclose(g, 2 * x, rtol=1e-4) self.assertIs(aux, o) def test_grad_and_jit(self): def f(x): g = stateful.grad(SquareModule())(x) return g x = jnp.array(3.) f = transform.transform_with_state(f) params, state = jax.jit(f.init)(None, x) g, state = jax.jit(f.apply)(params, state, None, x) np.testing.assert_allclose(g, 2 * x, rtol=1e-3) def test_value_and_grad_and_jit(self): def f(x): y, g = stateful.value_and_grad(SquareModule())(x) return y, g x = jnp.array(3.) 
f = transform.transform_with_state(f) params, state = jax.jit(f.init)(None, x) (y, g), state = jax.jit(f.apply)(params, state, None, x) np.testing.assert_allclose(y, x ** 2, rtol=1e-3) np.testing.assert_allclose(g, 2 * x, rtol=1e-3) @test_utils.transform_and_run def test_jit(self): mod = SquareModule() x = jnp.array(2) y = stateful.jit(mod)(x) self.assertEqual(y, x ** 2) def test_jit_no_transform(self): x = jnp.array(2) with self.assertRaises(ValueError, msg="Use jax.jit() instead"): stateful.jit(jnp.square)(x) @test_utils.transform_and_run def test_remat(self): forward, backward = [], [] callback = _callback_prim(lambda: forward.append(None), lambda: backward.append(None)) def test(remat): x = jnp.array(3.) mod = CountingModule() self.assertEqual(mod.count, 0) f = lambda x: callback(mod(x)) if remat: f = stateful.remat(f) y, g = stateful.value_and_grad(f)(x) np.testing.assert_allclose(y, x ** 2, rtol=1e-3) np.testing.assert_allclose(g, 2 * x, rtol=1e-3) self.assertEqual(mod.count, 1) num_forward = len(forward) num_backward = len(backward) del forward[:], backward[:] return num_forward, num_backward # Sanity check. self.assertEqual(test(remat=True), test(remat=True)) self.assertEqual(test(remat=False), test(remat=False)) # NOTE: JAX does not guarantee to execute primitives once and only once for # a given function (we observe f=2,b=1 without remat and f=5,b=1 with # remat), but we do expect that JAX will execute our primitive forward at # least one more time with remat than without it. num_forward_remat, num_backward_remat = test(remat=True) num_forward_no_remat, num_backward_no_remat = test(remat=False) self.assertGreater(num_forward_remat, num_forward_no_remat) self.assertEqual(num_backward_remat, num_backward_no_remat) def test_remat_no_transform(self): x = jnp.array(3.) 
with self.assertRaises(ValueError, msg="Use jax.remat() instead"): stateful.remat(jnp.square)(x) @test_utils.combined_named_parameters( test_utils.named_bools("jax_remat"), test_utils.named_bools("inline_hk_remat")) def test_create_module_inside_remat(self, jax_remat, inline_hk_remat): log = [] def forward(x): def create_and_use_layer(x): m = SquareModule(name="layer") log.append(m.module_name) return m(x) if not inline_hk_remat: create_and_use_layer = stateful.remat(create_and_use_layer) for _ in range(2): if inline_hk_remat: x = stateful.remat(create_and_use_layer)(x) else: x = create_and_use_layer(x) return x def reset(): del log[:] self.assertEmpty(log) # Test forward. x = jnp.float32(3) forward = transform.transform_with_state(forward) params, state = forward.init(None, x) self.assertEqual(log, ["layer", "layer_1"]) reset() # Test backward. for _ in range(3): grad_fn = jax.grad(lambda x: forward.apply(params, state, None, x)[0]) if jax_remat: grad_fn = jax.remat(grad_fn) self.assertEqual(int(grad_fn(x)), int(4 * (x ** 3))) self.assertEqual(log, ["layer", "layer_1"]) reset() @parameterized.parameters(True, False) def test_cond(self, single_arg): def f(x): mod = SquareModule() if single_arg: return stateful.cond(x == 2, mod, lambda x: mod(x + 1), x) else: return stateful.cond(x == 2, x, mod, x, lambda x: mod(x + 1)) f = transform.transform_with_state(f) for x, y in ((1, 4), (2, 4), (3, 16)): x, y = map(jnp.array, (x, y)) params, state = f.init(None, x) out, state = f.apply(params, state, None, x) self.assertEqual(state, {"square_module": {"y": y}}) self.assertEqual(out, y) @test_utils.transform_and_run def test_cond_traces_branches_with_same_id_once(self): witness = [] def f(x): witness.append(None) return x ** 2 stateful.cond(False, f, f, 0) hk_call_count = len(witness) self.assertEqual(hk_call_count, 1) # Ensure we are in sync with JAX. 
del witness[:] jax.lax.cond(False, f, f, 0) jax_call_count = len(witness) self.assertEqual(hk_call_count, jax_call_count) def test_cond_no_transform(self): x = jnp.array(3.) with self.assertRaises(ValueError, msg="Use jax.cond() instead"): stateful.cond(x == 2, x, jnp.square, x, lambda x: jnp.square(x + 1)) def test_switch(self): def f(i, x): mod = SquareModule() branches = [mod, lambda x: mod(x + 1), lambda x: mod(x + 2)] return stateful.switch(i, branches, x) f = transform.transform_with_state(f) for i, x, y in ((0, 1, 1), (1, 2, 9), (2, 3, 25)): i, x, y = map(jnp.array, (i, x, y)) params, state = f.init(None, i, x) out, state = f.apply(params, state, None, i, x) self.assertEqual(state, {"square_module": {"y": y}}) self.assertEqual(out, y) @parameterized.parameters(1, 2, 4, 8) @test_utils.transform_and_run def test_switch_traces_cases_with_same_id_once(self, n): f_witness = [] g_witness = [] def f(x): f_witness.append(None) return x ** 2 def g(x): g_witness.append(None) return x ** 2 stateful.switch(0, [f, g] * n, 2) f_hk_call_count = len(f_witness) g_hk_call_count = len(g_witness) self.assertEqual(f_hk_call_count, 1) self.assertEqual(g_hk_call_count, 1) # Ensure we are in sync with JAX. del f_witness[:], g_witness[:] jax.lax.switch(0, [f, g] * n, 2) f_jax_call_count = len(f_witness) g_jax_call_count = len(g_witness) self.assertEqual(f_hk_call_count, f_jax_call_count) self.assertEqual(f_hk_call_count, g_jax_call_count) def test_switch_no_transform(self): i = jnp.array(2) x = jnp.array(42.) 
with self.assertRaises(ValueError, msg="Use jax.switch() instead"): stateful.switch(i, [jnp.square] * 3, x) @test_utils.transform_and_run def test_difference_empty(self): before = stateful.internal_state() after = stateful.internal_state() self.assertEmpty(jax.tree_leaves(stateful.difference(before, after))) @parameterized.parameters(base.get_parameter, base.get_state) @test_utils.transform_and_run(run_apply=False) def test_difference_new(self, get_x): get_x("a", [], init=jnp.zeros) before = stateful.internal_state() b = get_x("b", [], init=jnp.zeros) after = stateful.internal_state() diff = stateful.difference(before, after) if get_x == base.get_state: self.assertEmpty(diff.params) self.assertEqual(diff.state, {"~": {"a": None, "b": base.StatePair(b, b)}}) else: self.assertEqual(diff.params, {"~": {"a": None, "b": b}}) self.assertEmpty(diff.state) self.assertIsNone(diff.rng) @test_utils.transform_and_run(run_apply=False) def test_difference_update_state(self): base.get_state("a", [], init=jnp.zeros) base.get_state("b", [], init=jnp.zeros) before = stateful.internal_state() base.set_state("b", jnp.ones([])) after = stateful.internal_state() diff = stateful.difference(before, after) self.assertEmpty(diff.params) self.assertEqual(diff.state, {"~": {"a": None, "b": base.StatePair(0., 1.)}}) self.assertIsNone(diff.rng) @test_utils.transform_and_run(run_apply=False) def test_difference_rng(self): before = stateful.internal_state() base.next_rng_key() after = stateful.internal_state() diff = stateful.difference(before, after) self.assertEmpty(diff.params) self.assertEmpty(diff.state) self.assertIsNotNone(diff.rng) def test_scan_no_transform(self): xs = jnp.arange(3) with self.assertRaises(ValueError, msg="Use jax.scan() instead"): stateful.scan(lambda c, x: (c, x), (), xs) @parameterized.parameters(0, 1, 2, 4, 8) def test_scan_with_state(self, unroll_length): def f(xs): m = CountingModule() def sf(c, x): self.assertEqual(c, ()) return c, m(x) _, ys = stateful.scan(sf, 
(), xs) return ys f = transform.transform_with_state(f) key = jax.random.PRNGKey(42) xs = jnp.arange(unroll_length) params, state = f.init(key, xs) self.assertEqual(list(state), ["counting_module"]) self.assertEqual(list(state["counting_module"]), ["count"]) np.testing.assert_allclose(state["counting_module"]["count"], 0, rtol=1e-4) ys, state = f.apply(params, state, key, xs) np.testing.assert_allclose(state["counting_module"]["count"], unroll_length, rtol=1e-4) np.testing.assert_allclose(ys, xs ** 2, rtol=1e-4) @parameterized.parameters(0, 1, 2, 8) @test_utils.transform_and_run def test_stateful_scan_with_rng_use(self, iteration_count): # TODO(lenamartens): remove when default changes to > 1. tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE base.DEFAULT_PRNG_RESERVE_SIZE = 64 def body_fun(c, x): for _ in range(10): _ = base.next_rng_key() return c, x base.reserve_rng_keys(5) _ = stateful.scan(body_fun, (), (), length=iteration_count) base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default @parameterized.parameters(0, 1, 2, 8) @test_utils.transform_and_run def test_stateful_fori_with_rng_use(self, iteration_count): tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE base.DEFAULT_PRNG_RESERVE_SIZE = 64 def body_fun(_, x): for _ in range(10): _ = base.next_rng_key() return x base.reserve_rng_keys(5) _ = stateful.fori_loop(0, iteration_count, body_fun, 1) base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default @test_utils.transform_and_run def test_stateful_cond_with_rng_use(self): tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE base.DEFAULT_PRNG_RESERVE_SIZE = 64 # Test if using different amount of keys in different branches # results in error def true_branch(x): _ = base.next_rng_key() return x def false_branch(x): _ = base.next_rng_key() _ = base.next_rng_key() return x base.reserve_rng_keys(5) _ = stateful.cond(True, true_branch, false_branch, 0) _ = stateful.cond(False, true_branch, false_branch, 0) base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default @test_utils.transform_and_run def 
test_stateful_switch_with_rng_use(self): tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE base.DEFAULT_PRNG_RESERVE_SIZE = 64 # Test if using different amount of keys in different branches # results in error def branch_f(i): for _ in range(i): _ = base.next_rng_key() return i base.reserve_rng_keys(5) branches = [lambda _, i=i: branch_f(i) for i in range(5)] self.assertEqual(stateful.switch(3, branches, None), 3) self.assertEqual(stateful.switch(0, branches, None), 0) base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default @parameterized.parameters(*it.product((0, 1, 2, 4, 8), (1, 2, 3))) @test_utils.transform_and_run def test_fori(self, lower, n): upper = lower + n m = CountingModule() y = stateful.fori_loop(lower, upper, lambda i, x: m(i), 2) self.assertEqual(y, jnp.square(upper - 1)) self.assertEqual(m.count, upper - lower) @test_utils.transform_and_run def test_fori_traced_length(self): m = CountingModule() def f(lower, upper): y = stateful.fori_loop(lower, upper, lambda i, x: m(i), 2) return y # Because of the jit, lower and upper will be tracers. out = stateful.jit(f)(0, 3) self.assertEqual(out, 4) self.assertEqual(m.count, 3) def test_vmap(self): def g(x): return CountingModule()(x) def f(x): return stateful.vmap(g)(x) f = transform.transform_with_state(f) x = jnp.ones([4]) + 1 params, state = f.init(None, x) # State should not be mapped. self.assertEmpty(params) cnt, = jax.tree_leaves(state) self.assertEqual(cnt.ndim, 0) self.assertEqual(cnt, 0) # The output should be mapped but state should not be. 
y, state = f.apply(params, state, None, x) self.assertEqual(y.shape, (4,)) np.testing.assert_allclose(y, x ** 2) cnt, = jax.tree_leaves(state) self.assertEqual(cnt.ndim, 0) self.assertEqual(cnt, 1) def test_while_loop_rejected_in_init(self): def f(): stateful.while_loop(lambda x: x.all(), lambda x: not x, 1) f = transform.transform(f) with self.assertRaisesRegex( ValueError, "hk.while_loop does not support initialization"): f.init(None) def test_updating_state_in_cond_fails(self): def f(x): m = CountingModule(op=lambda x: x + 1) if not base.params_frozen(): return m(x) else: stateful.while_loop(m, lambda x: x, x) f = transform.transform_with_state(f) x = jnp.zeros([]) params, state = f.init(None, x) with self.assertRaisesRegex( ValueError, "does not support.*set_state.*next_rng_key.*in.*cond_fun`"): f.apply(params, state, None, x) def test_rng_in_cond_fails(self): def f(x): m = CountingModule(op=lambda x: x + 1) if not base.params_frozen(): return m(x) else: stateful.while_loop(lambda _: base.next_rng_key(), lambda x: x, x) f = transform.transform_with_state(f) x = jnp.zeros([]) params, state = f.init(None, x) with self.assertRaisesRegex( ValueError, "does not support.*set_state.*next_rng_key.*in.*cond_fun`"): f.apply(params, state, jax.random.PRNGKey(42), x) @parameterized.parameters(0, 1, 2, 4, 8) def test_while_loop_with_state(self, iters): def f(x): m = CountingModule(op=lambda x: x + 1) if not base.params_frozen(): return m(x) else: _, y = stateful.while_loop(lambda a: a[0] < iters, lambda a: (a[0] + 1, m(a[1])), (0, x)) return y f = transform.transform_with_state(f) x = jnp.zeros([]) params, state = f.init(None, x) self.assertEqual(list(state), ["counting_module"]) self.assertEqual(list(state["counting_module"]), ["count"]) np.testing.assert_allclose(state["counting_module"]["count"], x, rtol=1e-4) y, state = f.apply(params, state, None, x) np.testing.assert_allclose(state["counting_module"]["count"], iters, rtol=1e-4) np.testing.assert_allclose(y, iters, 
rtol=1e-4) def test_named_call(self): def f(x): return stateful.named_call(SquareModule(), name="square")(x) x = jnp.array(2.) rng = jax.random.PRNGKey(42) init, apply = transform.transform_with_state(f) params, state = init(rng, x) y, state = jax.jit(apply)(params, state, rng, x) self.assertEqual(y, x ** 2) @parameterized.parameters(jax.jit, jax.grad, jax.vmap, jax.remat) def test_named_call_jax_transforms(self, jax_transform): f = jnp.sum x = jnp.array([1.]) unnamed_out = jax_transform(f)(x) named_out = jax_transform(stateful.named_call(f, name="test"))(x) self.assertEqual(unnamed_out, named_out) def test_static_argnums_named_call(self): f = stateful.named_call(lambda x, y: y if x else None, name="test") f = jax.jit(f, static_argnums=(0,)) out = f(True, 5) self.assertEqual(out, 5) def test_named_call_non_jaxtype_arg(self): # For the test to fail without the invalid JaxType filter we need to pass # in a valid JaxType that forces the invalid Jaxtype to be raised to an # abstract value. def f(not_a_jaxtype, a_jaxtype): # then Jax needs to try and evaluate the abstractified non-JaxType if not_a_jaxtype: return a_jaxtype return 0 f = stateful.named_call(f, name="test") out = jax.jit(f, static_argnums=(0,))("not a Jaxtype", 1) self.assertEqual(out, 1) @parameterized.parameters("hi", None, object(), object) def test_named_call_non_jaxtype_result(self, non_jaxtype): def fun_with_non_jaxtype_output(x, non_jaxtype): return x, non_jaxtype def jitted_fun(x, non_jaxtype): named_fun = stateful.named_call(fun_with_non_jaxtype_output) # The non-jaxtype is returned out of named_call (which is supported), # but is not returned out of the jit (which should not be supported). 
x, non_jaxtype_out = named_fun(x, non_jaxtype) self.assertEqual(non_jaxtype_out, non_jaxtype) return x jitted_fun = jax.jit(jitted_fun, static_argnums=1) self.assertEqual(jitted_fun(0, non_jaxtype), 0) def test_named_call_partial_function(self): f = stateful.named_call(lambda x, y: y if x else None) f = jax.jit(functools.partial(f, True)) out = f(5) self.assertEqual(out, 5) def test_named_call_default_name(self): @stateful.named_call def naming_things_is_hard(x): return x ** 2 @jax.jit def f(x): return naming_things_is_hard(x) + naming_things_is_hard(x) c = jax.xla_computation(f)(2) self.assertIn("naming_things_is_hard", c.as_hlo_text()) def test_eval_shape(self): def some_shape_changing_fun(x): return x[0, :] def f(x): m = CountingModule(op=some_shape_changing_fun) # state is not changed in this call out_shape_struct = stateful.eval_shape(m, x) return m(x), out_shape_struct f = transform.transform_with_state(f) key = jax.random.PRNGKey(42) in_shape = (10, 10) x = jnp.ones(in_shape) params, state = f.init(key, x) self.assertEqual(list(state), ["counting_module"]) self.assertEqual(list(state["counting_module"]), ["count"]) np.testing.assert_allclose(state["counting_module"]["count"], 0, rtol=1e-4) (out, shape_struct), state = f.apply(params, state, key, x) # Count is only advanced once np.testing.assert_allclose(state["counting_module"]["count"], 1, rtol=1e-4) np.testing.assert_allclose(out, some_shape_changing_fun(x), rtol=1e-4) self.assertEqual(shape_struct.shape, (in_shape[1],)) def test_eval_shape_no_transform(self): x = jnp.array(3.) 
with self.assertRaises(ValueError, msg="Use jax.eval_shape() instead"): stateful.eval_shape(jnp.square)(x) @test_utils.transform_and_run def test_temporary_state_resets_names(self): with stateful.temporary_internal_state(stateful.internal_state()): mod1 = module.Module(name="foo") mod2 = module.Module(name="foo") self.assertEqual(mod1.module_name, "foo") self.assertEqual(mod2.module_name, "foo") @test_utils.transform_and_run(run_apply=False) def test_eval_shape_no_leaked_tracers_under_leak_checker(self): with jax.checking_leaks(): stateful.eval_shape(SquareModule(), jnp.ones(())) # does not crash
class HaikuTransformsTest(parameterized.TestCase):
  """Compares Haiku's lifted transforms against their plain JAX equivalents.

  Each test builds a module from a descriptor and runs it once through the
  Haiku transform under test and once through the matching JAX transform (or
  the untransformed path). It then checks the two results are close with
  `np.testing.assert_allclose`.
  """

  @staticmethod
  def _random_input(rng, shape, dtype):
    """Returns a random array with the given `shape` and `dtype`.

    Integer dtypes are drawn with `jax.random.randint` from `[0, prod(shape))`.
    All other dtypes are drawn with `jax.random.uniform`.
    """
    if jnp.issubdtype(dtype, jnp.integer):
      return jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    return jax.random.uniform(rng, shape, dtype)

  @test_utils.combined_named_parameters(descriptors.ALL_MODULES,
                                        test_utils.named_bools('init'))
  def test_hk_jit(self, module_fn: ModuleFn, shape, dtype, init):
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x, jit=False):
      mod = module_fn()
      if jit:
        mod = hk.jit(mod)
      return mod(x)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-4)

    # NOTE: We shard init/apply tests since some modules are expensive to jit
    # (e.g. ResNet50 takes ~60s to compile and we compile it twice per test).
    if init:
      jax.tree_multimap(assert_allclose,
                        jax.jit(f.init)(rng, x),
                        f.init(rng, x, jit=True))

    else:
      params, state = f.init(rng, x)
      jax.tree_multimap(assert_allclose,
                        jax.jit(f.apply)(params, state, rng, x),
                        f.apply(params, state, rng, x, jit=True))

  @test_utils.combined_named_parameters(descriptors.ALL_MODULES,
                                        test_utils.named_bools('init'))
  def test_hk_scan(self, module_fn: descriptors.ModuleFn, shape, dtype, init):
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def f(x):
      mod = module_fn()
      return mod(x)

    def u_f(xs):
      mod = module_fn()
      def s(carry, x):
        y = mod(x)
        return carry, y
      _, ys = hk.scan(s, (), xs)
      return ys

    u_f = hk.transform_with_state(u_f)
    f = hk.transform_with_state(f)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-4)
    xs = jnp.broadcast_to(x, (8,) + x.shape)
    params, state = f.init(rng, x)

    if init:
      u_params, u_state = u_f.init(rng, xs)
      jax.tree_multimap(assert_allclose, u_params, params)
      jax.tree_multimap(assert_allclose, u_state, state)
      return

    def fun(state, x):
      y, state = f.apply(params, state, rng, x)
      return state, y
    s_state, s_ys = jax.lax.scan(fun, state, xs)
    u_ys, u_state = u_f.apply(params, state, rng, xs)

    jax.tree_multimap(assert_allclose, u_ys, s_ys)
    jax.tree_multimap(assert_allclose, u_state, s_state)

  @test_utils.combined_named_parameters(
      # TODO(tomhennigan) Enable once grad for _scan_transpose implemented.
      set(descriptors.ALL_MODULES) - set(descriptors.RECURRENT_MODULES))
  def test_hk_remat(self, module_fn: ModuleFn, shape, dtype):
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x, remat=False):
      mod = module_fn()
      if remat:
        mod = hk.remat(mod)
      out = mod(x)
      # Some modules return a dict of outputs; differentiate the 'loss' entry.
      if isinstance(out, dict):
        out = out['loss']
      return jnp.mean(out)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    grad_jax_remat = jax.grad(jax.remat(f.apply), has_aux=True)
    grad_hk_remat = jax.grad(functools.partial(f.apply, remat=True),
                             has_aux=True)

    params, state = f.init(rng, x)
    jax.tree_multimap(assert_allclose,
                      grad_jax_remat(params, state, rng, x),
                      grad_hk_remat(params, state, rng, x))

  @test_utils.combined_named_parameters(descriptors.ALL_MODULES)
  def test_profiler_name_scopes(self, module_fn: ModuleFn, shape, dtype):
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x, name_scopes=False):
      hk.experimental.profiler_name_scopes(enabled=name_scopes)
      mod = module_fn()
      return mod(x)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    try:
      params, state = f.init(rng, x)
      jax.tree_multimap(assert_allclose,
                        f.apply(params, state, rng, x),
                        f.apply(params, state, rng, x, name_scopes=True))
    finally:
      # `g` toggles the profiler-name-scopes setting, and that setting is
      # reset explicitly here. Do it in `finally` so a failing assertion
      # cannot leave it enabled for later tests.
      # TODO(lenamartens): flip to True when default changes
      hk.experimental.profiler_name_scopes(enabled=False)

  @test_utils.combined_named_parameters(descriptors.ALL_MODULES)
  def test_optimize_rng_use_under_jit(self, module_fn: ModuleFn, shape, dtype):
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    def g(x):
      return module_fn()(x)

    f = hk.transform_with_state(hk.experimental.optimize_rng_use(g))

    module_type = descriptors.module_type(module_fn)
    atol = CUSTOM_ATOL.get(module_type, DEFAULT_ATOL)
    assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol)

    params, state = jax.jit(f.init)(rng, x)
    jax.tree_multimap(assert_allclose, (params, state), f.init(rng, x))

    if module_type in (hk.nets.VectorQuantizer, hk.nets.VectorQuantizerEMA):
      # For stochastic modules just test apply runs.
      jax.device_get(jax.jit(f.apply)(params, state, rng, x))

    else:
      jax.tree_multimap(assert_allclose,
                        jax.jit(f.apply)(params, state, rng, x),
                        f.apply(params, state, rng, x))

  @test_utils.combined_named_parameters(descriptors.OPTIONAL_BATCH_MODULES)
  def test_vmap(self, module_fn: ModuleFn, shape, dtype):
    rng = jax.random.PRNGKey(42)
    x = self._random_input(rng, shape, dtype)

    # Expand our input since we will map over it.
    x = jnp.broadcast_to(x, (2,) + x.shape)

    f = hk.transform_with_state(lambda x: module_fn()(x))  # pylint: disable=unnecessary-lambda
    f_mapped = hk.transform_with_state(
        lambda x: hk.vmap(lambda x: module_fn()(x))(x))  # pylint: disable=unnecessary-lambda

    params, state = f_mapped.init(rng, x)

    # JAX vmap with explicitly unmapped params/state/rng. This should be
    # equivalent to `f_mapped.apply(..)` (since by default hk.vmap does not map
    # params/state/rng).
    v_apply = jax.vmap(f.apply,
                       in_axes=(None, None, None, 0),
                       out_axes=(0, None))

    module_type = descriptors.module_type(module_fn)
    atol = CUSTOM_ATOL.get(module_type, DEFAULT_ATOL)
    assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol)
    jax.tree_multimap(assert_allclose,
                      f_mapped.apply(params, state, rng, x),
                      v_apply(params, state, rng, x))
class BatchNormTest(parameterized.TestCase):
  """Unit tests for `batch_norm.BatchNorm`."""

  @test_utils.transform_and_run
  def test_basic(self):
    """Channels with equal spread normalise identically."""
    inputs = jnp.arange(24, dtype=jnp.float32).reshape([2, 3, 4])
    bn = batch_norm.BatchNorm(create_scale=True, create_offset=True,
                              decay_rate=0.9)
    train_out = bn(inputs, is_training=True)
    # Channel c of `inputs` is channel 0 shifted by the constant c, so every
    # channel has the same variance and normalises to the same values.
    first_channel = jnp.broadcast_to(train_out[:, :, :1], train_out.shape)
    np.testing.assert_allclose(train_out, first_channel)
    # Re-running in inference mode should approximately reproduce the output.
    eval_out = bn(inputs, is_training=False)
    np.testing.assert_allclose(eval_out, train_out, rtol=2e-2)

  @test_utils.transform_and_run
  def test_simple_training(self):
    """Constant input reduces to the supplied offset."""
    bn = batch_norm.BatchNorm(create_scale=False, create_offset=False,
                              decay_rate=0.9)
    inputs = np.ones([2, 3, 3, 5])
    out = bn(inputs, True, scale=np.full(5, 0.5), offset=np.full(5, 2.0))
    np.testing.assert_equal(out, np.full_like(inputs, 2.0))

  @test_utils.transform_and_run
  def test_simple_training_nchw(self):
    """Same as `test_simple_training` with channels on axis 1."""
    bn = batch_norm.BatchNorm(create_scale=False, create_offset=False,
                              decay_rate=0.9, data_format="NCHW")
    inputs = np.ones([2, 5, 3, 3])
    channel_shape = (5, 1, 1)
    out = bn(inputs, True,
             scale=np.full(channel_shape, 0.5),
             offset=np.full(channel_shape, 2.0))
    np.testing.assert_equal(out, np.full_like(inputs, 2.0))

  @test_utils.transform_and_run
  def test_simple_training_normalized_axes(self):
    """Statistics are reduced only over the requested axes."""
    # Reduce over axes 0, 2 and 3, keeping axis 1 separate.
    bn = batch_norm.BatchNorm(create_scale=False, create_offset=False,
                              decay_rate=0.9, axis=[0, 2, 3])
    # The two slices along axis 1 hold different constants (2.0 and 1.0).
    block = np.ones([5, 3, 3])
    inputs = np.stack([2.0 * block, block], 1)
    out = bn(inputs, True)
    # Each axis-1 slice is constant on its own, so once it is normalised with
    # its own statistics every value becomes zero.
    np.testing.assert_equal(out, np.zeros(inputs.shape))

  @test_utils.transform_and_run
  def test_no_scale_and_offset(self):
    bn = batch_norm.BatchNorm(create_scale=False, create_offset=False,
                              decay_rate=0.9)
    inputs = jnp.ones([2, 5, 3, 3, 3])
    np.testing.assert_equal(bn(inputs, True), np.zeros_like(inputs))

  @test_utils.transform_and_run
  def test_no_scale_and_init_provided(self):
    expected_error = "Cannot set `scale_init` if `create_scale=False`"
    with self.assertRaisesRegex(ValueError, expected_error):
      batch_norm.BatchNorm(create_scale=False, create_offset=True,
                           decay_rate=0.9, scale_init=jnp.ones)

  @test_utils.transform_and_run
  def test_no_offset_beta_init_provided(self):
    expected_error = "Cannot set `offset_init` if `create_offset=False`"
    with self.assertRaisesRegex(ValueError, expected_error):
      batch_norm.BatchNorm(create_scale=True, create_offset=False,
                           decay_rate=0.9, offset_init=jnp.zeros)

  @test_utils.combined_named_parameters(
      test_utils.named_bools("is_training"),
      test_utils.named_bools("test_local_stats"))
  def test_inits_ema_not_is_training(self, is_training, test_local_stats):
    """EMA state is created unless inference uses local statistics."""
    def forward(x):
      bn = batch_norm.BatchNorm(True, True, 0.9)
      return bn(x, is_training, test_local_stats)

    init_fn = transform.transform_with_state(forward).init
    _, state = init_fn(None, jnp.ones([]))

    if is_training or not test_local_stats:
      self.assertLen(state, 2)
      for name in ("mean_ema", "var_ema"):
        self.assertEqual(set(state[f"batch_norm/~/{name}"]),
                         {"counter", "average", "hidden"})
    else:
      self.assertEmpty(state)
class StatefulTest(parameterized.TestCase):
  """Tests for the stateful wrappers around JAX transforms in `stateful`."""

  @test_utils.transform_and_run
  def test_grad(self):
    x = jnp.array(3.)
    g = stateful.grad(SquareModule())(x)
    np.testing.assert_allclose(g, 2 * x, rtol=1e-4)

  def test_grad_no_transform(self):
    x = jnp.array(3.)
    with self.assertRaises(ValueError, msg="Use jax.grad() instead"):
      stateful.grad(jnp.square)(x)

  @test_utils.transform_and_run
  def test_value_and_grad(self):
    x = jnp.array(2.)
    y, g = stateful.value_and_grad(SquareModule())(x)
    self.assertEqual(y, x ** 2)
    np.testing.assert_allclose(g, 2 * x, rtol=1e-4)

  def test_value_and_grad_no_transform(self):
    x = jnp.array(3.)
    with self.assertRaises(ValueError, msg="Use jax.grad() instead"):
      stateful.value_and_grad(jnp.square)(x)

  @test_utils.transform_and_run
  def test_grad_aux(self):
    """`has_aux=True` passes the auxiliary output through by identity."""
    o = object()

    def f(x):
      m = SquareModule()
      return m(x), o

    x = jnp.array(3.)
    g, aux = stateful.grad(f, has_aux=True)(x)
    np.testing.assert_allclose(g, 2 * x, rtol=1e-4)
    self.assertIs(aux, o)

  @test_utils.transform_and_run
  def test_value_and_grad_aux(self):
    o = object()

    def f(x):
      m = SquareModule()
      return m(x), o

    x = jnp.array(3.)
    (y, aux), g = stateful.value_and_grad(f, has_aux=True)(x)
    self.assertEqual(y, jnp.power(x, 2))
    np.testing.assert_allclose(g, 2 * x, rtol=1e-4)
    self.assertIs(aux, o)

  def test_grad_and_jit(self):
    def f(x):
      g = stateful.grad(SquareModule())(x)
      return g

    x = jnp.array(3.)
    f = transform.transform_with_state(f)
    params, state = jax.jit(f.init)(None, x)
    g, state = jax.jit(f.apply)(params, state, None, x)
    np.testing.assert_allclose(g, 2 * x, rtol=1e-3)

  def test_value_and_grad_and_jit(self):
    def f(x):
      y, g = stateful.value_and_grad(SquareModule())(x)
      return y, g

    x = jnp.array(3.)
    f = transform.transform_with_state(f)
    params, state = jax.jit(f.init)(None, x)
    (y, g), state = jax.jit(f.apply)(params, state, None, x)
    np.testing.assert_allclose(y, x ** 2, rtol=1e-3)
    np.testing.assert_allclose(g, 2 * x, rtol=1e-3)

  @test_utils.transform_and_run
  def test_jit(self):
    mod = SquareModule()
    x = jnp.array(2)
    y = stateful.jit(mod)(x)
    self.assertEqual(y, x ** 2)

  def test_jit_no_transform(self):
    x = jnp.array(2)
    with self.assertRaises(ValueError, msg="Use jax.jit() instead"):
      stateful.jit(jnp.square)(x)

  @test_utils.transform_and_run
  def test_remat(self):
    """`stateful.remat` re-runs the forward pass during the backward pass."""
    forward, backward = [], []
    callback = _callback_prim(lambda: forward.append(None),
                              lambda: backward.append(None))

    def test(remat):
      x = jnp.array(3.)
      mod = CountingModule()
      self.assertEqual(mod.count, 0)
      f = lambda x: callback(mod(x))
      if remat:
        f = stateful.remat(f)
      y, g = stateful.value_and_grad(f)(x)
      np.testing.assert_allclose(y, x ** 2, rtol=1e-3)
      np.testing.assert_allclose(g, 2 * x, rtol=1e-3)
      self.assertEqual(mod.count, 1)
      num_forward = len(forward)
      num_backward = len(backward)
      del forward[:], backward[:]
      return num_forward, num_backward

    # Sanity check.
    self.assertEqual(test(remat=True), test(remat=True))
    self.assertEqual(test(remat=False), test(remat=False))

    # NOTE: JAX does not guarantee to execute primitives once and only once for
    # a given function (we observe f=2,b=1 without remat and f=5,b=1 with
    # remat), but we do expect that JAX will execute our primitive forward at
    # least one more time with remat than without it.
    num_forward_remat, num_backward_remat = test(remat=True)
    num_forward_no_remat, num_backward_no_remat = test(remat=False)
    self.assertGreater(num_forward_remat, num_forward_no_remat)
    self.assertEqual(num_backward_remat, num_backward_no_remat)

  def test_remat_no_transform(self):
    x = jnp.array(3.)
    with self.assertRaises(ValueError, msg="Use jax.remat() instead"):
      stateful.remat(jnp.square)(x)

  @test_utils.combined_named_parameters(
      test_utils.named_bools("jax_remat"),
      test_utils.named_bools("inline_hk_remat"))
  def test_create_module_inside_remat(self, jax_remat, inline_hk_remat):
    """Modules built under remat get the same names in forward and backward."""
    log = []

    def forward(x):
      def create_and_use_layer(x):
        m = SquareModule(name="layer")
        log.append(m.module_name)
        return m(x)

      if not inline_hk_remat:
        create_and_use_layer = stateful.remat(create_and_use_layer)

      for _ in range(2):
        if inline_hk_remat:
          x = stateful.remat(create_and_use_layer)(x)
        else:
          x = create_and_use_layer(x)
      return x

    def reset():
      del log[:]
      self.assertEmpty(log)

    # Test forward.
    x = jnp.float32(3)
    forward = transform.transform_with_state(forward)
    params, state = forward.init(None, x)
    self.assertEqual(log, ["layer", "layer_1"])
    reset()

    # Test backward.
    for _ in range(3):
      grad_fn = jax.grad(lambda x: forward.apply(params, state, None, x)[0])
      if jax_remat:
        grad_fn = jax.remat(grad_fn)
      self.assertEqual(int(grad_fn(x)), int(4 * (x ** 3)))
      self.assertEqual(log, ["layer", "layer_1"])
      reset()

  @parameterized.parameters(True, False)
  def test_cond(self, single_arg):
    def f(x):
      mod = SquareModule()
      if single_arg:
        return stateful.cond(x == 2, mod, lambda x: mod(x + 1), x)
      else:
        return stateful.cond(x == 2, x, mod, x, lambda x: mod(x + 1))

    f = transform.transform_with_state(f)
    for x, y in ((1, 4), (2, 4), (3, 16)):
      x, y = map(jnp.array, (x, y))
      params, state = f.init(None, x)
      out, state = f.apply(params, state, None, x)
      self.assertEqual(state, {"square_module": {"y": y}})
      self.assertEqual(out, y)

  @test_utils.transform_and_run
  def test_cond_traces_branches_with_same_id_once(self):
    witness = []

    def f(x):
      witness.append(None)
      return x ** 2

    stateful.cond(False, f, f, 0)
    hk_call_count = len(witness)
    self.assertEqual(hk_call_count, 1)

    # Ensure we are in sync with JAX.
    del witness[:]
    jax.lax.cond(False, f, f, 0)
    jax_call_count = len(witness)
    self.assertEqual(hk_call_count, jax_call_count)

  @test_utils.transform_and_run
  def test_cond_no_args(self):
    x = stateful.cond(True, lambda: 5, lambda: 4)
    self.assertEqual(x, 5)

  @test_utils.transform_and_run
  def test_cond_operand_kwarg(self):
    x = stateful.cond(True, lambda x: x + 5, lambda x: x + 4, operand=1)
    self.assertEqual(x, 6)

  @test_utils.transform_and_run
  def test_cond_operand_kwarg_and_operands(self):
    with self.assertRaisesRegex(ValueError, "cannot.*pass.*positionally"):
      stateful.cond(True, lambda x: x + 5, lambda x: x + 4, 1, operand=1)

  @test_utils.transform_and_run
  def test_cond_two_args(self):
    a, b = stateful.cond(True,
                         lambda a, b: (b, a),
                         lambda a, b: (a, b),
                         2, 1)
    self.assertEqual(a, 1)
    self.assertEqual(b, 2)

  @test_utils.transform_and_run
  def test_cond_three_args(self):
    a, b, c = stateful.cond(True,
                            lambda a, b, c: (c, b, a),
                            lambda a, b, c: (a, b, c),
                            3, 2, 1)
    self.assertEqual(a, 1)
    self.assertEqual(b, 2)
    self.assertEqual(c, 3)

  def test_cond_no_transform(self):
    x = jnp.array(3.)
    with self.assertRaises(ValueError, msg="Use jax.cond() instead"):
      stateful.cond(x == 2, x, jnp.square, x, lambda x: jnp.square(x + 1))

  def test_switch(self):
    def f(i, x):
      mod = SquareModule()
      branches = [mod, lambda x: mod(x + 1), lambda x: mod(x + 2)]
      return stateful.switch(i, branches, x)

    f = transform.transform_with_state(f)
    for i, x, y in ((0, 1, 1), (1, 2, 9), (2, 3, 25)):
      i, x, y = map(jnp.array, (i, x, y))
      params, state = f.init(None, i, x)
      out, state = f.apply(params, state, None, i, x)
      self.assertEqual(state, {"square_module": {"y": y}})
      self.assertEqual(out, y)

  @parameterized.parameters(1, 2, 4, 8)
  @test_utils.transform_and_run
  def test_switch_traces_cases_with_same_id_once(self, n):
    """Repeated identical branches are traced once each, matching JAX."""
    f_witness = []
    g_witness = []

    def f(x):
      f_witness.append(None)
      return x ** 2

    def g(x):
      g_witness.append(None)
      return x ** 2

    stateful.switch(0, [f, g] * n, 2)
    f_hk_call_count = len(f_witness)
    g_hk_call_count = len(g_witness)
    self.assertEqual(f_hk_call_count, 1)
    self.assertEqual(g_hk_call_count, 1)

    # Ensure we are in sync with JAX.
    del f_witness[:], g_witness[:]
    jax.lax.switch(0, [f, g] * n, 2)
    f_jax_call_count = len(f_witness)
    g_jax_call_count = len(g_witness)
    self.assertEqual(f_hk_call_count, f_jax_call_count)
    # Compare g's Haiku count with g's JAX count (previously compared f's).
    self.assertEqual(g_hk_call_count, g_jax_call_count)

  def test_switch_no_transform(self):
    i = jnp.array(2)
    x = jnp.array(42.)
    with self.assertRaises(ValueError, msg="Use jax.switch() instead"):
      stateful.switch(i, [jnp.square] * 3, x)

  @test_utils.transform_and_run
  def test_difference_empty(self):
    before = stateful.internal_state()
    after = stateful.internal_state()
    self.assertEmpty(jax.tree_leaves(stateful.difference(before, after)))

  @parameterized.parameters(base.get_parameter, base.get_state)
  @test_utils.transform_and_run(run_apply=False)
  def test_difference_new(self, get_x):
    """Only entries created between the two snapshots carry a value."""
    get_x("a", [], init=jnp.zeros)
    before = stateful.internal_state()
    b = get_x("b", [], init=jnp.zeros)
    after = stateful.internal_state()
    diff = stateful.difference(before, after)
    if get_x == base.get_state:
      self.assertEmpty(diff.params)
      self.assertEqual(diff.state, {"~": {"a": None,
                                          "b": base.StatePair(b, b)}})
    else:
      self.assertEqual(diff.params, {"~": {"a": None, "b": b}})
      self.assertEmpty(diff.state)
    self.assertIsNone(diff.rng)

  @test_utils.transform_and_run(run_apply=False)
  def test_difference_update_state(self):
    """A `set_state` between snapshots is reported as (initial, current)."""
    base.get_state("a", [], init=jnp.zeros)
    base.get_state("b", [], init=jnp.zeros)
    before = stateful.internal_state()
    base.set_state("b", jnp.ones([]))
    after = stateful.internal_state()
    diff = stateful.difference(before, after)
    self.assertEmpty(diff.params)
    self.assertEqual(diff.state, {"~": {"a": None,
                                        "b": base.StatePair(0., 1.)}})
    self.assertIsNone(diff.rng)

  @test_utils.transform_and_run(run_apply=False)
  def test_difference_rng(self):
    before = stateful.internal_state()
    base.next_rng_key()
    after = stateful.internal_state()
    diff = stateful.difference(before, after)
    self.assertEmpty(diff.params)
    self.assertEmpty(diff.state)
    self.assertIsNotNone(diff.rng)

  def test_scan_no_transform(self):
    xs = jnp.arange(3)
    with self.assertRaises(ValueError, msg="Use jax.scan() instead"):
      stateful.scan(lambda c, x: (c, x), (), xs)

  @parameterized.parameters(0, 1, 2, 4, 8)
  def test_scan_with_state(self, unroll_length):
    """Module state updated inside `scan` advances once per iteration."""
    def f(xs):
      m = CountingModule()
      def sf(c, x):
        self.assertEqual(c, ())
        return c, m(x)
      _, ys = stateful.scan(sf, (), xs)
      return ys

    f = transform.transform_with_state(f)
    key = jax.random.PRNGKey(42)
    init_key, apply_key = jax.random.split(key)
    xs = jnp.arange(unroll_length)
    params, state = f.init(init_key, xs)
    self.assertEqual(list(state), ["counting_module"])
    self.assertEqual(list(state["counting_module"]), ["count"])
    np.testing.assert_allclose(state["counting_module"]["count"], 0, rtol=1e-4)
    ys, state = f.apply(params, state, apply_key, xs)
    np.testing.assert_allclose(state["counting_module"]["count"], unroll_length,
                               rtol=1e-4)
    np.testing.assert_allclose(ys, xs ** 2, rtol=1e-4)

  @parameterized.parameters(0, 1, 2, 8)
  @test_utils.transform_and_run
  def test_stateful_scan_with_rng_use(self, iteration_count):
    """Drawing RNG keys inside a `scan` body works with reserved keys."""
    def body_fun(c, x):
      for _ in range(10):
        _ = base.next_rng_key()
      return c, x

    # TODO(lenamartens): remove when default changes to > 1.
    tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
    base.DEFAULT_PRNG_RESERVE_SIZE = 64
    try:
      base.reserve_rng_keys(5)
      _ = stateful.scan(body_fun, (), (), length=iteration_count)
    finally:
      # Always restore the module-level default, even when the body raises,
      # so a failure here cannot leak the override into other tests.
      base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default

  @parameterized.parameters(0, 1, 2, 8)
  @test_utils.transform_and_run
  def test_stateful_fori_with_rng_use(self, iteration_count):
    """Drawing RNG keys inside a `fori_loop` body works with reserved keys."""
    def body_fun(_, x):
      for _ in range(10):
        _ = base.next_rng_key()
      return x

    tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
    base.DEFAULT_PRNG_RESERVE_SIZE = 64
    try:
      base.reserve_rng_keys(5)
      _ = stateful.fori_loop(0, iteration_count, body_fun, 1)
    finally:
      # Restore the global override even on failure.
      base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default

  @test_utils.transform_and_run
  def test_stateful_cond_with_rng_use(self):
    """Branches drawing different numbers of RNG keys must not raise."""
    def true_branch(x):
      _ = base.next_rng_key()
      return x

    def false_branch(x):
      _ = base.next_rng_key()
      _ = base.next_rng_key()
      return x

    tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
    base.DEFAULT_PRNG_RESERVE_SIZE = 64
    try:
      base.reserve_rng_keys(5)
      _ = stateful.cond(True, true_branch, false_branch, 0)
      _ = stateful.cond(False, true_branch, false_branch, 0)
    finally:
      # Restore the global override even on failure.
      base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default
@test_utils.transform_and_run
def test_stateful_switch_with_rng_use(self):
  # Branches of hk.switch that consume different numbers of RNG keys must
  # still be callable without raising.
  tmp_default = base.DEFAULT_PRNG_RESERVE_SIZE
  base.DEFAULT_PRNG_RESERVE_SIZE = 64
  try:
    # Each branch i draws i keys; using a different amount of keys in
    # different branches should NOT result in an error.
    def branch_f(i):
      for _ in range(i):
        _ = base.next_rng_key()
      return i

    base.reserve_rng_keys(5)
    # Bind `i` as a default so every lambda captures its own index.
    branches = [lambda _, i=i: branch_f(i) for i in range(5)]
    self.assertEqual(stateful.switch(3, branches, None), 3)
    self.assertEqual(stateful.switch(0, branches, None), 0)
  finally:
    # Restore the module-level default even if an assertion fails, so later
    # tests are not affected.
    base.DEFAULT_PRNG_RESERVE_SIZE = tmp_default

@parameterized.parameters(*it.product((0, 1, 2, 4, 8), (1, 2, 3)))
@test_utils.transform_and_run
def test_fori(self, lower, n):
  # fori_loop returns the last body output and the module runs once per
  # iteration in [lower, upper).
  upper = lower + n
  m = CountingModule()
  y = stateful.fori_loop(lower, upper, lambda i, x: m(i), 2)
  self.assertEqual(y, jnp.square(upper - 1))
  self.assertEqual(m.count, upper - lower)

@test_utils.transform_and_run
def test_fori_traced_length(self):
  # fori_loop must work when its bounds are tracers (under hk.jit).
  m = CountingModule()

  def f(lower, upper):
    y = stateful.fori_loop(lower, upper, lambda i, x: m(i), 2)
    return y

  # Because of the jit, lower and upper will be tracers.
  out = stateful.jit(f)(0, 3)
  self.assertEqual(out, 4)
  self.assertEqual(m.count, 3)

def test_vmap(self):
  # hk.vmap maps inputs/outputs but leaves module state unbatched.
  def g(x):
    return CountingModule()(x)

  def f(x):
    return stateful.vmap(g, split_rng=False)(x)

  f = transform.transform_with_state(f)

  x = jnp.ones([4]) + 1
  params, state = f.init(None, x)

  # State should not be mapped.
  self.assertEmpty(params)
  cnt, = jax.tree_leaves(state)
  self.assertEqual(cnt.ndim, 0)
  self.assertEqual(cnt, 0)

  # The output should be mapped but state should not be.
  y, state = f.apply(params, state, None, x)
  self.assertEqual(y.shape, (4,))
  np.testing.assert_allclose(y, x ** 2)
  cnt, = jax.tree_leaves(state)
  self.assertEqual(cnt.ndim, 0)
  self.assertEqual(cnt, 1)

def test_vmap_must_be_called_in_transform(self):
  # Calling hk.vmap's result outside a transform raises ValueError.
  f = stateful.vmap(lambda x: x, split_rng=False)
  with self.assertRaisesRegex(ValueError,
                              "must be used as part of an.*hk.transform"):
    f(0)

@test_utils.transform_and_run
def test_vmap_no_in_axes(self):
  # in_axes=None (nothing mapped) is rejected with the function's name.
  def fn_name(_):
    pass

  with self.assertRaisesRegex(
      ValueError, "fn_name must have at least one non-None value in in_axes"):
    stateful.vmap(fn_name, in_axes=None, split_rng=False)

@test_utils.transform_and_run
def test_vmap_in_axes_different_size(self):
  # Mapped axes with mismatched sizes surface JAX's vmap error.
  x = jnp.ones([1, 2])
  with self.assertRaisesRegex(
      ValueError, "vmap got inconsistent sizes for array axes to be mapped"):
    stateful.vmap(lambda a, b: None, in_axes=(0, 1), split_rng=False)(x, x)

@test_utils.transform_and_run
def test_vmap_no_split_rng(self):
  # With split_rng=False every mapped call sees the same key, which differs
  # from the keys drawn before and after.
  key_before = base.next_rng_key()
  f = stateful.vmap(lambda _: base.next_rng_key(), split_rng=False)
  x = jnp.arange(4)
  k1, k2, k3, k4 = f(x)
  key_after = base.next_rng_key()
  np.testing.assert_array_equal(k1, k2)
  np.testing.assert_array_equal(k2, k3)
  np.testing.assert_array_equal(k3, k4)
  self.assertFalse(np.array_equal(key_before, k1))
  self.assertFalse(np.array_equal(key_after, k1))
  self.assertFalse(np.array_equal(key_before, key_after))

@test_utils.transform_and_run
def test_vmap_split_rng(self):
  # With split_rng=True every mapped call sees a distinct key.
  key_before = base.next_rng_key()
  f = stateful.vmap(lambda _: base.next_rng_key(), split_rng=True)
  x = jnp.arange(4)
  k1, k2, k3, k4 = f(x)
  key_after = base.next_rng_key()

  # Test that none of the keys are equal.
  named_keys = (("k1", k1), ("k2", k2), ("k3", k3), ("k4", k4),
                ("key_before", key_before), ("key_after", key_after))
  for (a_name, a), (b_name, b) in it.combinations(named_keys, 2):
    self.assertFalse(
        np.array_equal(a, b),
        msg=f"Keys should not be equal, but {a_name} == {b_name}")

def test_while_loop_rejected_in_init(self):
  # hk.while_loop cannot be used during init.
  def f():
    stateful.while_loop(lambda x: x.all(), lambda x: not x, 1)

  f = transform.transform(f)
  with self.assertRaisesRegex(
      ValueError, "hk.while_loop does not support initialization"):
    f.init(None)

def test_updating_state_in_cond_fails(self):
  # A cond_fun that updates module state is rejected at apply time.
  def f(x):
    m = CountingModule(op=lambda x: x + 1)
    if not base.params_frozen():
      return m(x)
    else:
      stateful.while_loop(m, lambda x: x, x)

  f = transform.transform_with_state(f)
  x = jnp.zeros([])
  params, state = f.init(None, x)
  with self.assertRaisesRegex(
      ValueError,
      "does not support.*set_state.*next_rng_key.*in.*cond_fun`"):
    f.apply(params, state, None, x)

def test_rng_in_cond_fails(self):
  # A cond_fun that draws an RNG key is rejected at apply time.
  def f(x):
    m = CountingModule(op=lambda x: x + 1)
    if not base.params_frozen():
      return m(x)
    else:
      stateful.while_loop(lambda _: base.next_rng_key(), lambda x: x, x)

  f = transform.transform_with_state(f)
  x = jnp.zeros([])
  params, state = f.init(None, x)
  with self.assertRaisesRegex(
      ValueError,
      "does not support.*set_state.*next_rng_key.*in.*cond_fun`"):
    f.apply(params, state, jax.random.PRNGKey(42), x)

@parameterized.parameters(0, 1, 2, 4, 8)
def test_while_loop_with_state(self, iters):
  # State updated in body_fun is threaded through every iteration.
  def f(x):
    m = CountingModule(op=lambda x: x + 1)
    if not base.params_frozen():
      return m(x)
    else:
      _, y = stateful.while_loop(lambda a: a[0] < iters,
                                 lambda a: (a[0] + 1, m(a[1])),
                                 (0, x))
      return y

  f = transform.transform_with_state(f)
  x = jnp.zeros([])
  params, state = f.init(None, x)
  self.assertEqual(list(state), ["counting_module"])
  self.assertEqual(list(state["counting_module"]), ["count"])
  np.testing.assert_allclose(state["counting_module"]["count"], x, rtol=1e-4)

  y, state = f.apply(params, state, None, x)
  np.testing.assert_allclose(state["counting_module"]["count"], iters,
                             rtol=1e-4)
  np.testing.assert_allclose(y, iters, rtol=1e-4)

def test_named_call(self):
  # A named_call-wrapped module works under transform + jit.
  def f(x):
    return stateful.named_call(SquareModule(), name="square")(x)

  x = jnp.array(2.)
  rng = jax.random.PRNGKey(42)
  init, apply = transform.transform_with_state(f)
  params, state = init(rng, x)
  y, state = jax.jit(apply)(params, state, rng, x)
  self.assertEqual(y, x ** 2)

@parameterized.parameters(jax.jit, jax.grad, jax.vmap, jax.remat)
def test_named_call_jax_transforms(self, jax_transform):
  # named_call must not change results under common JAX transforms.
  f = jnp.sum
  x = jnp.array([1.])

  unnamed_out = jax_transform(f)(x)
  named_out = jax_transform(stateful.named_call(f, name="test"))(x)

  self.assertEqual(unnamed_out, named_out)

def test_static_argnums_named_call(self):
  # named_call composes with jit static_argnums.
  f = stateful.named_call(lambda x, y: y if x else None,
                          name="test")
  f = jax.jit(f, static_argnums=(0,))
  out = f(True, 5)
  self.assertEqual(out, 5)

def test_named_call_non_jaxtype_arg(self):
  # For the test to fail without the invalid JaxType filter we need to pass
  # in a valid JaxType that forces the invalid Jaxtype to be raised to an
  # abstract value.
  def f(not_a_jaxtype, a_jaxtype):
    # then Jax needs to try and evaluate the abstractified non-JaxType
    if not_a_jaxtype:
      return a_jaxtype
    return 0

  f = stateful.named_call(f, name="test")
  out = jax.jit(f, static_argnums=(0,))("not a Jaxtype", 1)
  self.assertEqual(out, 1)

@parameterized.parameters("hi", None, object(), object)
def test_named_call_non_jaxtype_result(self, non_jaxtype):
  # named_call may return non-JAX values as long as they do not escape jit.
  def fun_with_non_jaxtype_output(x, non_jaxtype):
    return x, non_jaxtype

  def jitted_fun(x, non_jaxtype):
    named_fun = stateful.named_call(fun_with_non_jaxtype_output)
    # The non-jaxtype is returned out of named_call (which is supported),
    # but is not returned out of the jit (which should not be supported).
    x, non_jaxtype_out = named_fun(x, non_jaxtype)
    self.assertEqual(non_jaxtype_out, non_jaxtype)
    return x

  jitted_fun = jax.jit(jitted_fun, static_argnums=1)
  self.assertEqual(jitted_fun(0, non_jaxtype), 0)

def test_named_call_partial_function(self):
  # named_call works when wrapped in functools.partial before jit.
  f = stateful.named_call(lambda x, y: y if x else None)
  f = jax.jit(functools.partial(f, True))
  out = f(5)
  self.assertEqual(out, 5)

def test_named_call_default_name(self):
  # Without an explicit name, the wrapped function's __name__ appears in
  # the HLO metadata.
  @stateful.named_call
  def naming_things_is_hard(x):
    return x ** 2

  @jax.jit
  def f(x):
    return naming_things_is_hard(x) + naming_things_is_hard(x)

  c = jax.xla_computation(f)(1.)
  print_opts = jax.xla.xe.HloPrintOptions.short_parsable()
  print_opts.print_metadata = True
  hlo_text = c.as_hlo_module().to_string(print_opts)
  self.assertIn("naming_things_is_hard", hlo_text)

def test_eval_shape(self):
  # hk.eval_shape reports the output shape without advancing module state.
  def some_shape_changing_fun(x):
    return x[0, :]

  def f(x):
    m = CountingModule(op=some_shape_changing_fun)
    # state is not changed in this call
    out_shape_struct = stateful.eval_shape(m, x)
    return m(x), out_shape_struct

  f = transform.transform_with_state(f)
  key = jax.random.PRNGKey(42)
  in_shape = (10, 10)
  x = jnp.ones(in_shape)
  params, state = f.init(key, x)
  self.assertEqual(list(state), ["counting_module"])
  self.assertEqual(list(state["counting_module"]), ["count"])
  np.testing.assert_allclose(state["counting_module"]["count"], 0, rtol=1e-4)

  (out, shape_struct), state = f.apply(params, state, key, x)
  # Count is only advanced once
  np.testing.assert_allclose(state["counting_module"]["count"], 1, rtol=1e-4)
  np.testing.assert_allclose(out, some_shape_changing_fun(x), rtol=1e-4)
  self.assertEqual(shape_struct.shape, (in_shape[1],))

def test_eval_shape_no_transform(self):
  # hk.eval_shape outside a transform raises ValueError.
  x = jnp.array(3.)
  # NOTE: `msg` is only the failure message shown if no exception is raised;
  # it is not matched against the raised error.
  with self.assertRaises(ValueError, msg="Use jax.eval_shape() instead"):
    # Pass the arguments directly, matching the eval_shape(fun, *args)
    # signature used in test_eval_shape.
    stateful.eval_shape(jnp.square, x)

@test_utils.transform_and_run
def test_temporary_state_resets_names(self):
  # Module names created inside temporary_internal_state do not leak out,
  # so the same name can be reused afterwards.
  with stateful.temporary_internal_state(stateful.internal_state()):
    mod1 = module.Module(name="foo")
  mod2 = module.Module(name="foo")
  self.assertEqual(mod1.module_name, "foo")
  self.assertEqual(mod2.module_name, "foo")

@test_utils.transform_and_run(run_apply=False)
def test_eval_shape_no_leaked_tracers_under_leak_checker(self):
  # hk.eval_shape must not leak tracers when JAX's leak checker is enabled.
  with jax.checking_leaks():
    stateful.eval_shape(SquareModule(), jnp.ones(()))  # does not crash

@test_utils.combined_named_parameters(base_test.SIDE_EFFECTING_FUNCTIONS,
                                      HK_OVERLOADED_JAX_PURE_EXPECTING_FNS)
@test_utils.transform_and_run
@test_utils.with_guardrails
def test_safe_use_of_jax(self, haiku_side_effect_fn, hk_jax_fn):
  # Haiku's overloaded JAX functions accept Haiku side effects without
  # tripping the guardrails.
  if "reserve_rng_keys_while_loop" in self._testMethodName:
    self.skipTest("Expected not to work.")

  # Make `f` identify with the side effecting function included.
  f = hk_jax_fn(lambda x: [haiku_side_effect_fn(), x][1])
  x = jnp.ones([1])
  # These functions should not trigger exceptions from our guardrails.
  f(x)

@test_utils.transform_and_run
def test_vmap_split_rng_with_default(self):
  # Omitting split_rng is an error unless hk.vmap.require_split_rng is
  # False, in which case split_rng=False is implied.
  with self.assertRaisesRegex(TypeError,
                              "hk.vmap.require_split_rng = False"):
    # Intentionally missing split_rng arg.
    stateful.vmap(lambda: None)

  with self.subTest("require_split_rng=0"):
    stateful.vmap.require_split_rng = False
    try:
      # This call should not trigger an error, even though we are missing the
      # split_rng argument which appears required (if you look at the function
      # signature). It only works because require_split_rng is
      # propagated to vmap via a sneaky decorator. This only exists to support
      # users who import code that they cannot edit (e.g. from a read only
      # file system) that is not passing the argument.
      f = stateful.vmap(base.next_rng_key, axis_size=2)
    finally:
      stateful.vmap.require_split_rng = True

  # Check that split_rng=False was implied.
  k1, k2 = f()
  self.assertTrue((k1 == k2).all())

@parameterized.parameters(True, False)
@test_utils.transform_and_run
def test_vmap_split_rng_without_default(self, require_split_rng):
  # Tests that when split_rng is passed explicitly the value of
  # require_split_rng has no impact.
  x = jnp.arange(2)
  stateful.vmap.require_split_rng = require_split_rng
  try:
    k1, k2 = stateful.vmap(lambda x: base.next_rng_key(), split_rng=True)(x)
    self.assertTrue((k1 != k2).all())
    k1, k2 = stateful.vmap(lambda x: base.next_rng_key(), split_rng=False)(x)
    self.assertTrue((k1 == k2).all())
  finally:
    # Always restore the global flag so a failure here cannot leak into
    # other tests.
    stateful.vmap.require_split_rng = True