Пример #1
0
def transform_with_state(f) -> TransformedWithState:
    """Turns a function that uses Haiku modules into a pair of pure functions.

    Refer to :func:`transform` for the general mechanics of Haiku
    transformations.

    In contrast to :func:`transform`, the resulting pair also threads internal
    state (for instance the moving averages kept by batch norm) which the
    wrapped function reads and writes through :func:`get_state` and
    :func:`set_state`.

    >>> def f():
    ...   counter = hk.get_state("counter", shape=[], dtype=jnp.int32,
    ...                          init=jnp.zeros)
    ...   hk.set_state("counter", counter + 1)
    ...   return counter

    >>> f = hk.transform_with_state(f)

    >>> params, state = f.init(None)
    >>> for _ in range(10):
    ...   counter, state = f.apply(params, state, None)
    >>> counter
    DeviceArray(9, dtype=int32)

    Args:
      f: A function closing over `Module` instances.

    Returns:
      A named tuple exposing the pure ``init`` and ``apply`` functions.
    """
    analytics.log_once("transform_with_state")
    # Build the init half first, then the apply half, from the same ``f``.
    pure_init = mk_init_fn(f)
    pure_apply = mk_apply_fn(f)
    return TransformedWithState(pure_init, pure_apply)
Пример #2
0
def transform_with_state(f) -> TransformedWithState:
    """Transforms a function using Haiku modules into a pair of pure functions.

    See :func:`transform` for general details on Haiku transformations.

    For a function ``out = f(*a, **k)`` this function returns a pair of two
    pure functions that call ``f(*a, **k)`` explicitly collecting and
    injecting parameter values and state::

        params, state = init(rng, *a, **k)
        out, state = apply(params, state, rng, *a, **k)

    Note that the ``rng`` argument is typically not required for ``apply``
    and passing ``None`` is accepted.

    This function is equivalent to :func:`transform`, however it allows you
    to maintain and update internal state (e.g. moving averages in batch
    norm) via :func:`get_state` and :func:`set_state`.

    >>> def f():
    ...   counter = hk.get_state("counter", shape=[], dtype=jnp.int32,
    ...                          init=jnp.zeros)
    ...   hk.set_state("counter", counter + 1)
    ...   return counter

    >>> f = hk.transform_with_state(f)

    >>> params, state = f.init(None)
    >>> for _ in range(10):
    ...   counter, state = f.apply(params, state, None)
    >>> counter
    DeviceArray(9, dtype=int32)

    Args:
      f: A function closing over :class:`Module` instances.

    Returns:
      A :class:`TransformedWithState` tuple with ``init`` and ``apply``
      properties.
    """
    analytics.log_once("transform_with_state")
    # Both pure functions are derived from the same user function ``f``.
    return TransformedWithState(make_init_fn(f), make_apply_fn(f))
