def fast_eval_shape(fun, *args, **kwargs):
  """Equivalent to ``eval_shape`` in JAX.

  This utility is equivalent to ``eval_shape`` in JAX except that it avoids
  running Haiku functions whose shapes are trivially known. This can avoid
  some Python overheads in JAX which can accumulate for very large models.

  Optimizations:

  * All parameter/state initialisers replaced with zeros.
  * ``hk.dropout`` replaced with identity.
  * ``jax.random.fold_in`` replaced with identity.

  Args:
    fun: The function to trace.
    *args: Positional arguments to ``fun``.
    **kwargs: Keyword arguments to ``fun``.

  Returns:
    The shape produced by ``fun`` for the given args/kwargs.
  """
  with base.custom_creator_unsafe(zeros_creator), \
       mock.patch.object(basic, 'dropout_impl', noop_dropout), \
       mock.patch.object(jax.random, 'fold_in', lambda key, data: key):
    if base.inside_transform():
      return stateful.eval_shape(fun, *args, **kwargs)
    else:
      return jax.eval_shape(fun, *args, **kwargs)
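
# Illustrative usage sketch (not part of this module): how fast_eval_shape
# might be called in place of jax.eval_shape to get parameter and output
# shapes cheaply. The network, sizes and inputs below are made up, and
# `haiku` is imported locally purely for the example.
def _example_fast_eval_shape():
  import haiku as hk
  import jax.numpy as jnp

  def forward(x):
    return hk.nets.MLP([300, 10])(x)

  f = hk.transform(forward)
  rng = jax.random.PRNGKey(0)
  x = jnp.ones([1, 28 * 28])
  # Shapes/dtypes of the params, without running the real initialisers.
  params_spec = fast_eval_shape(f.init, rng, x)
  # Shapes/dtypes of the outputs, given only the param specs.
  out_spec = fast_eval_shape(f.apply, params_spec, rng, x)
  return params_spec, out_spec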
def test_eval_shape_no_leaked_tracers_under_leak_checker(self):
  with jax.checking_leaks():
    stateful.eval_shape(SquareModule(), jnp.ones(()))  # does not crash
def test_eval_shape_no_transform(self):
  x = jnp.array(3.)
  with self.assertRaises(ValueError, msg="Use jax.eval_shape() instead"):
    stateful.eval_shape(jnp.square)(x)
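
# Sketch of the supported pattern (assumes the SquareModule helper and the
# `transform` import used elsewhere in these tests): stateful.eval_shape is
# only valid inside a Haiku transform; in pure JAX code, jax.eval_shape
# should be used instead.
def _example_eval_shape_inside_transform():
  def f(x):
    mod = SquareModule()
    out_shape = stateful.eval_shape(mod, x)  # OK: inside a Haiku transform.
    return mod(x), out_shape

  init, apply = transform.transform_with_state(f)
  x = jnp.array(3.)
  params, state = init(jax.random.PRNGKey(0), x)
  return apply(params, state, None, x)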
def f(x):
  m = CountingModule(op=some_shape_changing_fun)
  # state is not changed in this call
  out_shape_struct = stateful.eval_shape(m, x)
  return m(x), out_shape_struct
import jax
import jax.numpy as jnp
import numpy as np

toggle = lambda i, a: lambda x: a(x) if base.params_frozen() else i(x)

# JAX transforms and control flow that need to be aware of Haiku internal
# state to operate unsurprisingly.
# pylint: disable=g-long-lambda
HK_OVERLOADED_JAX_PURE_EXPECTING_FNS = (
    # Just-in-time compilation.
    ("jit", stateful.jit),
    # ("make_jaxpr", stateful.make_jaxpr),
    ("eval_shape", lambda f: (lambda x: [f(x), stateful.eval_shape(f, x)])),
    ("named_call", stateful.named_call),

    # Parallelization.
    # TODO(tomhennigan): Add missing features (e.g. pjit,xmap).
    # ("pmap", lambda f: stateful.pmap(f, "i")),

    # Vectorization.
    ("vmap", lambda f: stateful.vmap(f, split_rng=False)),

    # Control flow.
    # TODO(tomhennigan): Enable for associative_scan.
    # ("associative_scan", lambda f:
    #  (lambda x: jax.lax.associative_scan(f, x))),
    ("cond", lambda f: (lambda x: stateful.cond(True, f, f, x))),
    ("fori_loop", lambda f:
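
# Sketch of why the Haiku-aware wrappers above exist (assumes the
# SquareModule helper and `transform` import from these tests; shapes are
# arbitrary): inside a Haiku transform, stateful.vmap/hk.vmap threads params,
# state and RNG through the mapped function, which plain jax.vmap does not
# know about.
def _example_haiku_aware_vmap():
  def forward(x):
    mod = SquareModule()  # assumed helper module from these tests
    # split_rng=False: every mapped instance sees the same RNG key, so the
    # module's parameters are created once and shared across the batch.
    return stateful.vmap(mod, split_rng=False)(x)

  init, apply = transform.transform_with_state(forward)
  x = jnp.arange(8.)
  params, state = init(jax.random.PRNGKey(0), x)
  out, _ = apply(params, state, None, x)
  return out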
def f(x):
  m = basic.Linear(20)
  y_slow = stateful.eval_shape(m, x)
  y_fast = eval_shape.fast_eval_shape(m, x)
  self.assertEqual(y_slow, y_fast)
  return m(x)