def to_data(x, name_to_dim=None, dim_type=DimType.LOCAL):
    """
    A primitive to extract a python object from a :class:`~funsor.terms.Funsor`.

    The request is packaged as a ``'to_data'`` message and sent through
    ``apply_stack``; the ``'value'`` entry of the returned message is the result.

    :param ~funsor.terms.Funsor x: A funsor object
    :param OrderedDict name_to_dim: An optional inputs hint which maps
        dimension names from `x` to dimension positions of the returned value.
    :param int dim_type: Either 0, 1, or 2. This optional argument indicates
        a dimension should be treated as 'local', 'global', or 'visible',
        which can be used to interact with the global :class:`DimStack`.
    :return: A non-funsor equivalent to `x`.
    """
    if name_to_dim is None:
        name_to_dim = OrderedDict()

    def _convert(x, name_to_dim, dim_type):
        # ``dim_type`` is accepted to match the message kwargs but is not
        # forwarded to ``funsor.to_data``.
        return funsor.to_data(x, name_to_dim=name_to_dim)

    reply = apply_stack({
        'type': 'to_data',
        'fn': _convert,
        'args': (x,),
        'kwargs': {"name_to_dim": name_to_dim, "dim_type": dim_type},
        'value': None,
        'mask': None,
    })
    return reply['value']
def process_message(self, msg):
    """
    Replace the value of a ``'sample'`` message with the intervention stored
    in ``self.data`` under the message's name.

    Messages of any other type, messages already tagged with this handler's
    ``_intervener_id``, and sites without an entry in ``self.data`` are left
    untouched.

    When an intervention applies, a copy of the message (tagged with this
    handler's id) is sent through ``apply_stack`` first. The current message
    is then renamed with a ``"__CF"`` suffix, given the intervention as its
    value, and marked as observed and stopped.

    :param dict msg: the message to process; modified in place.
    """
    if msg['type'] != 'sample':
        return

    previous_id = msg.get('_intervener_id', None)
    if previous_id == self._intervener_id:
        # Already handled by this intervener (e.g. the split-off copy below).
        return

    intervention = self.data.get(msg['name'])
    if intervention is None:
        return

    if previous_id is not None:
        # A different intervener has tagged this site before us.
        warnings.warn(
            "Attempting to intervene on variable {} multiple times, "
            "this is almost certainly incorrect behavior".format(msg['name']),
            RuntimeWarning)

    msg['_intervener_id'] = self._intervener_id

    # split node, avoid reapplying self recursively to new node
    new_msg = msg.copy()
    apply_stack(new_msg)

    msg['name'] = msg['name'] + "__CF"  # mangle old name
    msg['value'] = intervention
    msg['is_observed'] = True
    msg['stop'] = True
def __init__(
    self,
    name: str,
    size: int,
    coords: Optional[ArrayLike] = None,
    dim: Optional[ArrayLike] = None,
):
    """
    Store the dimension's settings, then send the message built by
    ``self._get_message()`` through ``apply_stack``.

    :param name: stored as ``self.name``.
    :param size: stored as ``self.size``; also the length of the default
        coordinates.
    :param coords: coordinates for the dimension; defaults to
        ``np.arange(size)``.
    :param dim: location in which to insert the dimension; must be negative.
        Defaults to ``-1``.
    """
    # NOTE(review): ``dim`` is annotated as array-like but is stored and
    # compared as an int below -- confirm the intended annotation.
    if dim is None:
        dim = -1

    self.name: str = name
    self.size: int = size
    self.dim: int = dim
    """int: Location in which to insert the dimension."""
    assert self.dim < 0

    self.coords: np.ndarray = np.array(
        np.arange(self.size) if coords is None else coords
    )
    """numpy.ndarray: Coordinates for the dimension."""

    # The message is dispatched before the parent initializer runs.
    apply_stack(self._get_message())
    super().__init__()
