Example 1
@contextmanager  # from contextlib: the generator below is entered via `with`
def helpful_support_errors(site, raise_warnings=False):
    name = site["name"]
    support = getattr(site["fn"], "support", None)
    if isinstance(support, constraints.independent):
        support = support.base_constraint

    # Warnings
    if raise_warnings:
        if support is constraints.circular:
            msg = (
                f"Continuous inference poorly handles circular sample site '{name}'. "
                "Consider using VonMises distribution together with "
                "a reparameterizer, e.g. "
                f"numpyro.handlers.reparam(config={{'{name}': CircularReparam()}})."
            )
            warnings.warn(msg, UserWarning, stacklevel=find_stack_level())

    # Exceptions
    try:
        yield
    except NotImplementedError as e:
        support_name = repr(support).lower()
        if "integer" in support_name or "boolean" in support_name:
            # TODO: mention enumeration when it is supported in SVI
            raise ValueError(
                f"Continuous inference cannot handle discrete sample site '{name}'."
            )
        if "sphere" in support_name:
            raise ValueError(
                f"Continuous inference cannot handle spherical sample site '{name}'. "
                "Consider using ProjectedNormal distribution together with "
                "a reparameterizer, e.g. "
                f"numpyro.handlers.reparam(config={{'{name}': ProjectedNormalReparam()}})."
            )
        raise e from None
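For context, this function is consumed as a context manager around support-dependent calls, as Examples 5, 16, and 17 below show. A minimal usage sketch (assuming `site` is a trace-site dict of the shape used above):

with helpful_support_errors(site, raise_warnings=True):
    transform = biject_to(site["fn"].support)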
Example 2
def _load_higgs(num_datapoints):
    warnings.warn(
        "Higgs is a 2.6 GB dataset",
        stacklevel=find_stack_level(),
    )
    _download(HIGGS)

    file_path = os.path.join(DATA_DIR, "HIGGS.csv.gz")
    with io.TextIOWrapper(gzip.open(file_path, "rb")) as f:
        csv_reader = csv.reader(f, delimiter=",", quoting=csv.QUOTE_NONE)
        obs = []
        data = []
        for i, row in enumerate(csv_reader):
            obs.append(int(float(row[0])))
            data.append([float(v) for v in row[1:]])
            if num_datapoints and i > num_datapoints:
                break
    obs = np.stack(obs)
    data = np.stack(data)
    (n, ) = obs.shape

    return {
        "train": (data[:-(n // 20)], obs[:-(n // 20)]),
        "test": (data[-(n // 20):], obs[-(n // 20):]),
    }  # mirrors the standard HIGGS split ([-500_000:] as test); here the last 5% of loaded rows
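A hypothetical call (the HIGGS descriptor and DATA_DIR are assumed to be configured as in the surrounding module):

splits = _load_higgs(100_000)      # cap the number of rows for a quick experiment
X_train, y_train = splits["train"]
X_test, y_test = splits["test"]    # last ~5% of the loaded rows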
Example 3
def init_to_median(site=None, num_samples=15):
    """
    Initialize to the prior median. For priors with no `.sample` method implemented,
    we defer to the :func:`init_to_uniform` strategy.

    :param int num_samples: number of prior samples used to compute the median.
    """
    if site is None:
        return partial(init_to_median, num_samples=num_samples)

    if (site["type"] == "sample" and not site["is_observed"]
            and not site["fn"].support.is_discrete):
        if site["value"] is not None:
            warnings.warn(
                f"init_to_median() skipping initialization of site '{site['name']}'"
                " which already stores a value.",
                stacklevel=find_stack_level(),
            )
            return site["value"]

        rng_key = site["kwargs"].get("rng_key")
        sample_shape = site["kwargs"].get("sample_shape")
        try:
            samples = site["fn"](sample_shape=(num_samples, ) + sample_shape,
                                 rng_key=rng_key)
            return jnp.median(samples, axis=0)
        except NotImplementedError:
            return init_to_uniform(site)
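The strategy is typically handed to an inference kernel; a sketch using NumPyro's NUTS (model is an assumed NumPyro model):

from numpyro.infer import NUTS

kernel = NUTS(model, init_strategy=init_to_median(num_samples=20))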
Example 4
    def _inverse(self, y):
        warnings.warn(
            "AbsTransform is not a bijective transform."
            " The inverse of `y` will be `y`.",
            stacklevel=find_stack_level(),
        )
        return y
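Because |x| discards the sign, the round trip is the identity only on non-negative inputs; a small sketch of the behavior (assuming AbsTransform from numpyro.distributions.transforms):

from numpyro.distributions.transforms import AbsTransform

t = AbsTransform()
t(-3.0)     # 3.0
t.inv(3.0)  # 3.0, after emitting the warning above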
Example 5
def init_to_uniform(site=None, radius=2):
    """
    Initialize to a random point in the region `(-radius, radius)` of the unconstrained domain.

    :param float radius: specifies the range to draw an initial point in the unconstrained domain.
    """
    if site is None:
        return partial(init_to_uniform, radius=radius)

    if (site["type"] == "sample" and not site["is_observed"]
            and not site["fn"].support.is_discrete):
        if site["value"] is not None:
            warnings.warn(
                f"init_to_uniform() skipping initialization of site '{site['name']}'"
                " which already stores a value.",
                stacklevel=find_stack_level(),
            )
            return site["value"]

        # XXX: we import here to avoid circular import
        from numpyro.infer.util import helpful_support_errors

        rng_key = site["kwargs"].get("rng_key")
        sample_shape = site["kwargs"].get("sample_shape")

        with helpful_support_errors(site):
            transform = biject_to(site["fn"].support)
        unconstrained_shape = transform.inverse_shape(site["fn"].shape())
        unconstrained_samples = dist.Uniform(-radius, radius)(
            rng_key=rng_key, sample_shape=sample_shape + unconstrained_shape)
        return transform(unconstrained_samples)
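Note the `site is None` branch: calling the strategy with keyword arguments only returns a configured initializer via functools.partial, e.g. (model is assumed):

strategy = init_to_uniform(radius=0.5)       # a partial, waiting for a site
kernel = NUTS(model, init_strategy=strategy)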
Example 6
    @property
    def event_dim(self):
        warnings.warn(
            "transform.event_dim is deprecated. Please use Transform.domain.event_dim to "
            "get input event dim or Transform.codomain.event_dim to get output event dim.",
            FutureWarning,
            stacklevel=find_stack_level(),
        )
        return self.domain.event_dim
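In sketch form, the replacements named in the warning (for some Transform instance `t`):

in_dim = t.domain.event_dim     # event dim of the input
out_dim = t.codomain.event_dim  # event dim of the output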
Example 7
    def __init__(self):
        warnings.warn(
            "PRNGIdentity distribution is deprecated. To get a random "
            "PRNG key, you can use `numpyro.prng_key()` instead.",
            FutureWarning,
            stacklevel=find_stack_level(),
        )
        super(PRNGIdentity, self).__init__(event_shape=(2, ))
Example 8
    def __init__(self, domain=constraints.lower_cholesky):
        warnings.warn(
            "InvCholeskyTransform is deprecated. Please use CholeskyTransform"
            " or CorrMatrixCholeskyTransform instead.",
            FutureWarning,
            stacklevel=find_stack_level(),
        )
        assert domain in [constraints.lower_cholesky, constraints.corr_cholesky]
        self.domain = domain
Example 9
def get_param_store():
    warnings.warn(
        "A limited parameter store is provided for compatibility with Pyro. "
        "Value of SVI parameters should be obtained via SVI.get_params() method.",
        category=UnsupportedAPIWarning,
        stacklevel=find_stack_level(),
    )
    # Return an empty dict for compatibility
    return _PARAM_STORE
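The supported path for reading optimized values, per the warning text (a sketch, assuming an SVI instance `svi` and its `svi_state`):

params = svi.get_params(svi_state)  # instead of the Pyro-style param store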
Example 10
    def init(self, *args, **kwargs):
        warnings.warn(
            "Importing distributions from numpyro.contrib.tfp.distributions is "
            "deprecated. You should import distributions directly from "
            "tensorflow_probability.substrates.jax.distributions instead.",
            FutureWarning,
            stacklevel=find_stack_level(),
        )
        self.tfp_dist = tfd_class(*args, **kwargs)
Example 11
    def __init__(
        self,
        sampler,
        *,
        num_warmup,
        num_samples,
        num_chains=1,
        thinning=1,
        postprocess_fn=None,
        chain_method="parallel",
        progress_bar=True,
        jit_model_args=False,
    ):
        self.sampler = sampler
        self._sample_field = sampler.sample_field
        self._default_fields = sampler.default_fields
        self.num_warmup = num_warmup
        self.num_samples = num_samples
        self.num_chains = num_chains
        if not isinstance(thinning, int) or thinning < 1:
            raise ValueError("thinning must be a positive integer")
        self.thinning = thinning
        self.postprocess_fn = postprocess_fn
        if chain_method not in ["parallel", "vectorized", "sequential"]:
            raise ValueError(
                "Only supporting the following methods to draw chains:"
                ' "sequential", "parallel", or "vectorized"'
            )
        if chain_method == "parallel" and local_device_count() < self.num_chains:
            chain_method = "sequential"
            warnings.warn(
                "There are not enough devices to run parallel chains: expected {} but got {}."
                " Chains will be drawn sequentially. If you are running MCMC on CPU,"
                " consider using `numpyro.set_host_device_count({})` at the beginning"
                " of your program. You can double-check how many devices are available in"
                " your system using `jax.local_device_count()`.".format(
                    self.num_chains, local_device_count(), self.num_chains
                ),
                stacklevel=find_stack_level(),
            )
        self.chain_method = chain_method
        self.progress_bar = progress_bar
        if "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ:
            self.progress_bar = False
        self._jit_model_args = jit_model_args
        self._states = None
        self._states_flat = None
        # HMCState returned by last run
        self._last_state = None
        # HMCState returned by last warmup
        self._warmup_state = None
        # HMCState returned by hmc.init_kernel
        self._init_state_cache = {}
        self._cache = {}
        self._collection_params = {}
        self._set_collection_params()
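A sketch of driving this constructor, including the device hint from the warning text (model is an assumed NumPyro model taking no arguments):

import numpyro
from jax import random
from numpyro.infer import MCMC, NUTS

numpyro.set_host_device_count(4)  # must run before JAX initializes its devices
mcmc = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=1000,
    num_chains=4,
    thinning=2,
    chain_method="parallel",
)
mcmc.run(random.PRNGKey(0))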
Example 12
def multinomial_goodness_of_fit(probs,
                                counts,
                                *,
                                total_count=None,
                                plot=False):
    """
    Pearson's chi^2 test, on possibly truncated data.
    https://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test

    :param numpy.ndarray probs: Vector of probabilities.
    :param numpy.ndarray counts: Vector of counts.
    :param int total_count: Optional total count in case data is truncated,
        otherwise None.
    :param bool plot: Whether to print a histogram. Defaults to False.
    :returns: p-value of truncated multinomial sample.
    :rtype: float
    """
    probs = jax.lax.stop_gradient(probs)
    assert len(probs.shape) == 1
    assert probs.shape == counts.shape
    if total_count is None:
        truncated = False
        total_count = int(counts.sum())
    else:
        truncated = True
        assert total_count >= counts.sum()
    if plot:
        print_histogram(probs, counts)

    chi_squared = 0
    dof = 0
    for p, c in zip(probs.tolist(), counts.tolist()):
        if abs(p - 1) < 1e-8:
            return 1 if c == total_count else 0
        assert p < 1, f"bad probability: {p:g}"
        if p > 0:
            mean = total_count * p
            variance = total_count * p * (1 - p)
            if not (variance > 1):
                raise InvalidTest(
                    "Goodness of fit is inaccurate; use more samples")
            chi_squared += (c - mean)**2 / variance
            dof += 1
        else:
            warnings.warn(
                "Zero probability in goodness-of-fit test",
                stacklevel=find_stack_level(),
            )
            if c > 0:
                return math.inf

    if not truncated:
        dof -= 1

    survival = _chi2sf(chi_squared, dof)
    return survival
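A toy invocation with made-up numbers (counts roughly matching 1000 draws from probs, so the returned p-value should typically be large):

probs = np.array([0.1, 0.2, 0.3, 0.4])
counts = np.array([98, 205, 301, 396])
p_value = multinomial_goodness_of_fit(probs, counts)
assert p_value > 0.01  # a well-matched sample rarely fails this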
Example 13
    def _validate_sample(self, value):
        mask = self.support(value)
        if not_jax_tracer(mask) and not np.all(mask):
            warnings.warn(
                "Out-of-support values provided to log prob method. "
                "The value argument should be within the support.",
                stacklevel=find_stack_level(),
            )
        return mask
Example 14
    def __init__(self, fn=None, trace=None, guide_trace=None):
        if guide_trace is not None:
            warnings.warn(
                "`guide_trace` argument is deprecated. Please replace it by `trace`.",
                FutureWarning,
                stacklevel=find_stack_level(),
            )
            trace = guide_trace
        assert trace is not None
        self.trace = trace
        super(replay, self).__init__(fn)
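replay is typically paired with a previously recorded trace, as Example 16 below does; a sketch (guide, model, rng_key, and args are assumed):

guide_tr = trace(seed(guide, rng_key)).get_trace(*args)
replayed = replay(model, trace=guide_tr)  # reuses the guide's sampled values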
Example 15
    def _inverse(self, y):
        leading_dims = [
            v.shape[0] if jnp.ndim(v) > 0 else 0 for v in tree_flatten(y)[0]
        ]
        d0 = leading_dims[0]
        not_scalar = d0 > 0 or len(leading_dims) > 1
        if not_scalar and all(d == d0 for d in leading_dims[1:]):
            warnings.warn(
                "UnpackTransform.inv might lead to an unexpected behavior because it"
                " cannot transform a batch of unpacked arrays.",
                stacklevel=find_stack_level(),
            )
        return ravel_pytree(y)[0]
Example 16
    def init(self, rng_key, *args, **kwargs):
        """
        Gets the initial SVI state.

        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: the initial :data:`SVIState`
        """
        rng_key, model_seed, guide_seed = random.split(rng_key, 3)
        model_init = seed(self.model, model_seed)
        guide_init = seed(self.guide, guide_seed)
        guide_trace = trace(guide_init).get_trace(*args, **kwargs,
                                                  **self.static_kwargs)
        model_trace = trace(replay(model_init, guide_trace)).get_trace(
            *args, **kwargs, **self.static_kwargs)
        params = {}
        inv_transforms = {}
        mutable_state = {}
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in list(model_trace.values()) + list(guide_trace.values()):
            if site["type"] == "param":
                constraint = site["kwargs"].pop("constraint", constraints.real)
                with helpful_support_errors(site):
                    transform = biject_to(constraint)
                inv_transforms[site["name"]] = transform
                params[site["name"]] = transform.inv(site["value"])
            elif site["type"] == "mutable":
                mutable_state[site["name"]] = site["value"]
            elif (site["type"] == "sample" and (not site["is_observed"])
                  and site["fn"].support.is_discrete
                  and not self.loss.can_infer_discrete):
                s_name = type(self.loss).__name__
                warnings.warn(
                    f"Currently, SVI with {s_name} loss does not support models with discrete latent variables",
                    stacklevel=find_stack_level(),
                )

        if not mutable_state:
            mutable_state = None
        self.constrain_fn = partial(transform_fn, inv_transforms)
        # we convert weak types like float to float32/float64
        # to avoid recompiling body_fn in svi.run
        params, mutable_state = tree_map(
            lambda x: lax.convert_element_type(x, jnp.result_type(x)),
            (params, mutable_state),
        )
        return SVIState(self.optim.init(params), mutable_state, rng_key)
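A sketch of obtaining the initial state (model, guide, and the data argument are assumed):

from jax import random
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam

svi = SVI(model, guide, Adam(1e-3), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0), data)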
Example 17
def _get_model_transforms(model, model_args=(), model_kwargs=None):
    model_kwargs = {} if model_kwargs is None else model_kwargs
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    # model code may need to be replayed in the presence of deterministic sites
    replay_model = False
    has_enumerate_support = False
    for k, v in model_trace.items():
        if v["type"] == "sample" and not v["is_observed"]:
            if v["fn"].support.is_discrete:
                enum_type = v["infer"].get("enumerate")
                if enum_type is not None and (enum_type != "parallel"):
                    raise RuntimeError(
                        "This algorithm might only work for discrete sites with"
                        f" enumerate marked 'parallel'. But the site {k} is marked"
                        f" as '{enum_type}'.")
                has_enumerate_support = True
                if not v["fn"].has_enumerate_support:
                    dist_name = type(v["fn"]).__name__
                    raise RuntimeError(
                        "This algorithm might only work for discrete sites with"
                        f" enumerate support. But the {dist_name} distribution at"
                        f" site {k} does not have enumerate support.")
                if enum_type is None:
                    warnings.warn(
                        "Some algorithms will automatically enumerate the discrete"
                        f" latent site {k} of your model. In the future,"
                        " enumerated sites need to be marked with"
                        " `infer={'enumerate': 'parallel'}`.",
                        FutureWarning,
                        stacklevel=find_stack_level(),
                    )
            else:
                support = v["fn"].support
                with helpful_support_errors(v, raise_warnings=True):
                    inv_transforms[k] = biject_to(support)
                # XXX: the following code filters out most situations with dynamic supports
                args = ()
                if isinstance(support, constraints._GreaterThan):
                    args = ("lower_bound", )
                elif isinstance(support, constraints._Interval):
                    args = ("lower_bound", "upper_bound")
                for arg in args:
                    if not isinstance(getattr(support, arg), (int, float)):
                        replay_model = True
        elif v["type"] == "deterministic":
            replay_model = True
    return inv_transforms, replay_model, has_enumerate_support, model_trace
Example 18
def _check_mean_field_requirement(model_trace, guide_trace):
    """
    Checks that the guide and model sample sites are ordered identically.
    This is sufficient but not necessary for correctness.
    """
    model_sites = [
        name for name, site in model_trace.items()
        if site["type"] == "sample" and name in guide_trace
    ]
    guide_sites = [
        name for name, site in guide_trace.items()
        if site["type"] == "sample" and name in model_trace
    ]
    assert set(model_sites) == set(guide_sites)
    if model_sites != guide_sites:
        warnings.warn(
            "Failed to verify mean field restriction on the guide. "
            "To eliminate this warning, ensure model and guide sites "
            "occur in the same order.\n"
            "Model sites:\n  " + "\n  ".join(model_sites) +
            "\nGuide sites:\n  " + "\n  ".join(guide_sites),
            stacklevel=find_stack_level(),
        )
Example 19
    def process_message(self, msg):
        if msg["type"] != "sample":
            return
        if (msg.get("_intervener_id", None) != self._intervener_id
                and self.data.get(msg["name"]) is not None):
            if msg.get("_intervener_id", None) is not None:
                warnings.warn(
                    "Attempting to intervene on variable {} multiple times,"
                    "this is almost certainly incorrect behavior".format(
                        msg["name"]),
                    RuntimeWarning,
                    stacklevel=find_stack_level(),
                )
            msg["_intervener_id"] = self._intervener_id

            # split node, avoid reapplying self recursively to new node
            new_msg = msg.copy()
            apply_stack(new_msg)

            intervention = self.data.get(msg["name"])
            msg["name"] = msg["name"] + "__CF"  # mangle old name
            msg["value"] = intervention
            msg["is_observed"] = True
            msg["stop"] = True
Example 20
    @staticmethod
    def _subsample(name, size, subsample_size, dim):
        msg = {
            "type": "plate",
            "fn": _subsample_fn,
            "name": name,
            "args": (size, subsample_size),
            "kwargs": {"rng_key": None},
            "value": (
                None
                if (subsample_size is not None and size != subsample_size)
                else jnp.arange(size)
            ),
            "scale": 1.0,
            "cond_indep_stack": [],
        }
        apply_stack(msg)
        subsample = msg["value"]
        subsample_size = msg["args"][1]
        if subsample_size is not None and subsample_size != subsample.shape[0]:
            warnings.warn(
                "subsample_size does not match len(subsample), {} vs {}.".format(
                    subsample_size, len(subsample)
                )
                + " Did you accidentally use different subsample_size in the model and guide?",
                stacklevel=find_stack_level(),
            )
        cond_indep_stack = msg["cond_indep_stack"]
        occupied_dims = {f.dim for f in cond_indep_stack}
        if dim is None:
            new_dim = -1
            while new_dim in occupied_dims:
                new_dim -= 1
            dim = new_dim
        else:
            assert dim not in occupied_dims
        return dim, subsample
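The public entry point to this machinery is numpyro.plate with a subsample_size, used inside a model; a sketch (data is an assumed array of length 10_000):

import numpyro

def model(data):
    with numpyro.plate("data", size=10_000, subsample_size=100) as idx:
        batch = data[idx]  # idx comes from _subsample via the handler stack
        ...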
Example 21
    def __init__(
        self,
        model: Callable,
        posterior_samples: Optional[Dict] = None,
        *,
        guide: Optional[Callable] = None,
        params: Optional[Dict] = None,
        num_samples: Optional[int] = None,
        return_sites: Optional[List[str]] = None,
        infer_discrete: bool = False,
        parallel: bool = False,
        batch_ndims: Optional[int] = None,
    ):
        if posterior_samples is None and num_samples is None:
            raise ValueError(
                "Either posterior_samples or num_samples must be specified.")

        batch_ndims = (batch_ndims if batch_ndims is not None else
                       1 if guide is None else 0)

        posterior_samples = {} if posterior_samples is None else posterior_samples

        prototype_site = batch_shape = batch_size = None
        for name, sample in posterior_samples.items():
            if batch_shape is not None and sample.shape[:batch_ndims] != batch_shape:
                raise ValueError(
                    f"Batch shapes at site {name} and {prototype_site} "
                    f"should be the same, but got "
                    f"{sample.shape[:batch_ndims]} and {batch_shape}")
            else:
                prototype_site = name
                batch_shape = sample.shape[:batch_ndims]
                batch_size = int(np.prod(batch_shape))
                if (num_samples is not None) and (num_samples != batch_size):
                    warnings.warn(
                        "Sample's batch dimension size {} is different from the "
                        "provided {} num_samples argument. Defaulting to {}.".
                        format(batch_size, num_samples, batch_size),
                        UserWarning,
                        stacklevel=find_stack_level(),
                    )
                num_samples = batch_size

        if num_samples is None:
            raise ValueError(
                "No sample sites in posterior samples to infer `num_samples`.")

        if batch_shape is None:
            batch_shape = (1, ) * (batch_ndims - 1) + (num_samples, )

        if return_sites is not None:
            assert isinstance(return_sites, (list, tuple, set))

        self.model = model
        self.posterior_samples = posterior_samples  # already coerced to a dict above
        self.num_samples = num_samples
        self.guide = guide
        self.params = {} if params is None else params
        self.infer_discrete = infer_discrete
        self.return_sites = return_sites
        self.parallel = parallel
        self.batch_ndims = batch_ndims
        self._batch_shape = batch_shape
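A sketch of the two construction modes the checks above support (model and posterior_samples are assumed):

from jax import random
from numpyro.infer import Predictive

prior_pred = Predictive(model, num_samples=100)        # num_samples given explicitly
posterior_pred = Predictive(model, posterior_samples)  # num_samples inferred from batch shape
draws = prior_pred(random.PRNGKey(0))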
Example 22
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import warnings

from numpyro.ops import Vindex, vindex  # noqa: F401
from numpyro.util import find_stack_level

warnings.warn(
    "`indexing` module has been moved from `numpyro.contrib` to `numpyro.ops`."
    " Please import Vindex or vindex functions from `numpyro.ops.indexing`.",
    FutureWarning,
    stacklevel=find_stack_level(),
)
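The import the shim points to instead:

from numpyro.ops.indexing import Vindex, vindex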