示例#1
0
def to_data(x, name_to_dim=None, dim_type=DimType.LOCAL):
    """
    Extract a plain Python object from a :class:`~funsor.terms.Funsor`,
    routing the conversion through the effect-handler stack.

    :param ~funsor.terms.Funsor x: The funsor to convert.
    :param OrderedDict name_to_dim: Optional inputs hint mapping dimension
        names of `x` to dimension positions in the returned value. Defaults
        to an empty ``OrderedDict``.
    :param int dim_type: Either 0, 1, or 2, marking a dimension as 'local',
        'global', or 'visible'; it can be used to interact with the global
        :class:`DimStack`.
    :return: A non-funsor equivalent of `x`, taken from the processed
        message's ``value``.
    """
    if name_to_dim is None:
        name_to_dim = OrderedDict()

    def _convert(x, name_to_dim, dim_type):
        # ``dim_type`` travels in the message kwargs for the handlers; it is
        # not forwarded to ``funsor.to_data``.
        return funsor.to_data(x, name_to_dim=name_to_dim)

    handler_kwargs = {"name_to_dim": name_to_dim, "dim_type": dim_type}
    processed = apply_stack({
        "type": "to_data",
        "fn": _convert,
        "args": (x,),
        "kwargs": handler_kwargs,
        "value": None,
        "mask": None,
    })
    return processed["value"]
示例#2
0
    def process_message(self, msg):
        """
        Intervene on a ``sample`` site whose name has a non-None entry in
        ``self.data``.

        A shallow copy of the site, keeping the original name, is sent through
        the handler stack separately. ``msg`` itself is renamed with a
        ``__CF`` suffix, set to the intervention value, marked as observed and
        stopped.

        :param dict msg: The message being processed; mutated in place.
        """
        if msg['type'] != 'sample':
            return
        previous_id = msg.get('_intervener_id', None)
        if previous_id == self._intervener_id:
            # Already tagged by this handler (e.g. the copy re-sent below),
            # so do not intervene again.
            return
        intervention = self.data.get(msg['name'])
        if intervention is None:
            return
        if previous_id is not None:
            warnings.warn(
                "Attempting to intervene on variable {} multiple times, "
                "this is almost certainly incorrect behavior".format(
                    msg['name']), RuntimeWarning)
        msg['_intervener_id'] = self._intervener_id

        # split node, avoid reapplying self recursively to new node
        new_msg = msg.copy()
        apply_stack(new_msg)

        msg['name'] = msg['name'] + "__CF"  # mangle old name
        msg['value'] = intervention
        msg['is_observed'] = True
        msg['stop'] = True

    def __init__(
        self,
        name: str,
        size: int,
        coords: Optional[ArrayLike] = None,
        dim: Optional[int] = None,
    ):
        """
        :param str name: Name of the dimension.
        :param int size: Size of the dimension; also used to build the default
            ``coords``.
        :param coords: Optional coordinates for the dimension; defaults to
            ``np.arange(size)``. Converted with ``np.array``.
        :param int dim: Optional (negative) location at which to insert the
            dimension; defaults to ``-1``.
        """
        self.name: str = name
        self.size: int = size
        self.dim: int = -1 if dim is None else dim
        """int: Location in which to insert the dimension."""

        # Only negative locations are accepted.
        assert self.dim < 0, "dim must be negative, got {}".format(self.dim)
        if coords is None:
            coords = np.arange(self.size)
        self.coords: np.ndarray = np.array(coords)
        """numpy.ndarray: Coordinates for the dimension."""

        msg = self._get_message()
        apply_stack(msg)
        super().__init__()
示例#4
0
    def process_message(self, msg):
        """
        Intervene on a ``sample`` site whose name has a non-None entry in
        ``self.data``.

        A shallow copy of the site, keeping the original name, is sent through
        the handler stack separately. ``msg`` itself is renamed with a
        ``__CF`` suffix, set to the intervention value, marked as observed and
        stopped.

        :param dict msg: The message being processed; mutated in place.
        """
        if msg["type"] != "sample":
            return
        previous_id = msg.get("_intervener_id", None)
        if previous_id == self._intervener_id:
            # Already tagged by this handler (e.g. the copy re-sent below),
            # so do not intervene again.
            return
        intervention = self.data.get(msg["name"])
        if intervention is None:
            return
        if previous_id is not None:
            warnings.warn(
                "Attempting to intervene on variable {} multiple times, "
                "this is almost certainly incorrect behavior".format(
                    msg["name"]),
                RuntimeWarning,
            )
        msg["_intervener_id"] = self._intervener_id

        # split node, avoid reapplying self recursively to new node
        new_msg = msg.copy()
        apply_stack(new_msg)

        msg["name"] = msg["name"] + "__CF"  # mangle old name
        msg["value"] = intervention
        msg["is_observed"] = True
        msg["stop"] = True
