def test_prng_key():
    """Outside any seed handler no key exists; inside one, a raw uint32 key pair is returned."""
    assert numpyro.prng_key() is None
    with handlers.seed(rng_seed=0):
        key = numpyro.prng_key()
        assert key.shape == (2,)
        assert key.dtype == "uint32"
def guide(docs, hyperparams, is_training=False, nn_framework="flax"):
    """Amortized guide: encode (a subsample of) documents into Dirichlet
    concentrations and sample the per-document topic proportions ``theta``.

    :param docs: array of document word counts, one row per document.
    :param dict hyperparams: must provide "vocab_size", "num_topics",
        "hidden", "dropout_rate" and "batch_size".
    :param bool is_training: forwarded to the encoder (controls dropout/BatchNorm).
    :param str nn_framework: either "flax" or "haiku".
    """
    if nn_framework not in ("flax", "haiku"):
        raise ValueError(
            f"Invalid choice {nn_framework} for argument nn_framework")
    vocab_size = hyperparams["vocab_size"]
    if nn_framework == "flax":
        encoder = flax_module(
            "encoder",
            FlaxEncoder(
                vocab_size,
                hyperparams["num_topics"],
                hyperparams["hidden"],
                hyperparams["dropout_rate"],
            ),
            input_shape=(1, vocab_size),
            # ensure PRNGKey is made available to dropout layers
            apply_rng=["dropout"],
            # indicate mutable state due to BatchNorm layers
            mutable=["batch_stats"],
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    else:  # haiku
        encoder = haiku_module(
            "encoder",
            # use `transform_with_state` for BatchNorm
            hk.transform_with_state(
                HaikuEncoder(
                    vocab_size,
                    hyperparams["num_topics"],
                    hyperparams["hidden"],
                    hyperparams["dropout_rate"],
                )),
            input_shape=(1, vocab_size),
            apply_rng=True,
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    with numpyro.plate("documents", docs.shape[0],
                       subsample_size=hyperparams["batch_size"]):
        batch_docs = numpyro.subsample(docs, event_dim=1)
        if nn_framework == "flax":
            concentration = encoder(batch_docs, is_training,
                                    rngs={"dropout": numpyro.prng_key()})
        else:  # haiku takes the rng key positionally
            concentration = encoder(numpyro.prng_key(), batch_docs, is_training)
        numpyro.sample("theta", dist.Dirichlet(concentration))
def model(docs, hyperparams, is_training=False, nn_framework="flax"):
    """Generative model: sample topic proportions ``theta`` per document and
    decode them into multinomial word-count likelihoods over the vocabulary.

    :param docs: array of document word counts, one row per document.
    :param dict hyperparams: must provide "vocab_size", "num_topics",
        "dropout_rate" and "batch_size".
    :param bool is_training: forwarded to the decoder (controls dropout/BatchNorm).
    :param str nn_framework: either "flax" or "haiku".
    """
    if nn_framework not in ("flax", "haiku"):
        raise ValueError(
            f"Invalid choice {nn_framework} for argument nn_framework")
    num_topics = hyperparams["num_topics"]
    if nn_framework == "flax":
        decoder = flax_module(
            "decoder",
            FlaxDecoder(hyperparams["vocab_size"],
                        hyperparams["dropout_rate"]),
            input_shape=(1, num_topics),
            # ensure PRNGKey is made available to dropout layers
            apply_rng=["dropout"],
            # indicate mutable state due to BatchNorm layers
            mutable=["batch_stats"],
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    else:  # haiku
        decoder = haiku_module(
            "decoder",
            # use `transform_with_state` for BatchNorm
            hk.transform_with_state(
                HaikuDecoder(hyperparams["vocab_size"],
                             hyperparams["dropout_rate"])),
            input_shape=(1, num_topics),
            apply_rng=True,
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    with numpyro.plate("documents", docs.shape[0],
                       subsample_size=hyperparams["batch_size"]):
        batch_docs = numpyro.subsample(docs, event_dim=1)
        theta = numpyro.sample(
            "theta", dist.Dirichlet(jnp.ones(num_topics)))
        if nn_framework == "flax":
            logits = decoder(theta, is_training,
                             rngs={"dropout": numpyro.prng_key()})
        else:  # haiku takes the rng key positionally
            logits = decoder(numpyro.prng_key(), theta, is_training)
        total_count = batch_docs.sum(-1)
        numpyro.sample("obs",
                       dist.Multinomial(total_count, logits=logits),
                       obs=batch_docs)
def _setup_prototype(self, *args, **kwargs):
    """Trace the model once (blocked from any enclosing handlers) to record
    initial parameter locations and the plate/frame structure of the model.

    Populates ``self._init_locs``, ``self._postprocess_fn``,
    ``self.prototype_trace``, ``self._prototype_frames`` and
    ``self._prototype_frame_full_sizes``.
    """
    rng_key = numpyro.prng_key()
    with handlers.block():
        (
            init_params,
            _,
            self._postprocess_fn,
            self.prototype_trace,
        ) = initialize_model(
            rng_key,
            self.model,
            init_strategy=self.init_loc_fn,
            dynamic_args=False,
            model_args=args,
            model_kwargs=kwargs,
        )
    self._init_locs = init_params[0]
    self._prototype_frames = {}
    # BUG FIX: the loop below writes to `_prototype_frame_full_sizes`, but the
    # original code initialized `_prototype_plate_sizes` instead, so any model
    # containing a plate site raised AttributeError here.
    self._prototype_frame_full_sizes = {}
    for name, site in self.prototype_trace.items():
        if site["type"] == "sample":
            for frame in site["cond_indep_stack"]:
                self._prototype_frames[frame.name] = frame
        elif site["type"] == "plate":
            # args[0] is the plate's full (un-subsampled) size
            self._prototype_frame_full_sizes[name] = site["args"][0]
def flax_module(name, nn_module, *, input_shape=None):
    """
    Declare a :mod:`~flax` style neural network inside a model so that its
    parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    :param str name: name of the module to be registered.
    :param flax.nn.Module nn_module: a `flax` Module which has .init and .apply methods
    :param tuple input_shape: shape of the input taken by the
        neural network.
    :return: a callable with bound parameters that takes an array
        as an input and returns the neural network transformed output
        array.
    """
    try:
        import flax  # noqa: F401
    except ImportError as e:
        raise ImportError("Looking like you want to use flax to declare "
                          "nn modules. This is an experimental feature. "
                          "You need to install `flax` to be able to use this feature. "
                          "It can be installed with `pip install flax`.") from e

    param_key = name + '$params'
    params = numpyro.param(param_key)
    if params is None:
        if input_shape is None:
            raise ValueError('Valid value for `input_shape` needed to initialize.')
        # initialize by running the module once on dummy data
        seed = numpyro.prng_key()
        _, params = nn_module.init(seed, jnp.ones(input_shape))
        # round-trip through flatten/unflatten so the params pytree
        # keeps a canonical ordering
        leaves, treedef = tree_flatten(params)
        params = tree_unflatten(treedef, leaves)
        numpyro.param(param_key, params)
    return partial(nn_module.call, params)
def run(self, *args, rng_key=None, **kwargs):
    """Run the wrapped MCMC; draws a fresh key when none is supplied."""
    key = numpyro.prng_key() if rng_key is None else rng_key
    self._mcmc.run(key, *args, init_params=self._initial_params, **kwargs)
def model():
    """Test model exercising haiku_module with optional dropout/batchnorm."""
    # BatchNorm requires the state-carrying transform
    transform = hk.transform_with_state if batchnorm else hk.transform
    nn = haiku_module("nn", transform(fn), apply_rng=dropout, input_shape=(4, 3))
    x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))
    # with dropout, the wrapped module expects an rng key as first argument
    y = nn(numpyro.prng_key(), x) if dropout else nn(x)
    numpyro.deterministic("y", y)
def model():
    """Test model exercising flax_module with optional dropout/batchnorm."""
    net = flax_module(
        "nn",
        Net(),
        apply_rng=["dropout"] if dropout else None,
        mutable=["batch_stats"] if batchnorm else None,
        input_shape=(4, 3),
    )
    x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))
    # with dropout, flax expects the key under the "dropout" rng kind
    y = net(x, rngs={"dropout": numpyro.prng_key()}) if dropout else net(x)
    numpyro.deterministic("y", y)
