def transform_with_state(f) -> TransformedWithState:
  """Converts a function that uses Haiku modules into two pure functions.

  This is the stateful counterpart of :func:`transform`; refer to that function
  for the general mechanics of Haiku transformations. In addition to
  parameters, the transformed function may read and write internal state
  (for example the moving averages kept by batch norm) through
  :func:`get_state` and :func:`set_state`.

  >>> def f():
  ...   counter = hk.get_state("counter", shape=[], dtype=jnp.int32,
  ...                          init=jnp.zeros)
  ...   hk.set_state("counter", counter + 1)
  ...   return counter
  >>> f = hk.transform_with_state(f)
  >>> params, state = f.init(None)
  >>> for _ in range(10):
  ...   counter, state = f.apply(params, state, None)
  >>> counter
  DeviceArray(9, dtype=int32)

  Args:
    f: A function closing over `Module` instances.

  Returns:
    A named tuple exposing `init` and `apply` properties.
  """
  analytics.log_once("transform_with_state")
  # Build the pure initializer first, then the pure applier, and bundle both
  # into the named tuple returned to the caller.
  pure_init = mk_init_fn(f)
  pure_apply = mk_apply_fn(f)
  return TransformedWithState(pure_init, pure_apply)
def transform_with_state(f) -> TransformedWithState:
  """Transforms a function using Haiku modules into a pair of pure functions.

  See :func:`transform` for general details on Haiku transformations.

  For a function ``out = f(*a, **k)`` this function returns a pair of two pure
  functions that call ``f(*a, **k)`` explicitly collecting and injecting
  parameter values and state::

      params, state = init(rng, *a, **k)
      out, state = apply(params, state, rng, *a, **k)

  Note that the ``rng`` argument is typically not required for ``apply`` and
  passing ``None`` is accepted.

  This function is equivalent to :func:`transform`, however it allows you to
  maintain and update internal state (e.g. moving averages in batch norm) via
  :func:`get_state` and :func:`set_state`.

  >>> def f():
  ...   counter = hk.get_state("counter", shape=[], dtype=jnp.int32,
  ...                          init=jnp.zeros)
  ...   hk.set_state("counter", counter + 1)
  ...   return counter
  >>> f = hk.transform_with_state(f)
  >>> params, state = f.init(None)
  >>> for _ in range(10):
  ...   counter, state = f.apply(params, state, None)
  >>> counter
  DeviceArray(9, dtype=int32)

  Args:
    f: A function closing over :class:`Module` instances.

  Returns:
    A :class:`TransformedWithState` tuple with ``init`` and ``apply``
    properties.
  """
  analytics.log_once("transform_with_state")
  # make_init_fn / make_apply_fn presumably wrap `f` into the pure init/apply
  # pair described above (defined elsewhere in this module).
  return TransformedWithState(make_init_fn(f), make_apply_fn(f))
