Example 1
import numpyro
import numpyro.distributions as dist
from numpyro.primitives import mutable as numpyro_mutable


def guide():
    loc = numpyro.param("loc", 0.0)
    p = numpyro_mutable("loc1p", {"value": None})
    # we can modify the content of `p` if it is a dict
    p["value"] = loc + 2
    numpyro.sample("x", dist.Normal(loc, 0.1))
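A sketch of recovering the mutated value after fitting the guide above with SVI. This assumes SVI carries mutable sites in ``svi_result.state.mutable_state``; the trivial model paired with the guide is illustrative only:

from jax import random
from numpyro import optim
from numpyro.infer import SVI, Trace_ELBO

def model():
    numpyro.sample("x", dist.Normal(0.0, 1.0))

svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 100)
# the mutable site travels with the SVI state rather than with the params
print(svi_result.state.mutable_state["loc1p"]["value"])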
Example 2
from functools import partial

import jax.numpy as jnp
from jax import random
from jax.tree_util import tree_flatten, tree_unflatten

import numpyro
from numpyro.primitives import mutable as numpyro_mutable

def flax_module(
    name, nn_module, *, input_shape=None, apply_rng=None, mutable=None, **kwargs
):
    """
    Declare a :mod:`~flax` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    Given a flax ``nn_module``, to evaluate it with a given set of parameters,
    flax uses ``nn_module.apply(params, x)``.
    In a NumPyro model, the pattern will be::

        net = flax_module("net", nn_module)
        y = net(x)

    or with dropout layers::

        net = flax_module("net", nn_module, apply_rng=["dropout"])
        rng_key = numpyro.prng_key()
        y = net(x, rngs={"dropout": rng_key})

    :param str name: name of the module to be registered.
    :param flax.linen.Module nn_module: a `flax` Module that has ``.init`` and ``.apply`` methods
    :param tuple input_shape: shape of the input taken by the
        neural network.
    :param list apply_rng: A list to indicate which extra rng _kinds_ are needed for
        ``nn_module``. For example, when ``nn_module`` includes dropout layers, we
        need to set ``apply_rng=["dropout"]``. Defaults to None, which means no extra
        rng key is needed. Please see
        `Flax Linen Intro <https://flax.readthedocs.io/en/latest/notebooks/linen_intro.html#Invoking-Modules>`_
        for more information on how Flax deals with stochastic layers like dropout.
    :param list mutable: A list indicating the mutable states of ``nn_module``. For example,
        if your module has a BatchNorm layer, you will need to set ``mutable=["batch_stats"]``.
        See the `Flax Linen Intro` tutorial above for more information.
    :param kwargs: optional keyword arguments to initialize the flax neural network
        as an alternative to `input_shape`
    :return: a callable with bound parameters that takes an array as input and
        returns the output array transformed by the neural network.
    """
    try:
        import flax  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "It looks like you want to use flax to declare "
            "nn modules. This is an experimental feature. "
            "You need to install `flax` to use it. "
            "It can be installed with `pip install flax`."
        ) from e
    module_key = name + "$params"
    nn_params = numpyro.param(module_key)

    if mutable:
        nn_state = numpyro_mutable(name + "$state")
        assert nn_state is None or isinstance(nn_state, dict)
        assert (nn_state is None) == (nn_params is None)

    if nn_params is None:
        # feed in dummy data to init params
        args = (jnp.ones(input_shape),) if input_shape is not None else ()
        rng_key = numpyro.prng_key()
        # split rng_key into a dict of rng_kind: rng_key
        rngs = {}
        if apply_rng:
            assert isinstance(apply_rng, list)
            for kind in apply_rng:
                rng_key, subkey = random.split(rng_key)
                rngs[kind] = subkey
        rngs["params"] = rng_key

        nn_vars = flax.core.unfreeze(nn_module.init(rngs, *args, **kwargs))
        if "params" not in nn_vars:
            raise ValueError(
                "Your nn_module does not have any parameter. Currently, it is not"
                " supported in NumPyro. Please make a github issue if you need"
                " that feature."
            )
        nn_params = nn_vars["params"]
        if mutable:
            nn_state = {k: v for k, v in nn_vars.items() if k != "params"}
            assert set(mutable) == set(nn_state)
            numpyro_mutable(name + "$state", nn_state)
        # make sure that nn_params keeps the same order after unflatten
        params_flat, tree_def = tree_flatten(nn_params)
        nn_params = tree_unflatten(tree_def, params_flat)
        numpyro.param(module_key, nn_params)

    def apply_with_state(params, *args, **kwargs):
        params = {"params": params, **nn_state}
        out, new_state = nn_module.apply(params, *args, mutable=mutable, **kwargs)
        nn_state.update(**new_state)
        return out

    def apply_without_state(params, *args, **kwargs):
        return nn_module.apply({"params": params}, *args, **kwargs)

    apply_fn = apply_with_state if mutable else apply_without_state
    return partial(apply_fn, nn_params)
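A hedged usage sketch (the ``flax.linen.Dense`` regression model here is illustrative, not part of the source):

import flax.linen as nn
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.module import flax_module

def model(x, y=None):
    # registers the Dense parameters under the "net$params" site
    net = flax_module("net", nn.Dense(features=1), input_shape=(x.shape[-1],))
    mu = net(x).squeeze(-1)
    numpyro.sample("y", dist.Normal(mu, 0.1), obs=y)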
Example 3
import numpyro
import numpyro.distributions as dist
from numpyro.primitives import mutable as numpyro_mutable


def model():
    x = numpyro.sample("x", dist.Normal(-1, 1))
    numpyro_mutable("x1p", x + 1)
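The mutable primitive also supports reading a stored value back by name; a small sketch (assuming the site has already been written during the current execution):

def model_with_readback():
    x = numpyro.sample("x", dist.Normal(-1, 1))
    numpyro_mutable("x1p", x + 1)
    # calling the primitive again with only the site name returns the stored value
    stored = numpyro_mutable("x1p")
    return stored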
Example 4
from functools import partial

import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

import numpyro
from numpyro.primitives import mutable as numpyro_mutable

def haiku_module(name, nn_module, *, input_shape=None, apply_rng=False, **kwargs):
    """
    Declare a :mod:`~haiku` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    Given a haiku ``nn_module``, to evaluate it with a given set of parameters,
    haiku uses ``nn_module.apply(params, None, x)``.
    In a NumPyro model, the pattern will be::

        net = haiku_module("net", nn_module)
        y = net(x)  # or y = net(rng_key, x)

    or with dropout layers::

        net = haiku_module("net", nn_module, apply_rng=True)
        rng_key = numpyro.prng_key()
        y = net(rng_key, x)

    :param str name: name of the module to be registered.
    :param nn_module: a `haiku` Module that has ``.init`` and ``.apply`` methods
    :type nn_module: haiku.Transformed or haiku.TransformedWithState
    :param tuple input_shape: shape of the input taken by the
        neural network.
    :param bool apply_rng: A flag to indicate if the returned callable requires
        an rng argument (e.g. when ``nn_module`` includes dropout layers). Defaults
        to False, which means no rng argument is needed. If this is True, the signature
        of the returned callable ``nn = haiku_module(..., apply_rng=True)`` will be
        ``nn(rng_key, x)`` (rather than ``nn(x)``).
    :param kwargs: optional keyword arguments to initialize the haiku neural network
        as an alternative to `input_shape`
    :return: a callable with bound parameters that takes an array as input and
        returns the output array transformed by the neural network.
    """
    try:
        import haiku as hk  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "It looks like you want to use haiku to declare "
            "nn modules. This is an experimental feature. "
            "You need to install `haiku` to use it. "
            "It can be installed with `pip install dm-haiku`."
        ) from e

    if not apply_rng:
        nn_module = hk.without_apply_rng(nn_module)

    module_key = name + "$params"
    nn_params = numpyro.param(module_key)
    with_state = isinstance(nn_module, hk.TransformedWithState)
    if with_state:
        nn_state = numpyro_mutable(name + "$state")
        assert nn_state is None or isinstance(nn_state, dict)
        assert (nn_state is None) == (nn_params is None)

    if nn_params is None:
        args = (jnp.ones(input_shape),) if input_shape is not None else ()
        # feed in dummy data to init params
        rng_key = numpyro.prng_key()
        if with_state:
            nn_params, nn_state = nn_module.init(rng_key, *args, **kwargs)
            nn_state = dict(nn_state)
            numpyro_mutable(name + "$state", nn_state)
        else:
            nn_params = nn_module.init(rng_key, *args, **kwargs)
        # haiku init returns an immutable dict; cast it to a mutable one
        # so that priors can be set for individual parameters
        nn_params = hk.data_structures.to_mutable_dict(nn_params)
        # make sure that nn_params keeps the same order after unflatten
        params_flat, tree_def = tree_flatten(nn_params)
        nn_params = tree_unflatten(tree_def, params_flat)
        numpyro.param(module_key, nn_params)

    def apply_with_state(params, *args, **kwargs):
        out, new_state = nn_module.apply(params, nn_state, *args, **kwargs)
        nn_state.update(**new_state)
        return out

    apply_fn = apply_with_state if with_state else nn_module.apply
    return partial(apply_fn, nn_params)
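A hedged usage sketch mirroring the flax example above (the ``hk.Linear`` model is illustrative only):

import haiku as hk
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.module import haiku_module

def model(x, y=None):
    # hk.transform wraps a pure function; haiku_module then registers its
    # parameters under the "net$params" site
    net = haiku_module(
        "net", hk.transform(lambda x: hk.Linear(1)(x)), input_shape=(x.shape[-1],)
    )
    mu = net(x).squeeze(-1)
    numpyro.sample("y", dist.Normal(mu, 0.1), obs=y)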