Example #1
from unittest import mock

import jax

# ``base``, ``basic`` and ``stateful`` are Haiku-internal modules;
# ``zeros_creator`` and ``noop_dropout`` are helpers defined alongside this
# function in the same module.
from haiku._src import base
from haiku._src import basic
from haiku._src import stateful


def fast_eval_shape(fun, *args, **kwargs):
  """Equivalent to ``eval_shape`` in JAX.

  This utility is equivalent to ``eval_shape`` in JAX except that it avoids
  running Haiku functions whose shapes are trivially known. This can avoid
  some Python overheads in JAX which can accumulate for very large models.

  Optimizations:

  * All parameter/state initialisers replaced with zeros.
  * ``hk.dropout`` replaced with identity.
  * ``jax.random.fold_in`` replaced with identity.

  Args:
    fun: The function to trace.
    *args: Positional arguments to ``fun``.
    **kwargs: Keyword arguments to ``fun``.

  Returns:
    The shape produced by ``fun`` for the given args/kwargs.
  """
  with base.custom_creator_unsafe(zeros_creator), \
       mock.patch.object(basic, 'dropout_impl', noop_dropout), \
       mock.patch.object(jax.random, 'fold_in', lambda key, data: key):
    if base.inside_transform():
      return stateful.eval_shape(fun, *args, **kwargs)
    else:
      return jax.eval_shape(fun, *args, **kwargs)
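
For context, a minimal usage sketch (the toy model below is ours, and we assume the utility is exposed publicly as ``hk.experimental.fast_eval_shape``, as in recent Haiku releases):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  return hk.nets.MLP([300, 10])(x)

f = hk.transform(forward)
rng = jax.random.PRNGKey(0)
x = jnp.ones([32, 784])

# Same pytree of jax.ShapeDtypeStructs as jax.eval_shape(f.init, rng, x),
# but initialisers, dropout and fold_in are stubbed out during tracing.
param_shapes = hk.experimental.fast_eval_shape(f.init, rng, x)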
Example #2
def test_eval_shape_no_leaked_tracers_under_leak_checker(self):
  with jax.checking_leaks():
    stateful.eval_shape(SquareModule(), jnp.ones(()))  # does not crash
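
``SquareModule`` is defined elsewhere in the test file; a plausible minimal stand-in (our assumption, not the original) squares its input and records the result as Haiku state:

import haiku as hk
import jax.numpy as jnp

class SquareModule(hk.Module):
  """Hypothetical stand-in: squares the input and stores it as state."""

  def __call__(self, x):
    y = jnp.square(x)
    hk.set_state("y", y)
    return y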
Example #3
def test_eval_shape_no_transform(self):
  x = jnp.array(3.)
  with self.assertRaises(ValueError, msg="Use jax.eval_shape() instead"):
    stateful.eval_shape(jnp.square, x)
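
As the error message says, the equivalent outside of a Haiku transform is plain ``jax.eval_shape``, which returns only shape/dtype metadata:

import jax
import jax.numpy as jnp

out = jax.eval_shape(jnp.square, jnp.array(3.))
# -> a ShapeDtypeStruct with shape=() and dtype=float32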
Example #4
def f(x):
  m = CountingModule(op=some_shape_changing_fun)
  # Tracing with eval_shape does not change the module's state.
  out_shape_struct = stateful.eval_shape(m, x)
  return m(x), out_shape_struct
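
``CountingModule`` and ``some_shape_changing_fun`` are defined elsewhere in the test file. A plausible stand-in for the module (our assumption, not the original) keeps a call counter in Haiku state, which is what lets a test assert that ``eval_shape`` left the counter untouched:

import haiku as hk
import jax.numpy as jnp

class CountingModule(hk.Module):
  """Hypothetical stand-in: applies ``op`` and counts how often it ran."""

  def __init__(self, op=jnp.square, name=None):
    super().__init__(name=name)
    self.op = op

  def __call__(self, x):
    count = hk.get_state("count", [], jnp.int32, init=jnp.zeros)
    hk.set_state("count", count + 1)
    return self.op(x)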
Example #5
import jax
import jax.numpy as jnp
import numpy as np

# ``base`` and ``stateful`` below are Haiku-internal modules.
from haiku._src import base
from haiku._src import stateful

toggle = lambda i, a: (lambda x: a(x) if base.params_frozen() else i(x))


# JAX transforms and control flow that need to be aware of Haiku internal
# state to operate unsurprisingly.
# pylint: disable=g-long-lambda
HK_OVERLOADED_JAX_PURE_EXPECTING_FNS = (
    # Just-in-time compilation.
    ("jit", stateful.jit),

    # ("make_jaxpr", stateful.make_jaxpr),
    ("eval_shape", lambda f: (lambda x: [f(x), stateful.eval_shape(f, x)])),
    ("named_call", stateful.named_call),

    # Parallelization.
    # TODO(tomhennigan): Add missing features (e.g. pjit,xmap).
    # ("pmap", lambda f: stateful.pmap(f, "i")),

    # Vectorization.
    ("vmap", lambda f: stateful.vmap(f, split_rng=False)),

    # Control flow.
    # TODO(tomhennigan): Enable for associative_scan.
    # ("associative_scan", lambda f:
    #  (lambda x: jax.lax.associative_scan(f, x))),
    ("cond", lambda f: (lambda x: stateful.cond(True, f, f, x))),
    ("fori_loop", lambda f:
Example #6
def f(x):
  m = basic.Linear(20)
  y_slow = stateful.eval_shape(m, x)
  y_fast = eval_shape.fast_eval_shape(m, x)
  self.assertEqual(y_slow, y_fast)
  return m(x)
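
A sketch of how such a helper is driven (our illustration; ``f`` closes over the test's ``self``, and ``transform`` is Haiku's transform module):

g = transform.transform(f)
x = jnp.ones([2, 5])
params = g.init(jax.random.PRNGKey(0), x)
g.apply(params, None, x)  # runs both eval_shape paths and compares them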