def transform_with_state(f) -> TransformedWithState:
  """Converts a function that uses Haiku modules into two pure functions.

  This is the stateful counterpart of :func:`transform`; see that function for
  the general mechanics of Haiku transformations.

  Given ``out = f(*a, **k)``, the returned pair calls ``f(*a, **k)`` while
  explicitly collecting (``init``) or injecting (``apply``) both parameter
  values and state::

      params, state = init(rng, *a, **k)
      out, state = apply(params, state, rng, *a, **k)

  ``apply`` usually does not need ``rng``; passing ``None`` is accepted.

  Internal state (e.g. batch norm moving averages) is read and written with
  :func:`get_state` and :func:`set_state`:

  >>> def f():
  ...   counter = hk.get_state("counter", shape=[], dtype=jnp.int32,
  ...                          init=jnp.zeros)
  ...   hk.set_state("counter", counter + 1)
  ...   return counter
  >>> f = hk.transform_with_state(f)
  >>> params, state = f.init(None)
  >>> for _ in range(10):
  ...   counter, state = f.apply(params, state, None)
  >>> counter
  DeviceArray(9, dtype=int32)

  Args:
    f: A function closing over :class:`Module` instances.

  Returns:
    A :class:`TransformedWithState` tuple exposing ``init`` and ``apply``.
  """
  analytics.log_once("transform_with_state")

  def init_fn(
      rng: Optional[Union[PRNGKey, PRNGSeed]],
      *args,
      **kwargs,
  ) -> Tuple[Params, State]:
    """Initializes your function collecting parameters and state."""
    rng_seq = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)
    with base.new_context(rng=rng_seq) as ctx:
      f(*args, **kwargs)
    return ctx.collect_params(), ctx.collect_initial_state()

  def apply_fn(
      params: Params,
      state: State,
      rng: Optional[Union[PRNGKey, PRNGSeed]],
      *args,
      **kwargs,
  ) -> Tuple[Any, State]:
    """Applies your function injecting parameters and state."""
    params = check_mapping("params", params)
    state = check_mapping("state", state)
    # The error shown for a missing rng depends on whether state was supplied.
    missing_rng_msg = APPLY_RNG_STATE_ERROR if state else APPLY_RNG_ERROR
    rng_seq = to_prng_sequence(rng, err_msg=missing_rng_msg)
    with base.new_context(params=params, state=state, rng=rng_seq) as ctx:
      result = f(*args, **kwargs)
    return result, ctx.collect_state()

  # EXPERIMENTAL: Expose the original function as a private attribute.
  for pure_fn in (init_fn, apply_fn):
    pure_fn._original_fn = f  # pylint: disable=protected-access

  return TransformedWithState(init_fn, apply_fn)
def transform(f, *, apply_rng=False) -> Transformed:
  """Transforms a function using Haiku modules into a pair of pure functions.

  For a function ``out = f(*a, **k)`` this function returns a pair of two pure
  functions that call ``f(*a, **k)`` explicitly collecting and injecting
  parameter values::

      params = init(rng, *a, **k)
      out = apply(params, *a, **k)       # apply_rng=False (the default)
      out = apply(params, rng, *a, **k)  # apply_rng=True

  When ``apply`` accepts an ``rng`` (``apply_rng=True``) it is typically not
  required and passing ``None`` is accepted.

  The first thing to do is to define a :class:`Module`. A module encapsulates
  some parameters and a computation on those parameters:

  >>> class MyModule(hk.Module):
  ...   def __call__(self, x):
  ...     w = hk.get_parameter("w", [], init=jnp.zeros)
  ...     return x + w

  Next, define some function that creates and applies modules. We use
  :func:`transform` to transform that function into a pair of functions that
  allow us to lift all the parameters out of the function (``f.init``) and
  apply the function with a given set of parameters (``f.apply``):

  >>> def f(x):
  ...   a = MyModule()
  ...   b = MyModule()
  ...   return a(x) + b(x)
  >>> f = hk.transform(f)

  To get the initial state of the module call ``init`` with an example input:

  >>> params = f.init(None, 1)
  >>> params
  frozendict({
    'my_module': frozendict({'w': DeviceArray(0., dtype=float32)}),
    'my_module_1': frozendict({'w': DeviceArray(0., dtype=float32)}),
  })

  You can then apply the function with the given parameters by calling
  ``apply`` (with the default ``apply_rng=False`` no rng is passed):

  >>> f.apply(params, 1)
  DeviceArray(2., dtype=float32)

  It is expected that your program will at some point produce updated
  parameters and you will want to re-apply ``apply``. You can do this by
  calling ``apply`` with different parameters:

  >>> new_params = {"my_module": {"w": jnp.array(2.)},
  ...               "my_module_1": {"w": jnp.array(3.)}}
  >>> f.apply(new_params, 2)
  DeviceArray(9., dtype=float32)

  If your transformed function needs to maintain internal state (e.g. moving
  averages in batch norm) then see :func:`transform_with_state`.

  Args:
    f: A function closing over :class:`Module` instances.
    apply_rng: Whether ``apply`` should accept ``rng`` as an argument. Passing
      a falsy value (the default) emits a :class:`DeprecationWarning`, since
      this argument will soon be removed and defaulted to ``True``.

  Returns:
    A :class:`Transformed` tuple with ``init`` and ``apply`` pure functions.
  """
  analytics.log_once("transform")
  if not apply_rng:
    # stacklevel=2 attributes the warning to the caller of `transform`, which
    # is the code that needs to change (and lets the default warning filters
    # show it when called from __main__).
    warnings.warn("Apply_rng will soon be removed and defaulted to True",
                  DeprecationWarning, stacklevel=2)
  pair = transform_with_state(f)
  if not apply_rng:
    pair = without_apply_rng(pair)
  return without_state(pair)
