def update_modified_frame_data(frame_data: FrameData):
  frame = current_frame()
  if not params_frozen():
    update_recursive_skip_none(frame.params, frame_data.params)
  update_recursive_skip_none(frame.state, frame_data.state)
  if not params_frozen():
    update_recursive_skip_none(frame.constants, frame_data.constants)
  rng = frame_data.rng
  if rng is not None:
    frame.rng_stack.peek().replace_internal_state(rng)
def update_modified_frame_data_from_args(params, bundled_state):
  (state, constants, rng) = bundled_state
  frame = current_frame()
  if not params_frozen():
    update_recursive_skip_none(frame.params, params)
  update_recursive_skip_none(frame.state, state)
  if not params_frozen():
    update_recursive_skip_none(frame.constants, constants)
  if rng is not None:
    frame.rng_stack.peek().replace_internal_state(rng)
def fori_loop(lower, upper, body_fun, init_val):
  """Equivalent to :func:`jax.lax.fori_loop` with Haiku state passed in/out."""
  if not base.inside_transform():
    raise ValueError(
        "hk.fori_loop() should not be used outside of hk.transform(). "
        "Use jax.lax.fori_loop() instead.")

  def pure_body_fun(i, val):
    state, val = val
    with temporary_internal_state(state):
      val = body_fun(i, val)
      reserve_up_to_full_rng_block()
      state = internal_state()
      return state, val

  if not base.params_frozen():
    # During init we need to unwind one step of the loop to ensure the Haiku
    # state before and after the body has the same structure.
    init_val = body_fun(lower, init_val)
    lower += 1

    try:
      if upper - lower == 0:
        return init_val
    except jax.errors.ConcretizationTypeError:
      # upper or lower might be tracers, which jax.lax.fori_loop can handle.
      pass

  reserve_up_to_full_rng_block()
  state = internal_state()
  init_val = state, init_val
  state, val = jax.lax.fori_loop(lower, upper, pure_body_fun, init_val)
  update_internal_state(state)
  return val
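# Usage sketch (not part of the source above; assumes the function is exposed
# as hk.fori_loop): a stateful counter updated inside the loop body, wrapped
# in hk.transform_with_state so Haiku state is threaded in and out of the loop.
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  def body(i, val):
    counter = hk.get_state("counter", [], jnp.int32, init=jnp.zeros)
    hk.set_state("counter", counter + 1)
    return val + 1.0
  return hk.fori_loop(0, 5, body, x)

forward = hk.transform_with_state(forward)
params, state = forward.init(jax.random.PRNGKey(0), jnp.zeros([]))
out, state = forward.apply(params, state, None, jnp.zeros([]))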
def f(x):
  m = CountingModule(op=lambda x: x + 1)
  if not base.params_frozen():
    return m(x)
  else:
    stateful.while_loop(lambda _: base.next_rng_key(), lambda x: x, x)
def running_init() -> bool:
  """Return True if running the ``init`` function of a Haiku transform.

  In general you should not need to gate behaviour of your module based on
  whether you are running ``init`` or ``apply``, but sometimes (e.g. when
  making use of JAX control flow) this is required.

  For example, if you want to use :func:`switch` to pick between experts, when
  we run your init function we need to ensure that params/state for all
  experts are created (unconditionally), but during apply we want to
  conditionally apply (and perhaps update the internal state of) only one of
  our experts:

  >>> experts = [hk.nets.ResNet50(10) for _ in range(5)]
  >>> x = jnp.ones([1, 224, 224, 3])
  >>> if hk.running_init():
  ...   # During init unconditionally create params/state for all experts.
  ...   for expert in experts:
  ...     out = expert(x, is_training=True)
  ... else:
  ...   # During apply conditionally apply (and update) only one expert.
  ...   index = jax.random.randint(hk.next_rng_key(), [], 0, len(experts) - 1)
  ...   out = hk.switch(index, experts, x)

  Returns:
    True if running ``init`` otherwise False.
  """
  base.assert_context("running_init")
  return not base.params_frozen()
def fori_loop(lower, upper, body_fun, init_val):
  """Equivalent to ``jax.lax.fori_loop`` with Haiku state threaded in/out."""
  if not base.inside_transform():
    raise ValueError(
        "hk.fori_loop() should not be used outside of hk.transform(). "
        "Use jax.lax.fori_loop() instead.")

  def pure_body_fun(i, val):
    state, val = val
    with temporary_internal_state(state):
      val = body_fun(i, val)
      state = internal_state()
      return state, val

  if not base.params_frozen():
    # During init we need to unwind one step of the loop to ensure the Haiku
    # state before and after the body has the same structure.
    init_val = body_fun(lower, init_val)
    lower += 1
    if upper - lower == 0:
      return init_val

  state = internal_state()
  init_val = state, init_val
  state, val = jax.lax.fori_loop(lower, upper, pure_body_fun, init_val)
  update_internal_state(state)
  return val
def f(x):
  m = CountingModule(op=lambda x: x + 1)
  if not base.params_frozen():
    return m(x)
  else:
    _, y = stateful.while_loop(lambda a: a[0] < iters,
                               lambda a: (a[0] + 1, m(a[1])),
                               (0, x))
    return y
def scan(f, init, xs, length=None, reverse=False):
  """Equivalent to `jax.lax.scan` but with Haiku state threaded in and out."""
  if not base.inside_transform():
    raise ValueError(
        "hk.scan() should not be used outside of hk.transform(). "
        "Use jax.scan() instead.")

  if length is None:
    length = jax.tree_leaves(xs)[0].shape[0]

  running_init_fn = not base.params_frozen()

  if running_init_fn:
    # During `init` we need to unroll one step of the scan, this is because
    # our carry contains the Haiku state and during `init` this may change
    # structure (e.g. as state is created).
    if not length:
      x0 = jax.tree_map(lambda x: jnp.zeros(x.shape[1:], x.dtype), xs)
      _, y0 = f(init, x0)
      y0 = jax.tree_map(lambda y: jnp.zeros((0,) + y.shape, y.dtype), y0)
      return init, y0

    if reverse:
      x0 = jax.tree_map(lambda x: x[-1], xs)
      xs = jax.tree_map(lambda x: x[:-1], xs)
    else:
      x0 = jax.tree_map(lambda x: x[0], xs)
      xs = jax.tree_map(lambda x: x[1:], xs)
    init, y0 = f(init, x0)
    y0 = jax.tree_map(lambda y: jnp.expand_dims(y, 0), y0)
    length -= 1
    if not length:
      return init, y0

  def stateful_fun(carry, x):
    carry, state = carry
    with temporary_internal_state(state):
      with base.assert_no_new_parameters():
        carry, out = f(carry, x)
      carry = (carry, internal_state(params=False))
      return carry, out

  # We know that we don't need to thread params in and out, since for init we
  # have already created them (given that above we unroll one step of the
  # scan) and for apply we know they are immutable. As such we only need to
  # thread the state and rng in and out.
  init = (init, internal_state(params=False))
  (carry, state), ys = jax.lax.scan(stateful_fun, init, xs, length, reverse)
  update_internal_state(state)

  if running_init_fn:
    ys = jax.tree_multimap(lambda y0, ys: jnp.concatenate([y0, ys]), y0, ys)

  return carry, ys
def __call__(self, *args, **kwargs):
  frame = base.current_frame()
  bundle_name = self.module_name
  if base.params_frozen():
    prefix = bundle_name + "/"
    lifted_params = unpack_from_dict(frame.params, prefix)
    return lifted_params
  else:
    # Inside init.
    # Lift parameters into this transform's params_dict.
    params = self._init_fn(*args, **kwargs)
    pack_into_dict(params, frame.params, bundle_name)
    return params
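# Usage sketch (not part of the source above; assumes this lifting machinery
# backs hk.lift, which has also lived at hk.experimental.lift): an inner
# transform's init is lifted so its parameters become part of the outer
# transform's parameters; during apply they are read back out of the outer
# params, as in __call__ above.
import haiku as hk
import jax
import jax.numpy as jnp

def inner(x):
  return hk.Linear(4)(x)

inner_t = hk.transform(inner)

def outer(x):
  lifted_init_fn = hk.lift(inner_t.init, name="inner")
  inner_params = lifted_init_fn(hk.next_rng_key(), x)
  return inner_t.apply(inner_params, None, x)

outer_t = hk.transform(outer)
x = jnp.ones([1, 3])
params = outer_t.init(jax.random.PRNGKey(0), x)
out = outer_t.apply(params, jax.random.PRNGKey(1), x)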
def conv_weight_with_spectral_norm(x: jnp.ndarray,
                                   kernel_shape: Sequence[int],
                                   out_channel: int,
                                   name_suffix: str = "",
                                   w_init: Callable = None,
                                   b_init: Callable = None,
                                   use_bias: bool = True,
                                   is_training: bool = True,
                                   update_params: bool = True,
                                   max_singular_value: float = 0.95,
                                   max_power_iters: int = 1,
                                   **conv_kwargs):
  batch_size, H, W, C = x.shape
  w_shape = kernel_shape + (C, out_channel)

  w = hk.get_parameter(f"w_{name_suffix}", w_shape, x.dtype, init=w_init)
  if use_bias:
    b = hk.get_parameter(f"b_{name_suffix}", (out_channel,), x.dtype,
                         init=b_init)

  u = hk.get_state(f"u_{name_suffix}", (H, W, out_channel),
                   init=hk.initializers.RandomNormal())
  v = hk.get_state(f"v_{name_suffix}", (H, W, C),
                   init=hk.initializers.RandomNormal())

  w, u, v = sn.spectral_norm_conv_apply(w, u, v, conv_kwargs["stride"],
                                        conv_kwargs["padding"],
                                        max_singular_value, max_power_iters,
                                        update_params)

  # Run for a lot of steps when we're first initializing.
  running_init_fn = not hk_base.params_frozen()
  if running_init_fn:
    w, u, v = sn.spectral_norm_conv_apply(w, u, v, conv_kwargs["stride"],
                                          conv_kwargs["padding"],
                                          max_singular_value, None, True)

  if is_training or running_init_fn:
    hk.set_state(f"u_{name_suffix}", u)
    hk.set_state(f"v_{name_suffix}", v)

  if use_bias:
    return w, b
  return w
def weight_with_spectral_norm(x: jnp.ndarray,
                              out_dim: int,
                              name_suffix: str = "",
                              w_init: Callable = None,
                              b_init: Callable = None,
                              is_training: bool = True,
                              update_params: bool = True,
                              use_bias: bool = True,
                              force_in_dim: Optional[int] = None,
                              max_singular_value: float = 0.99,
                              max_power_iters: int = 1,
                              **kwargs):
  in_dim, dtype = x.shape[-1], x.dtype
  if force_in_dim:
    in_dim = force_in_dim

  w = hk.get_parameter(f"w_{name_suffix}", (out_dim, in_dim), dtype,
                       init=w_init)
  if use_bias:
    b = hk.get_parameter(f"b_{name_suffix}", (out_dim,), dtype, init=b_init)

  u = hk.get_state(f"u_{name_suffix}", (out_dim,), dtype,
                   init=hk.initializers.RandomNormal())
  v = hk.get_state(f"v_{name_suffix}", (in_dim,), dtype,
                   init=hk.initializers.RandomNormal())

  w, u, v = sn.spectral_norm_apply(w, u, v, max_singular_value,
                                   max_power_iters, update_params)

  # Run extra power iterations when we're first initializing.
  running_init_fn = not hk_base.params_frozen()
  if running_init_fn:
    w, u, v = sn.spectral_norm_apply(w, u, v, max_singular_value, None, True)

  if is_training or running_init_fn:
    hk.set_state(f"u_{name_suffix}", u)
    hk.set_state(f"v_{name_suffix}", v)

  if use_bias:
    return w, b
  return w
def while_loop(cond_fun, body_fun, init_val):
  """Equivalent to jax.lax.while_loop with Haiku state threaded in/out."""
  if not base.params_frozen():
    raise ValueError(
        "hk.while_loop does not support initialization (since we cannot "
        "statically determine if your loop will run at least once). Please "
        "use `hk.running_init` to run the body unconditionally:\n"
        "\n"
        "  if hk.running_init():\n"
        "    # Unconditionally connect the module at init time.\n"
        "    val = module(val)\n"
        "  else:\n"
        "    val = hk.while_loop(lambda val: val.mean() < 1, module, val)\n")

  @functools.wraps(cond_fun)
  def pure_cond_fun(val):
    val, _ = val
    try:
      with base.assert_state_unchanged():
        return cond_fun(val)
    except base.StateChangedError as e:
      # If we find a use case for updating state/using rng in `cond` we would
      # need to make a change in JAX itself (to support aux in/out of the
      # cond).
      raise ValueError(
          "`hk.while_loop` does not support `hk.set_state`, "
          "`hk.next_rng_key` (et al) in `cond_fun`.") from e

  @functools.wraps(body_fun)
  def pure_body_fun(val):
    val, state = val
    with temporary_internal_state(state), \
         base.push_jax_trace_level():
      val = body_fun(val)
      state = internal_state()
      return val, state

  init_val = (init_val, internal_state())
  val, state = jax.lax.while_loop(pure_cond_fun, pure_body_fun, init_val)
  update_internal_state(state)
  return val
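# Usage sketch (not part of the source above; assumes the function is exposed
# as hk.while_loop): since the loop cannot run at init time, hk.running_init
# connects the module unconditionally during init and loops only during apply,
# following the pattern suggested by the error message above.
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  net = hk.Linear(x.shape[-1])
  if hk.running_init():
    # Unconditionally connect the module at init time.
    return net(x)
  _, y = hk.while_loop(lambda carry: carry[0] < 3,
                       lambda carry: (carry[0] + 1, net(carry[1])),
                       (0, x))
  return y

forward = hk.transform(forward)
x = jnp.ones([1, 2])
params = forward.init(jax.random.PRNGKey(0), x)
out = forward.apply(params, None, x)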
def update_frame_data(params, state):
  frame = current_frame()
  if not params_frozen():
    update_recursive_skip_none(frame.params, params)
  update_recursive_skip_none(frame.state, state)
def init_if_needed(self, x, rng):
  # Before extracting the frame data, we need to make sure that the
  # network is initialized!
  running_init_fn = not hk_base.params_frozen()
  if running_init_fn:
    self.auto_batched_res_block(x, rng)
def scan(f, init, xs, length=None, reverse=False, unroll=1):
  """Equivalent to :func:`jax.lax.scan` but with Haiku state passed in/out."""
  if not base.inside_transform():
    raise ValueError(
        "hk.scan() should not be used outside of hk.transform(). "
        "Use jax.scan() instead.")

  if length is None:
    length = jax.tree_leaves(xs)[0].shape[0]

  running_init_fn = not base.params_frozen()

  if running_init_fn:
    # During `init` we need to unroll one step of the scan, this is because
    # our carry contains the Haiku state and during `init` this may change
    # structure (e.g. as state is created).
    if not length:
      x0 = jax.tree_map(lambda x: jnp.zeros(x.shape[1:], x.dtype), xs)
      _, y0 = f(init, x0)
      y0 = jax.tree_map(lambda y: jnp.zeros((0,) + y.shape, y.dtype), y0)
      return init, y0

    if reverse:
      x0 = jax.tree_map(lambda x: x[-1], xs)
      xs = jax.tree_map(lambda x: x[:-1], xs)
    else:
      x0 = jax.tree_map(lambda x: x[0], xs)
      xs = jax.tree_map(lambda x: x[1:], xs)
    init, y0 = f(init, x0)
    y0 = jax.tree_map(lambda y: jnp.expand_dims(y, 0), y0)
    length -= 1
    if not length:
      return init, y0

  @functools.wraps(f)
  def stateful_fun(carry, x):
    carry, state = carry
    with temporary_internal_state(state):
      with base.assert_no_new_parameters(), \
           base.push_jax_trace_level():
        carry, out = f(carry, x)
      reserve_up_to_full_rng_block()
      carry = (carry, internal_state(params=False))
      return carry, out

  # Before pulling out the internal state, reserve a full block of RNG keys.
  # This is to make sure we're always passing in the same amount of subkeys in
  # and out of the scan carry (scan requires equal length lists).
  # After every scan iteration we reserve back up to the full block.
  reserve_up_to_full_rng_block()

  # We know that we don't need to thread params in and out, since for init we
  # have already created them (given that above we unroll one step of the
  # scan) and for apply we know they are immutable. As such we only need to
  # thread the state and rng in and out.
  init = (init, internal_state(params=False))
  (carry, state), ys = jax.lax.scan(
      stateful_fun, init, xs, length, reverse, unroll=unroll)
  update_internal_state(state)

  if running_init_fn:
    if reverse:
      ys = jax.tree_map(lambda y0, ys: jnp.concatenate([ys, y0]), y0, ys)
    else:
      ys = jax.tree_map(lambda y0, ys: jnp.concatenate([y0, ys]), y0, ys)

  return carry, ys
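# Usage sketch (not part of the source above; assumes the function is exposed
# as hk.scan): a cumulative sum whose step also updates Haiku state, showing
# that state is threaded through the scan carry.
import haiku as hk
import jax
import jax.numpy as jnp

def forward(xs):
  def step(carry, x):
    count = hk.get_state("count", [], jnp.int32, init=jnp.zeros)
    hk.set_state("count", count + 1)
    total = carry + x
    return total, total
  return hk.scan(step, jnp.zeros([], xs.dtype), xs)

forward = hk.transform_with_state(forward)
xs = jnp.arange(4.0)
params, state = forward.init(jax.random.PRNGKey(0), xs)
(carry, ys), state = forward.apply(params, state, None, xs)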
import itertools as it

from absl.testing import absltest
from absl.testing import parameterized
from haiku._src import base
from haiku._src import base_test
from haiku._src import module
from haiku._src import stateful
from haiku._src import test_utils
from haiku._src import transform
import jax
import jax.numpy as jnp
import numpy as np


toggle = lambda i, a: lambda x: a(x) if base.params_frozen() else i(x)

# JAX transforms and control flow that need to be aware of Haiku internal
# state to operate unsurprisingly.
# pylint: disable=g-long-lambda
HK_OVERLOADED_JAX_PURE_EXPECTING_FNS = (
    # Just-in-time compilation.
    ("jit", stateful.jit),
    # ("make_jaxpr", stateful.make_jaxpr),
    ("eval_shape", lambda f: (lambda x: [f(x), stateful.eval_shape(f, x)])),
    ("named_call", stateful.named_call),

    # Parallelization.
    # TODO(tomhennigan): Add missing features (e.g. pjit,xmap).