def step(self, *args, rng_key=None, **kwargs):
    """Run a single SVI update step and sync parameters to the param store.

    Lazily initializes the SVI state on the first call (drawing a fresh
    key if ``rng_key`` is not given), performs one jitted update, and
    returns the loss for this step.

    :raises TypeError: with a clearer message when args/kwargs are not
        JAX-compatible types.
    """
    if self.svi_state is None:
        if rng_key is None:
            rng_key = numpyro.prng_key()
        self.svi_state = self.init(rng_key, *args, **kwargs)
    try:
        self.svi_state, loss = jit(self.update)(self.svi_state, *args, **kwargs)
    except TypeError as e:
        if "not a valid JAX type" in str(e):
            raise TypeError(
                "NumPyro backend requires args, kwargs to be arrays or tuples, "
                "dicts of arrays.") from e
        # Idiom fix: bare `raise` re-raises the active exception without
        # rebinding it (the original `raise e` is non-idiomatic).
        raise
    params = jit(super(SVI, self).get_params)(self.svi_state)
    get_param_store().update(params)
    return loss
def haiku_module(name, nn_module, *, input_shape=None, **kwargs):
    """
    Declare a :mod:`~haiku` style neural network inside a model so that its
    parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    :param str name: name of the module to be registered.
    :param haiku.Module nn_module: a `haiku` Module which has .init and .apply methods
    :param tuple input_shape: shape of the input taken by the neural network.
    :param kwargs: optional keyword arguments to initialize flax neural network
        as an alternative to `input_shape`
    :return: a callable with bound parameters that takes an array
        as an input and returns the neural network transformed output
        array.
    """
    try:
        import haiku  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "Looking like you want to use haiku to declare "
            "nn modules. This is an experimental feature. "
            "You need to install `haiku` to be able to use this feature. "
            "It can be installed with `pip install dm-haiku`.") from e

    param_key = name + '$params'
    params = numpyro.param(param_key)
    if params is None:
        # initialize by running the module once on dummy data (if a shape is given)
        dummy_args = () if input_shape is None else (jnp.ones(input_shape), )
        seed = numpyro.prng_key()
        params = nn_module.init(seed, *dummy_args, **kwargs)
        # haiku init returns an immutable mapping; cast to a mutable dict
        # so priors can later be set on individual parameters
        params = haiku.data_structures.to_mutable_dict(params)
        # round-trip through flatten/unflatten so the params pytree
        # keeps a canonical ordering
        leaves, treedef = tree_flatten(params)
        params = tree_unflatten(treedef, leaves)
        numpyro.param(param_key, params)
    return partial(nn_module.apply, params, None)
