# Example #3
def transform_with_state(f) -> TransformedWithState:
    """Transforms a function using Haiku modules into a pair of pure functions.

    See :func:`transform` for general details on Haiku transformations.

    For a function ``out = f(*a, **k)`` this function returns a pair of two pure
    functions that call ``f(*a, **k)`` explicitly collecting and injecting
    parameter values and state::

        params, state = init(rng, *a, **k)
        out, state = apply(params, state, rng, *a, **k)

    Note that the ``rng`` argument is typically not required for ``apply`` and
    passing ``None`` is accepted.

    This function is equivalent to :func:`transform`, however it allows you to
    maintain and update internal state (e.g. moving averages in batch norm) via
    :func:`get_state` and :func:`set_state`.

    Example:

    >>> def f():
    ...   counter = hk.get_state("counter", shape=[], dtype=jnp.int32,
    ...                          init=jnp.zeros)
    ...   hk.set_state("counter", counter + 1)
    ...   return counter

    >>> f = hk.transform_with_state(f)

    >>> params, state = f.init(None)
    >>> for _ in range(10):
    ...   counter, state = f.apply(params, state, None)
    >>> counter
    DeviceArray(9, dtype=int32)

    Args:
      f: A function closing over :class:`Module` instances.

    Returns:
      A :class:`TransformedWithState` tuple with ``init`` and ``apply``
      properties.
    """

    def init_fn(
        rng: Optional[Union[PRNGKey, PRNGSeed]],
        *args,
        **kwargs,
    ) -> Tuple[Params, State]:
        """Initializes your function collecting parameters and state.

        Args:
          rng: Key or seed for the RNG sequence used during initialization.
            It is converted with ``to_prng_sequence``.
          *args: Positional arguments forwarded to ``f``.
          **kwargs: Keyword arguments forwarded to ``f``.

        Returns:
          A ``(params, initial_state)`` tuple collected from the context after
          running ``f`` once.
        """
        rng = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)
        with base.new_context(rng=rng) as ctx:
            # ``f`` is run only so that whatever parameters/state it creates are
            # recorded in ``ctx``; its return value is intentionally discarded.
            f(*args, **kwargs)
        return ctx.collect_params(), ctx.collect_initial_state()

    def apply_fn(
        params: Params,
        state: State,
        rng: Optional[Union[PRNGKey, PRNGSeed]],
        *args,
        **kwargs,
    ) -> Tuple[Any, State]:
        """Applies your function injecting parameters and state.

        Args:
          params: Mapping of parameters, validated with ``check_mapping``.
          state: Mapping of state, validated with ``check_mapping``.
          rng: Optional key or seed for the RNG sequence; may be ``None``.
          *args: Positional arguments forwarded to ``f``.
          **kwargs: Keyword arguments forwarded to ``f``.

        Returns:
          A ``(out, new_state)`` tuple, where ``out`` is the return value of
          ``f`` and ``new_state`` is the state collected from the context.
        """
        params = check_mapping("params", params)
        state = check_mapping("state", state)
        # A different RNG error message is chosen when ``state`` is non-empty;
        # the message texts themselves are defined outside this function.
        rng = to_prng_sequence(
            rng, err_msg=(APPLY_RNG_STATE_ERROR if state else APPLY_RNG_ERROR))
        with base.new_context(params=params, state=state, rng=rng) as ctx:
            out = f(*args, **kwargs)
        return out, ctx.collect_state()

    # EXPERIMENTAL: Expose the original function as a private attribute.
    init_fn._original_fn = f  # pylint: disable=protected-access
    apply_fn._original_fn = f  # pylint: disable=protected-access

    return TransformedWithState(init_fn, apply_fn)
