Example #1
0
  def wrapper(*args, **kwargs):
    """Reserves the RNG keys that ``fun`` requests, then calls ``fun``.

    ``fun`` is first evaluated abstractly with ``jax.eval_shape`` inside a new
    context built from a snapshot of the current params, state and RNG, while
    the number of requested RNG keys is counted. If that count is non-zero the
    same number of keys is reserved on the current RNG before ``fun`` is
    invoked for real.

    Args:
      *args: Positional arguments forwarded to ``fun``.
      **kwargs: Keyword arguments forwarded to ``fun``.

    Returns:
      Whatever ``fun(*args, **kwargs)`` returns.
    """
    base.assert_context("optimize_rng_use")

    # Snapshot the current frame; falsy params/state are passed on as None.
    frame = base.current_frame()

    params = frame.params
    params = data_structures.to_haiku_dict(params) if params else None

    state = frame.state
    state = base.extract_state(state, initial=True) if state else None

    rng_seq = frame.rng_stack.peek()
    rng = None if rng_seq is None else rng_seq.internal_state

    def pure_fun(params, state, rng, *args, **kwargs):
      # Run `fun` inside a new context built from the snapshot taken above.
      with base.new_context(params=params, state=state, rng=rng):
        return fun(*args, **kwargs)

    # Abstract evaluation only: we care about the key count, not the output.
    with count_hk_rngs_requested() as get_rng_count:
      jax.eval_shape(pure_fun, params, state, rng, *args, **kwargs)
    num_rngs = get_rng_count()

    if num_rngs:
      base.current_frame().rng_stack.peek().reserve(num_rngs)

    return fun(*args, **kwargs)
Example #2
0
def running_init() -> bool:
    """Return True if running the ``init`` function of a Haiku transform.

    In general you should not need to gate behaviour of your module based on
    whether you are running ``init`` or ``apply``, but sometimes (e.g. when
    making use of JAX control flow) this is required.

    For example, if you want to use :func:`switch` to pick between experts,
    when we run your init function we need to ensure that params/state for all
    experts are created (unconditionally) but during apply we want to
    conditionally apply (and perhaps update the internal state) of only one of
    our experts:

    >>> experts = [hk.nets.ResNet50(10) for _ in range(5)]
    >>> x = jnp.ones([1, 224, 224, 3])
    >>> if hk.running_init():
    ...   # During init unconditionally create params/state for all experts.
    ...   for expert in experts:
    ...     out = expert(x, is_training=True)
    ... else:
    ...   # During apply conditionally apply (and update) only one expert.
    ...   # ``maxval`` is exclusive, so every expert index can be drawn.
    ...   index = jax.random.randint(hk.next_rng_key(), [], 0, len(experts))
    ...   out = hk.switch(index, experts, x)

    Returns:
      True if running ``init`` otherwise False.
    """
    base.assert_context("running_init")
    # Reported as "init" exactly when params are not frozen.
    return not base.params_frozen()
Example #3
0
    def mapped_fun(*args):
        """Runs ``pure_fun`` under ``jax.vmap`` and writes back Haiku state.

        When ``split_rng`` is set, fresh RNG keys are drawn and used as the
        RNG for the mapped call; the RNG captured right after drawing them is
        restored before the internal state is updated.

        Args:
          *args: Positional arguments passed (as a tuple) to the mapped
            function.

        Returns:
          The first element of the mapped function's result.
        """
        base.assert_context("vmap")

        vmapped = jax.vmap(
            pure_fun,
            in_axes=in_axes,
            out_axes=out_axes,
            axis_name=axis_name,
            axis_size=axis_size,
        )
        state = internal_state()

        if split_rng:
            # Draw as many new keys as the mapped axis size calls for.
            num_keys = get_mapped_axis_size(args, in_axes[0])
            mapped_rng = base.next_rng_keys(num_keys)
            # Re-read the state: drawing keys mutated the internal RNG.
            state = internal_state()
            saved_rng = state.rng
            state = InternalState(state.params, state.state, mapped_rng)

        out, state = vmapped(args, state)

        if split_rng:
            # Put back the RNG that was saved before the mapped call.
            state = InternalState(state.params, state.state, saved_rng)

        update_internal_state(state)
        return out