示例#5
0
def to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL):
    """
    Convert a Python object into a :class:`~funsor.terms.Funsor`, routing the
    conversion through the effect-handler stack.

    :param x: The object to convert.
    :param funsor.domains.Domain output: Optional output hint, needed to
        convert some data unambiguously (e.g. when `x` is a string).
    :param OrderedDict dim_to_name: Optional mapping from negative batch
        dimensions to name strings. Defaults to an empty ``OrderedDict``.
    :param int dim_type: Either 0, 1, or 2, marking a dimension as 'local',
        'global', or 'visible'; it can be used to interact with the global
        :class:`DimStack`.
    :return: A Funsor equivalent of `x`, taken from the processed message's
        ``value``.
    :rtype: funsor.terms.Funsor
    """
    if dim_to_name is None:
        dim_to_name = OrderedDict()

    def _convert(x, output, dim_to_name, dim_type):
        # ``dim_type`` travels in the message kwargs for the handlers; it is
        # not forwarded to ``funsor.to_funsor``.
        return funsor.to_funsor(x, output=output, dim_to_name=dim_to_name)

    handler_kwargs = {
        "output": output,
        "dim_to_name": dim_to_name,
        "dim_type": dim_type,
    }
    processed = apply_stack({
        "type": "to_funsor",
        "fn": _convert,
        "args": (x,),
        "kwargs": handler_kwargs,
        "value": None,
        "mask": None,
    })
    return processed["value"]
示例#6
0
文件: cond.py 项目: pyro-ppl/numpyro
def cond(pred, true_fun, false_fun, operand):
    """
    Primitive that applies either ``true_fun`` or ``false_fun`` to ``operand``,
    depending on ``pred``. See :func:`jax.lax.cond` for details.

    **Usage**:

    .. doctest::

       >>> import numpyro
       >>> import numpyro.distributions as dist
       >>> from jax import random
       >>> from numpyro.contrib.control_flow import cond
       >>> from numpyro.infer import SVI, Trace_ELBO
       >>>
       >>> def model():
       ...     def true_fun(_):
       ...         return numpyro.sample("x", dist.Normal(20.0))
       ...
       ...     def false_fun(_):
       ...         return numpyro.sample("x", dist.Normal(0.0))
       ...
       ...     cluster = numpyro.sample("cluster", dist.Normal())
       ...     return cond(cluster > 0, true_fun, false_fun, None)
       >>>
       >>> def guide():
       ...     m1 = numpyro.param("m1", 10.0)
       ...     s1 = numpyro.param("s1", 0.1, constraint=dist.constraints.positive)
       ...     m2 = numpyro.param("m2", 10.0)
       ...     s2 = numpyro.param("s2", 0.1, constraint=dist.constraints.positive)
       ...
       ...     def true_fun(_):
       ...         return numpyro.sample("x", dist.Normal(m1, s1))
       ...
       ...     def false_fun(_):
       ...         return numpyro.sample("x", dist.Normal(m2, s2))
       ...
       ...     cluster = numpyro.sample("cluster", dist.Normal())
       ...     return cond(cluster > 0, true_fun, false_fun, None)
       >>>
       >>> svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100))
       >>> svi_result = svi.run(random.PRNGKey(0), num_steps=2500)

    .. warning:: This is an experimental utility that lets JAX control flow be
        combined with NumPyro's effect handlers. At present, `sample` and
        `deterministic` sites inside `true_fun` and `false_fun` are supported.
        Please file an issue if you find an effect handler or distribution that
        is not supported.

    .. warning:: The ``cond`` primitive does not yet support enumeration, and it
        cannot be used inside a ``numpyro.plate`` context.

    .. note:: Every ``sample`` site must belong to the same distribution class.
        For example, the following is not supported:

        .. code-block:: python

            cond(
                True,
                lambda _: numpyro.sample("x", dist.Normal()),
                lambda _: numpyro.sample("x", dist.Laplace()),
                None,
            )

    :param bool pred: Boolean scalar selecting which branch function to apply.
    :param callable true_fun: Function applied when ``pred`` is true.
    :param callable false_fun: Function applied when ``pred`` is false.
    :param operand: Input passed to whichever branch is applied. Any JAX PyTree
        (e.g. a list or dict of arrays) is accepted.
    :return: The output of the applied branch function.
    """
    if _PYRO_STACK:
        # Send the call through the active Messengers, then replay every
        # non-plate site in the returned trace.
        processed = apply_stack(
            {
                "type": "control_flow",
                "fn": cond_wrapper,
                "args": (pred, true_fun, false_fun, operand),
                "kwargs": {"rng_key": None, "substitute_stack": []},
                "value": None,
            }
        )
        result, pytree_trace = processed["value"]
        for site in pytree_trace.trace.values():
            if site["type"] != "plate":
                apply_stack(site)
        return result

    # No Messengers are active: evaluate directly and discard the trace.
    result, _ = cond_wrapper(pred, true_fun, false_fun, operand)
    return result
