Example #1
0
def update_modified_frame_data(frame_data: FrameData):
  """Writes ``frame_data`` back into the current frame.

  Each field is merged with ``update_recursive_skip_none``. ``params`` and
  ``constants`` are only merged while params are not frozen; ``state`` is
  always merged. When ``frame_data.rng`` is not None it replaces the internal
  state of the sequence on top of the frame's RNG stack.
  """
  target = current_frame()
  can_write_params = not params_frozen()
  if can_write_params:
    update_recursive_skip_none(target.params, frame_data.params)
  update_recursive_skip_none(target.state, frame_data.state)
  if can_write_params:
    update_recursive_skip_none(target.constants, frame_data.constants)
  new_rng = frame_data.rng
  if new_rng is not None:
    target.rng_stack.peek().replace_internal_state(new_rng)
Example #2
0
def update_modified_frame_data_from_args(params, bundled_state):
  """Writes ``params`` and a ``(state, constants, rng)`` bundle into the frame.

  Args:
    params: Parameters merged via ``update_recursive_skip_none``; only applied
      while params are not frozen.
    bundled_state: A ``(state, constants, rng)`` triple. ``state`` is always
      merged, ``constants`` only while params are not frozen, and ``rng`` (when
      not None) replaces the internal state of the top of the RNG stack.
  """
  state, constants, rng = bundled_state
  frame = current_frame()
  if not params_frozen():
    update_recursive_skip_none(frame.params, params)
  update_recursive_skip_none(frame.state, state)
  if not params_frozen():
    update_recursive_skip_none(frame.constants, constants)
  if rng is not None:
    frame.rng_stack.peek().replace_internal_state(rng)
Example #3
0
def fori_loop(lower, upper, body_fun, init_val):
    """Equivalent to :func:`jax.lax.fori_loop` with Haiku state passed in/out."""
    if not base.inside_transform():
        raise ValueError(
            "hk.fori_loop() should not be used outside of hk.transform(). "
            "Use jax.lax.fori_loop() instead.")

    def pure_body_fun(i, carry):
        # The loop carry is (haiku_state, user_value): run the user body with
        # that Haiku state installed, top the RNG block back up so the carry
        # keeps a fixed structure, and return the refreshed state.
        haiku_state, user_val = carry
        with temporary_internal_state(haiku_state):
            user_val = body_fun(i, user_val)
            reserve_up_to_full_rng_block()
            return internal_state(), user_val

    if not base.params_frozen():
        # During init we need to unwind one step of the loop to ensure the Haiku
        # state before and after the body has the same structure.
        init_val = body_fun(lower, init_val)
        lower += 1
        try:
            nothing_left = bool(upper - lower == 0)
        except jax.errors.ConcretizationTypeError:
            # upper or lower might be tracers, which jax.lax.fori_loop can handle.
            nothing_left = False
        if nothing_left:
            return init_val

    reserve_up_to_full_rng_block()
    carry = (internal_state(), init_val)
    final_state, result = jax.lax.fori_loop(lower, upper, pure_body_fun, carry)
    update_internal_state(final_state)
    return result
Example #4
0
 def f(x):
     counter = CountingModule(op=lambda v: v + 1)
     if base.params_frozen():
         # At apply time, call next_rng_key() from inside cond_fun. The loop
         # result is discarded; presumably this exercises while_loop's error
         # path for RNG use in cond_fun (confirm against the enclosing test).
         stateful.while_loop(lambda _: base.next_rng_key(), lambda v: v, x)
         return None
     # At init time, connect the module directly.
     return counter(x)