def transform(f, *, apply_rng=True) -> Transformed:
  """Transforms a function using Haiku modules into a pair of pure functions.

  For a function ``out = f(*a, **k)`` this function returns a pair of two pure
  functions that call ``f(*a, **k)`` explicitly collecting and injecting
  parameter values::

      params = init(rng, *a, **k)
      out = apply(params, rng, *a, **k)

  Note that the ``rng`` argument is typically not required for ``apply`` and
  passing ``None`` is accepted.

  The first thing to do is to define a :class:`Module`. A module encapsulates
  some parameters and a computation on those parameters:

  >>> class MyModule(hk.Module):
  ...   def __call__(self, x):
  ...     w = hk.get_parameter("w", [], init=jnp.zeros)
  ...     return x + w

  Next, define some function that creates and applies modules. We use
  :func:`transform` to transform that function into a pair of functions that
  allow us to lift all the parameters out of the function (``f.init``) and
  apply the function with a given set of parameters (``f.apply``):

  >>> def f(x):
  ...   a = MyModule()
  ...   b = MyModule()
  ...   return a(x) + b(x)
  >>> f = hk.transform(f)

  To get the initial state of the module call ``init`` with an example input:

  >>> params = f.init(None, 1)
  >>> params
  frozendict({
    'my_module': frozendict({'w': DeviceArray(0., dtype=float32)}),
    'my_module_1': frozendict({'w': DeviceArray(0., dtype=float32)}),
  })

  You can then apply the function with the given parameters by calling
  ``apply`` (note that since we don't use Haiku's random number APIs to apply
  our network we pass ``None`` as an RNG key):

  >>> f.apply(params, None, 1)
  DeviceArray(2., dtype=float32)

  It is expected that your program will at some point produce updated
  parameters and you will want to re-apply ``apply``. You can do this by
  calling ``apply`` with different parameters:

  >>> new_params = {"my_module": {"w": jnp.array(2.)},
  ...               "my_module_1": {"w": jnp.array(3.)}}
  >>> f.apply(new_params, None, 2)
  DeviceArray(9., dtype=float32)

  If your transformed function needs to maintain internal state (e.g. moving
  averages in batch norm) then see :func:`transform_with_state`.

  Args:
    f: A function closing over :class:`Module` instances.
    apply_rng: In the process of being removed. Must be ``True``; any falsy
      value raises ``ValueError``.

  Returns:
    A :class:`Transformed` tuple with ``init`` and ``apply`` pure functions.

  Raises:
    ValueError: If ``apply_rng`` is falsy. Use
      ``hk.without_apply_rng(hk.transform(...))`` instead.
  """
  analytics.log_once("transform")
  if not apply_rng:
    raise ValueError(
        "The apply_rng argument has been removed and hk.transform "
        "now *always* applies an rng.\n"
        "Replace hk.transform(..., apply_rng=False) with "
        "hk.without_apply_rng(hk.transform(...)).\n"
        "Replace hk.transform(..., apply_rng=True) with hk.transform(...)."
    )
  return without_state(transform_with_state(f))