示例#7
0
def scan(f, init, xs, length=None, reverse=False, history=1):
    """
    This primitive scans a function over the leading array axes of
    `xs` while carrying along state. See :func:`jax.lax.scan` for more
    information.

    **Usage**:

    .. doctest::

       >>> import numpy as np
       >>> import numpyro
       >>> import numpyro.distributions as dist
       >>> from numpyro.contrib.control_flow import scan
       >>>
       >>> def gaussian_hmm(y=None, T=10):
       ...     def transition(x_prev, y_curr):
       ...         x_curr = numpyro.sample('x', dist.Normal(x_prev, 1))
       ...         y_curr = numpyro.sample('y', dist.Normal(x_curr, 1), obs=y_curr)
       ...         return x_curr, (x_curr, y_curr)
       ...
       ...     x0 = numpyro.sample('x_0', dist.Normal(0, 1))
       ...     _, (x, y) = scan(transition, x0, y, length=T)
       ...     return (x, y)
       >>>
       >>> # here we do some quick tests
       >>> with numpyro.handlers.seed(rng_seed=0):
       ...     x, y = gaussian_hmm(np.arange(10.))
       >>> assert x.shape == (10,) and y.shape == (10,)
       >>> assert np.all(y == np.arange(10))
       >>>
       >>> with numpyro.handlers.seed(rng_seed=0):  # generative
       ...     x, y = gaussian_hmm()
       >>> assert x.shape == (10,) and y.shape == (10,)

    .. warning:: This is an experimental utility function that allows users to use
        JAX control flow with NumPyro's effect handlers. Currently, `sample` and
        `deterministic` sites within the scan body `f` are supported. If you notice
        that any effect handlers or distributions are unsupported, please file an issue.

    .. note:: It is ambiguous to align the `scan` dimension inside a `plate` context,
        so the following pattern is not supported:

        .. code-block:: python

            with numpyro.plate('N', 10):
                last, ys = scan(f, init, xs)

        All `plate` statements should be put inside `f`. For example, the
        corresponding working code is:

        .. code-block:: python

            def g(*args, **kwargs):
                with numpyro.plate('N', 10):
                    return f(*args, **kwargs)

            last, ys = scan(g, init, xs)

    .. note:: Nested scan is currently not supported.

    .. note:: We can scan over discrete latent variables in `f`. The joint density is
        evaluated using parallel-scan (reference [1]) over the time dimension, which
        reduces parallel complexity to `O(log(length))`.

        A :class:`~numpyro.handlers.trace` of `scan` with discrete latent
        variables will contain the following sites:

            + init sites: those sites belonging to the first `history` traces of `f`.
                Sites at the `i`-th trace will have names prefixed with
                `'_PREV_' * (2 * history - 1 - i)`.
            + scanned sites: those sites collect the values of the remaining scan
                loop over `f`. An additional time dimension `_time_foo` will be
                added to those sites, where `foo` is the name of the first site
                appearing in `f`.

        Not all transition functions `f` are supported. All of the restrictions from
        Pyro's enumeration tutorial [2] still apply here. In addition, no site
        outside of `scan` should depend on the first output of `scan`
        (the last carry value).

    **References**

    1. *Temporal Parallelization of Bayesian Smoothers*,
       Simo Sarkka, Angel F. Garcia-Fernandez
       (https://arxiv.org/abs/1905.13002)

    2. *Inference with Discrete Latent Variables*
       (http://pyro.ai/examples/enumeration.html#Dependencies-among-plates)

    :param callable f: a function to be scanned.
    :param init: the initial carrying state
    :param xs: the values over which we scan along the leading axis. This can
        be any JAX pytree (e.g. list/dict of arrays).
    :param length: optional value specifying the length of `xs`;
        can be used when `xs` is an empty pytree (e.g. None)
    :param bool reverse: optional boolean specifying whether to run the scan iteration
        forward (the default) or in reverse
    :param int history: The number of previous contexts visible from the current context.
        Defaults to 1. If zero, this is similar to :class:`numpyro.plate`.
    :return: output of scan, quoted from :func:`jax.lax.scan` docs:
        "pair of type (c, [b]) where the first element represents the final loop
        carry value and the second element represents the stacked outputs of the
        second output of f when scanned over the leading axis of the inputs".
    """
    # If there are no active Messengers, just run the scan and return.
    if not _PYRO_STACK:
        (_, _, carry), (_, ys) = scan_wrapper(
            f, init, xs, length=length, reverse=reverse
        )
        return carry, ys

    # Otherwise, build a message and send it through the Messengers.
    initial_msg = {
        "type": "control_flow",
        "fn": scan_wrapper,
        "args": (f, init, xs, length, reverse),
        "kwargs": {
            "rng_key": None,
            "substitute_stack": [],
            "history": history,
        },
        "value": None,
    }
    msg = apply_stack(initial_msg)
    (_, _, carry), (pytree_trace, ys) = msg["value"]

    # Replay the recorded sites; ``site`` avoids shadowing ``msg``.
    if not msg["kwargs"].get("enum", False):
        for site in pytree_trace.trace.values():
            apply_stack(site)
    else:
        from numpyro.contrib.funsor import to_funsor
        from numpyro.contrib.funsor.enum_messenger import LocalNamedMessenger

        for site in pytree_trace.trace.values():
            with LocalNamedMessenger():
                # Convert each value with its dim_to_name (keys sorted) inside
                # a fresh LocalNamedMessenger before replaying the site.
                dim_to_name = site["infer"].get("dim_to_name")
                to_funsor(
                    site["value"],
                    dim_to_name=OrderedDict(
                        [(k, dim_to_name[k]) for k in sorted(dim_to_name)]
                    ),
                )
                apply_stack(site)

    return carry, ys