def process_message(self, msg):
    """
    Replace the value of a ``"sample"`` message with the intervention stored
    in ``self.data`` under the message's name.

    Messages of any other type, messages already tagged with this handler's
    ``_intervener_id``, and sites without an entry in ``self.data`` are left
    untouched.

    When an intervention applies, a copy of the message (tagged with this
    handler's id) is sent through ``apply_stack`` first. The current message
    is then renamed with a ``"__CF"`` suffix, given the intervention as its
    value, and marked as observed and stopped.

    :param dict msg: the message to process; modified in place.
    """
    if msg["type"] != "sample":
        return

    tagged_by = msg.get("_intervener_id", None)
    if tagged_by == self._intervener_id:
        # Already handled by this intervener (e.g. the split-off copy below).
        return

    intervention = self.data.get(msg["name"])
    if intervention is None:
        return

    if tagged_by is not None:
        # A different intervener has tagged this site before us.
        warnings.warn(
            "Attempting to intervene on variable {} multiple times, "
            "this is almost certainly incorrect behavior".format(msg["name"]),
            RuntimeWarning,
        )

    msg["_intervener_id"] = self._intervener_id

    # split node, avoid reapplying self recursively to new node
    new_msg = msg.copy()
    apply_stack(new_msg)

    msg["name"] = msg["name"] + "__CF"  # mangle old name
    msg["value"] = intervention
    msg["is_observed"] = True
    msg["stop"] = True
def to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL):
    """
    A primitive to convert a Python object to a :class:`~funsor.terms.Funsor`.

    The request is packaged as a ``'to_funsor'`` message and sent through
    ``apply_stack``; the ``'value'`` entry of the returned message is the result.

    :param x: An object.
    :param funsor.domains.Domain output: An optional output hint to uniquely
        convert a data to a Funsor (e.g. when `x` is a string).
    :param OrderedDict dim_to_name: An optional mapping from negative
        batch dimensions to name strings.
    :param int dim_type: Either 0, 1, or 2. This optional argument indicates
        a dimension should be treated as 'local', 'global', or 'visible',
        which can be used to interact with the global :class:`DimStack`.
    :return: A Funsor equivalent to `x`.
    :rtype: funsor.terms.Funsor
    """
    if dim_to_name is None:
        dim_to_name = OrderedDict()

    def _convert(x, output, dim_to_name, dim_type):
        # ``dim_type`` is accepted to match the message kwargs but is not
        # forwarded to ``funsor.to_funsor``.
        return funsor.to_funsor(x, output=output, dim_to_name=dim_to_name)

    reply = apply_stack({
        'type': 'to_funsor',
        'fn': _convert,
        'args': (x,),
        'kwargs': {
            "output": output,
            "dim_to_name": dim_to_name,
            "dim_type": dim_type,
        },
        'value': None,
        'mask': None,
    })
    return reply['value']