def multi_transform_with_state(
    f: Callable[[], Tuple[TemplateFn, TreeOfApplyFns]],
) -> MultiTransformedWithState:
  """Transforms a collection of functions using Haiku into pure functions.

  See :func:`multi_transform` for more details.

  Example:

  >>> def f():
  ...   encoder = hk.Linear(1, name="encoder")
  ...   decoder = hk.Linear(1, name="decoder")
  ...
  ...   def init(x):
  ...     z = encoder(x)
  ...     return decoder(z)
  ...
  ...   return init, (encoder, decoder)

  >>> f = hk.multi_transform_with_state(f)
  >>> rng = jax.random.PRNGKey(42)
  >>> x = jnp.ones([1, 1])
  >>> params, state = f.init(rng, x)
  >>> jax.tree_util.tree_map(jnp.shape, params)
  {'decoder': {'b': (1,), 'w': (1, 1)},
   'encoder': {'b': (1,), 'w': (1, 1)}}

  >>> encode, decode = f.apply
  >>> z, state = encode(params, state, None, x)
  >>> y, state = decode(params, state, None, z)

  Args:
    f: Function returning a "template" function and an arbitrary
      tree of functions using modules connected in the template function.

  Returns:
    A :class:`MultiTransformedWithState` holding an ``init`` function and a
    tree of pure ``apply`` functions with the same tree structure as the
    second element returned by ``f``.

  See also:
    :func:`transform_with_state`: Transform a single apply function.
    :func:`multi_transform`: Transform multiple apply functions without state.
  """
  analytics.log_once('multi_transform_with_state')

  def init_fn(*args, **kwargs):
    """Returns initial state for the transformed functions."""
    return f()[0](*args, **kwargs)

  init_fn = hk.transform_with_state(init_fn).init

  def apply_fn_i(i):
    # `i` is bound as a parameter (not a loop variable) so each returned
    # closure keeps its own leaf index.
    def apply_fn(*args, **kwargs):
      """Applies the transformed function at the given inputs."""
      return jax.tree_util.tree_leaves(f()[1])[i](*args, **kwargs)
    return apply_fn

  # We need to find out the structure of f()[1], including how many
  # functions there are, so that we can transform them individually and repack
  # into the same tree structure. It's valid for modules to declare parameters
  # in their constructor, so we need to create something that looks like
  # hk.Params in order to do this. `jax.eval_shape` interprets the function
  # abstractly, ie no real params are created, and we don't need to touch the
  # accelerator. This means hardcoding the RNG below is fine.
  def get_output_treedef() -> Box:
    rng = jax.random.PRNGKey(42)  # This is fine, see above
    fns = hk.transform_with_state(lambda: f()[1])
    apply_fns, _ = fns.apply(*fns.init(rng), rng)
    return Box(jax.tree_util.tree_structure(apply_fns))

  output_treedef = jax.eval_shape(get_output_treedef).python_value
  apply_fns = make_tree(
      lambda i: hk.transform_with_state(apply_fn_i(i)).apply, output_treedef)

  return MultiTransformedWithState(init_fn, apply_fns)