Example #4
0
def intercept_methods(interceptor: MethodGetter):
    """Register a new method interceptor.

    A method interceptor lets you intercept (at a distance) method calls made
    on modules. It can modify args/kwargs before the underlying method runs,
    and can modify the result before it is handed back to the user.

    For example you could intercept method calls to :class:`~haiku.BatchNorm`
    and ensure it is always computed in full precision:

    >>> def my_interceptor(next_f, args, kwargs, context):
    ...   if (type(context.module) is not hk.BatchNorm
    ...       or context.method_name != "__call__"):
    ...     # We ignore methods other than BatchNorm.__call__.
    ...     return next_f(*args, **kwargs)
    ...
    ...   def cast_if_array(x):
    ...     if isinstance(x, jnp.ndarray):
    ...       x = x.astype(jnp.float32)
    ...     return x
    ...
    ...   args, kwargs = jax.tree_map(cast_if_array, (args, kwargs))
    ...   out = next_f(*args, **kwargs)
    ...   return out

    The module is created and used in the usual way; any method calls that
    should be intercepted just need to happen inside the context manager:

    >>> mod = hk.BatchNorm(decay_rate=0.9, create_scale=True, create_offset=True)
    >>> x = jnp.ones([], jnp.bfloat16)
    >>> with hk.experimental.intercept_methods(my_interceptor):
    ...   out = mod(x, is_training=True)
    >>> assert out.dtype == jnp.float32

    Without the interceptor BatchNorm would compute in bf16, however since we
    cast `x` before the underlying method is called we compute in f32.

    Args:
      interceptor: A method interceptor.

    Returns:
      Context manager under which the interceptor is active.
    """
    base.assert_context("experimental.intercept_methods")
    active_interceptor = interceptor_stack(interceptor)
    return active_interceptor
Example #5
0
def name_scope(name: str) -> ContextManager[None]:
    """Context manager which adds a prefix to all new modules, params or state.

    >>> with hk.experimental.name_scope("my_name_scope"):
    ...   net = hk.Linear(1, name="my_linear")
    >>> net.module_name
    'my_name_scope/my_linear'

    Inside a module, the prefix is added to the names of any submodules,
    parameters or state created while the name scope is active:

    >>> class MyModule(hk.Module):
    ...   def __call__(self, x):
    ...     with hk.experimental.name_scope("my_name_scope"):
    ...       submodule = hk.Linear(1, name="submodule")
    ...       w = hk.get_parameter("w", [], init=jnp.ones)
    ...     return submodule(x) + w

    >>> f = hk.transform(lambda x: MyModule()(x))
    >>> params = f.init(jax.random.PRNGKey(42), jnp.ones([1, 1]))
    >>> jax.tree_map(jnp.shape, params)
    FlatMapping({
      'my_module/my_name_scope': FlatMapping({'w': ()}),
      'my_module/my_name_scope/submodule': FlatMapping({'b': (1,), 'w': (1, 1)}),
    })

    Name scopes are very similar to putting all of the code inside the context
    manager inside a method on a :class:`Module` with the name you provide.
    Behind the scenes this is precisely how name scopes are implemented.

    If you are familiar with TensorFlow then Haiku's :func:`name_scope` is
    similar to ``tf.variable_scope(..)`` in TensorFlow 1 and
    ``tf.name_scope(..)`` in TensorFlow 1 and 2 in that it changes the names
    associated with modules, parameters and state.

    Args:
      name: The name scope to use (e.g. ``"foo"`` or ``"foo/bar"``).

    Returns:
      A single use context manager that when active prefixes new modules,
      parameters or state with the given name.
    """
    base.assert_context("experimental.name_scope")
    scope = NameScope(name)
    return scope
Example #6
0
def lift(
    init_fn: Callable[..., hk.Params],
    name: str = "lifted",
) -> Callable[..., hk.Params]:
    """Lifts the given init fn to a function in the current Haiku namespace.

    During init, the returned callable runs the given ``init_fn`` and includes
    the resulting params in the outer transform's dictionaries. During
    ``apply``, the returned callable instead pulls the relevant parameters
    from the outer transform's dictionaries.

    Must be called inside :func:`transform`, and be passed the ``init`` member
    of a :class:`Transformed`.

    The user must ensure that the given ``init`` does not accidentally catch
    modules from an outer :func:`transform` via functional closure.

    Example:

      >>> def g(x):
      ...   return hk.Linear(1, name='g_linear')(x)
      >>> g = hk.transform(g)
      >>> init_rng = hk.next_rng_key() if hk.running_init() else None
      >>> x = jnp.ones([1, 1])
      >>> params = hk.lift(g.init, name='f_lift')(init_rng, x)
      >>> out = g.apply(params, None, x)

    Args:
      init_fn: The ``init`` function from an :class:`Transformed`.
      name: A string name to prefix parameters with.

    Returns:
      A callable that during ``init`` injects parameter values into the outer
      context and during ``apply`` reuses parameters from the outer context.
      In both cases returns parameter values to be used with an ``apply``
      function.
    """
    base.assert_context("lift")
    lifting_module = LiftingModule(init_fn, name=name)

    # NOTE: Return a plain function so the module object is not exposed.
    def call_lifted(*a, **k):
        return lifting_module(*a, **k)

    return call_lifted