# Example #6
def multi_transform_with_state(
    f: Callable[[], Tuple[TemplateFn, TreeOfApplyFns]],
) -> MultiTransformedWithState:
    """Transforms a collection of functions using Haiku into pure functions.

    See :func:`multi_transform` for more details.

    Example:

    >>> def f():
    ...   encoder = hk.Linear(1, name="encoder")
    ...   decoder = hk.Linear(1, name="decoder")
    ...   def init(x):
    ...     z = encoder(x)
    ...     return decoder(z)
    ...   return init, (encoder, decoder)

    >>> f = hk.multi_transform_with_state(f)
    >>> rng = jax.random.PRNGKey(42)
    >>> x = jnp.ones([1, 1])
    >>> params, state = f.init(rng, x)
    >>> jax.tree_util.tree_map(jnp.shape, params)
    {'decoder': {'b': (1,), 'w': (1, 1)},
     'encoder': {'b': (1,), 'w': (1, 1)}}

    >>> encode, decode = f.apply
    >>> z, state = encode(params, state, None, x)
    >>> y, state = decode(params, state, None, z)

    Args:
      f: Function returning a "template" function and an arbitrary
        tree of functions using modules connected in the template function.

    Returns:
      An ``init`` function and a tree of pure ``apply`` functions.

    See also:
      :func:`transform_with_state`: Transform a single apply function.
      :func:`multi_transform`: Transform multiple apply functions without state.
    """

    def init_fn(*args, **kwargs):
        """Returns initial state for the transformed functions."""
        # Calls the template function (first element returned by ``f``).
        return f()[0](*args, **kwargs)

    init_fn = hk.transform_with_state(init_fn).init

    def apply_fn_i(i):
        # Factory so that each returned ``apply_fn`` captures its own ``i``
        # (avoids the late-binding closure pitfall of a lambda in a loop).
        def apply_fn(*args, **kwargs):
            """Applies the transformed function at the given inputs."""
            # ``f()`` is called on every invocation, so the modules are
            # constructed each time the transformed function runs; the i-th
            # leaf of the flattened apply-function tree is then invoked.
            return jax.tree_util.tree_leaves(f()[1])[i](*args, **kwargs)

        return apply_fn

    # We need to find out the structure of f()[1], including how many
    # functions there are, so that we can transform them individually and repack
    # into the same tree structure. It's valid for modules to declare parameters
    # in their constructor, so we need to create something that looks like
    # hk.Params in order to do this. `jax.eval_shape` interprets the function
    # abstractly, ie no real params are created, and we don't need to touch the
    # accelerator. This means hardcoding the RNG below is fine.
    def get_output_treedef() -> Box:
        rng = jax.random.PRNGKey(42)  # This is fine, see above
        fns = hk.transform_with_state(lambda: f()[1])
        apply_fns, _ = fns.apply(*fns.init(rng), rng)
        # NOTE(review): ``Box`` presumably carries the treedef through
        # ``jax.eval_shape`` as a non-array value; confirm against its definition.
        return Box(jax.tree_util.tree_structure(apply_fns))

    output_treedef = jax.eval_shape(get_output_treedef).python_value
    # ``make_tree`` is assumed to call the lambda once per leaf index ``i`` and
    # rebuild a tree with ``output_treedef``'s structure — TODO confirm.
    apply_fns = make_tree(
        lambda i: hk.transform_with_state(apply_fn_i(i)).apply, output_treedef)

    return MultiTransformedWithState(init_fn, apply_fns)
# Example #7
def multi_transform(
    f: Callable[[], Tuple[TemplateFn, TreeOfApplyFns]], ) -> MultiTransformed:
    """Transforms a collection of functions using Haiku into pure functions.

    In many scenarios we have several modules which are used either as primitives
    for several Haiku modules/functions, or whose pure versions are to be reused
    in downstream code. This utility enables this by applying
    :func:`transform` to an arbitrary tree of Haiku functions which share modules
    and have a common ``init`` function.

    ``f`` is expected to return a tuple of two elements. First is a ``template``
    Haiku function which provides an example of how all internal Haiku modules are
    connected. This function is used to create a common ``init`` function (with
    your parameters).

    The second object is an arbitrary tree of Haiku functions all of which reuse
    the modules connected in the ``template`` function. These functions are
    transformed to pure ``apply`` functions.

    Example:

    >>> def f():
    ...   encoder = hk.Linear(1, name="encoder")
    ...   decoder = hk.Linear(1, name="decoder")
    ...   def init(x):
    ...     z = encoder(x)
    ...     return decoder(z)
    ...   return init, (encoder, decoder)

    >>> f = hk.multi_transform(f)
    >>> rng = jax.random.PRNGKey(42)
    >>> x = jnp.ones([1, 1])
    >>> params = f.init(rng, x)
    >>> jax.tree_util.tree_map(jnp.shape, params)
    {'decoder': {'b': (1,), 'w': (1, 1)},
     'encoder': {'b': (1,), 'w': (1, 1)}}

    >>> encode, decode = f.apply
    >>> z = encode(params, None, x)
    >>> y = decode(params, None, z)

    Args:
      f: A factory function that returns two functions, firstly a common init
        function that creates all modules, and secondly a pytree of apply
        functions which make use of those modules.

    Returns:
      A :class:`MultiTransformed` instance which contains a pure init function
        that creates all parameters, and a pytree of pure apply functions that
        given the params apply the given function.

    See also:
      :func:`multi_transform_with_state`: Equivalent for modules using state.
    """
    # Build the stateful variant, then strip state handling from its
    # ``init``/``apply`` functions via ``without_state``.
    return without_state(multi_transform_with_state(f))