def cond(pred, true_fun, false_fun, operand):
    """
    This primitive conditionally applies ``true_fun`` or ``false_fun``. See
    :func:`jax.lax.cond` for more information.

    **Usage**:

    .. doctest::

       >>> import numpyro
       >>> import numpyro.distributions as dist
       >>> from jax import random
       >>> from numpyro.contrib.control_flow import cond
       >>> from numpyro.infer import SVI, Trace_ELBO
       >>>
       >>> def model():
       ...     def true_fun(_):
       ...         return numpyro.sample("x", dist.Normal(20.0))
       ...
       ...     def false_fun(_):
       ...         return numpyro.sample("x", dist.Normal(0.0))
       ...
       ...     cluster = numpyro.sample("cluster", dist.Normal())
       ...     return cond(cluster > 0, true_fun, false_fun, None)
       >>>
       >>> def guide():
       ...     m1 = numpyro.param("m1", 10.0)
       ...     s1 = numpyro.param("s1", 0.1, constraint=dist.constraints.positive)
       ...     m2 = numpyro.param("m2", 10.0)
       ...     s2 = numpyro.param("s2", 0.1, constraint=dist.constraints.positive)
       ...
       ...     def true_fun(_):
       ...         return numpyro.sample("x", dist.Normal(m1, s1))
       ...
       ...     def false_fun(_):
       ...         return numpyro.sample("x", dist.Normal(m2, s2))
       ...
       ...     cluster = numpyro.sample("cluster", dist.Normal())
       ...     return cond(cluster > 0, true_fun, false_fun, None)
       >>>
       >>> svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100))
       >>> svi_result = svi.run(random.PRNGKey(0), num_steps=2500)

    .. warning:: This is an experimental utility function that allows users to use
        JAX control flow with NumPyro's effect handlers. Currently, `sample` and
        `deterministic` sites within `true_fun` and `false_fun` are supported. If you
        notice that any effect handlers or distributions are unsupported, please file
        an issue.

    .. warning:: The ``cond`` primitive does not currently support enumeration and can
        not be used inside a ``numpyro.plate`` context.

    .. note:: All ``sample`` sites must belong to the same distribution class. For
        example the following is not supported

        .. code-block:: python

            cond(
                True,
                lambda _: numpyro.sample("x", dist.Normal()),
                lambda _: numpyro.sample("x", dist.Laplace()),
                None,
            )

    :param bool pred: Boolean scalar type indicating which branch function to apply
    :param callable true_fun: A function to be applied if ``pred`` is true.
    :param callable false_fun: A function to be applied if ``pred`` is false.
    :param operand: Operand input to either branch depending on ``pred``. This can
        be any JAX PyTree (e.g. list / dict of arrays).
    :return: Output of the applied branch function.
    """
    # With no active handlers, run the wrapper directly and discard the trace.
    if not _PYRO_STACK:
        direct_value, _ = cond_wrapper(pred, true_fun, false_fun, operand)
        return direct_value

    reply = apply_stack({
        "type": "control_flow",
        "fn": cond_wrapper,
        "args": (pred, true_fun, false_fun, operand),
        "kwargs": {"rng_key": None, "substitute_stack": []},
        "value": None,
    })
    branch_value, branch_trace = reply["value"]

    # Re-send every recorded site through the handler stack, except plates.
    for site in branch_trace.trace.values():
        if site["type"] != "plate":
            apply_stack(site)

    return branch_value
