def wrapper(*args, **kwargs):
  """Reserves the RNG keys ``fun`` will request, then calls ``fun``.

  ``fun`` is first traced abstractly with ``jax.eval_shape`` inside a fresh
  Haiku context built from a snapshot of the current params, state and RNG.
  While that trace runs, ``count_hk_rngs_requested`` records how many RNG keys
  are requested (assumed to count ``hk.next_rng_key``-style requests — see that
  helper). If any were requested, that many keys are reserved on the current
  RNG sequence before ``fun`` is executed for real with the original arguments.
  """
  base.assert_context("optimize_rng_use")

  # Snapshot everything the real call would see so the abstract trace below
  # makes the same RNG requests.
  frame = base.current_frame()

  snapshot_params = frame.params or None
  if snapshot_params is not None:
    snapshot_params = data_structures.to_haiku_dict(snapshot_params)

  snapshot_state = frame.state or None
  if snapshot_state is not None:
    snapshot_state = base.extract_state(snapshot_state, initial=True)

  rng_seq = frame.rng_stack.peek()
  snapshot_rng = None if rng_seq is None else rng_seq.internal_state

  def pure_fun(params, state, rng, *inner_args, **inner_kwargs):
    # Runs ``fun`` in an isolated context so the trace does not touch the
    # caller's frame.
    with base.new_context(params=params, state=state, rng=rng):
      return fun(*inner_args, **inner_kwargs)

  with count_hk_rngs_requested() as get_rng_count:
    jax.eval_shape(pure_fun, snapshot_params, snapshot_state, snapshot_rng,
                   *args, **kwargs)
  num_keys = get_rng_count()

  if num_keys:
    base.current_frame().rng_stack.peek().reserve(num_keys)
  return fun(*args, **kwargs)
def running_init() -> bool:
  """Return True if running the ``init`` function of a Haiku transform.

  In general you should not need to gate behaviour of your module based on
  whether you are running ``init`` or ``apply``, but sometimes (e.g. when making
  use of JAX control flow) this is required.

  For example, if you want to use :func:`switch` to pick between experts, when
  we run your init function we need to ensure that params/state for all experts
  are created (unconditionally) but during apply we want to conditionally apply
  (and perhaps update the internal state of) only one of our experts:

  >>> experts = [hk.nets.MLP([300, 100, 10]) for _ in range(5)]
  >>> x = jnp.ones([1, 28 * 28])
  >>> if hk.running_init():
  ...   # During init unconditionally create params/state for all experts.
  ...   for expert in experts:
  ...     out = expert(x)
  ... else:
  ...   # During apply conditionally apply (and update) only one expert.
  ...   # ``maxval`` is exclusive, so ``len(experts)`` makes every expert
  ...   # selectable.
  ...   index = jax.random.randint(hk.next_rng_key(), [], 0, len(experts))
  ...   out = hk.switch(index, experts, x)

  Returns:
    True if running ``init`` otherwise False.
  """
  base.assert_context("running_init")
  # Params are frozen outside of ``init``, so "not frozen" means "in init".
  return not base.params_frozen()
def mapped_fun(*args):
  """Applies the vmapped ``pure_fun`` and threads Haiku internal state.

  ``pure_fun`` is assumed to take ``(args, internal_state)`` and return
  ``(out, internal_state)``; the resulting internal state is written back with
  ``update_internal_state``. When ``split_rng`` is set, one key per element of
  the mapped axis is drawn with ``base.next_rng_keys`` and passed in place of
  the RNG state. Afterwards the RNG state captured just after drawing those
  keys is restored into the internal state.
  """
  base.assert_context("vmap")

  vmapped = jax.vmap(pure_fun, in_axes=in_axes, out_axes=out_axes,
                     axis_name=axis_name, axis_size=axis_size)
  inner_state = internal_state()

  outer_rng = None
  if split_rng:
    # Draw a fresh key per mapped element.
    batch_size = get_mapped_axis_size(args, in_axes[0])
    per_example_keys = base.next_rng_keys(batch_size)
    # Drawing keys mutated the internal RNG, so re-snapshot the state.
    inner_state = internal_state()
    outer_rng = inner_state.rng
    inner_state = InternalState(inner_state.params, inner_state.state,
                                per_example_keys)

  out, inner_state = vmapped(args, inner_state)

  if split_rng:
    # Put back the (unmapped) RNG state in place of the per-example keys.
    inner_state = InternalState(inner_state.params, inner_state.state,
                                outer_rng)

  update_internal_state(inner_state)
  return out
