def guide():
    """Example guide that stashes a value derived from a parameter.

    Declares the parameter ``loc``, stores ``loc + 2`` under the ``"value"``
    key of the dict held at the mutable site ``loc1p``, and samples ``x``
    from ``Normal(loc, 0.1)``.
    """
    center = numpyro.param("loc", 0.0)
    # The mutable site is initialized with a dict, so its content can be
    # modified in place after it has been declared.
    holder = numpyro_mutable("loc1p", {"value": None})
    holder["value"] = center + 2
    numpyro.sample("x", dist.Normal(center, 0.1))
def flax_module(
    name, nn_module, *, input_shape=None, apply_rng=None, mutable=None, **kwargs
):
    """
    Wrap a :mod:`~flax` neural network so that its parameters are declared
    through :func:`~numpyro.primitives.param` statements and can therefore be
    optimized from inside a NumPyro model.

    In plain flax a module is evaluated with ``nn_module.apply(params, x)``.
    Inside a NumPyro model the corresponding pattern is::

        net = flax_module("net", nn_module)
        y = net(x)

    and, for a module containing dropout layers::

        net = flax_module("net", nn_module, apply_rng=["dropout"])
        rng_key = numpyro.prng_key()
        y = net(x, rngs={"dropout": rng_key})

    The parameters live at the site ``name + "$params"``. When ``mutable`` is
    given, the remaining variable collections live at the mutable site
    ``name + "$state"`` and are updated in place after every call of the
    returned callable.

    :param str name: name of the module to be registered.
    :param flax.linen.Module nn_module: a `flax` Module which has .init and
        .apply methods
    :param tuple input_shape: shape of the input taken by the neural network.
        When given, an array of ones with this shape is fed to ``.init``.
    :param list apply_rng: extra rng _kinds_ needed by ``nn_module``. For
        instance, a module with dropout layers needs
        ``apply_rng=["dropout"]``. Defaults to None, i.e. no extra rng key is
        needed. See `Flax Linen Intro
        <https://flax.readthedocs.io/en/latest/notebooks/linen_intro.html#Invoking-Modules>`_
        for how Flax deals with stochastic layers like dropout.
    :param list mutable: names of the mutable state collections of
        ``nn_module``. For instance, a module with a BatchNorm layer needs
        ``mutable=["batch_stats"]``. See the `Flax Linen Intro` tutorial
        linked above for more information.
    :param kwargs: optional keyword arguments to initialize flax neural
        network as an alternative to `input_shape`
    :return: a callable with bound parameters that takes an array as an input
        and returns the neural network transformed output array.
    :raises ImportError: if `flax` cannot be imported.
    :raises ValueError: if the initialized module has no ``"params"``
        collection.
    """
    try:
        import flax  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "Looking like you want to use flax to declare "
            "nn modules. This is an experimental feature. "
            "You need to install `flax` to be able to use this feature. "
            "It can be installed with `pip install flax`."
        ) from e

    module_key = name + "$params"
    state_key = name + "$state"

    nn_params = numpyro.param(module_key)
    if mutable:
        nn_state = numpyro_mutable(state_key)
        assert nn_state is None or isinstance(nn_state, dict)
        # Parameters and state must be both initialized or both missing.
        assert (nn_state is None) == (nn_params is None)

    if nn_params is None:
        # Nothing registered yet: initialize the module on dummy data.
        if input_shape is not None:
            init_args = (jnp.ones(input_shape),)
        else:
            init_args = ()

        # Build the {rng_kind: key} dict expected by flax; whatever is left
        # of the original key after splitting is used for "params".
        rng_key = numpyro.prng_key()
        rngs = {}
        if apply_rng:
            assert isinstance(apply_rng, list)
            for rng_kind in apply_rng:
                rng_key, kind_key = random.split(rng_key)
                rngs[rng_kind] = kind_key
        rngs["params"] = rng_key

        nn_vars = flax.core.unfreeze(nn_module.init(rngs, *init_args, **kwargs))
        if "params" not in nn_vars:
            raise ValueError(
                "Your nn_module does not have any parameter. Currently, it is not"
                " supported in NumPyro. Please make a github issue if you need"
                " that feature."
            )
        nn_params = nn_vars["params"]

        if mutable:
            # Every collection other than "params" is treated as state.
            nn_state = {
                collection: values
                for collection, values in nn_vars.items()
                if collection != "params"
            }
            assert set(mutable) == set(nn_state)
            numpyro_mutable(state_key, nn_state)

        # Round-trip through flatten/unflatten so that nn_params keep the
        # same order after unflatten.
        leaves, structure = tree_flatten(nn_params)
        nn_params = tree_unflatten(structure, leaves)
        numpyro.param(module_key, nn_params)

    def apply_with_state(params, *args, **kwargs):
        # Merge parameters with the current state, run the module, then write
        # the updated collections back into the shared state dict.
        variables = {"params": params, **nn_state}
        out, new_state = nn_module.apply(variables, *args, mutable=mutable, **kwargs)
        nn_state.update(**new_state)
        return out

    def apply_without_state(params, *args, **kwargs):
        return nn_module.apply({"params": params}, *args, **kwargs)

    if mutable:
        return partial(apply_with_state, nn_params)
    return partial(apply_without_state, nn_params)