def scan(f, init, xs, length=None, reverse=False, history=1):
    """
    This primitive scans a function over the leading array axes of
    `xs` while carrying along state. See :func:`jax.lax.scan` for more
    information.

    **Usage**:

    .. doctest::

       >>> import numpy as np
       >>> import numpyro
       >>> import numpyro.distributions as dist
       >>> from numpyro.contrib.control_flow import scan
       >>>
       >>> def gaussian_hmm(y=None, T=10):
       ...     def transition(x_prev, y_curr):
       ...         x_curr = numpyro.sample('x', dist.Normal(x_prev, 1))
       ...         y_curr = numpyro.sample('y', dist.Normal(x_curr, 1), obs=y_curr)
       ...         return x_curr, (x_curr, y_curr)
       ...
       ...     x0 = numpyro.sample('x_0', dist.Normal(0, 1))
       ...     _, (x, y) = scan(transition, x0, y, length=T)
       ...     return (x, y)
       >>>
       >>> # here we do some quick tests
       >>> with numpyro.handlers.seed(rng_seed=0):
       ...     x, y = gaussian_hmm(np.arange(10.))
       >>> assert x.shape == (10,) and y.shape == (10,)
       >>> assert np.all(y == np.arange(10))
       >>>
       >>> with numpyro.handlers.seed(rng_seed=0):  # generative
       ...     x, y = gaussian_hmm()
       >>> assert x.shape == (10,) and y.shape == (10,)

    .. warning:: This is an experimental utility function that allows users to use
        JAX control flow with NumPyro's effect handlers. Currently, `sample` and
        `deterministic` sites within the scan body `f` are supported. If you notice
        that any effect handlers or distributions are unsupported, please file an issue.

    .. note:: It is ambiguous to align `scan` dimension inside a `plate` context.
        So the following pattern won't be supported

        .. code-block:: python

            with numpyro.plate('N', 10):
                last, ys = scan(f, init, xs)

        All `plate` statements should be put inside `f`. For example, the corresponding
        working code is

        .. code-block:: python

            def g(*args, **kwargs):
                with numpyro.plate('N', 10):
                    return f(*args, **kwargs)

            last, ys = scan(g, init, xs)

    .. note:: Nested scan is currently not supported.

    .. note:: We can scan over discrete latent variables in `f`. The joint density is
        evaluated using parallel-scan (reference [1]) over time dimension, which
        reduces parallel complexity to `O(log(length))`.

        A :class:`~numpyro.handlers.trace` of `scan` with discrete latent
        variables will contain the following sites:

            + init sites: those sites belong to the first `history` traces of `f`.
              Sites at the `i`-th trace will have name prefixed with
              `'_PREV_' * (2 * history - 1 - i)`.
            + scanned sites: those sites collect the values of the remaining scan
              loop over `f`. An addition time dimension `_time_foo` will be
              added to those sites, where `foo` is the name of the first site
              appeared in `f`.

        Not all transition functions `f` are supported. All of the restrictions from
        Pyro's enumeration tutorial [2] still apply here. In addition, there should
        not have any site outside of `scan` depend on the first output of `scan`
        (the last carry value).

    **References**

    1. *Temporal Parallelization of Bayesian Smoothers*,
       Simo Sarkka, Angel F. Garcia-Fernandez
       (https://arxiv.org/abs/1905.13002)

    2. *Inference with Discrete Latent Variables*
       (http://pyro.ai/examples/enumeration.html#Dependencies-among-plates)

    :param callable f: a function to be scanned.
    :param init: the initial carrying state
    :param xs: the values over which we scan along the leading axis. This can
        be any JAX pytree (e.g. list/dict of arrays).
    :param length: optional value specifying the length of `xs`
        but can be used when `xs` is an empty pytree (e.g. None)
    :param bool reverse: optional boolean specifying whether to run the scan iteration
        forward (the default) or in reverse
    :param int history: The number of previous contexts visible from the current context.
        Defaults to 1. If zero, this is similar to :class:`numpyro.plate`.
    :return: output of scan, quoted from :func:`jax.lax.scan` docs:
        "pair of type (c, [b]) where the first element represents the final loop
        carry value and the second element represents the stacked outputs of the
        second output of f when scanned over the leading axis of the inputs".
    """
    # if there are no active Messengers, we just run and return it as expected:
    if not _PYRO_STACK:
        (_, _, carry), (_, ys) = scan_wrapper(
            f, init, xs, length=length, reverse=reverse
        )
        return carry, ys

    # Otherwise, we initialize a message and use apply_stack to send it to
    # the Messengers.
    reply = apply_stack({
        "type": "control_flow",
        "fn": scan_wrapper,
        "args": (f, init, xs, length, reverse),
        "kwargs": {"rng_key": None, "substitute_stack": [], "history": history},
        "value": None,
    })
    (_, _, carry), (pytree_trace, ys) = reply["value"]
    recorded_sites = pytree_trace.trace.values()

    if not reply["kwargs"].get("enum", False):
        # Plain replay: send every recorded site through the handler stack.
        for site in recorded_sites:
            apply_stack(site)
        return carry, ys

    # Handlers flagged the message with ``enum``: the funsor helpers are
    # imported only on this path.
    from numpyro.contrib.funsor import to_funsor
    from numpyro.contrib.funsor.enum_messenger import LocalNamedMessenger

    for site in recorded_sites:
        with LocalNamedMessenger():
            dim_to_name = site["infer"].get("dim_to_name")
            # Call to_funsor with the dims in sorted order before replaying
            # the site (the returned funsor itself is discarded).
            to_funsor(
                site["value"],
                dim_to_name=OrderedDict(
                    (dim, dim_to_name[dim]) for dim in sorted(dim_to_name)
                ),
            )
            apply_stack(site)
    return carry, ys