def intercept_methods(interceptor: MethodGetter):
  """Register a new method interceptor.

  Method interceptors allow you to (at a distance) intercept method calls to
  modules and modify args/kwargs before calling the underlying method. After the
  underlying method is called you can modify its result before it is passed back
  to the user.

  For example you could intercept method calls to :class:`~haiku.BatchNorm` and
  ensure it is always computed in full precision:

  >>> def my_interceptor(next_f, args, kwargs, context):
  ...   if (type(context.module) is not hk.BatchNorm
  ...       or context.method_name != "__call__"):
  ...     # We ignore methods other than BatchNorm.__call__.
  ...     return next_f(*args, **kwargs)
  ...
  ...   def cast_if_array(x):
  ...     if isinstance(x, jnp.ndarray):
  ...       x = x.astype(jnp.float32)
  ...     return x
  ...
  ...   args, kwargs = jax.tree_util.tree_map(cast_if_array, (args, kwargs))
  ...   out = next_f(*args, **kwargs)
  ...   return out

  We can create and use our module in the usual way, we just need to wrap any
  method calls we want to intercept in the context manager:

  >>> mod = hk.BatchNorm(decay_rate=0.9, create_scale=True, create_offset=True)
  >>> x = jnp.ones([], jnp.bfloat16)
  >>> with hk.experimental.intercept_methods(my_interceptor):
  ...   out = mod(x, is_training=True)
  >>> assert out.dtype == jnp.float32

  Without the interceptor BatchNorm would compute in bf16, however since we
  cast ``x`` before the underlying method is called we compute in f32.

  Args:
    interceptor: A method interceptor.

  Returns:
    Context manager under which the interceptor is active.
  """
  base.assert_context("experimental.intercept_methods")
  return interceptor_stack(interceptor)
def name_scope(name: str) -> ContextManager[None]:
  """Context manager which adds a prefix to all new modules, params or state.

  >>> with hk.experimental.name_scope("my_name_scope"):
  ...   net = hk.Linear(1, name="my_linear")
  >>> net.module_name
  'my_name_scope/my_linear'

  When used inside a module, any submodules, parameters or state created inside
  the name scope will have a prefix added to their names:

  >>> class MyModule(hk.Module):
  ...   def __call__(self, x):
  ...     with hk.experimental.name_scope("my_name_scope"):
  ...       submodule = hk.Linear(1, name="submodule")
  ...       w = hk.get_parameter("w", [], init=jnp.ones)
  ...     return submodule(x) + w

  >>> f = hk.transform(lambda x: MyModule()(x))
  >>> params = f.init(jax.random.PRNGKey(42), jnp.ones([1, 1]))
  >>> jax.tree_util.tree_map(jnp.shape, params)
  FlatMapping({
    'my_module/my_name_scope': FlatMapping({'w': ()}),
    'my_module/my_name_scope/submodule': FlatMapping({'b': (1,), 'w': (1, 1)}),
  })

  Name scopes are very similar to putting all of the code inside the context
  manager inside a method on a :class:`Module` with the name you provide. Behind
  the scenes this is precisely how name scopes are implemented.

  If you are familiar with TensorFlow then Haiku's :func:`name_scope` is similar
  to ``tf.variable_scope(..)`` in TensorFlow 1 and ``tf.name_scope(..)`` in
  TensorFlow 1 and 2 in that it changes the names associated with modules,
  parameters and state.

  Args:
    name: The name scope to use (e.g. ``"foo"`` or ``"foo/bar"``).

  Returns:
    A single use context manager that when active prefixes new modules,
    parameters or state with the given name.
  """
  base.assert_context("experimental.name_scope")
  return NameScope(name)