Example #5
0
def running_init() -> bool:
    """Return True if running the ``init`` function of a Haiku transform.

    In general you should not need to gate behaviour of your module based on
    whether you are running ``init`` or ``apply``, but sometimes (e.g. when
    making use of JAX control flow) this is required.

    For example, if you want to use :func:`switch` to pick between experts, when
    we run your init function we need to ensure that params/state for all
    experts are created (unconditionally) but during apply we want to
    conditionally apply (and perhaps update the internal state) of only one of
    our experts:

    >>> experts = [hk.nets.ResNet50(10) for _ in range(5)]
    >>> x = jnp.ones([1, 224, 224, 3])
    >>> if hk.running_init():
    ...   # During init unconditionally create params/state for all experts.
    ...   for expert in experts:
    ...     out = expert(x, is_training=True)
    ... else:
    ...   # During apply conditionally apply (and update) only one expert.
    ...   # ``maxval`` of ``jax.random.randint`` is exclusive, so pass
    ...   # ``len(experts)`` to make every expert selectable.
    ...   index = jax.random.randint(hk.next_rng_key(), [], 0, len(experts))
    ...   out = hk.switch(index, experts, x)

    Returns:
      True if running ``init`` otherwise False.
    """
    # Checks that we are inside a Haiku context before querying frame state.
    base.assert_context("running_init")
    # Params are only mutable (not frozen) while running ``init``.
    return not base.params_frozen()
Example #6
0
def fori_loop(lower, upper, body_fun, init_val):
    """Equivalent to ``jax.lax.fori_loop`` with Haiku state threaded in/out.

    Args:
      lower: Loop start index (inclusive).
      upper: Loop end index (exclusive).
      body_fun: Called as ``body_fun(i, val)`` and returns the new ``val``.
      init_val: Initial loop value.

    Returns:
      The final loop value. Haiku internal state updated by the loop body is
      written back to the current frame.

    Raises:
      ValueError: If called outside of ``hk.transform``.
    """
    if not base.inside_transform():
        raise ValueError(
            "hk.fori_loop() should not be used outside of hk.transform(). "
            "Use jax.lax.fori_loop() instead.")

    def pure_body_fun(i, val):
        # The carry is (haiku_state, user_value).
        state, val = val
        with temporary_internal_state(state):
            val = body_fun(i, val)
            state = internal_state()
            return state, val

    if not base.params_frozen():
        # During init we need to unwind one step of the loop to ensure the Haiku
        # state before and after the body has the same structure.
        init_val = body_fun(lower, init_val)
        lower += 1
        try:
            if upper - lower == 0:
                return init_val
        except jax.errors.ConcretizationTypeError:
            # upper or lower might be tracers, which jax.lax.fori_loop can
            # handle; we just cannot short-circuit the empty loop here.
            pass

    state = internal_state()
    init_val = state, init_val
    state, val = jax.lax.fori_loop(lower, upper, pure_body_fun, init_val)
    update_internal_state(state)
    return val
Example #7
0
 def f(x):
     counter = CountingModule(op=lambda v: v + 1)
     if not base.params_frozen():
         # During init, connect the module directly so its state is created.
         return counter(x)

     def cond(carry):
         step, _ = carry
         return step < iters

     def body(carry):
         step, value = carry
         return step + 1, counter(value)

     # At apply time, run the module ``iters`` times inside the while loop.
     _, y = stateful.while_loop(cond, body, (0, x))
     return y