Example #7
0
def transparent_lift_with_state(
    init_fn: Callable[..., Tuple[hk.Params, hk.State]],
    *,
    allow_reuse: bool = False
) -> Tuple[Callable[..., Tuple[hk.Params, hk.State]], LiftWithStateUpdater]:
    """Registers params and state in an outer transform without adding scope.

    Functionally this is equivalent to :func:`lift_with_state` but without
    automatically adding an additional variable scoping.

    See :func:`lift_with_state` for more context on when to use
    ``lift_with_state``.

    Args:
      init_fn: The ``init`` function from an :class:`TransformedWithState`.
      allow_reuse: Allows lifted parameters and state to be reused from the
        outer :func:`transform_with_state`. This can be desirable when e.g.
        within control flow (e.g. ``hk.scan``).

    Returns:
      A ``(params_and_state_fn, updater)`` tuple. ``params_and_state_fn`` is a
      callable that during ``init`` injects parameter values into the outer
      context and during ``apply`` reuses parameters from the outer context.
      In both cases it returns parameter values to be used with an ``apply``
      function. ``updater`` is an object used to update the outer context
      with new state after ``apply`` is called.

    See also:
      :func:`~haiku.lift`: Register params with an outer transform.
      :func:`lift_with_state`: Register params and state with an outer
        transform.
      :func:`transparent_lift`: Register params with an outer transform
        without a namespace.
    """
    # Label the context check with this function's own public name so the
    # check refers to the function the caller actually used.
    base.assert_context("experimental.transparent_lift_with_state")
    params_and_state_fn = _to_callable(
        LiftingModule(init_fn, transparent=True, allow_reuse=allow_reuse))
    # The updater is named after the current bundle when called from inside a
    # module, and gets no name otherwise.
    name = base.current_bundle_name() if base.current_module() else None
    updater = LiftWithStateUpdater(name)
    return params_and_state_fn, updater
Example #8
0
def transparent_lift(init_fn: Callable[..., hk.Params],
                     *,
                     allow_reuse: bool = False) -> Callable[..., hk.Params]:
    """Registers parameters in an outer transform without adding a name scope.

    Functionally this is equivalent to :func:`lift` but without automatically
    adding an additional variable scoping. Note that closing over a module
    from an outer scope is disallowed.

    See :func:`lift` for more context on when to use ``lift``.

    Args:
      init_fn: The ``init`` function from an :class:`Transformed`.
      allow_reuse: Allows lifted parameters to be reused from the outer
        :func:`transform_with_state`. This can be desirable when e.g. within
        control flow (e.g. ``hk.scan``).

    Returns:
      A callable that during ``init`` injects parameter values into the outer
      context and during ``apply`` reuses parameters from the outer context.
      In both cases returns parameter values to be used with an ``apply``
      function.

    See also:
      :func:`~haiku.lift`: Register params with an outer transform.
      :func:`lift_with_state`: Register params and state with an outer
        transform.
      :func:`transparent_lift_with_state`: Register params and state with an
        outer transform without a namespace.
    """
    # Use the same "experimental.<name>" label form as the sibling
    # experimental entry points.
    base.assert_context("experimental.transparent_lift")
    init_fn = add_state_to_init_fn(init_fn)
    lifted = LiftingModule(init_fn, transparent=True, allow_reuse=allow_reuse)

    def fn(*a, **k):
        with base.closure_boundary_stack(base.current_frame().frame_id + 1):
            # Only the first element of the lifted result is handed back.
            return lifted(*a, **k)[0]

    return fn
