def test_optimize_rng_use_under_jit(
    self,
    module_fn: descriptors.ModuleFn,
    shape: Shape,
    dtype: DType,
):
    """Check `optimize_rng_use`-wrapped modules agree with and without jit.

    The module is wrapped in `hk.experimental.optimize_rng_use`, transformed
    with state, and then both `init` and `apply` are compared between their
    jitted and non-jitted forms.

    Args:
      module_fn: Zero-argument callable that constructs the module under test.
      shape: Shape of the randomly generated input.
      dtype: Dtype of the input. Integer dtypes get random integers, every
        other dtype gets uniform samples.
    """
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
        x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
        x = jax.random.uniform(rng, shape, dtype)

    def g(x):
        return module_fn()(x)

    f = hk.transform_with_state(hk.experimental.optimize_rng_use(g))

    module_type = descriptors.module_type(module_fn)
    atol = CUSTOM_ATOL.get(module_type, DEFAULT_ATOL)
    assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol)

    # `jax.tree_multimap` has been removed from JAX; `jax.tree_util.tree_map`
    # accepts several trees and is the supported replacement.
    params, state = jax.jit(f.init)(rng, x)
    jax.tree_util.tree_map(assert_allclose, (params, state), f.init(rng, x))

    if module_type in (hk.nets.VectorQuantizer, hk.nets.VectorQuantizerEMA):
        # For stochastic modules just test apply runs.
        jax.device_get(jax.jit(f.apply)(params, state, rng, x))
    else:
        jax.tree_util.tree_map(
            assert_allclose,
            jax.jit(f.apply)(params, state, rng, x),
            f.apply(params, state, rng, x),
        )
def test_jit(self, module_fn: ModuleFn, shape, dtype):
    """Check that `init` and `apply` give matching results with and without jit.

    Args:
      module_fn: Zero-argument callable that constructs the module under test.
      shape: Shape of the randomly generated input.
      dtype: Dtype of the input. Integer dtypes get random integers, every
        other dtype gets uniform samples.
    """
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
        x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
        x = jax.random.uniform(rng, shape, dtype)

    def g(x):
        return module_fn()(x)

    f = hk.transform_with_state(g)

    atol = CUSTOM_ATOL.get(descriptors.module_type(module_fn), DEFAULT_ATOL)
    assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol)

    # `jax.tree_multimap` has been removed from JAX; `jax.tree_util.tree_map`
    # accepts several trees and is the supported replacement.

    # Ensure initialization under jit is the same.
    jax.tree_util.tree_map(
        assert_allclose,
        f.init(rng, x),
        jax.jit(f.init)(rng, x),
    )

    # Ensure application under jit is the same.
    params, state = f.init(rng, x)
    jax.tree_util.tree_map(
        assert_allclose,
        f.apply(params, state, rng, x),
        jax.jit(f.apply)(params, state, rng, x),
    )
def test_vmap(self, module_fn: ModuleFn, shape, dtype):
    """Check that `hk.vmap` inside a transform matches `jax.vmap` outside it.

    Args:
      module_fn: Zero-argument callable that constructs the module under test.
      shape: Shape of a single (unbatched) input example.
      dtype: Dtype of the input. Integer dtypes get random integers, every
        other dtype gets uniform samples.
    """
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
        x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
        x = jax.random.uniform(rng, shape, dtype)

    # Expand our input since we will map over it.
    x = jnp.broadcast_to(x, (2,) + x.shape)

    f = hk.transform_with_state(lambda x: module_fn()(x))  # pylint: disable=unnecessary-lambda
    f_mapped = hk.transform_with_state(
        lambda x: hk.vmap(lambda x: module_fn()(x))(x))  # pylint: disable=unnecessary-lambda

    params, state = f_mapped.init(rng, x)

    # JAX vmap with explicitly unmapped params/state/rng. This should be
    # equivalent to `f_mapped.apply(..)` (since by default hk.vmap does not map
    # params/state/rng).
    v_apply = jax.vmap(
        f.apply,
        in_axes=(None, None, None, 0),
        out_axes=(0, None),
    )

    module_type = descriptors.module_type(module_fn)
    atol = CUSTOM_ATOL.get(module_type, DEFAULT_ATOL)
    assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol)

    # `jax.tree_multimap` has been removed from JAX; `jax.tree_util.tree_map`
    # accepts several trees and is the supported replacement.
    jax.tree_util.tree_map(
        assert_allclose,
        f_mapped.apply(params, state, rng, x),
        v_apply(params, state, rng, x),
    )
def test_info_and_html(self, module_fn: ModuleFn, shape, dtype):
    """Build jaxpr model info for a module and render it as an HTML page.

    Args:
      module_fn: Zero-argument callable that constructs the module under test.
      shape: Shape of the all-ones input closed over by the forward function.
      dtype: Dtype of that input.
    """
    inputs = jnp.ones(shape, dtype)

    def forward():
        return module_fn()(inputs)

    transformed = hk.transform_with_state(forward)
    key = jax.random.PRNGKey(42)
    params, state = transformed.init(key)

    collect_info = jaxpr_info.make_model_info(transformed.apply)
    model_info = collect_info(params, state, key)

    # `Sequential` is exempt from the non-empty expressions check.
    module_name = descriptors.module_type(module_fn).__name__
    if module_name != 'Sequential':
        self.assertNotEmpty(model_info.expressions)

    html_page = jaxpr_info.as_html_page(model_info)
    self.assertIsNotNone(html_page)