Example #8
0
def scan(f, init, xs, length=None, reverse=False):
    """Equivalent to `jax.lax.scan` but with Haiku state threaded in and out.

    Args:
      f: Step function called as ``f(carry, x)`` returning ``(carry, y)``.
        Inside the scanned body it runs under
        ``base.assert_no_new_parameters``.
      init: Initial carry.
      xs: Pytree of arrays scanned along their leading axis.
      length: Number of iterations; inferred from the first leaf of ``xs``
        when ``None``.
      reverse: If True, scan from the end of ``xs``.

    Returns:
      ``(carry, ys)`` where ``ys`` is stacked along the leading axis and
      aligned with ``xs``.

    Raises:
      ValueError: If called outside of ``hk.transform``, or if ``length`` is
        None and ``xs`` has no leaves.
    """
    if not base.inside_transform():
        raise ValueError(
            "hk.scan() should not be used outside of hk.transform(). "
            "Use jax.lax.scan() instead.")

    if length is None:
        leaves = jax.tree_leaves(xs)
        if not leaves:
            # Fail with a clear message instead of an IndexError on leaves[0].
            raise ValueError(
                "hk.scan() got no values to scan over and `length` not "
                "provided.")
        length = leaves[0].shape[0]

    running_init_fn = not base.params_frozen()

    if running_init_fn:
        # During `init` we need to unroll one step of the scan, this is because our
        # carry contains the Haiku state and during `init` this may change structure
        # (e.g. as state is created).
        if not length:
            x0 = jax.tree_map(lambda x: jnp.zeros(x.shape[1:], x.dtype), xs)
            _, y0 = f(init, x0)
            y0 = jax.tree_map(lambda y: jnp.zeros((0, ) + y.shape, y.dtype),
                              y0)
            return init, y0

        if reverse:
            x0 = jax.tree_map(lambda x: x[-1], xs)
            xs = jax.tree_map(lambda x: x[:-1], xs)
        else:
            x0 = jax.tree_map(lambda x: x[0], xs)
            xs = jax.tree_map(lambda x: x[1:], xs)
        init, y0 = f(init, x0)
        y0 = jax.tree_map(lambda y: jnp.expand_dims(y, 0), y0)
        length -= 1
        if not length:
            return init, y0

    def stateful_fun(carry, x):
        carry, state = carry
        with temporary_internal_state(state):
            with base.assert_no_new_parameters():
                carry, out = f(carry, x)
            carry = (carry, internal_state(params=False))
            return carry, out

    # We know that we don't need to thread params in and out, since for init we
    # have already created them (given that above we unroll one step of the scan)
    # and for apply we know they are immutable. As such we only need to thread the
    # state and rng in and out.
    init = (init, internal_state(params=False))
    (carry, state), ys = jax.lax.scan(stateful_fun, init, xs, length, reverse)
    update_internal_state(state)

    if running_init_fn:
        # y0 came from the element peeled off above: xs[0] for a forward scan,
        # xs[-1] for a reverse scan. Put it back in the matching position so
        # ys stays aligned with the original xs.
        if reverse:
            ys = jax.tree_map(lambda y0, ys: jnp.concatenate([ys, y0]), y0, ys)
        else:
            ys = jax.tree_map(lambda y0, ys: jnp.concatenate([y0, ys]), y0, ys)

    return carry, ys
Example #9
0
 def __call__(self, *args, **kwargs):
     frame = base.current_frame()
     bundle_name = self.module_name
     if not base.params_frozen():
         # Inside init: run the wrapped init function and lift the resulting
         # parameters into this transform's params dict under ``bundle_name``.
         created = self._init_fn(*args, **kwargs)
         pack_into_dict(created, frame.params, bundle_name)
         return created
     # Inside apply: read back the params previously stored under the
     # ``bundle_name + "/"`` prefix.
     return unpack_from_dict(frame.params, bundle_name + "/")