def lift(
    init_fn: Callable[..., hk.Params],
    name: str = "lifted",
) -> Callable[..., hk.Params]:
  r"""Lifts the given init fn to a function in the current Haiku namespace.

  While the outer transform is running ``init``, the returned callable executes
  ``init_fn`` and records the parameters it produces in the outer transform's
  dictionaries. While the outer transform is running ``apply``, the callable
  instead looks those parameters up in the outer transform's dictionaries.

  Must be called inside :func:`transform`\ , and be passed the ``init`` member
  of a :class:`Transformed`\ .

  It is the caller's responsibility to make sure ``init_fn`` does not
  accidentally capture modules from an outer :func:`transform` through a
  functional closure.

  Example:

  >>> def g(x):
  ...   return hk.Linear(1, name='g_linear')(x)
  >>> g = hk.transform(g)
  >>> init_rng = hk.next_rng_key() if hk.running_init() else None
  >>> x = jnp.ones([1, 1])
  >>> params = hk.lift(g.init, name='f_lift')(init_rng, x)
  >>> out = g.apply(params, None, x)

  Args:
    init_fn: The ``init`` function of a :class:`Transformed`\ .
    name: A string name to prefix parameters with.

  Returns:
    A callable that injects parameter values into the outer context during
    ``init`` and reuses the outer context's parameters during ``apply``. Either
    way it returns the parameter values to feed to an ``apply`` function.
  """
  base.assert_context("lift")
  lifting_module = LiftingModule(init_fn, name=name)

  # Hand back a plain function rather than the module object itself, so the
  # LiftingModule is not exposed to callers.
  def lifted_fn(*a, **k):
    return lifting_module(*a, **k)

  return lifted_fn
def transparent_lift_with_state(
    init_fn: Callable[..., Tuple[hk.Params, hk.State]],
    *,
    allow_reuse: bool = False,
) -> Tuple[Callable[..., Tuple[hk.Params, hk.State]], LiftWithStateUpdater]:
  r"""Registers params and state in an outer transform without adding scope.

  Functionally this is equivalent to :func:`lift_with_state`\ but without
  automatically adding an additional variable scoping.

  See :func:`lift_with_state`\ for more context on when to use
  ``lift_with_state``.

  Args:
    init_fn: The ``init`` function from a :class:`TransformedWithState`\ .
    allow_reuse: Allows lifted parameters and state to be reused from the outer
      :func:`transform_with_state`\ . This can be desirable when used within
      control flow (e.g. ``hk.scan``).

  Returns:
    A tuple ``(params_and_state_fn, updater)``. ``params_and_state_fn`` is a
    callable that during ``init`` injects parameter and state values into the
    outer context and during ``apply`` reuses them from the outer context; in
    both cases it returns the ``(params, state)`` to be used with an ``apply``
    function. ``updater`` is used to update the outer context with new state
    after ``apply`` is called.

  See also:
    :func:`~haiku.lift`: Register params with an outer transform.
    :func:`lift_with_state`: Register params and state with an outer transform.
    :func:`transparent_lift`: Register params with an outer transform without a
      namespace.
  """
  # Report the actual public symbol if this is called outside a transform
  # (previously reported as ``lift_with_state``).
  base.assert_context("experimental.transparent_lift_with_state")
  params_and_state_fn = _to_callable(
      LiftingModule(init_fn, transparent=True, allow_reuse=allow_reuse))
  # The updater is keyed on the enclosing module's bundle name, or ``None``
  # when called outside of any module (no extra scope is added here).
  name = base.current_bundle_name() if base.current_module() else None
  updater = LiftWithStateUpdater(name)
  return params_and_state_fn, updater