Example #9
0
def lift_with_state(
    init_fn: Callable[..., Tuple[hk.Params, hk.State]],
    *,
    allow_reuse: bool = False,
    name: str = "lifted",
) -> Tuple[Callable[..., Tuple[hk.Params, hk.State]], LiftWithStateUpdater]:
    """Registers params and state from an init function in an outer transform.

    See :func:`lift` for more context on when to use ``lift``.

    Two objects are returned. The first is a callable that runs your init
    function, behaving slightly differently at init time and at apply time.
    The second is an updater used to pass on the updated state values that
    result from running your apply function. A worked example follows below.

    During init, the returned callable will run the given ``init_fn``, and
    include the resulting params/state in the outer transform's dictionaries.
    During ``apply``, the returned callable will instead pull the relevant
    params/state from the outer transform's dictionaries.

    Must be called inside :func:`transform_with_state`, and be passed the
    ``init`` member of a :class:`TransformedWithState`.

    By default, users must ensure that the given ``init`` does not
    accidentally catch modules from an outer :func:`transform_with_state` via
    functional closure. If this behavior is desirable, set ``allow_reuse`` to
    ``True``.

    Example:

      >>> def g(x):
      ...   return hk.nets.ResNet50(1)(x, True)
      >>> g = hk.transform_with_state(g)
      >>> params_and_state_fn, updater = (
      ...   hk.experimental.lift_with_state(g.init, name='f_lift'))
      >>> init_rng = hk.next_rng_key() if hk.running_init() else None
      >>> x = jnp.ones([1, 224, 224, 3])
      >>> params, state = params_and_state_fn(init_rng, x)
      >>> out, state = g.apply(params, state, None, x)
      >>> updater.update(state)

    Args:
      init_fn: The ``init`` function from an :class:`TransformedWithState`.
      allow_reuse: Allows lifted parameters and state to be reused from the
        outer :func:`transform_with_state`. This can be desirable when using
        ``lift_with_state`` within control flow (e.g. ``hk.scan``).
      name: A string name to prefix parameters with.

    Returns:
      A callable that during ``init`` injects parameter values into the outer
      context and during ``apply`` reuses parameters from the outer context.
      In both cases returns parameter values to be used with an ``apply``
      function. The ``init`` function additionally returns an object used to
      update the outer context with new state after ``apply`` is called.

    See also:
      :func:`~haiku.lift`: Register parameters with an outer transform.
      :func:`transparent_lift`: Register parameters with an outer transform
        without a namespace.
      :func:`transparent_lift_with_state`: Register parameters and state with
        an outer transform without a namespace.
    """
    base.assert_context("experimental.lift_with_state")
    lifting_module = LiftingModule(init_fn, allow_reuse=allow_reuse, name=name)
    params_and_state_fn = _to_callable(lifting_module)

    # Inside a module the updater name is prefixed with the bundle name.
    if base.current_module():
        updater_name = f"{base.current_bundle_name()}/{name}"
    else:
        updater_name = name

    return params_and_state_fn, LiftWithStateUpdater(updater_name)
Example #10
0
def lift(
    init_fn: Callable[..., hk.Params],
    *,
    allow_reuse: bool = False,
    name: str = "lifted",
) -> Callable[..., hk.Params]:
    """Registers parameters from an inner init function in an outer transform.

    Use :func:`lift` when nesting Haiku transforms to register the parameters
    of the inner transform in any outer transform. This is mainly useful when
    using JAX functions inside of a Haiku module (eg. using ``jax.vmap`` on a
    layer). See
    https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html#Using-hk.lift
    for more explanation of when to use :func:`lift`. (If you're not using JAX
    functions inside of a module or don't need access to your parameters
    inside of a transform, you probably don't need to use :func:`lift`.)

    Must be called inside :func:`transform`, and be passed the ``init`` member
    of a :class:`Transformed`.

    During init, the returned callable will run the given ``init_fn``, and
    include the resulting params in the outer transform's dictionaries.
    During ``apply``, the returned callable will instead pull the relevant
    parameters from the outer transform's dictionaries.

    By default, users must ensure that the given ``init`` does not
    accidentally catch modules from an outer :func:`transform` via functional
    closure. If this behavior is desirable, set ``allow_reuse`` to ``True``.

    Example:

      >>> # outer can be `hk.transform`ed and will contain the params of inner.
      >>> def outer(x):
      ...   @hk.transform
      ...   def inner(x):
      ...     return hk.Linear(1)(x)
      ...   init_rng = hk.next_rng_key() if hk.running_init() else None
      ...   x = jnp.ones([1, 1])
      ...   params = hk.lift(inner.init, name='f_lift')(init_rng, x)
      ...   # inner.apply is a pure function and can be vmapped.
      ...   return jax.vmap(inner.apply, in_axes=(0, None, 0))(params, None, x)

    Args:
      init_fn: The ``init`` function from an :class:`Transformed`.
      allow_reuse: Allows lifted parameters and state to be reused from the
        outer :func:`transform`. This can be desirable when using ``lift``
        within control flow (e.g. ``hk.scan``).
      name: A string name to prefix parameters with.

    Returns:
      A callable that during ``init`` injects parameter values into the outer
      context and during ``apply`` retrieves parameters from the outer
      context. In both cases returns parameter values to be used with an
      ``apply`` function.

    See also:
      :func:`~haiku.experimental.lift_with_state`: Register params and state
        with an outer transform.
      :func:`~haiku.experimental.transparent_lift`: Register params with an
        outer transform without a namespace.
      :func:`~haiku.experimental.transparent_lift_with_state`: Register params
        and state with an outer transform without a namespace.
    """
    base.assert_context("lift")
    # Reuse the stateful implementation and tell its updater to ignore updates.
    stateful_init_fn = add_state_to_init_fn(init_fn)
    params_and_state_fn, updater = lift_with_state(
        stateful_init_fn, allow_reuse=allow_reuse, name=name)
    updater.ignore_update()

    def params_fn(*a, **k):
        # Keep only the first element (the params) of the stateful result.
        return params_and_state_fn(*a, **k)[0]

    return params_fn