Пример #3
0
def transform_with_state(f) -> TransformedWithState:
    """Converts a Haiku-module-using function into two pure functions.

    General background on Haiku transformations lives in :func:`transform`.

    Given ``out = f(*a, **k)``, the returned pair invokes ``f(*a, **k)`` while
    explicitly gathering (``init``) and supplying (``apply``) both parameter
    values and state::

        params, state = init(rng, *a, **k)
        out, state = apply(params, state, rng, *a, **k)

    ``apply`` usually does not need an ``rng``; passing ``None`` is fine.

    Compared with :func:`transform`, this variant additionally lets ``f``
    keep and update internal state (e.g. moving averages in batch norm)
    through :func:`get_state` and :func:`set_state`.

    >>> def f():
    ...   counter = hk.get_state("counter", shape=[], dtype=jnp.int32,
    ...                          init=jnp.zeros)
    ...   hk.set_state("counter", counter + 1)
    ...   return counter

    >>> f = hk.transform_with_state(f)

    >>> params, state = f.init(None)
    >>> for _ in range(10):
    ...   counter, state = f.apply(params, state, None)
    >>> counter
    DeviceArray(9, dtype=int32)

    Args:
      f: A function closing over :class:`Module` instances.

    Returns:
      A :class:`TransformedWithState` tuple with ``init`` and ``apply``
      properties.
    """
    analytics.log_once("transform_with_state")

    def init_fn(
        rng: Optional[Union[PRNGKey, PRNGSeed]],
        *args,
        **kwargs,
    ) -> Tuple[Params, State]:
        """Runs ``f`` once, gathering the parameters and state it creates."""
        rng_seq = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)
        with base.new_context(rng=rng_seq) as frame:
            f(*args, **kwargs)
        collected_params = frame.collect_params()
        collected_state = frame.collect_initial_state()
        return collected_params, collected_state

    def apply_fn(
        params: Params,
        state: State,
        rng: Optional[Union[PRNGKey, PRNGSeed]],
        *args,
        **kwargs,
    ) -> Tuple[Any, State]:
        """Runs ``f`` with the supplied parameters and state injected."""
        checked_params = check_mapping("params", params)
        checked_state = check_mapping("state", state)
        # The rng error message depends on whether the checked state is truthy
        # (i.e. non-empty).
        if checked_state:
            rng_error = APPLY_RNG_STATE_ERROR
        else:
            rng_error = APPLY_RNG_ERROR
        rng_seq = to_prng_sequence(rng, err_msg=rng_error)
        with base.new_context(params=checked_params,
                              state=checked_state,
                              rng=rng_seq) as frame:
            result = f(*args, **kwargs)
        return result, frame.collect_state()

    # EXPERIMENTAL: Expose the original function as a private attribute.
    for pure_fn in (init_fn, apply_fn):
        pure_fn._original_fn = f  # pylint: disable=protected-access

    return TransformedWithState(init_fn, apply_fn)
Пример #4
0
def transform(f, *, apply_rng=False) -> Transformed:
    """Turns a function that uses Haiku modules into a pair of pure functions.

    For a function ``out = f(*a, **k)`` the result is a pair of pure functions
    that call ``f(*a, **k)`` while explicitly collecting (``init``) and
    injecting (``apply``) parameter values::

        params = init(rng, *a, **k)
        out = apply(params, *a, **k)         # default, apply_rng=False
        out = apply(params, rng, *a, **k)    # with apply_rng=True

    Even with ``apply_rng=True`` an ``rng`` is typically not required by
    ``apply`` and ``None`` is accepted.

    Start by defining a :class:`Module`, which bundles some parameters with a
    computation over them:

    >>> class MyModule(hk.Module):
    ...   def __call__(self, x):
    ...     w = hk.get_parameter("w", [], init=jnp.zeros)
    ...     return x + w

    Then write a function that builds and applies modules, and pass it to
    :func:`transform`. The two resulting functions let you pull every
    parameter out of the function (``f.init``) and run the function against a
    given set of parameters (``f.apply``):

    >>> def f(x):
    ...   a = MyModule()
    ...   b = MyModule()
    ...   return a(x) + b(x)

    >>> f = hk.transform(f)

    Obtain the initial parameters by calling ``init`` on an example input:

    >>> params = f.init(None, 1)
    >>> params
    frozendict({
      'my_module': frozendict({'w': DeviceArray(0., dtype=float32)}),
      'my_module_1': frozendict({'w': DeviceArray(0., dtype=float32)}),
    })

    Run the function with those parameters via ``apply``:

    >>> f.apply(params, 1)
    DeviceArray(2., dtype=float32)

    Training will eventually produce new parameters; simply call ``apply``
    again with them:

    >>> new_params = {"my_module": {"w": jnp.array(2.)},
    ...               "my_module_1": {"w": jnp.array(3.)}}
    >>> f.apply(new_params, 2)
    DeviceArray(9., dtype=float32)

    For functions that must keep internal state (e.g. moving averages in batch
    norm), use :func:`transform_with_state` instead.

    Args:
      f: A function closing over :class:`Module` instances.
      apply_rng: Whether ``apply`` should accept ``rng`` as an argument. A
        falsy value currently emits a :class:`DeprecationWarning`.

    Returns:
      A :class:`Transformed` tuple with ``init`` and ``apply`` pure functions.
    """
    analytics.log_once("transform")

    drop_rng = not apply_rng
    if drop_rng:
        warnings.warn("Apply_rng will soon be removed and defaulted to True",
                      DeprecationWarning)

    stateful_pair = transform_with_state(f)
    # Strip the rng argument from ``apply`` only when it was not requested.
    rng_handled = (without_apply_rng(stateful_pair) if drop_rng
                   else stateful_pair)
    return without_state(rng_handled)