def multi_transform(
    f: Callable[[], Tuple[TemplateFn, TreeOfApplyFns]],
) -> MultiTransformed:
  """Turns a collection of Haiku functions that share modules into pure ones.

  Often several modules are used as building blocks by more than one Haiku
  function, or their pure versions need to be reused by downstream code. This
  utility handles that case by applying :func:`transform` to an arbitrary tree
  of Haiku functions that share modules and have a common ``init`` function.

  ``f`` must return a 2-tuple:

  1. A ``template`` Haiku function that shows how all internal Haiku modules
     are connected. It is used to build a common ``init`` function (which
     creates your parameters).
  2. An arbitrary tree of Haiku functions, all reusing the modules connected
     in ``template``. Each of these is transformed into a pure ``apply``
     function.

  Example:

  >>> def f():
  ...   encoder = hk.Linear(1, name="encoder")
  ...   decoder = hk.Linear(1, name="decoder")
  ...
  ...   def init(x):
  ...     z = encoder(x)
  ...     return decoder(z)
  ...
  ...   return init, (encoder, decoder)

  >>> f = hk.multi_transform(f)
  >>> rng = jax.random.PRNGKey(42)
  >>> x = jnp.ones([1, 1])
  >>> params = f.init(rng, x)
  >>> jax.tree_util.tree_map(jnp.shape, params)
  {'decoder': {'b': (1,), 'w': (1, 1)},
   'encoder': {'b': (1,), 'w': (1, 1)}}

  >>> encode, decode = f.apply
  >>> z = encode(params, None, x)
  >>> y = decode(params, None, z)

  Args:
    f: A factory function returning two things: first, a common init function
      that creates all modules; second, a pytree of apply functions that make
      use of those modules.

  Returns:
    A :class:`MultiTransformed` instance containing a pure init function that
    creates all parameters, and a pytree of pure apply functions that apply
    the given functions using those params.

  See also:
    :func:`multi_transform_with_state`: Equivalent for modules using state.
  """
  analytics.log_once('multi_transform')
  # Build the stateful variant first, then wrap it with `without_state` so
  # callers get the stateless signatures shown in the example above.
  return without_state(multi_transform_with_state(f))
def transform(
    f,
    apply_rng=False,
    state=False,
) -> Transformed:
  """Transforms a function using Haiku modules into a pair of pure functions.

  The first thing to do is to define a :class:`Module`. A module encapsulates
  some parameters and a computation on those parameters:

  >>> class MyModule(hk.Module):
  ...   def __call__(self, x):
  ...     w = hk.get_parameter("w", [], init=jnp.zeros)
  ...     return x + w

  Next, define some function that creates and applies modules. We use
  :func:`transform` to transform that function into a pair of functions that
  allow us to lift all the parameters out of the function (``f.init``) and
  apply the function with a given set of parameters (``f.apply``):

  >>> def f(x):
  ...   a = MyModule()
  ...   b = MyModule()
  ...   return a(x) + b(x)
  >>> f = hk.transform(f)

  To get the initial state of the module call ``f.init`` with an example
  input:

  >>> params = f.init(None, 1)
  >>> params
  frozendict({
    'my_module': frozendict({'w': DeviceArray(0., dtype=float32)}),
    'my_module_1': frozendict({'w': DeviceArray(0., dtype=float32)}),
  })

  You can then apply the function with the given parameters by calling
  ``f.apply`` (with the default ``apply_rng=False`` no rng is passed):

  >>> f.apply(params, 1)
  DeviceArray(2., dtype=float32)

  It is expected that your program will at some point produce updated
  parameters and you will want to re-apply ``f.apply``. You can do this by
  calling ``f.apply`` with different parameters:

  >>> new_params = {"my_module": {"w": jnp.array(2.)},
  ...               "my_module_1": {"w": jnp.array(3.)}}
  >>> f.apply(new_params, 2)
  DeviceArray(9., dtype=float32)

  If your transformed function needs to maintain internal state (e.g. moving
  averages in batch norm) then see :func:`transform_with_state`.

  Args:
    f: A function closing over :class:`Module` instances.
    apply_rng: Whether ``apply`` should accept ``rng`` as an argument. Leaving
      this falsy (the default) emits a :class:`DeprecationWarning`, since the
      argument will soon be removed and defaulted to ``True``.
    state: *Deprecated:* use :func:`transform_with_state`. Passing a truthy
      value emits a :class:`DeprecationWarning`.

  Returns:
    A named tuple with ``init`` and ``apply`` pure functions.
  """
  analytics.log_once("transform")
  # stacklevel=2 attributes both warnings to the caller of `transform`, which
  # is the code that needs to change.
  if state:
    warnings.warn(
        "Prefer using hk.transform_with_state(f) vs. passing state=True.",
        DeprecationWarning, stacklevel=2)
  # Only callers still relying on the legacy default (apply_rng=False) need to
  # migrate; callers already passing apply_rng=True match the future default.
  if not apply_rng:
    warnings.warn("Apply_rng will soon be removed and defaulted to True",
                  DeprecationWarning, stacklevel=2)
  # NOTE(review): when `state` is truthy this returns the stateful tuple from
  # transform_with_state, not a `Transformed` as the annotation suggests.
  pair = transform_with_state(f)
  if not apply_rng:
    pair = without_apply_rng(pair)
  if not state:
    pair = without_state(pair)
  return pair
