def flax_module(name, nn_module, *, input_shape=None):
    """
    Declare a :mod:`~flax` style neural network inside a model so that its
    parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    Parameters are stored under the param site ``name + '$params'``. If that
    site has no value yet, the module is initialized by passing a dummy array
    of ones with shape ``input_shape`` to ``nn_module.init``. The RNG key for
    that call is drawn from the sample site ``name + '$rng_key'``. The
    resulting parameters are then registered. Otherwise the stored
    parameters are reused.

    :param str name: name of the module to be registered.
    :param flax.nn.Module nn_module: a `flax` Module which has ``.init`` and
        ``.call`` methods.
    :param tuple input_shape: shape of the input taken by the neural network.
        Only required when the parameters have not been registered yet.
    :return: a callable with bound parameters that takes an array as an input
        and returns the neural network transformed output array.
    :raises ImportError: if `flax` cannot be imported.
    :raises ValueError: if initialization is needed but ``input_shape`` is
        ``None``.
    """
    try:
        import flax  # noqa: F401
    except ImportError as err:
        # Chain the original failure so the real import error stays visible.
        raise ImportError("Looking like you want to use flax to declare "
                          "nn modules. This is an experimental feature. "
                          "You need to install `flax` to be able to use this feature. "
                          "It can be installed with `pip install flax`.") from err

    module_key = name + '$params'
    nn_params = numpyro.param(module_key)
    if nn_params is None:
        if input_shape is None:
            raise ValueError('Valid value for `input_shape` needed to initialize.')
        # feed in dummy data to init params
        rng_key = numpyro.sample(name + '$rng_key', PRNGIdentity())
        _, nn_params = nn_module.init(rng_key, jnp.ones(input_shape))
        # Round-trip through flatten/unflatten so nn_params keep the same
        # order after unflatten. NOTE(review): presumably the container
        # returned by flax.init can differ in key order from what
        # tree_unflatten produces later (e.g. in the optimizer) - confirm.
        params_flat, tree_def = tree_flatten(nn_params)
        nn_params = tree_unflatten(tree_def, params_flat)
        numpyro.param(module_key, nn_params)
    return partial(nn_module.call, nn_params)
def module(name, nn, input_shape=None):
    """
    Declare a :mod:`~jax.experimental.stax` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    Parameters are stored under the param site ``name + '$params'``. If that
    site has no value yet, ``init_fn`` is called with an RNG key drawn from
    the sample site ``name + '$rng_key'`` and with ``input_shape``. The
    resulting parameters are then registered. Otherwise the stored
    parameters are reused.

    :param str name: name of the module to be registered.
    :param tuple nn: a tuple of `(init_fn, apply_fn)` obtained by a
        :mod:`~jax.experimental.stax` constructor function.
    :param tuple input_shape: shape of the input taken by the neural network.
        Only required when the parameters have not been registered yet.
    :return: an `apply_fn` with bound parameters that takes an array as an
        input and returns the neural network transformed output array.
    :raises ValueError: if initialization is needed but ``input_shape`` is
        ``None``.
    """
    module_key = name + '$params'
    nn_init, nn_apply = nn
    nn_params = param(module_key)
    if nn_params is None:
        if input_shape is None:
            # The message names the actual parameter (`input_shape`).
            raise ValueError(
                'Valid value for `input_shape` needed to initialize.')
        rng_key = numpyro.sample(name + '$rng_key', PRNGIdentity())
        # Unlike the flax/haiku helpers, init_fn receives the shape itself
        # rather than a dummy array. The first element it returns is unused.
        _, nn_params = nn_init(rng_key, input_shape)
        param(module_key, nn_params)
    return functools.partial(nn_apply, nn_params)
def haiku_module(name, nn, input_shape=None):
    """
    Declare a :mod:`~haiku` style neural network inside a model so that its
    parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    Parameters are stored under the param site ``name + '$params'``. If that
    site has no value yet, the network is initialized by passing a dummy
    array of ones with shape ``input_shape`` to ``nn.init``. The RNG key for
    that call is drawn from the sample site ``name + '$rng_key'``. The
    resulting parameters are then registered. Otherwise the stored
    parameters are reused.

    :param str name: name of the module to be registered.
    :param haiku.Module nn: a `haiku` Module which has .init and .apply methods
    :param tuple input_shape: shape of the input taken by the neural network.
        Only required when the parameters have not been registered yet.
    :return: a callable with bound parameters that takes an array as an input
        and returns the neural network transformed output array.
    :raises ImportError: if `haiku` cannot be imported.
    :raises ValueError: if initialization is needed but ``input_shape`` is
        ``None``.
    """
    try:
        import haiku  # noqa: F401
    except ImportError as err:
        # Chain the original failure so the real import error stays visible.
        raise ImportError("Looking like you want to use haiku to declare "
                          "nn modules. This is an experimental feature. "
                          "You need to install `haiku` to be able to use this feature. "
                          "It can be installed with `pip install git+https://github.com/deepmind/dm-haiku`.") from err

    module_key = name + '$params'
    nn_params = numpyro.param(module_key)
    if nn_params is None:
        if input_shape is None:
            raise ValueError('Valid value for `input_shape` needed to initialize.')
        # feed in dummy data to init params
        rng_key = numpyro.sample(name + '$rng_key', PRNGIdentity())
        nn_params = nn.init(rng_key, jnp.ones(input_shape))
        numpyro.param(module_key, nn_params)
    # The second positional argument bound here is None. NOTE(review):
    # presumably this is haiku's apply-time rng, meaning the network must not
    # need randomness at apply time - confirm.
    return partial(nn.apply, nn_params, None)