def test_strict_promotion(self, module_fn: ModuleFn, shape, dtype):
    """Run `init` and `apply` on an all-ones input and check apply's result.

    The vector quantizer modules are skipped.

    NOTE(review): nothing in this body switches on strict dtype promotion;
    presumably that is configured by a decorator or fixture outside this
    view - confirm.

    Args:
      module_fn: Zero-argument callable that constructs the module under test.
      shape: Shape of the all-ones input.
      dtype: Dtype of the all-ones input.
    """
    if descriptors.module_type(module_fn) in (
        hk.nets.VectorQuantizer,
        hk.nets.VectorQuantizerEMA,
    ):
        self.skipTest('Requires: https://github.com/google/jax/pull/2901')

    def forward(inputs):
        return module_fn()(inputs)

    transformed = hk.transform_with_state(forward)
    key = jax.random.PRNGKey(42)
    ones = jnp.ones(shape, dtype)

    params, state = transformed.init(key, ones)
    outputs = transformed.apply(params, state, key, ones)
    self.assertIsNotNone(outputs)
def main(argv):
    """Write one JSON summary file per module descriptor into `FLAGS.base_dir`.

    Each file is named after the descriptor (via `descriptors.to_file_name`)
    with a `.json` suffix and holds the 2-space indented JSON of
    `checkpoint_utils.summarize(descriptor)` followed by a trailing newline.

    Args:
      argv: Command-line arguments; only the program name is accepted.

    Raises:
      app.UsageError: If any extra command-line argument is given.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    for descriptor in descriptors.ALL_MODULES:
        file_name = descriptors.to_file_name(descriptor) + ".json"
        summary = checkpoint_utils.summarize(descriptor)
        out_path = os.path.join(FLAGS.base_dir, file_name)
        # json.dumps escapes non-ASCII by default, so the explicit encoding only
        # pins down what was previously the platform default.
        with open(out_path, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(summary, indent=2) + "\n")
def test_checkpoint_format(self, name, module_fn: ModuleFn, shape, dtype):
    """Compare a module's summary against its stored JSON checkpoint summary.

    Args:
      name: Name used to build the module descriptor.
      module_fn: Module constructor used to build the module descriptor.
      shape: Input shape used to build the module descriptor.
      dtype: Input dtype used to build the module descriptor.

    Raises:
      ValueError: If the expected checkpoint file does not exist. The message
        includes the JSON that the file was expected to contain.
    """
    descriptor = descriptors.ModuleDescriptor(name, module_fn, shape, dtype)
    expected = checkpoint_utils.summarize(descriptor)
    file_path = os.path.join(
        "haiku/_src/integration/checkpoints/",
        descriptors.to_file_name(descriptor) + ".json")

    if not os.path.exists(file_path):
        expected_json = json.dumps(expected, indent=2)
        raise ValueError(f"Missing checkpoint file: {file_path}\n\n"
                         f"Expected:\n\n{expected_json}")

    with open(file_path, "r") as fp:
        actual = json.load(fp)

    self.assertEqual(expected, actual, msg=HOW_TO_REGENERATE)
def test_convert(
    self,
    module_fn: ModuleFn,
    shape,
    dtype,
    init: bool,
    tf_transform,
    jax_transform,
):
    """Check that a `jax2tf`-converted function matches its JAX original.

    Args:
      module_fn: Zero-argument callable that constructs the module under test.
      shape: Shape of the randomly generated input.
      dtype: Dtype of the input. Integer dtypes get random integers, every
        other dtype gets uniform samples.
      init: If true the transformed `init` function is compared, otherwise
        the transformed `apply` function is.
      tf_transform: Callable applied to the `jax2tf`-converted function before
        it is called.
      jax_transform: Callable applied to the JAX function before it is called.
    """
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
        x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
        x = jax.random.uniform(rng, shape, dtype)

    def g(x):
        return module_fn()(x)

    f = hk.transform_with_state(g)

    atol = CUSTOM_ATOL.get(descriptors.module_type(module_fn), DEFAULT_ATOL)
    assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol)

    # `jax.tree_multimap` and the top-level `jax.tree_map` alias have been
    # removed from JAX; `jax.tree_util.tree_map` covers both uses.
    def get(tree):
        # Calls `.numpy()` on every leaf of the converted function's output.
        return jax.tree_util.tree_map(lambda leaf: leaf.numpy(), tree)

    if init:
        init_jax = jax_transform(f.init)
        init_tf = tf_transform(jax2tf.convert(f.init))
        jax.tree_util.tree_map(
            assert_allclose,
            init_jax(rng, x),
            get(init_tf(rng, x)),
        )
    else:
        params, state = f.init(rng, x)
        apply_jax = jax_transform(f.apply)
        apply_tf = tf_transform(jax2tf.convert(f.apply))
        jax.tree_util.tree_map(
            assert_allclose,
            apply_jax(params, state, rng, x),
            get(apply_tf(params, state, rng, x)),
        )