Example #1
def test_prng_key():
    assert numpyro.prng_key() is None

    with handlers.seed(rng_seed=0):
        rng_key = numpyro.prng_key()

    assert rng_key.shape == (2,) and rng_key.dtype == "uint32"
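
For reference, a minimal sketch of where that key comes from: outside a seed handler there is nothing to draw from, so numpyro.prng_key() returns None; inside one, every call splits a fresh subkey off the seeded stream.

import numpyro
from numpyro import handlers

with handlers.seed(rng_seed=0):
    k1 = numpyro.prng_key()
    k2 = numpyro.prng_key()

# successive calls under the same seed handler yield distinct subkeys
assert (k1 != k2).any()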
Example #2
def guide(docs, hyperparams, is_training=False, nn_framework="flax"):
    if nn_framework == "flax":
        encoder = flax_module(
            "encoder",
            FlaxEncoder(
                hyperparams["vocab_size"],
                hyperparams["num_topics"],
                hyperparams["hidden"],
                hyperparams["dropout_rate"],
            ),
            input_shape=(1, hyperparams["vocab_size"]),
            # ensure PRNGKey is made available to dropout layers
            apply_rng=["dropout"],
            # indicate mutable state due to BatchNorm layers
            mutable=["batch_stats"],
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    elif nn_framework == "haiku":
        encoder = haiku_module(
            "encoder",
            # use `transform_with_state` for BatchNorm
            hk.transform_with_state(
                HaikuEncoder(
                    hyperparams["vocab_size"],
                    hyperparams["num_topics"],
                    hyperparams["hidden"],
                    hyperparams["dropout_rate"],
                )),
            input_shape=(1, hyperparams["vocab_size"]),
            apply_rng=True,
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    else:
        raise ValueError(
            f"Invalid choice {nn_framework} for argument nn_framework")

    with numpyro.plate("documents",
                       docs.shape[0],
                       subsample_size=hyperparams["batch_size"]):
        batch_docs = numpyro.subsample(docs, event_dim=1)

        if nn_framework == "flax":
            concentration = encoder(batch_docs,
                                    is_training,
                                    rngs={"dropout": numpyro.prng_key()})
        elif nn_framework == "haiku":
            concentration = encoder(numpyro.prng_key(), batch_docs,
                                    is_training)

        numpyro.sample("theta", dist.Dirichlet(concentration))
Example #3
def model(docs, hyperparams, is_training=False, nn_framework="flax"):
    if nn_framework == "flax":
        decoder = flax_module(
            "decoder",
            FlaxDecoder(hyperparams["vocab_size"],
                        hyperparams["dropout_rate"]),
            input_shape=(1, hyperparams["num_topics"]),
            # ensure PRNGKey is made available to dropout layers
            apply_rng=["dropout"],
            # indicate mutable state due to BatchNorm layers
            mutable=["batch_stats"],
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    elif nn_framework == "haiku":
        decoder = haiku_module(
            "decoder",
            # use `transform_with_state` for BatchNorm
            hk.transform_with_state(
                HaikuDecoder(hyperparams["vocab_size"],
                             hyperparams["dropout_rate"])),
            input_shape=(1, hyperparams["num_topics"]),
            apply_rng=True,
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    else:
        raise ValueError(
            f"Invalid choice {nn_framework} for argument nn_framework")

    with numpyro.plate("documents",
                       docs.shape[0],
                       subsample_size=hyperparams["batch_size"]):
        batch_docs = numpyro.subsample(docs, event_dim=1)
        theta = numpyro.sample(
            "theta", dist.Dirichlet(jnp.ones(hyperparams["num_topics"])))

        if nn_framework == "flax":
            logits = decoder(theta,
                             is_training,
                             rngs={"dropout": numpyro.prng_key()})
        elif nn_framework == "haiku":
            logits = decoder(numpyro.prng_key(), theta, is_training)

        total_count = batch_docs.sum(-1)
        numpyro.sample("obs",
                       dist.Multinomial(total_count, logits=logits),
                       obs=batch_docs)
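
The two snippets above form a model/guide pair. Below is a self-contained sketch of training such a pair with SVI, using a toy model in place of the ProdLDA encoder/decoder; every name in the sketch is illustrative, not part of the example above.

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam

def toy_model(docs):
    # a single global rate stands in for the decoder network
    theta = numpyro.sample("theta", dist.Gamma(1.0, 1.0))
    with numpyro.plate("documents", docs.shape[0], subsample_size=8):
        batch = numpyro.subsample(docs, event_dim=1)
        numpyro.sample(
            "obs",
            dist.Poisson(theta).expand(batch.shape[-1:]).to_event(1),
            obs=batch,
        )

def toy_guide(docs):
    a = numpyro.param("a", 1.0, constraint=dist.constraints.positive)
    b = numpyro.param("b", 1.0, constraint=dist.constraints.positive)
    numpyro.sample("theta", dist.Gamma(a, b))

docs = random.poisson(random.PRNGKey(0), 3.0, (32, 16)).astype(jnp.float32)
svi = SVI(toy_model, toy_guide, Adam(0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(1), 500, docs)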
Example #4
    def _setup_prototype(self, *args, **kwargs):
        rng_key = numpyro.prng_key()
        with handlers.block():
            (
                init_params,
                _,
                self._postprocess_fn,
                self.prototype_trace,
            ) = initialize_model(
                rng_key,
                self.model,
                init_strategy=self.init_loc_fn,
                dynamic_args=False,
                model_args=args,
                model_kwargs=kwargs,
            )
        self._init_locs = init_params[0]

        self._prototype_frames = {}
        self._prototype_frame_full_sizes = {}
        for name, site in self.prototype_trace.items():
            if site["type"] == "sample":
                for frame in site["cond_indep_stack"]:
                    self._prototype_frames[frame.name] = frame
            elif site["type"] == "plate":
                self._prototype_frame_full_sizes[name] = site["args"][0]
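
For orientation, a minimal self-contained sketch of the initialize_model call that this method wraps; the one-site model below is illustrative.

import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer.util import initialize_model

def model():
    numpyro.sample("x", dist.Normal(0.0, 1.0))

init_params, _, postprocess_fn, prototype_trace = initialize_model(
    random.PRNGKey(0), model
)
# init_params is a namedtuple whose first field holds the initial
# (unconstrained) values, which is what `init_params[0]` reads above
print(init_params[0])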
Example #5
def flax_module(name, nn_module, *, input_shape=None):
    """
    Declare a :mod:`~flax` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    :param str name: name of the module to be registered.
    :param flax.nn.Module nn_module: a `flax` Module which has .init and .apply methods
    :param tuple input_shape: shape of the input taken by the
        neural network.
    :return: a callable with bound parameters that takes an array
        as input and returns the output array of the neural
        network.
    """
    try:
        import flax  # noqa: F401
    except ImportError as e:
        raise ImportError("Looking like you want to use flax to declare "
                          "nn modules. This is an experimental feature. "
                          "You need to install `flax` to be able to use this feature. "
                          "It can be installed with `pip install flax`.") from e
    module_key = name + '$params'
    nn_params = numpyro.param(module_key)
    if nn_params is None:
        if input_shape is None:
            raise ValueError('Valid value for `input_shape` needed to initialize.')
        # feed in dummy data to init params
        rng_key = numpyro.prng_key()
        _, nn_params = nn_module.init(rng_key, jnp.ones(input_shape))
        # make sure that nn_params keeps the same order after unflatten
        params_flat, tree_def = tree_flatten(nn_params)
        nn_params = tree_unflatten(tree_def, params_flat)
        numpyro.param(module_key, nn_params)
    return partial(nn_module.call, nn_params)
Example #6
def run(self, *args, rng_key=None, **kwargs):
    if rng_key is None:
        rng_key = numpyro.prng_key()
    self._mcmc.run(rng_key,
                   *args,
                   init_params=self._initial_params,
                   **kwargs)
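
The rng_key=None fallback above is a recurring pattern: an explicitly passed key wins, otherwise the key is drawn from the ambient seed handler. A tiny illustration (the helper name below is made up):

import numpyro
from jax import random
from numpyro import handlers

def get_key(rng_key=None):
    # an explicit key wins; otherwise draw from the ambient seed handler
    if rng_key is None:
        rng_key = numpyro.prng_key()
    return rng_key

assert (get_key(random.PRNGKey(7)) == random.PRNGKey(7)).all()
with handlers.seed(rng_seed=0):
    assert get_key() is not None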
Example #7
def model():
    transform = hk.transform_with_state if batchnorm else hk.transform
    nn = haiku_module("nn", transform(fn), apply_rng=dropout, input_shape=(4, 3))
    x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))
    if dropout:
        y = nn(numpyro.prng_key(), x)
    else:
        y = nn(x)
    numpyro.deterministic("y", y)
Example #8
def model():
    net = flax_module(
        "nn",
        Net(),
        apply_rng=["dropout"] if dropout else None,
        mutable=["batch_stats"] if batchnorm else None,
        input_shape=(4, 3),
    )
    x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))
    if dropout:
        y = net(x, rngs={"dropout": numpyro.prng_key()})
    else:
        y = net(x)
    numpyro.deterministic("y", y)
Example #9
def step(self, *args, rng_key=None, **kwargs):
    if self.svi_state is None:
        if rng_key is None:
            rng_key = numpyro.prng_key()
        self.svi_state = self.init(rng_key, *args, **kwargs)
    try:
        self.svi_state, loss = jit(self.update)(self.svi_state, *args,
                                                **kwargs)
    except TypeError as e:
        if "not a valid JAX type" in str(e):
            raise TypeError(
                "NumPyro backend requires args and kwargs to be arrays, "
                "or tuples/dicts of arrays.") from e
        else:
            raise
    params = jit(super(SVI, self).get_params)(self.svi_state)
    get_param_store().update(params)
    return loss
Example #10
def haiku_module(name, nn_module, *, input_shape=None, **kwargs):
    """
    Declare a :mod:`~haiku` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    :param str name: name of the module to be registered.
    :param haiku.Module nn_module: a `haiku` Module which has .init and .apply methods
    :param tuple input_shape: shape of the input taken by the
        neural network.
    :param kwargs: optional keyword arguments to initialize the haiku neural
        network as an alternative to `input_shape`
    :return: a callable with bound parameters that takes an array
        as input and returns the output array of the neural
        network.
    """
    try:
        import haiku  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "It looks like you want to use haiku to declare "
            "nn modules. This is an experimental feature; "
            "you need to install `haiku` to use it. "
            "It can be installed with `pip install dm-haiku`.") from e

    module_key = name + '$params'
    nn_params = numpyro.param(module_key)
    if nn_params is None:
        args = (jnp.ones(input_shape),) if input_shape is not None else ()
        # feed in dummy data to init params
        rng_key = numpyro.prng_key()
        nn_params = nn_module.init(rng_key, *args, **kwargs)
        # haiku init returns an immutable dict; we cast it to a mutable
        # one to be able to set priors for parameters
        nn_params = haiku.data_structures.to_mutable_dict(nn_params)
        # make sure that nn_params keeps the same order after unflatten
        params_flat, tree_def = tree_flatten(nn_params)
        nn_params = tree_unflatten(tree_def, params_flat)
        numpyro.param(module_key, nn_params)
    return partial(nn_module.apply, nn_params, None)
Example #11
def model(x=None):
    return numpyro.prng_key()
Example #12
def flax_module(
    name, nn_module, *, input_shape=None, apply_rng=None, mutable=None, **kwargs
):
    """
    Declare a :mod:`~flax` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    Given a flax ``nn_module``, evaluating the module with a given set
    of parameters in flax is done via ``nn_module.apply(params, x)``.
    In a NumPyro model, the pattern will be::

        net = flax_module("net", nn_module)
        y = net(x)

    or with dropout layers::

        net = flax_module("net", nn_module, apply_rng=["dropout"])
        rng_key = numpyro.prng_key()
        y = net(x, rngs={"dropout": rng_key})

    :param str name: name of the module to be registered.
    :param flax.linen.Module nn_module: a `flax` Module which has .init and .apply methods
    :param tuple input_shape: shape of the input taken by the
        neural network.
    :param list apply_rng: A list to indicate which extra rng _kinds_ are needed for
        ``nn_module``. For example, when ``nn_module`` includes dropout layers, we
        need to set ``apply_rng=["dropout"]``. Defaults to None, which means no extra
        rng key is needed. Please see
        `Flax Linen Intro <https://flax.readthedocs.io/en/latest/notebooks/linen_intro.html#Invoking-Modules>`_
        for more information on how Flax deals with stochastic layers like dropout.
    :param list mutable: A list to indicate mutable states of ``nn_module``. For example,
        if your module has BatchNorm layer, we will need to define ``mutable=["batch_stats"]``.
        See the above `Flax Linen Intro` tutorial for more information.
    :param kwargs: optional keyword arguments to initialize flax neural network
        as an alternative to `input_shape`
    :return: a callable with bound parameters that takes an array
        as input and returns the output array of the neural
        network.
    """
    try:
        import flax  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "It looks like you want to use flax to declare "
            "nn modules. This is an experimental feature; "
            "you need to install `flax` to use it. "
            "It can be installed with `pip install flax`."
        ) from e
    module_key = name + "$params"
    nn_params = numpyro.param(module_key)

    if mutable:
        nn_state = numpyro_mutable(name + "$state")
        assert nn_state is None or isinstance(nn_state, dict)
        assert (nn_state is None) == (nn_params is None)

    if nn_params is None:
        # feed in dummy data to init params
        args = (jnp.ones(input_shape),) if input_shape is not None else ()
        rng_key = numpyro.prng_key()
        # split rng_key into a dict of rng_kind: rng_key
        rngs = {}
        if apply_rng:
            assert isinstance(apply_rng, list)
            for kind in apply_rng:
                rng_key, subkey = random.split(rng_key)
                rngs[kind] = subkey
        rngs["params"] = rng_key

        nn_vars = flax.core.unfreeze(nn_module.init(rngs, *args, **kwargs))
        if "params" not in nn_vars:
            raise ValueError(
                "Your nn_module does not have any parameter. Currently, it is not"
                " supported in NumPyro. Please make a github issue if you need"
                " that feature."
            )
        nn_params = nn_vars["params"]
        if mutable:
            nn_state = {k: v for k, v in nn_vars.items() if k != "params"}
            assert set(mutable) == set(nn_state)
            numpyro_mutable(name + "$state", nn_state)
        # make sure that nn_params keeps the same order after unflatten
        params_flat, tree_def = tree_flatten(nn_params)
        nn_params = tree_unflatten(tree_def, params_flat)
        numpyro.param(module_key, nn_params)

    def apply_with_state(params, *args, **kwargs):
        params = {"params": params, **nn_state}
        out, new_state = nn_module.apply(params, *args, mutable=mutable, **kwargs)
        nn_state.update(**new_state)
        return out

    def apply_without_state(params, *args, **kwargs):
        return nn_module.apply({"params": params}, *args, **kwargs)

    apply_fn = apply_with_state if mutable else apply_without_state
    return partial(apply_fn, nn_params)
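
A runnable end-to-end sketch of the pattern this docstring describes; the Net module below is illustrative, not part of the library.

import flax.linen as nn
import jax.numpy as jnp
import numpyro
from numpyro import handlers
from numpyro.contrib.module import flax_module

class Net(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(2)(x)

def model(x):
    # registers "net$params" via numpyro.param on first use
    net = flax_module("net", Net(), input_shape=(1, 3))
    numpyro.deterministic("y", net(x))

with handlers.seed(rng_seed=0):
    model(jnp.ones((4, 3)))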
Example #13
def haiku_module(name, nn_module, *, input_shape=None, apply_rng=False, **kwargs):
    """
    Declare a :mod:`~haiku` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    Given a haiku ``nn_module``, evaluating the module with a given set
    of parameters in haiku is done via ``nn_module.apply(params, None, x)``.
    In a NumPyro model, the pattern will be::

        net = haiku_module("net", nn_module)
        y = net(x)  # or y = net(rng_key, x)

    or with dropout layers::

        net = haiku_module("net", nn_module, apply_rng=True)
        rng_key = numpyro.prng_key()
        y = net(rng_key, x)

    :param str name: name of the module to be registered.
    :param nn_module: a `haiku` Module which has .init and .apply methods
    :type nn_module: haiku.Transformed or haiku.TransformedWithState
    :param tuple input_shape: shape of the input taken by the
        neural network.
    :param bool apply_rng: A flag to indicate if the returned callable requires
        an rng argument (e.g. when ``nn_module`` includes dropout layers). Defaults
        to False, which means no rng argument is needed. If this is True, the signature
        of the returned callable ``nn = haiku_module(..., apply_rng=True)`` will be
        ``nn(rng_key, x)`` (rather than ``nn(x)``).
    :param kwargs: optional keyword arguments to initialize the haiku neural
        network as an alternative to `input_shape`
    :return: a callable with bound parameters that takes an array
        as input and returns the output array of the neural
        network.
    """
    try:
        import haiku as hk  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "It looks like you want to use haiku to declare "
            "nn modules. This is an experimental feature; "
            "you need to install `haiku` to use it. "
            "It can be installed with `pip install dm-haiku`."
        ) from e

    if not apply_rng:
        nn_module = hk.without_apply_rng(nn_module)

    module_key = name + "$params"
    nn_params = numpyro.param(module_key)
    with_state = isinstance(nn_module, hk.TransformedWithState)
    if with_state:
        nn_state = numpyro_mutable(name + "$state")
        assert nn_state is None or isinstance(nn_state, dict)
        assert (nn_state is None) == (nn_params is None)

    if nn_params is None:
        args = (jnp.ones(input_shape),) if input_shape is not None else ()
        # feed in dummy data to init params
        rng_key = numpyro.prng_key()
        if with_state:
            nn_params, nn_state = nn_module.init(rng_key, *args, **kwargs)
            nn_state = dict(nn_state)
            numpyro_mutable(name + "$state", nn_state)
        else:
            nn_params = nn_module.init(rng_key, *args, **kwargs)
        # haiku init returns an immutable dict; we cast it to a mutable
        # one to be able to set priors for parameters
        nn_params = hk.data_structures.to_mutable_dict(nn_params)
        # make sure that nn_params keeps the same order after unflatten
        params_flat, tree_def = tree_flatten(nn_params)
        nn_params = tree_unflatten(tree_def, params_flat)
        numpyro.param(module_key, nn_params)

    def apply_with_state(params, *args, **kwargs):
        out, new_state = nn_module.apply(params, nn_state, *args, **kwargs)
        nn_state.update(**new_state)
        return out

    apply_fn = apply_with_state if with_state else nn_module.apply
    return partial(apply_fn, nn_params)
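
And a matching runnable sketch for the haiku side; the linear fn below is illustrative.

import haiku as hk
import jax.numpy as jnp
import numpyro
from numpyro import handlers
from numpyro.contrib.module import haiku_module

def fn(x):
    return hk.Linear(2)(x)

def model(x):
    # apply_rng defaults to False, so the returned callable is net(x)
    net = haiku_module("net", hk.transform(fn), input_shape=(1, 3))
    numpyro.deterministic("y", net(x))

with handlers.seed(rng_seed=0):
    model(jnp.ones((4, 3)))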