def model():
    """Example model: sample ``x`` and record ``x + 1`` at the mutable site ``x1p``."""
    draw = numpyro.sample("x", dist.Normal(-1, 1))
    numpyro_mutable("x1p", draw + 1)
def haiku_module(name, nn_module, *, input_shape=None, apply_rng=False, **kwargs):
    """
    Wrap a :mod:`~haiku` neural network so that its parameters are declared
    through :func:`~numpyro.primitives.param` statements and can therefore be
    optimized from inside a NumPyro model.

    In plain haiku a module is evaluated with
    ``nn_module.apply(params, None, x)``. Inside a NumPyro model the
    corresponding pattern is::

        net = haiku_module("net", nn_module)
        y = net(x)  # or y = net(rng_key, x)

    and, for a module containing dropout layers::

        net = haiku_module("net", nn_module, apply_rng=True)
        rng_key = numpyro.prng_key()
        y = net(rng_key, x)

    The parameters live at the site ``name + "$params"``. For a
    ``haiku.TransformedWithState`` module the state lives at the mutable site
    ``name + "$state"`` and is updated in place after every call of the
    returned callable.

    :param str name: name of the module to be registered.
    :param nn_module: a `haiku` Module which has .init and .apply methods
    :type nn_module: haiku.Transformed or haiku.TransformedWithState
    :param tuple input_shape: shape of the input taken by the neural network.
        When given, an array of ones with this shape is fed to ``.init``.
    :param bool apply_rng: whether the returned callable requires an rng
        argument (e.g. when ``nn_module`` includes dropout layers). Defaults
        to False, i.e. no rng argument is needed. When True, the callable
        returned by ``nn = haiku_module(..., apply_rng=True)`` has the
        signature ``nn(rng_key, x)`` (rather than ``nn(x)``).
    :param kwargs: optional keyword arguments to initialize haiku neural
        network as an alternative to `input_shape`
    :return: a callable with bound parameters that takes an array as an input
        and returns the neural network transformed output array.
    :raises ImportError: if `haiku` cannot be imported.
    """
    try:
        import haiku as hk  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "Looking like you want to use haiku to declare "
            "nn modules. This is an experimental feature. "
            "You need to install `haiku` to be able to use this feature. "
            "It can be installed with `pip install dm-haiku`."
        ) from e

    if not apply_rng:
        nn_module = hk.without_apply_rng(nn_module)

    module_key = name + "$params"
    state_key = name + "$state"

    nn_params = numpyro.param(module_key)
    with_state = isinstance(nn_module, hk.TransformedWithState)
    if with_state:
        nn_state = numpyro_mutable(state_key)
        assert nn_state is None or isinstance(nn_state, dict)
        # Parameters and state must be both initialized or both missing.
        assert (nn_state is None) == (nn_params is None)

    if nn_params is None:
        # Nothing registered yet: initialize the module on dummy data.
        if input_shape is not None:
            init_args = (jnp.ones(input_shape),)
        else:
            init_args = ()
        rng_key = numpyro.prng_key()

        init_out = nn_module.init(rng_key, *init_args, **kwargs)
        if with_state:
            nn_params, nn_state = init_out
            nn_state = dict(nn_state)
            numpyro_mutable(state_key, nn_state)
        else:
            nn_params = init_out

        # haiku init returns an immutable dict; cast it to a mutable one to
        # be able to set priors for parameters.
        nn_params = hk.data_structures.to_mutable_dict(nn_params)

        # Round-trip through flatten/unflatten so that nn_params keep the
        # same order after unflatten.
        leaves, structure = tree_flatten(nn_params)
        nn_params = tree_unflatten(structure, leaves)
        numpyro.param(module_key, nn_params)

    def apply_with_state(params, *args, **kwargs):
        # Run the module with the current state, then write the new state
        # back into the shared state dict.
        out, new_state = nn_module.apply(params, nn_state, *args, **kwargs)
        nn_state.update(**new_state)
        return out

    if with_state:
        return partial(apply_with_state, nn_params)
    return partial(nn_module.apply, nn_params)