Example #10
0
def conv_weight_with_spectral_norm(x: jnp.ndarray,
                                   kernel_shape: Sequence[int],
                                   out_channel: int,
                                   name_suffix: str = "",
                                   w_init: Callable = None,
                                   b_init: Callable = None,
                                   use_bias: bool = True,
                                   is_training: bool = True,
                                   update_params: bool = True,
                                   max_singular_value: float = 0.95,
                                   max_power_iters: int = 1,
                                   **conv_kwargs):
    """Creates a conv kernel (and optional bias) passed through ``sn``.

    The kernel ``w_{name_suffix}`` and the state vectors ``u_{name_suffix}`` /
    ``v_{name_suffix}`` are passed through ``sn.spectral_norm_conv_apply``.

    Args:
      x: Input; its shape is unpacked as ``(batch, H, W, C)``, so it must be
        rank 4.
      kernel_shape: Spatial kernel shape; the full kernel shape is
        ``tuple(kernel_shape) + (C, out_channel)``.
      out_channel: Number of output channels.
      name_suffix: Suffix appended to the parameter/state names.
      w_init: Initializer for the kernel.
      b_init: Initializer for the bias (only used when ``use_bias``).
      use_bias: Whether to create and return a bias.
      is_training: When truthy, the updated ``u``/``v`` are written to state.
      update_params: Forwarded to ``sn.spectral_norm_conv_apply``.
      max_singular_value: Forwarded to ``sn.spectral_norm_conv_apply``.
      max_power_iters: Iteration count forwarded to
        ``sn.spectral_norm_conv_apply`` on the regular pass.
      **conv_kwargs: Must contain ``"stride"`` and ``"padding"``.

    Returns:
      ``(w, b)`` if ``use_bias`` else ``w``.
    """
    _, height, width, in_channel = x.shape
    # ``kernel_shape`` may be any Sequence; normalise to a tuple so a list
    # does not fail on ``list + tuple``.
    w_shape = tuple(kernel_shape) + (in_channel, out_channel)

    w = hk.get_parameter(f"w_{name_suffix}", w_shape, x.dtype, init=w_init)
    b = None
    if use_bias:
        # Fetched exactly once, in the input dtype like the kernel.
        b = hk.get_parameter(f"b_{name_suffix}", (out_channel, ),
                             x.dtype,
                             init=b_init)

    u = hk.get_state(f"u_{name_suffix}", (height, width, out_channel),
                     init=hk.initializers.RandomNormal())
    v = hk.get_state(f"v_{name_suffix}", (height, width, in_channel),
                     init=hk.initializers.RandomNormal())
    stride, padding = conv_kwargs["stride"], conv_kwargs["padding"]
    w, u, v = sn.spectral_norm_conv_apply(w, u, v, stride, padding,
                                          max_singular_value, max_power_iters,
                                          update_params)

    # Run for a lot of steps when we're first initializing (presumably
    # ``None`` iterations means "many / until converged" in ``sn`` — confirm).
    running_init_fn = not hk_base.params_frozen()
    if running_init_fn:
        w, u, v = sn.spectral_norm_conv_apply(w, u, v, stride, padding,
                                              max_singular_value, None, True)

    if is_training or running_init_fn:
        hk.set_state(f"u_{name_suffix}", u)
        hk.set_state(f"v_{name_suffix}", v)

    if use_bias:
        return w, b
    return w
Example #11
0
def weight_with_spectral_norm(x: jnp.ndarray,
                              out_dim: int,
                              name_suffix: str = "",
                              w_init: Callable = None,
                              b_init: Callable = None,
                              is_training: bool = True,
                              update_params: bool = True,
                              use_bias: bool = True,
                              force_in_dim: Optional = None,
                              max_singular_value: float = 0.99,
                              max_power_iters: int = 1,
                              **kwargs):
    """Creates a dense weight (and optional bias) passed through ``sn``.

    The weight ``w_{name_suffix}`` has shape ``(out_dim, in_dim)`` and is passed
    through ``sn.spectral_norm_apply`` together with the state vectors
    ``u_{name_suffix}`` (``(out_dim,)``) and ``v_{name_suffix}`` (``(in_dim,)``).

    Args:
      x: Input; ``in_dim`` is its last axis unless ``force_in_dim`` is set.
      out_dim: Output dimension.
      name_suffix: Suffix appended to the parameter/state names.
      w_init: Initializer for the weight.
      b_init: Initializer for the bias (only used when ``use_bias``).
      is_training: When truthy, the updated ``u``/``v`` are written to state.
      update_params: Forwarded to ``sn.spectral_norm_apply``.
      use_bias: Whether to create and return a bias.
      force_in_dim: If truthy, overrides the inferred ``in_dim``.
      max_singular_value: Forwarded to ``sn.spectral_norm_apply``.
      max_power_iters: Iteration count forwarded to ``sn.spectral_norm_apply``
        on the regular pass.
      **kwargs: Accepted and ignored.

    Returns:
      ``(w, b)`` if ``use_bias`` else ``w``.
    """
    in_dim, dtype = x.shape[-1], x.dtype
    if force_in_dim:
        in_dim = force_in_dim

    w = hk.get_parameter(f"w_{name_suffix}", (out_dim, in_dim),
                         dtype,
                         init=w_init)
    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", (out_dim, ),
                             dtype,
                             init=b_init)

    u = hk.get_state(f"u_{name_suffix}", (out_dim, ),
                     dtype,
                     init=hk.initializers.RandomNormal())
    v = hk.get_state(f"v_{name_suffix}", (in_dim, ),
                     dtype,
                     init=hk.initializers.RandomNormal())
    w, u, v = sn.spectral_norm_apply(w, u, v, max_singular_value,
                                     max_power_iters, update_params)

    # During init run an extra pass with ``None`` iterations (presumably
    # "many / until converged" in ``sn`` — confirm).
    running_init_fn = not hk_base.params_frozen()
    if running_init_fn:
        w, u, v = sn.spectral_norm_apply(w, u, v, max_singular_value, None,
                                         True)

    if is_training or running_init_fn:
        hk.set_state(f"u_{name_suffix}", u)
        hk.set_state(f"v_{name_suffix}", v)

    if use_bias:
        return w, b
    return w