Пример #5
0
def transform(f, *, apply_rng=True) -> Transformed:
    """Transforms a function using Haiku modules into a pair of pure functions.

    For a function ``out = f(*a, **k)`` this function returns a pair of two
    pure functions that call ``f(*a, **k)`` explicitly collecting and
    injecting parameter values::

        params = init(rng, *a, **k)
        out = apply(params, rng, *a, **k)

    Note that the ``rng`` argument is typically not required for ``apply``
    and passing ``None`` is accepted.

    The first thing to do is to define a :class:`Module`. A module
    encapsulates some parameters and a computation on those parameters:

    >>> class MyModule(hk.Module):
    ...   def __call__(self, x):
    ...     w = hk.get_parameter("w", [], init=jnp.zeros)
    ...     return x + w

    Next, define some function that creates and applies modules. We use
    :func:`transform` to transform that function into a pair of functions
    that allow us to lift all the parameters out of the function
    (``f.init``) and apply the function with a given set of parameters
    (``f.apply``):

    >>> def f(x):
    ...   a = MyModule()
    ...   b = MyModule()
    ...   return a(x) + b(x)
    >>> f = hk.transform(f)

    To get the initial state of the module call ``init`` with an example
    input:

    >>> params = f.init(None, 1)
    >>> params
    frozendict({
      'my_module': frozendict({'w': DeviceArray(0., dtype=float32)}),
      'my_module_1': frozendict({'w': DeviceArray(0., dtype=float32)}),
    })

    You can then apply the function with the given parameters by calling
    ``apply`` (note that since we don't use Haiku's random number APIs to
    apply our network we pass ``None`` as an RNG key):

    >>> f.apply(params, None, 1)
    DeviceArray(2., dtype=float32)

    It is expected that your program will at some point produce updated
    parameters and you will want to re-apply ``apply``. You can do this by
    calling ``apply`` with different parameters:

    >>> new_params = {"my_module": {"w": jnp.array(2.)},
    ...               "my_module_1": {"w": jnp.array(3.)}}
    >>> f.apply(new_params, None, 2)
    DeviceArray(9., dtype=float32)

    If your transformed function needs to maintain internal state (e.g.
    moving averages in batch norm) then see :func:`transform_with_state`.

    Args:
      f: A function closing over :class:`Module` instances.
      apply_rng: In the process of being removed. Can only take the value
        ``True``.

    Returns:
      A :class:`Transformed` tuple with ``init`` and ``apply`` pure functions.

    Raises:
      ValueError: If ``apply_rng`` is falsy.
    """
    analytics.log_once("transform")

    # The argument is kept only so that old call sites get an actionable
    # migration message instead of a TypeError.
    if not apply_rng:
        raise ValueError(
            "The apply_rng argument has been removed and hk.transform "
            "now *always* applies an rng.\n"
            "Replace hk.transform(..., apply_rng=False) with "
            "hk.without_apply_rng(hk.transform(...)).\n"
            "Replace hk.transform(..., apply_rng=True) with hk.transform(...)."
        )

    return without_state(transform_with_state(f))