def scan(f, init, xs, length=None, reverse=False):
    """
    This primitive scans a function over the leading array axes of
    `xs` while carrying along state. See :func:`jax.lax.scan` for more
    information.

    **Usage**:

    .. doctest::

       >>> import numpy as np
       >>> import numpyro
       >>> import numpyro.distributions as dist
       >>> from numpyro.contrib.control_flow import scan
       >>>
       >>> def gaussian_hmm(y=None, T=10):
       ...     def transition(x_prev, y_curr):
       ...         x_curr = numpyro.sample('x', dist.Normal(x_prev, 1))
       ...         y_curr = numpyro.sample('y', dist.Normal(x_curr, 1), obs=y_curr)
       ...         return x_curr, (x_curr, y_curr)
       ...
       ...     x0 = numpyro.sample('x_0', dist.Normal(0, 1))
       ...     _, (x, y) = scan(transition, x0, y, length=T)
       ...     return (x, y)
       >>>
       >>> # here we do some quick tests
       >>> with numpyro.handlers.seed(rng_seed=0):
       ...     x, y = gaussian_hmm(np.arange(10.))
       >>> assert x.shape == (10,) and y.shape == (10,)
       >>> assert np.all(y == np.arange(10))
       >>>
       >>> with numpyro.handlers.seed(rng_seed=0):  # generative
       ...     x, y = gaussian_hmm()
       >>> assert x.shape == (10,) and y.shape == (10,)

    .. warning:: This is an experimental utility function that allows users to use
        JAX control flow with NumPyro's effect handlers. Currently, `sample` and
        `deterministic` sites within the scan body `f` are supported. If you notice
        that any effect handlers or distributions are unsupported, please file an issue.

    .. note:: It is ambiguous to align `scan` dimension inside a `plate` context.
        So the following pattern won't be supported

        .. code-block:: python

            with numpyro.plate('N', 10):
                last, ys = scan(f, init, xs)

        All `plate` statements should be put inside `f`. For example, the corresponding
        working code is

        .. code-block:: python

            def g(*args, **kwargs):
                with numpyro.plate('N', 10):
                    return f(*args, **kwargs)

            last, ys = scan(g, init, xs)

    :param callable f: a function to be scanned.
    :param init: the initial carrying state
    :param xs: the values over which we scan along the leading axis. This can
        be any JAX pytree (e.g. list/dict of arrays).
    :param length: optional value specifying the length of `xs`
        but can be used when `xs` is an empty pytree (e.g. None)
    :param bool reverse: optional boolean specifying whether to run the scan iteration
        forward (the default) or in reverse
    :return: output of scan, quoted from :func:`jax.lax.scan` docs:
        "pair of type (c, [b]) where the first element represents the final loop
        carry value and the second element represents the stacked outputs of the
        second output of f when scanned over the leading axis of the inputs".
    """
    # if there are no active Messengers, we just run and return it as expected:
    if not _PYRO_STACK:
        (_, _, carry), (_, ys) = scan_wrapper(
            f, init, xs, length=length, reverse=reverse
        )
        return carry, ys

    # Otherwise, we initialize a message and use apply_stack to send it to
    # the Messengers.
    reply = apply_stack({
        'type': 'control_flow',
        'fn': scan_wrapper,
        'args': (f, init, xs, length, reverse),
        'kwargs': {'rng_key': None, 'substitute_stack': []},
        'value': None,
    })
    (_, _, carry), (pytree_trace, ys) = reply['value']

    # Re-send every site recorded inside the scan through the handler stack.
    for site in pytree_trace.trace.values():
        apply_stack(site)

    return carry, ys