def model(x=None):
    # Surface the handler-provided key (None when no seed handler is active).
    key = numpyro.prng_key()
    return key
def flax_module(
    name, nn_module, *, input_shape=None, apply_rng=None, mutable=None, **kwargs
):
    """
    Declare a :mod:`~flax` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    Given a flax ``nn_module``, in flax to evaluate the module with a given set
    of parameters, we use: ``nn_module.apply(params, x)``. In a NumPyro model,
    the pattern will be::

        net = flax_module("net", nn_module)
        y = net(x)

    or with dropout layers::

        net = flax_module("net", nn_module, apply_rng=["dropout"])
        rng_key = numpyro.prng_key()
        y = net(x, rngs={"dropout": rng_key})

    :param str name: name of the module to be registered.
    :param flax.linen.Module nn_module: a `flax` Module which has .init and .apply methods
    :param tuple input_shape: shape of the input taken by the
        neural network.
    :param list apply_rng: A list to indicate which extra rng _kinds_ are needed for
        ``nn_module``. For example, when ``nn_module`` includes dropout layers, we
        need to set ``apply_rng=["dropout"]``. Defaults to None, which means no extra
        rng key is needed. Please see
        `Flax Linen Intro <https://flax.readthedocs.io/en/latest/notebooks/linen_intro.html#Invoking-Modules>`_
        for more information in how Flax deals with stochastic layers like dropout.
    :param list mutable: A list to indicate mutable states of ``nn_module``. For example,
        if your module has BatchNorm layer, we will need to define ``mutable=["batch_stats"]``.
        See the above `Flax Linen Intro` tutorial for more information.
    :param kwargs: optional keyword arguments to initialize flax neural network
        as an alternative to `input_shape`
    :return: a callable with bound parameters that takes an array
        as an input and returns the neural network transformed output
        array.
    """
    try:
        import flax  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "Looking like you want to use flax to declare "
            "nn modules. This is an experimental feature. "
            "You need to install `flax` to be able to use this feature. "
            "It can be installed with `pip install flax`."
        ) from e
    module_key = name + "$params"
    nn_params = numpyro.param(module_key)
    if mutable:
        # mutable state (e.g. BatchNorm running stats) lives in a separate
        # mutable site keyed by "<name>$state"
        nn_state = numpyro_mutable(name + "$state")
        assert nn_state is None or isinstance(nn_state, dict)
        # params and state are created together, so both must be (un)set together
        assert (nn_state is None) == (nn_params is None)
    if nn_params is None:
        # feed in dummy data to init params
        args = (jnp.ones(input_shape),) if input_shape is not None else ()
        rng_key = numpyro.prng_key()
        # split rng_key into a dict of rng_kind: rng_key
        rngs = {}
        if apply_rng:
            assert isinstance(apply_rng, list)
            for kind in apply_rng:
                rng_key, subkey = random.split(rng_key)
                rngs[kind] = subkey
        rngs["params"] = rng_key
        # unfreeze so the returned variable collections are plain mutable dicts
        nn_vars = flax.core.unfreeze(nn_module.init(rngs, *args, **kwargs))
        if "params" not in nn_vars:
            raise ValueError(
                "Your nn_module does not have any parameter. Currently, it is not"
                " supported in NumPyro. Please make a github issue if you need"
                " that feature."
            )
        nn_params = nn_vars["params"]
        if mutable:
            # everything except "params" is mutable state; the user-declared
            # `mutable` list must match exactly what the module produced
            nn_state = {k: v for k, v in nn_vars.items() if k != "params"}
            assert set(mutable) == set(nn_state)
            numpyro_mutable(name + "$state", nn_state)
        # make sure that nn_params keep the same order after unflatten
        params_flat, tree_def = tree_flatten(nn_params)
        nn_params = tree_unflatten(tree_def, params_flat)
        numpyro.param(module_key, nn_params)

    def apply_with_state(params, *args, **kwargs):
        # merge params with the captured mutable state, apply, then write the
        # updated state back into the shared nn_state dict in place
        params = {"params": params, **nn_state}
        out, new_state = nn_module.apply(params, mutable=mutable, *args, **kwargs)
        nn_state.update(**new_state)
        return out

    def apply_without_state(params, *args, **kwargs):
        # stateless path: only the "params" collection is needed
        return nn_module.apply({"params": params}, *args, **kwargs)

    apply_fn = apply_with_state if mutable else apply_without_state
    return partial(apply_fn, nn_params)