Пример #6
0
def multi_transform_with_state(
    f: Callable[[], Tuple[TemplateFn, TreeOfApplyFns]],
) -> MultiTransformedWithState:
    """Transforms a collection of functions using Haiku into pure functions.

    See :func:`multi_transform` for more details.

    Example:

    >>> def f():
    ...   encoder = hk.Linear(1, name="encoder")
    ...   decoder = hk.Linear(1, name="decoder")
    ...
    ...   def init(x):
    ...     z = encoder(x)
    ...     return decoder(z)
    ...
    ...   return init, (encoder, decoder)

    >>> f = hk.multi_transform_with_state(f)
    >>> rng = jax.random.PRNGKey(42)
    >>> x = jnp.ones([1, 1])
    >>> params, state = f.init(rng, x)
    >>> jax.tree_util.tree_map(jnp.shape, params)
    {'decoder': {'b': (1,), 'w': (1, 1)},
     'encoder': {'b': (1,), 'w': (1, 1)}}

    >>> encode, decode = f.apply
    >>> z, state = encode(params, state, None, x)
    >>> y, state = decode(params, state, None, z)

    Args:
      f: Function returning a "template" function and an arbitrary
        tree of functions using modules connected in the template function.

    Returns:
      An ``init`` function and a tree of pure ``apply`` functions.

    See also:
      :func:`transform_with_state`: Transform a single apply function.
      :func:`multi_transform`: Transform multiple apply functions without
        state.
    """
    analytics.log_once('multi_transform_with_state')

    def init_fn(*args, **kwargs):
        """Calls the template function ``f()[0]`` on the given inputs."""
        return f()[0](*args, **kwargs)

    init_fn = hk.transform_with_state(init_fn).init

    def apply_fn_i(i):
        """Returns an impure function that calls the ``i``-th leaf of ``f()[1]``."""

        def apply_fn(*args, **kwargs):
            """Applies the transformed function at the given inputs."""
            # ``f()`` is called afresh on every invocation, i.e. inside
            # whichever Haiku transform is running this function.
            return jax.tree_util.tree_leaves(f()[1])[i](*args, **kwargs)

        return apply_fn

    # We need to find out the structure of f()[1], including how many
    # functions there are, so that we can transform them individually and
    # repack into the same tree structure. It's valid for modules to declare
    # parameters in their constructor, so we need to create something that
    # looks like hk.Params in order to do this. `jax.eval_shape` interprets
    # the function abstractly, i.e. no real params are created, and we don't
    # need to touch the accelerator. This means hardcoding the RNG below is
    # fine.
    def get_output_treedef() -> Box:
        rng = jax.random.PRNGKey(42)  # This is fine, see above
        fns = hk.transform_with_state(lambda: f()[1])
        apply_fns, _ = fns.apply(*fns.init(rng), rng)
        # Box hides the (non-array) treedef from `jax.eval_shape`.
        return Box(jax.tree_util.tree_structure(apply_fns))

    output_treedef = jax.eval_shape(get_output_treedef).python_value
    apply_fns = make_tree(
        lambda i: hk.transform_with_state(apply_fn_i(i)).apply, output_treedef)

    return MultiTransformedWithState(init_fn, apply_fns)
