Example #1
0
 def testing():
     """Check how ``markov`` lays out per-step vs. shared named dimensions.

     Every step builds one tensor whose only input is named after the step
     index and one tensor whose only input is the shared name ``'a'``.  The
     per-step tensor's data shape is expected to alternate between ``(2,)``
     and ``(2, 1, 1)``; the shared tensor's shape is expected to stay
     ``(2, 1)``.
     """
     for step in markov(range(5)):
         # The call order (per-step tensor first, then shared tensor) is kept
         # as-is because dimension allocation may depend on it.
         fresh = to_data(Tensor(jnp.ones(2), OrderedDict([(str(step), bint(2))]), 'real'))
         shared = to_data(Tensor(jnp.zeros(2), OrderedDict([('a', bint(2))]), 'real'))
         fresh_funsor = to_funsor(fresh, reals())
         shared_funsor = to_funsor(shared, reals())
         print(step, fresh.shape)  # shapes should alternate
         expected_shape = (2, 1, 1) if step % 2 else (2,)
         assert fresh.shape == expected_shape
         assert shared.shape == (2, 1)
         print(step, fresh_funsor.inputs)
         print('a', shared.shape)  # shapes should stay the same
         print('a', shared_funsor.inputs)
Example #2
0
 def testing():
     """Check that a tensor over the shared name ``'a'`` converts to shape ``(2,)``.

     The tensor is only built on every fourth step of the ``markov`` loop.
     """
     for step in markov(range(12)):
         if step % 4 != 0:
             continue
         data = to_data(Tensor(jnp.zeros(2), OrderedDict([('a', bint(2))]), 'real'))
         data_funsor = to_funsor(data, reals())
         assert data.shape == (2,)
         print('a', data.shape)
         print('a', data_funsor.inputs)
Example #3
0
 def testing():
     """Check how ``markov`` lays out per-step vs. shared named dimensions.

     Each step converts a tensor whose only input is named after the step
     index, plus a tensor whose only input is the shared name ``"a"``.  The
     first is expected to alternate between data shapes ``(2,)`` and
     ``(2, 1, 1)``; the second is expected to keep shape ``(2, 1)``.
     """
     for step in markov(range(5)):
         step_name = str(step)
         # Keep the per-step tensor ahead of the shared one: dimension
         # allocation may depend on the order of these calls.
         fresh = to_data(
             Tensor(jnp.ones(2), OrderedDict([(step_name, Bint[2])]), "real"))
         shared = to_data(
             Tensor(jnp.zeros(2), OrderedDict([("a", Bint[2])]), "real"))
         fresh_funsor = to_funsor(fresh, Real)
         shared_funsor = to_funsor(shared, Real)
         print(step, fresh.shape)  # shapes should alternate
         expected_shape = (2,) if step % 2 == 0 else (2, 1, 1)
         assert fresh.shape == expected_shape
         assert shared.shape == (2, 1)
         print(step, fresh_funsor.inputs)
         print("a", shared.shape)  # shapes should stay the same
         print("a", shared_funsor.inputs)
Example #4
0
 def testing():
     """Check that a tensor over the shared name ``"a"`` converts to shape ``(2,)``.

     Only every fourth step of the ``markov`` loop builds the tensor.
     """
     for step in markov(range(12)):
         if step % 4:
             continue
         data = to_data(
             Tensor(jnp.zeros(2), OrderedDict([("a", Bint[2])]), "real"))
         data_funsor = to_funsor(data, Real)
         assert data.shape == (2,)
         print("a", data.shape)
         print("a", data_funsor.inputs)