def transparent_lift(
    init_fn: Callable[..., hk.Params],
    *,
    allow_reuse: bool = False,
) -> Callable[..., hk.Params]:
  r"""Registers parameters in an outer transform without adding a name scope.

  This behaves like :func:`lift`\ except that no extra level of variable
  scoping is introduced for the lifted parameters. Closing over a module from
  an outer scope is not allowed. See :func:`lift`\ for guidance on when lifting
  is appropriate.

  Args:
    init_fn: The ``init`` function of a :class:`Transformed`\ .
    allow_reuse: If ``True``, lifted parameters may be reused from the outer
      transform. This can be needed inside control flow such as ``hk.scan``.

  Returns:
    A callable that during ``init`` injects parameter values into the outer
    context and during ``apply`` reuses the parameters already stored there.
    Either way it returns the parameters to pass to an ``apply`` function.

  See also:
    :func:`~haiku.lift`: Register params with an outer transform.
    :func:`lift_with_state`: Register params and state with an outer transform.
    :func:`transparent_lift_with_state`: Register params and state with an
      outer transform without a namespace.
  """
  base.assert_context("transparent_lift")
  lifting_module = LiftingModule(
      add_state_to_init_fn(init_fn),
      transparent=True,
      allow_reuse=allow_reuse)

  def params_fn(*a, **k):
    # Mark a closure boundary one frame below the current one while the lifted
    # init/lookup runs.
    boundary_id = base.current_frame().frame_id + 1
    with base.closure_boundary_stack(boundary_id):
      # The lifted module yields (params, state); only params are exposed.
      return lifting_module(*a, **k)[0]

  return params_fn
def lift_with_state(
    init_fn: Callable[..., Tuple[hk.Params, hk.State]],
    *,
    allow_reuse: bool = False,
    name: str = "lifted",
) -> Tuple[Callable[..., Tuple[hk.Params, hk.State]], LiftWithStateUpdater]:
  r"""Registers params and state from an init function in an outer transform.

  See :func:`lift`\ for more context on when to use ``lift``.

  Two objects are returned. The first is a callable that runs your init
  function, behaving slightly differently depending on whether the outer
  transform is in init or apply. The second is an updater used to push the
  updated state values produced by your apply function back to the outer
  transform. A worked example follows below.

  During ``init`` the returned callable runs ``init_fn`` and records the
  resulting params/state in the outer transform's dictionaries. During
  ``apply`` it instead reads the relevant params/state from the outer
  transform's dictionaries.

  Must be called inside :func:`transform_with_state`\ , and be passed the
  ``init`` member of a :class:`TransformedWithState`\ .

  By default, callers must make sure ``init_fn`` does not accidentally capture
  modules from an outer :func:`transform_with_state` through a functional
  closure. If that behaviour is wanted, set ``allow_reuse`` to ``True``.

  Example:

  >>> def g(x):
  ...   return hk.nets.ResNet50(1)(x, True)
  >>> g = hk.transform_with_state(g)
  >>> params_and_state_fn, updater = (
  ...   hk.experimental.lift_with_state(g.init, name='f_lift'))
  >>> init_rng = hk.next_rng_key() if hk.running_init() else None
  >>> x = jnp.ones([1, 224, 224, 3])
  >>> params, state = params_and_state_fn(init_rng, x)
  >>> out, state = g.apply(params, state, None, x)
  >>> updater.update(state)

  Args:
    init_fn: The ``init`` function of a :class:`TransformedWithState`\ .
    allow_reuse: Allows lifted parameters and state to be reused from the
      outer :func:`transform_with_state`. This can be desirable when using
      ``lift_with_state`` within control flow (e.g. ``hk.scan``).
    name: A string name to prefix parameters with.

  Returns:
    A tuple ``(params_and_state_fn, updater)``. ``params_and_state_fn`` injects
    parameter and state values into the outer context during ``init`` and
    reuses them from the outer context during ``apply``; either way it returns
    the ``(params, state)`` to use with an ``apply`` function. ``updater`` is
    used to update the outer context with new state after ``apply`` is called.

  See also:
    :func:`~haiku.lift`: Register parameters with an outer transform.
    :func:`transparent_lift`: Register parameters with an outer transform
      without a namespace.
    :func:`transparent_lift_with_state`: Register parameters and state with an
      outer transform without a namespace.
  """
  base.assert_context("experimental.lift_with_state")
  lifting_module = LiftingModule(init_fn, allow_reuse=allow_reuse, name=name)
  params_and_state_fn = _to_callable(lifting_module)
  # Inside a module, the updater's name is nested under the module's bundle.
  updater_name = (
      f"{base.current_bundle_name()}/{name}" if base.current_module() else name)
  return params_and_state_fn, LiftWithStateUpdater(updater_name)