Пример #7
0
def multi_transform(
    f: Callable[[], Tuple[TemplateFn, TreeOfApplyFns]], ) -> MultiTransformed:
    """Transforms a collection of functions using Haiku into pure functions.

    In many scenarios we have several modules which are used either as
    primitives for several Haiku modules/functions, or whose pure versions are
    to be reused in downstream code. This utility enables this by applying
    :func:`transform` to an arbitrary tree of Haiku functions which share
    modules and have a common ``init`` function.

    ``f`` is expected to return a tuple of two elements. First is a
    ``template`` Haiku function which provides an example of how all internal
    Haiku modules are connected. This function is used to create a common
    ``init`` function (with your parameters).

    The second object is an arbitrary tree of Haiku functions all of which
    reuse the modules connected in the ``template`` function. These functions
    are transformed to pure ``apply`` functions.

    Example:

    >>> def f():
    ...   encoder = hk.Linear(1, name="encoder")
    ...   decoder = hk.Linear(1, name="decoder")
    ...
    ...   def init(x):
    ...     z = encoder(x)
    ...     return decoder(z)
    ...
    ...   return init, (encoder, decoder)

    >>> f = hk.multi_transform(f)
    >>> rng = jax.random.PRNGKey(42)
    >>> x = jnp.ones([1, 1])
    >>> params = f.init(rng, x)
    >>> jax.tree_util.tree_map(jnp.shape, params)
    {'decoder': {'b': (1,), 'w': (1, 1)},
     'encoder': {'b': (1,), 'w': (1, 1)}}

    >>> encode, decode = f.apply
    >>> z = encode(params, None, x)
    >>> y = decode(params, None, z)

    Args:
      f: A factory function that returns two functions, firstly a common init
        function that creates all modules, and secondly a pytree of apply
        functions which make use of those modules.

    Returns:
      A :class:`MultiTransformed` instance which contains a pure init function
      that creates all parameters, and a pytree of pure apply functions that
      given the params apply the given function.

    See also:
      :func:`multi_transform_with_state`: Equivalent for modules using state.
    """
    analytics.log_once('multi_transform')

    # Reuse the stateful implementation, then strip the state from the
    # resulting init/apply functions.
    return without_state(multi_transform_with_state(f))
Пример #8
0
def transform(
    f,
    apply_rng=False,
    state=False,
) -> Transformed:
    """Transforms a function using Haiku modules into a pair of pure functions.

    The first thing to do is to define a :class:`Module`. A module
    encapsulates some parameters and a computation on those parameters:

    >>> class MyModule(hk.Module):
    ...   def __call__(self, x):
    ...     w = hk.get_parameter("w", [], init=jnp.zeros)
    ...     return x + w

    Next, define some function that creates and applies modules. We use
    :func:`transform` to transform that function into a pair of functions
    that allow us to lift all the parameters out of the function
    (``f.init``) and apply the function with a given set of parameters
    (``f.apply``):

    >>> def f(x):
    ...   a = MyModule()
    ...   b = MyModule()
    ...   return a(x) + b(x)

    >>> f = hk.transform(f)

    To get the initial state of the module call ``f.init`` with an example
    input:

    >>> params = f.init(None, 1)
    >>> params
    frozendict({
      'my_module': frozendict({'w': DeviceArray(0., dtype=float32)}),
      'my_module_1': frozendict({'w': DeviceArray(0., dtype=float32)}),
    })

    You can then apply the function with the given parameters by calling
    ``f.apply``:

    >>> f.apply(params, 1)
    DeviceArray(2., dtype=float32)

    It is expected that your program will at some point produce updated
    parameters and you will want to re-apply ``f.apply``. You can do this by
    calling ``f.apply`` with different parameters:

    >>> new_params = {"my_module": {"w": jnp.array(2.)},
    ...               "my_module_1": {"w": jnp.array(3.)}}
    >>> f.apply(new_params, 2)
    DeviceArray(9., dtype=float32)

    If your transformed function needs to maintain internal state (e.g.
    moving averages in batch norm) then see :func:`transform_with_state`.

    Args:
      f: A function closing over :class:`Module` instances.
      apply_rng: Whether ``apply`` should accept ``rng`` as an argument.
      state: *Deprecated:* use :func:`transform_with_state`.

    Returns:
      A named tuple with ``init`` and ``apply`` pure functions. When ``state``
      is false these are the stateless versions; when ``state`` is true the
      stateful pair built by :func:`transform_with_state` is returned
      (with ``rng`` removed from ``apply`` unless ``apply_rng`` is true).
    """
    analytics.log_once("transform")

    if state:
        warnings.warn(
            "Prefer using hk.transform_with_state(f) vs. passing state=True.",
            DeprecationWarning)

    # NOTE(review): this fires when ``apply_rng`` is truthy, i.e. for callers
    # already opting into the announced future default; confirm the condition
    # is not meant to be ``not apply_rng``.
    if apply_rng:
        warnings.warn("Apply_rng will soon be removed and defaulted to True",
                      DeprecationWarning)

    # ``transform_with_state`` yields a ``TransformedWithState``; the wrappers
    # below optionally drop the rng argument and the state from it.
    pair = transform_with_state(f)
    if not apply_rng:
        pair = without_apply_rng(pair)
    if not state:
        pair = without_state(pair)
    return pair