Example #5
0
def scan(f, init, xs, length=None, reverse=False, history=1):
    """
    This primitive scans a function over the leading array axes of
    `xs` while carrying along state. See :func:`jax.lax.scan` for more
    information.

    **Usage**:

    .. doctest::

       >>> import numpy as np
       >>> import numpyro
       >>> import numpyro.distributions as dist
       >>> from numpyro.contrib.control_flow import scan
       >>>
       >>> def gaussian_hmm(y=None, T=10):
       ...     def transition(x_prev, y_curr):
       ...         x_curr = numpyro.sample('x', dist.Normal(x_prev, 1))
       ...         y_curr = numpyro.sample('y', dist.Normal(x_curr, 1), obs=y_curr)
       ...         return x_curr, (x_curr, y_curr)
       ...
       ...     x0 = numpyro.sample('x_0', dist.Normal(0, 1))
       ...     _, (x, y) = scan(transition, x0, y, length=T)
       ...     return (x, y)
       >>>
       >>> # here we do some quick tests
       >>> with numpyro.handlers.seed(rng_seed=0):
       ...     x, y = gaussian_hmm(np.arange(10.))
       >>> assert x.shape == (10,) and y.shape == (10,)
       >>> assert np.all(y == np.arange(10))
       >>>
       >>> with numpyro.handlers.seed(rng_seed=0):  # generative
       ...     x, y = gaussian_hmm()
       >>> assert x.shape == (10,) and y.shape == (10,)

    .. warning:: This is an experimental utility function that allows users to use
        JAX control flow with NumPyro's effect handlers. Currently, `sample` and
        `deterministic` sites within the scan body `f` are supported. If you notice
        that any effect handlers or distributions are unsupported, please file an issue.

    .. note:: It is ambiguous how to align the `scan` dimension with an enclosing
        `plate` context, so the following pattern is not supported

        .. code-block:: python

            with numpyro.plate('N', 10):
                last, ys = scan(f, init, xs)

        All `plate` statements should be put inside `f`. For example, the corresponding
        working code is

        .. code-block:: python

            def g(*args, **kwargs):
                with numpyro.plate('N', 10):
                    return f(*args, **kwargs)

            last, ys = scan(g, init, xs)

    .. note:: Nested scan is currently not supported.

    .. note:: We can scan over discrete latent variables in `f`. The joint density is
        evaluated using parallel-scan (reference [1]) over the time dimension, which
        reduces parallel complexity to `O(log(length))`.

        A :class:`~numpyro.handlers.trace` of `scan` with discrete latent
        variables will contain the following sites:

        + init sites: the sites belonging to the first `history` traces of `f`.
          Sites at the `i`-th trace will have their names prefixed with
          `'_PREV_' * (2 * history - 1 - i)`.
        + scanned sites: the sites that collect the values of the remaining scan
          loop over `f`. An additional time dimension `_time_foo` will be
          added to those sites, where `foo` is the name of the first site
          that appears in `f`.

        Not all transition functions `f` are supported. All of the restrictions from
        Pyro's enumeration tutorial [2] still apply here. In addition, no site
        outside of `scan` should depend on the first output of `scan`
        (the last carry value).

    **References**

    1. *Temporal Parallelization of Bayesian Smoothers*,
       Simo Sarkka, Angel F. Garcia-Fernandez
       (https://arxiv.org/abs/1905.13002)

    2. *Inference with Discrete Latent Variables*
       (http://pyro.ai/examples/enumeration.html#Dependencies-among-plates)

    :param callable f: a function to be scanned.
    :param init: the initial carrying state
    :param xs: the values over which we scan along the leading axis. This can
        be any JAX pytree (e.g. list/dict of arrays).
    :param length: optional value specifying the length of `xs`. Use it when
        `xs` is an empty pytree (e.g. None) and so carries no length information.
    :param bool reverse: optional boolean specifying whether to run the scan iteration
        forward (the default) or in reverse
    :param int history: The number of previous contexts visible from the current context.
        Defaults to 1. If zero, this is similar to :class:`numpyro.plate`.
    :return: output of scan, quoted from :func:`jax.lax.scan` docs:
        "pair of type (c, [b]) where the first element represents the final loop
        carry value and the second element represents the stacked outputs of the
        second output of f when scanned over the leading axis of the inputs".
    """
    # With no active Messengers, nothing can intercept the scan: run it
    # directly and return its result.
    if not _PYRO_STACK:
        (_, _, carry), (_, ys) = scan_wrapper(
            f, init, xs, length=length, reverse=reverse
        )
        return carry, ys

    # Otherwise, wrap the call in a "control_flow" message and send it through
    # the Messengers with apply_stack. Handlers may add kwargs to the message;
    # `enum` (read below) is not set here, so it can only come from a handler.
    initial_msg = {
        "type": "control_flow",
        "fn": scan_wrapper,
        "args": (f, init, xs, length, reverse),
        "kwargs": {
            "rng_key": None,
            "substitute_stack": [],
            "history": history,
        },
        "value": None,
    }
    msg = apply_stack(initial_msg)
    (_, _, carry), (pytree_trace, ys) = msg["value"]

    if not msg["kwargs"].get("enum", False):
        # Replay every site recorded inside the scan body through the handler
        # stack so that the active Messengers see them.
        for site in pytree_trace.trace.values():
            apply_stack(site)
    else:
        from numpyro.contrib.funsor import to_funsor
        from numpyro.contrib.funsor.enum_messenger import LocalNamedMessenger

        for site in pytree_trace.trace.values():
            with LocalNamedMessenger():
                dim_to_name = site["infer"].get("dim_to_name")
                # NOTE(review): presumably this registers the site's
                # dim -> name mapping (ordered by dim) in the local naming
                # frame before the site is replayed; the result is unused.
                to_funsor(
                    site["value"],
                    dim_to_name=OrderedDict(
                        [(k, dim_to_name[k]) for k in sorted(dim_to_name)]
                    ),
                )
                apply_stack(site)

    return carry, ys