def transform_with_state(f) -> TransformedWithState:
  """Converts a function that uses Haiku modules into two pure functions.

  This is the stateful counterpart of :func:`transform`; see that function for
  the general mechanics of Haiku transformations.

  Given ``out = f(*a, **k)``, the returned pair calls ``f(*a, **k)`` while
  explicitly collecting (``init``) or injecting (``apply``) both parameter
  values and state::

      params, state = init(rng, *a, **k)
      out, state = apply(params, state, rng, *a, **k)

  ``apply`` usually does not need ``rng``; passing ``None`` is accepted.

  Internal state (e.g. the :class:`ExponentialMovingAverage` used inside
  :class:`BatchNorm`) is read and written with :func:`get_state` and
  :func:`set_state`:

  >>> def f():
  ...   counter = hk.get_state("counter", shape=[], dtype=jnp.int32,
  ...                          init=jnp.zeros)
  ...   hk.set_state("counter", counter + 1)
  ...   return counter
  >>> f = hk.transform_with_state(f)
  >>> params, state = f.init(None)
  >>> for _ in range(10):
  ...   counter, state = f.apply(params, state, None)
  >>> counter
  DeviceArray(9, dtype=int32)

  Args:
    f: A function closing over :class:`Module` instances.

  Returns:
    A :class:`TransformedWithState` tuple holding the pure ``init`` and
    ``apply`` functions.
  """
  analytics.log_once("transform_with_state")
  check_not_jax_transformed(f)

  unexpected_tracer_hint = (
      "An UnexpectedTracerError was raised while inside a Haiku transformed "
      "function (see error above).\n"
      "Hint: are you using a JAX transform or JAX control-flow function "
      "(jax.vmap/jax.scan/...) inside a Haiku transform? You might want to use "
      "the Haiku version of the transform instead (hk.vmap/hk.scan/...).\n"
      "See https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html "
      "on why you can't use JAX transforms inside a Haiku module.")

  def call_f_with_hint(*args, **kwargs):
    """Calls ``f``, re-raising tracer errors with a Haiku-specific hint."""
    try:
      return f(*args, **kwargs)
    except jax.errors.UnexpectedTracerError as e:
      raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e

  def init_fn(
      rng: Optional[Union[PRNGKey, int]],
      *args,
      **kwargs,
  ) -> Tuple[hk.Params, hk.State]:
    """Initializes your function collecting parameters and state."""
    rng_seq = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)
    with base.new_context(rng=rng_seq) as ctx:
      call_f_with_hint(*args, **kwargs)
    return ctx.collect_params(), ctx.collect_initial_state()

  def apply_fn(
      params: Optional[hk.Params],
      state: Optional[hk.State],
      rng: Optional[Union[PRNGKey, int]],
      *args,
      **kwargs,
  ) -> Tuple[Any, hk.State]:
    """Applies your function injecting parameters and state."""
    params = check_mapping("params", params)
    state = check_mapping("state", state)
    # The error shown for a missing rng depends on whether state was supplied.
    missing_rng_msg = APPLY_RNG_STATE_ERROR if state else APPLY_RNG_ERROR
    rng_seq = to_prng_sequence(rng, err_msg=missing_rng_msg)
    with base.new_context(params=params, state=state, rng=rng_seq) as ctx:
      result = call_f_with_hint(*args, **kwargs)
    return result, ctx.collect_state()

  tie_in_original_fn(f, init_fn, apply_fn)

  return TransformedWithState(init_fn, apply_fn)