Пример #9
0
def transform_with_state(f) -> TransformedWithState:
    """Converts a Haiku-module-using function into two pure functions.

    General background on Haiku transformations lives in :func:`transform`.

    Given ``out = f(*a, **k)``, the returned pair invokes ``f(*a, **k)`` while
    explicitly gathering (``init``) and supplying (``apply``) both parameter
    values and state::

        params, state = init(rng, *a, **k)
        out, state = apply(params, state, rng, *a, **k)

    ``apply`` usually does not need an ``rng``; passing ``None`` is fine.

    Compared with :func:`transform`, this variant additionally lets ``f``
    keep and update internal state (e.g. :class:`ExponentialMovingAverage` in
    :class:`BatchNorm`) through :func:`get_state` and :func:`set_state`:

    >>> def f():
    ...   counter = hk.get_state("counter", shape=[], dtype=jnp.int32,
    ...                          init=jnp.zeros)
    ...   hk.set_state("counter", counter + 1)
    ...   return counter
    >>> f = hk.transform_with_state(f)

    >>> params, state = f.init(None)
    >>> for _ in range(10):
    ...   counter, state = f.apply(params, state, None)
    >>> counter
    DeviceArray(9, dtype=int32)

    Args:
      f: A function closing over :class:`Module` instances.

    Returns:
      A :class:`TransformedWithState` tuple with ``init`` and ``apply`` pure
      functions.
    """
    analytics.log_once("transform_with_state")
    check_not_jax_transformed(f)

    tracer_hint = (
        "An UnexpectedTracerError was raised while inside a Haiku transformed "
        "function (see error above).\n"
        "Hint: are you using a JAX transform or JAX control-flow function "
        "(jax.vmap/jax.scan/...) inside a Haiku transform? You might want to use "
        "the Haiku version of the transform instead (hk.vmap/hk.scan/...).\n"
        "See https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html "
        "on why you can't use JAX transforms inside a Haiku module.")

    def call_f(*args, **kwargs):
        """Invokes ``f``; tracer leaks are re-raised with a Haiku hint."""
        try:
            return f(*args, **kwargs)
        except jax.errors.UnexpectedTracerError as err:
            # Same exception type, chained to the original for context.
            raise jax.errors.UnexpectedTracerError(tracer_hint) from err

    def init_fn(
        rng: Optional[Union[PRNGKey, int]],
        *args,
        **kwargs,
    ) -> Tuple[hk.Params, hk.State]:
        """Runs ``f`` once, returning the parameters and state it created."""
        rng_seq = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)
        with base.new_context(rng=rng_seq) as frame:
            call_f(*args, **kwargs)
        return frame.collect_params(), frame.collect_initial_state()

    def apply_fn(
        params: Optional[hk.Params],
        state: Optional[hk.State],
        rng: Optional[Union[PRNGKey, int]],
        *args,
        **kwargs,
    ) -> Tuple[Any, hk.State]:
        """Runs ``f`` with ``params`` and ``state`` injected."""
        checked_params = check_mapping("params", params)
        checked_state = check_mapping("state", state)
        rng_error = (APPLY_RNG_STATE_ERROR if checked_state
                     else APPLY_RNG_ERROR)
        rng_seq = to_prng_sequence(rng, err_msg=rng_error)
        with base.new_context(params=checked_params,
                              state=checked_state,
                              rng=rng_seq) as frame:
            result = call_f(*args, **kwargs)
        return result, frame.collect_state()

    tie_in_original_fn(f, init_fn, apply_fn)

    return TransformedWithState(init_fn, apply_fn)