Example #12
0
def while_loop(cond_fun, body_fun, init_val):
    """Equivalent to jax.lax.while_loop with Haiku state threaded in/out.

    The loop carry is ``(value, haiku_internal_state)``; the final state is
    written back to the current frame.

    Raises:
      ValueError: If called while params are not frozen (i.e. during init), or
        if ``cond_fun`` changes Haiku state.
    """
    initializing = not base.params_frozen()
    if initializing:
        raise ValueError(
            "hk.while_loop does not support initialization (since we cannot "
            "statically determine if your loop will run at least once). Please "
            "use `hk.running_init` to run the body unconditionally:\n"
            "\n"
            "    if hk.running_init():\n"
            "      # Unconditionally connect the module at init time.\n"
            "      val = module(val)\n"
            "    else:\n"
            "      val = hk.while_loop(lambda val: val.mean() < 1, module, val)\n"
        )

    @functools.wraps(cond_fun)
    def pure_cond_fun(carry):
        value, _unused_state = carry
        try:
            with base.assert_state_unchanged():
                return cond_fun(value)
        except base.StateChangedError as e:
            # Supporting state updates / rng use in `cond` would need a change in
            # JAX itself (to support aux in/out of the cond).
            raise ValueError(
                "`hk.while_loop` does not support `hk.set_state`, `hk.next_rng_key` "
                "(et al) in `cond_fun`.") from e

    @functools.wraps(body_fun)
    def pure_body_fun(carry):
        value, haiku_state = carry
        with temporary_internal_state(haiku_state), base.push_jax_trace_level():
            value = body_fun(value)
            return value, internal_state()

    carry = (init_val, internal_state())
    final_value, final_state = jax.lax.while_loop(pure_cond_fun, pure_body_fun,
                                                  carry)
    update_internal_state(final_state)
    return final_value
Example #13
0
def update_frame_data(params, state):
  """Merges ``params`` (only while not frozen) and ``state`` into the frame.

  Both are merged with ``update_recursive_skip_none``; params first.
  """
  target = current_frame()
  if params_frozen():
    update_recursive_skip_none(target.state, state)
    return
  update_recursive_skip_none(target.params, params)
  update_recursive_skip_none(target.state, state)
Example #14
0
 def init_if_needed(self, x, rng):
     """Calls ``auto_batched_res_block`` once when params are not frozen.

     The network must be initialized before its frame data is extracted; when
     params are frozen (apply) this does nothing.
     """
     if hk_base.params_frozen():
         return
     self.auto_batched_res_block(x, rng)