def lift(
    init_fn: Callable[..., hk.Params],
    *,
    allow_reuse: bool = False,
    name: str = "lifted",
) -> Callable[..., hk.Params]:
  r"""Registers parameters from an inner init function in an outer transform.

  Use :func:`lift`\ when nesting Haiku transforms to register the parameters of
  the inner transform in any outer transform. This is mainly useful when using
  JAX functions inside of a Haiku module (e.g. using ``jax.vmap`` on a layer).
  See https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html#Using-hk.lift
  for more explanation of when to use :func:`lift`\ . (If you are not using JAX
  functions inside of a module, or don't need access to your parameters inside
  of a transform, you probably don't need :func:`lift`\ .)

  Must be called inside :func:`transform`\ , and be passed the ``init`` member
  of a :class:`Transformed`\ .

  During ``init`` the returned callable runs ``init_fn`` and records the
  resulting params in the outer transform's dictionaries. During ``apply`` it
  instead reads the relevant parameters from the outer transform's
  dictionaries.

  By default, callers must make sure ``init_fn`` does not accidentally capture
  modules from an outer :func:`transform` through a functional closure. If that
  behaviour is wanted, set ``allow_reuse`` to ``True``.

  Example:

  >>> # outer can be `hk.transform`ed and will contain the params of inner.
  >>> def outer(x):
  ...   @hk.transform
  ...   def inner(x):
  ...     return hk.Linear(1)(x)
  ...   init_rng = hk.next_rng_key() if hk.running_init() else None
  ...   x = jnp.ones([1, 1])
  ...   params = hk.lift(inner.init, name='f_lift')(init_rng, x)
  ...   # inner.apply is a pure function and can be vmapped.
  ...   return jax.vmap(inner.apply, in_axes=(0, None, 0))(params, None, x)

  Args:
    init_fn: The ``init`` function of a :class:`Transformed`\ .
    allow_reuse: Allows lifted parameters and state to be reused from the outer
      :func:`transform`. This can be desirable when using ``lift`` within
      control flow (e.g. ``hk.scan``).
    name: A string name to prefix parameters with.

  Returns:
    A callable that injects parameter values into the outer context during
    ``init`` and retrieves them from the outer context during ``apply``. Either
    way it returns the parameter values to use with an ``apply`` function.

  See also:
    :func:`~haiku.experimental.lift_with_state`: Register params and state with
      an outer transform.
    :func:`~haiku.experimental.transparent_lift`: Register params with an outer
      transform without a namespace.
    :func:`~haiku.experimental.transparent_lift_with_state`: Register params
      and state with an outer transform without a namespace.
  """
  base.assert_context("lift")
  params_and_state_fn, updater = lift_with_state(
      add_state_to_init_fn(init_fn), allow_reuse=allow_reuse, name=name)
  # ``lift`` only exposes params, so the state updater is not used; presumably
  # ``ignore_update`` marks it as intentionally unused.
  updater.ignore_update()

  def params_fn(*a, **k):
    # Drop the state half of the (params, state) pair.
    return params_and_state_fn(*a, **k)[0]

  return params_fn