示例#8
0
def scan(f, init, xs, length=None, reverse=False):
    """
    This primitive scans a function over the leading array axes of
    `xs` while carrying along state. See :func:`jax.lax.scan` for more
    information.

    **Usage**:

    .. doctest::

       >>> import numpy as np
       >>> import numpyro
       >>> import numpyro.distributions as dist
       >>> from numpyro.contrib.control_flow import scan
       >>>
       >>> def gaussian_hmm(y=None, T=10):
       ...     def transition(x_prev, y_curr):
       ...         x_curr = numpyro.sample('x', dist.Normal(x_prev, 1))
       ...         y_curr = numpyro.sample('y', dist.Normal(x_curr, 1), obs=y_curr)
       ...         return x_curr, (x_curr, y_curr)
       ...
       ...     x0 = numpyro.sample('x_0', dist.Normal(0, 1))
       ...     _, (x, y) = scan(transition, x0, y, length=T)
       ...     return (x, y)
       >>>
       >>> # here we do some quick tests
       >>> with numpyro.handlers.seed(rng_seed=0):
       ...     x, y = gaussian_hmm(np.arange(10.))
       >>> assert x.shape == (10,) and y.shape == (10,)
       >>> assert np.all(y == np.arange(10))
       >>>
       >>> with numpyro.handlers.seed(rng_seed=0):  # generative
       ...     x, y = gaussian_hmm()
       >>> assert x.shape == (10,) and y.shape == (10,)

    .. warning:: This is an experimental utility function that allows users to use
        JAX control flow with NumPyro's effect handlers. Currently, `sample` and
        `deterministic` sites within the scan body `f` are supported. If you notice
        that any effect handlers or distributions are unsupported, please file an issue.

    .. note:: It is ambiguous to align the `scan` dimension inside a `plate` context,
        so the following pattern is not supported:

        .. code-block:: python

            with numpyro.plate('N', 10):
                last, ys = scan(f, init, xs)

        All `plate` statements should be put inside `f`. For example, the
        corresponding working code is:

        .. code-block:: python

            def g(*args, **kwargs):
                with numpyro.plate('N', 10):
                    return f(*args, **kwargs)

            last, ys = scan(g, init, xs)

    :param callable f: a function to be scanned.
    :param init: the initial carrying state
    :param xs: the values over which we scan along the leading axis. This can
        be any JAX pytree (e.g. list/dict of arrays).
    :param length: optional value specifying the length of `xs`;
        can be used when `xs` is an empty pytree (e.g. None)
    :param bool reverse: optional boolean specifying whether to run the scan iteration
        forward (the default) or in reverse
    :return: output of scan, quoted from :func:`jax.lax.scan` docs:
        "pair of type (c, [b]) where the first element represents the final loop
        carry value and the second element represents the stacked outputs of the
        second output of f when scanned over the leading axis of the inputs".
    """
    # If there are no active Messengers, just run the scan and return; there
    # are no handlers to replay the recorded sites through.
    if not _PYRO_STACK:
        (_, _, carry), (_, ys) = scan_wrapper(
            f, init, xs, length=length, reverse=reverse
        )
        return carry, ys

    # Otherwise, build a message and send it through the Messengers.
    initial_msg = {
        'type': 'control_flow',
        'fn': scan_wrapper,
        'args': (f, init, xs, length, reverse),
        'kwargs': {
            'rng_key': None,
            'substitute_stack': []
        },
        'value': None,
    }
    msg = apply_stack(initial_msg)
    (_, _, carry), (pytree_trace, ys) = msg['value']

    # Replay every recorded site through the handler stack.
    for site in pytree_trace.trace.values():
        apply_stack(site)

    return carry, ys