Example #15
0
def scan(f, init, xs, length=None, reverse=False, unroll=1):
    """Equivalent to :func:`jax.lax.scan` but with Haiku state passed in/out.

    Args:
      f: Step function called as ``f(carry, x)`` returning ``(carry, y)``.
        Inside the scanned body it runs under
        ``base.assert_no_new_parameters``.
      init: Initial carry.
      xs: Pytree of arrays scanned along their leading axis.
      length: Number of iterations; inferred from the first leaf of ``xs``
        when ``None``.
      reverse: If True, scan from the end of ``xs``.
      unroll: Forwarded to ``jax.lax.scan``.

    Returns:
      ``(carry, ys)`` where ``ys`` is stacked along the leading axis and
      aligned with ``xs``.

    Raises:
      ValueError: If called outside of ``hk.transform``, or if ``length`` is
        None and ``xs`` has no leaves.
    """
    if not base.inside_transform():
        raise ValueError(
            "hk.scan() should not be used outside of hk.transform(). "
            "Use jax.lax.scan() instead.")

    if length is None:
        leaves = jax.tree_leaves(xs)
        if not leaves:
            # Fail with a clear message instead of an IndexError on leaves[0].
            raise ValueError(
                "hk.scan() got no values to scan over and `length` not "
                "provided.")
        length = leaves[0].shape[0]

    running_init_fn = not base.params_frozen()

    if running_init_fn:
        # During `init` we need to unroll one step of the scan, this is because our
        # carry contains the Haiku state and during `init` this may change structure
        # (e.g. as state is created).
        if not length:
            x0 = jax.tree_map(lambda x: jnp.zeros(x.shape[1:], x.dtype), xs)
            _, y0 = f(init, x0)
            y0 = jax.tree_map(lambda y: jnp.zeros((0, ) + y.shape, y.dtype),
                              y0)
            return init, y0

        if reverse:
            x0 = jax.tree_map(lambda x: x[-1], xs)
            xs = jax.tree_map(lambda x: x[:-1], xs)
        else:
            x0 = jax.tree_map(lambda x: x[0], xs)
            xs = jax.tree_map(lambda x: x[1:], xs)
        init, y0 = f(init, x0)
        y0 = jax.tree_map(lambda y: jnp.expand_dims(y, 0), y0)
        length -= 1
        if not length:
            return init, y0

    @functools.wraps(f)
    def stateful_fun(carry, x):
        carry, state = carry
        with temporary_internal_state(state):
            with base.assert_no_new_parameters(), \
                 base.push_jax_trace_level():
                carry, out = f(carry, x)
            reserve_up_to_full_rng_block()
            carry = (carry, internal_state(params=False))
            return carry, out

    # Before pulling out the internal state, reserve a full block of RNG keys.
    # This is to make sure we're always passing in the same amount of subkeys in
    # and out of the scan carry (scan requires equal length lists).
    # After every scan iteration we reserve back up to the full block.
    reserve_up_to_full_rng_block()

    # We know that we don't need to thread params in and out, since for init we
    # have already created them (given that above we unroll one step of the scan)
    # and for apply we know they are immutable. As such we only need to thread the
    # state and rng in and out.

    init = (init, internal_state(params=False))
    (carry, state), ys = jax.lax.scan(stateful_fun,
                                      init,
                                      xs,
                                      length,
                                      reverse,
                                      unroll=unroll)
    update_internal_state(state)

    if running_init_fn:
        # y0 came from the element peeled off above (xs[-1] when reversed,
        # xs[0] otherwise); put it back in the matching position.
        if reverse:
            ys = jax.tree_map(lambda y0, ys: jnp.concatenate([ys, y0]), y0, ys)
        else:
            ys = jax.tree_map(lambda y0, ys: jnp.concatenate([y0, ys]), y0, ys)

    return carry, ys
Example #16
0
import itertools as it

from absl.testing import absltest
from absl.testing import parameterized
from haiku._src import base
from haiku._src import base_test
from haiku._src import module
from haiku._src import stateful
from haiku._src import test_utils
from haiku._src import transform

import jax
import jax.numpy as jnp
import numpy as np

# toggle(i, a) builds a function that calls ``i`` while params are not frozen
# (init) and ``a`` once they are frozen (apply).
toggle = lambda i, a: (lambda x: i(x) if not base.params_frozen() else a(x))


# JAX transforms and control flow that need to be aware of Haiku internal
# state to operate unsurprisingly.
# pylint: disable=g-long-lambda
HK_OVERLOADED_JAX_PURE_EXPECTING_FNS = (
    # Just-in-time compilation.
    ("jit", stateful.jit),

    # ("make_jaxpr", stateful.make_jaxpr),
    ("eval_shape", lambda f: (lambda x: [f(x), stateful.eval_shape(f, x)])),
    ("named_call", stateful.named_call),

    # Parallelization.
    # TODO(tomhennigan): Add missing features (e.g. pjit,xmap).