def haiku_module(name, nn_module, *, input_shape=None, apply_rng=False, **kwargs):
    """
    Declare a :mod:`~haiku` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    Given a haiku ``nn_module``, in haiku to evaluate the module with a given set
    of parameters, we use: ``nn_module.apply(params, None, x)``. In a NumPyro model,
    the pattern will be::

        net = haiku_module("net", nn_module)
        y = net(x)  # or y = net(rng_key, x)

    or with dropout layers::

        net = haiku_module("net", nn_module, apply_rng=True)
        rng_key = numpyro.prng_key()
        y = net(rng_key, x)

    :param str name: name of the module to be registered.
    :param nn_module: a `haiku` Module which has .init and .apply methods
    :type nn_module: haiku.Transformed or haiku.TransformedWithState
    :param tuple input_shape: shape of the input taken by the
        neural network.
    :param bool apply_rng: A flag to indicate if the returned callable requires
        an rng argument (e.g. when ``nn_module`` includes dropout layers). Defaults
        to False, which means no rng argument is needed. If this is True, the signature
        of the returned callable ``nn = haiku_module(..., apply_rng=True)`` will be
        ``nn(rng_key, x)`` (rather than ``nn(x)``).
    :param kwargs: optional keyword arguments to initialize flax neural network
        as an alternative to `input_shape`
    :return: a callable with bound parameters that takes an array
        as an input and returns the neural network transformed output
        array.
    """
    try:
        import haiku as hk  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "Looking like you want to use haiku to declare "
            "nn modules. This is an experimental feature. "
            "You need to install `haiku` to be able to use this feature. "
            "It can be installed with `pip install dm-haiku`."
        ) from e

    if not apply_rng:
        # strip the rng argument from apply() so callers can do nn(x)
        nn_module = hk.without_apply_rng(nn_module)
    module_key = name + "$params"
    nn_params = numpyro.param(module_key)
    # TransformedWithState (e.g. when the module uses BatchNorm) carries
    # mutable state alongside params
    with_state = isinstance(nn_module, hk.TransformedWithState)
    if with_state:
        nn_state = numpyro_mutable(name + "$state")
        assert nn_state is None or isinstance(nn_state, dict)
        # params and state are created together, so both must be (un)set together
        assert (nn_state is None) == (nn_params is None)
    if nn_params is None:
        args = (jnp.ones(input_shape),) if input_shape is not None else ()
        # feed in dummy data to init params
        rng_key = numpyro.prng_key()
        if with_state:
            nn_params, nn_state = nn_module.init(rng_key, *args, **kwargs)
            nn_state = dict(nn_state)
            numpyro_mutable(name + "$state", nn_state)
        else:
            nn_params = nn_module.init(rng_key, *args, **kwargs)
        # haiku init returns an immutable dict
        nn_params = hk.data_structures.to_mutable_dict(nn_params)
        # we cast it to a mutable one to be able to set priors for parameters
        # make sure that nn_params keep the same order after unflatten
        params_flat, tree_def = tree_flatten(nn_params)
        nn_params = tree_unflatten(tree_def, params_flat)
        numpyro.param(module_key, nn_params)

    def apply_with_state(params, *args, **kwargs):
        # apply with the captured state, then write the updated state back
        # into the shared nn_state dict in place
        out, new_state = nn_module.apply(params, nn_state, *args, **kwargs)
        nn_state.update(**new_state)
        return out

    apply_fn = apply_with_state if with_state else nn_module.apply
    return partial(apply_fn, nn_params)