from contextlib import contextmanager


# This generator is used via `with helpful_support_errors(site): ...`, so it
# must be wrapped as a context manager.
@contextmanager
def helpful_support_errors(site, raise_warnings=False):
    name = site["name"]
    support = getattr(site["fn"], "support", None)
    if isinstance(support, constraints.independent):
        support = support.base_constraint

    # Warnings
    if raise_warnings:
        if support is constraints.circular:
            msg = (
                f"Continuous inference poorly handles circular sample site '{name}'. "
                + "Consider using VonMises distribution together with "
                + "a reparameterizer, e.g. "
                + f"numpyro.handlers.reparam(config={{'{name}': CircularReparam()}})."
            )
            warnings.warn(msg, UserWarning, stacklevel=find_stack_level())

    # Exceptions
    try:
        yield
    except NotImplementedError as e:
        support_name = repr(support).lower()
        if "integer" in support_name or "boolean" in support_name:
            # TODO: mention enumeration when it is supported in SVI
            raise ValueError(
                f"Continuous inference cannot handle discrete sample site '{name}'."
            )
        if "sphere" in support_name:
            raise ValueError(
                f"Continuous inference cannot handle spherical sample site '{name}'. "
                "Consider using ProjectedNormal distribution together with "
                "a reparameterizer, e.g. "
                f"numpyro.handlers.reparam(config={{'{name}': ProjectedNormalReparam()}})."
            )
        raise e from None
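# A minimal usage sketch for the context manager above. The `site` dict here is
# a hand-built stand-in for what a NumPyro trace provides; real sites carry
# more keys. Discrete supports have no registered bijection, so `biject_to`
# raises NotImplementedError, which the block rewrites into an actionable error.
import numpyro.distributions as dist
from numpyro.distributions.transforms import biject_to
from numpyro.infer.util import helpful_support_errors

site = {"name": "x", "fn": dist.Bernoulli(probs=0.5)}
try:
    with helpful_support_errors(site):
        biject_to(site["fn"].support)
except ValueError as err:
    print(err)  # Continuous inference cannot handle discrete sample site 'x'.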
def _load_higgs(num_datapoints):
    warnings.warn(
        "Higgs is a 2.6 GB dataset",
        stacklevel=find_stack_level(),
    )
    _download(HIGGS)

    file_path = os.path.join(DATA_DIR, "HIGGS.csv.gz")
    with io.TextIOWrapper(gzip.open(file_path, "rb")) as f:
        csv_reader = csv.reader(f, delimiter=",", quoting=csv.QUOTE_NONE)
        obs = []
        data = []
        for i, row in enumerate(csv_reader):
            obs.append(int(float(row[0])))
            data.append([float(v) for v in row[1:]])
            if num_datapoints and i > num_datapoints:
                break

    obs = np.stack(obs)
    data = np.stack(data)
    (n,) = obs.shape
    return {
        "train": (data[: -(n // 20)], obs[: -(n // 20)]),
        "test": (data[-(n // 20) :], obs[-(n // 20) :]),
    }  # standard split: the last 500_000 rows ("[-500_000:]") serve as test
def init_to_median(site=None, num_samples=15):
    """
    Initialize to the prior median. For priors with no `.sample` method implemented,
    we defer to the :func:`init_to_uniform` strategy.

    :param int num_samples: number of prior points to calculate median.
    """
    if site is None:
        return partial(init_to_median, num_samples=num_samples)

    if (
        site["type"] == "sample"
        and not site["is_observed"]
        and not site["fn"].support.is_discrete
    ):
        if site["value"] is not None:
            warnings.warn(
                f"init_to_median() skipping initialization of site '{site['name']}'"
                " which already stores a value.",
                stacklevel=find_stack_level(),
            )
            return site["value"]
        rng_key = site["kwargs"].get("rng_key")
        sample_shape = site["kwargs"].get("sample_shape")
        try:
            samples = site["fn"](
                sample_shape=(num_samples,) + sample_shape, rng_key=rng_key
            )
            return jnp.median(samples, axis=0)
        except NotImplementedError:
            return init_to_uniform(site)
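# A sketch of plugging this strategy into NUTS/MCMC; `toy_model` is written
# here purely for illustration and is not part of the library.
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, init_to_median

def toy_model():
    numpyro.sample("mu", dist.Normal(0.0, 10.0))

kernel = NUTS(toy_model, init_strategy=init_to_median(num_samples=30))
mcmc = MCMC(kernel, num_warmup=100, num_samples=100)
mcmc.run(random.PRNGKey(0))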
def _inverse(self, y):
    warnings.warn(
        "AbsTransform is not a bijective transform."
        " The inverse of `y` will be `y`.",
        stacklevel=find_stack_level(),
    )
    return y
def init_to_uniform(site=None, radius=2):
    """
    Initialize to a random point in the area `(-radius, radius)` of unconstrained domain.

    :param float radius: specifies the range to draw an initial point in the
        unconstrained domain.
    """
    if site is None:
        return partial(init_to_uniform, radius=radius)

    if (
        site["type"] == "sample"
        and not site["is_observed"]
        and not site["fn"].support.is_discrete
    ):
        if site["value"] is not None:
            warnings.warn(
                f"init_to_uniform() skipping initialization of site '{site['name']}'"
                " which already stores a value.",
                stacklevel=find_stack_level(),
            )
            return site["value"]
        # XXX: we import here to avoid circular import
        from numpyro.infer.util import helpful_support_errors

        rng_key = site["kwargs"].get("rng_key")
        sample_shape = site["kwargs"].get("sample_shape")
        with helpful_support_errors(site):
            transform = biject_to(site["fn"].support)
        unconstrained_shape = transform.inverse_shape(site["fn"].shape())
        unconstrained_samples = dist.Uniform(-radius, radius)(
            rng_key=rng_key, sample_shape=sample_shape + unconstrained_shape
        )
        return transform(unconstrained_samples)
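# Calling the strategy with no site returns a functools.partial carrying the
# radius, which is the form autoguides and MCMC kernels expect. A sketch
# reusing `toy_model` from the example above:
from numpyro.infer import init_to_uniform
from numpyro.infer.autoguide import AutoNormal

guide = AutoNormal(toy_model, init_loc_fn=init_to_uniform(radius=0.1))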
@property
def event_dim(self):
    warnings.warn(
        "transform.event_dim is deprecated. Please use Transform.domain.event_dim to "
        "get input event dim or Transform.codomain.event_dim to get output event dim.",
        FutureWarning,
        stacklevel=find_stack_level(),
    )
    return self.domain.event_dim
def __init__(self):
    warnings.warn(
        "PRNGIdentity distribution is deprecated. To get a random "
        "PRNG key, you can use `numpyro.prng_key()` instead.",
        FutureWarning,
        stacklevel=find_stack_level(),
    )
    super(PRNGIdentity, self).__init__(event_shape=(2,))
def __init__(self, domain=constraints.lower_cholesky):
    warnings.warn(
        "InvCholeskyTransform is deprecated. Please use CholeskyTransform"
        " or CorrMatrixCholeskyTransform instead.",
        FutureWarning,
        stacklevel=find_stack_level(),
    )
    assert domain in [constraints.lower_cholesky, constraints.corr_cholesky]
    self.domain = domain
def get_param_store():
    warnings.warn(
        "A limited parameter store is provided for compatibility with Pyro. "
        "Value of SVI parameters should be obtained via SVI.get_params() method.",
        category=UnsupportedAPIWarning,
        stacklevel=find_stack_level(),
    )
    # Return an empty dict for compatibility
    return _PARAM_STORE
def init(self, *args, **kwargs):
    warnings.warn(
        "Importing distributions from numpyro.contrib.tfp.distributions is "
        "deprecated. You should import distributions directly from "
        "tensorflow_probability.substrates.jax.distributions instead.",
        FutureWarning,
        stacklevel=find_stack_level(),
    )
    self.tfp_dist = tfd_class(*args, **kwargs)
def __init__(
    self,
    sampler,
    *,
    num_warmup,
    num_samples,
    num_chains=1,
    thinning=1,
    postprocess_fn=None,
    chain_method="parallel",
    progress_bar=True,
    jit_model_args=False,
):
    self.sampler = sampler
    self._sample_field = sampler.sample_field
    self._default_fields = sampler.default_fields
    self.num_warmup = num_warmup
    self.num_samples = num_samples
    self.num_chains = num_chains
    if not isinstance(thinning, int) or thinning < 1:
        raise ValueError("thinning must be a positive integer")
    self.thinning = thinning
    self.postprocess_fn = postprocess_fn
    if chain_method not in ["parallel", "vectorized", "sequential"]:
        raise ValueError(
            "Only supporting the following methods to draw chains:"
            ' "sequential", "parallel", or "vectorized"'
        )
    if chain_method == "parallel" and local_device_count() < self.num_chains:
        chain_method = "sequential"
        warnings.warn(
            "There are not enough devices to run parallel chains: expected {} but got {}."
            " Chains will be drawn sequentially. If you are running MCMC on CPU,"
            " consider using `numpyro.set_host_device_count({})` at the beginning"
            " of your program. You can double-check how many devices are available in"
            " your system using `jax.local_device_count()`.".format(
                self.num_chains, local_device_count(), self.num_chains
            ),
            stacklevel=find_stack_level(),
        )
    self.chain_method = chain_method
    self.progress_bar = progress_bar
    if "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ:
        self.progress_bar = False
    self._jit_model_args = jit_model_args
    self._states = None
    self._states_flat = None
    # HMCState returned by last run
    self._last_state = None
    # HMCState returned by last warmup
    self._warmup_state = None
    # HMCState returned by hmc.init_kernel
    self._init_state_cache = {}
    self._cache = {}
    self._collection_params = {}
    self._set_collection_params()
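# To actually run parallel chains on CPU and avoid the sequential fallback
# above, request host devices before JAX initializes its backend. A sketch
# reusing `toy_model` from the earlier example:
import numpyro

numpyro.set_host_device_count(4)

import jax.random as random
from numpyro.infer import MCMC, NUTS

mcmc = MCMC(NUTS(toy_model), num_warmup=500, num_samples=1000, num_chains=4)
mcmc.run(random.PRNGKey(0))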
def multinomial_goodness_of_fit(probs, counts, *, total_count=None, plot=False):
    """
    Pearson's chi^2 test, on possibly truncated data.
    https://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test

    :param numpy.ndarray probs: Vector of probabilities.
    :param numpy.ndarray counts: Vector of counts.
    :param int total_count: Optional total count in case data is truncated,
        otherwise None.
    :param bool plot: Whether to print a histogram. Defaults to False.
    :returns: p-value of truncated multinomial sample.
    :rtype: float
    """
    probs = jax.lax.stop_gradient(probs)
    assert len(probs.shape) == 1
    assert probs.shape == counts.shape
    if total_count is None:
        truncated = False
        total_count = int(counts.sum())
    else:
        truncated = True
        assert total_count >= counts.sum()

    if plot:
        print_histogram(probs, counts)

    chi_squared = 0
    dof = 0
    for p, c in zip(probs.tolist(), counts.tolist()):
        if abs(p - 1) < 1e-8:
            return 1 if c == total_count else 0
        assert p < 1, f"bad probability: {p:g}"
        if p > 0:
            mean = total_count * p
            variance = total_count * p * (1 - p)
            if not (variance > 1):
                raise InvalidTest("Goodness of fit is inaccurate; use more samples")
            chi_squared += (c - mean) ** 2 / variance
            dof += 1
        else:
            warnings.warn(
                "Zero probability in goodness-of-fit test",
                stacklevel=find_stack_level(),
            )
            if c > 0:
                return math.inf

    if not truncated:
        dof -= 1

    survival = _chi2sf(chi_squared, dof)
    return survival
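# A quick self-check of the test above, assuming the module context (jax,
# _chi2sf, etc.) is in scope. The counts sit near total_count * probs, so
# chi^2 is small and the p-value should be large:
import numpy as np

probs = np.array([0.1, 0.2, 0.3, 0.4])
counts = np.array([11, 19, 28, 42])  # 100 draws, close to expectation
p_value = multinomial_goodness_of_fit(probs, counts)
assert p_value > 0.05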
def _validate_sample(self, value):
    mask = self.support(value)
    if not_jax_tracer(mask):
        if not np.all(mask):
            warnings.warn(
                "Out-of-support values provided to log prob method. "
                "The value argument should be within the support.",
                stacklevel=find_stack_level(),
            )
    return mask
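# The check above only fires when argument validation is enabled; a sketch of
# triggering the warning deliberately with an out-of-support value:
import numpyro
import numpyro.distributions as dist

numpyro.enable_validation()
dist.Bernoulli(probs=0.5).log_prob(2.0)  # 2.0 is outside {0, 1}: warns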
def __init__(self, fn=None, trace=None, guide_trace=None):
    if guide_trace is not None:
        warnings.warn(
            "`guide_trace` argument is deprecated. Please replace it by `trace`.",
            FutureWarning,
            stacklevel=find_stack_level(),
        )
        trace = guide_trace
    assert trace is not None
    self.trace = trace
    super(replay, self).__init__(fn)
def _inverse(self, y):
    leading_dims = [
        v.shape[0] if jnp.ndim(v) > 0 else 0 for v in tree_flatten(y)[0]
    ]
    d0 = leading_dims[0]
    not_scalar = d0 > 0 or len(leading_dims) > 1
    if not_scalar and all(d == d0 for d in leading_dims[1:]):
        warnings.warn(
            "UnpackTransform.inv might lead to an unexpected behavior because it"
            " cannot transform a batch of unpacked arrays.",
            stacklevel=find_stack_level(),
        )
    return ravel_pytree(y)[0]
def init(self, rng_key, *args, **kwargs):
    """
    Gets the initial SVI state.

    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param args: arguments to the model / guide (these can possibly vary during
        the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly
        vary during the course of fitting).
    :return: the initial :data:`SVIState`
    """
    rng_key, model_seed, guide_seed = random.split(rng_key, 3)
    model_init = seed(self.model, model_seed)
    guide_init = seed(self.guide, guide_seed)
    guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
    model_trace = trace(replay(model_init, guide_trace)).get_trace(
        *args, **kwargs, **self.static_kwargs
    )
    params = {}
    inv_transforms = {}
    mutable_state = {}
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site in list(model_trace.values()) + list(guide_trace.values()):
        if site["type"] == "param":
            constraint = site["kwargs"].pop("constraint", constraints.real)
            with helpful_support_errors(site):
                transform = biject_to(constraint)
            inv_transforms[site["name"]] = transform
            params[site["name"]] = transform.inv(site["value"])
        elif site["type"] == "mutable":
            mutable_state[site["name"]] = site["value"]
        elif (
            site["type"] == "sample"
            and (not site["is_observed"])
            and site["fn"].support.is_discrete
            and not self.loss.can_infer_discrete
        ):
            s_name = type(self.loss).__name__
            warnings.warn(
                f"Currently, SVI with {s_name} loss does not support models"
                " with discrete latent variables",
                stacklevel=find_stack_level(),
            )

    if not mutable_state:
        mutable_state = None
    self.constrain_fn = partial(transform_fn, inv_transforms)
    # we convert weak types like float to float32/float64
    # to avoid recompiling body_fn in svi.run
    params, mutable_state = tree_map(
        lambda x: lax.convert_element_type(x, jnp.result_type(x)),
        (params, mutable_state),
    )
    return SVIState(self.optim.init(params), mutable_state, rng_key)
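# A compact sketch of the call sequence around init (svi.run wraps init and
# update internally); the model and data here are toy illustrations:
import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro import optim
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

def gauss_model(data=None):
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=data)

data = jnp.array([0.3, -0.1, 0.8])
guide = AutoNormal(gauss_model)
svi = SVI(gauss_model, guide, optim.Adam(0.01), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0), data)  # the SVIState assembled above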
def _get_model_transforms(model, model_args=(), model_kwargs=None):
    model_kwargs = {} if model_kwargs is None else model_kwargs
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    # model code may need to be replayed in the presence of deterministic sites
    replay_model = False
    has_enumerate_support = False
    for k, v in model_trace.items():
        if v["type"] == "sample" and not v["is_observed"]:
            if v["fn"].support.is_discrete:
                enum_type = v["infer"].get("enumerate")
                if enum_type is not None and (enum_type != "parallel"):
                    raise RuntimeError(
                        "This algorithm might only work for discrete sites with"
                        f" enumerate marked 'parallel'. But the site {k} is marked"
                        f" as '{enum_type}'."
                    )
                has_enumerate_support = True
                if not v["fn"].has_enumerate_support:
                    dist_name = type(v["fn"]).__name__
                    raise RuntimeError(
                        "This algorithm might only work for discrete sites with"
                        f" enumerate support. But the {dist_name} distribution at"
                        f" site {k} does not have enumerate support."
                    )
                if enum_type is None:
                    warnings.warn(
                        "Some algorithms will automatically enumerate the discrete"
                        f" latent site {k} of your model. In the future,"
                        " enumerated sites need to be marked with"
                        " `infer={'enumerate': 'parallel'}`.",
                        FutureWarning,
                        stacklevel=find_stack_level(),
                    )
            else:
                support = v["fn"].support
                with helpful_support_errors(v, raise_warnings=True):
                    inv_transforms[k] = biject_to(support)
                # XXX: the following code filters out most situations with
                # dynamic supports
                args = ()
                if isinstance(support, constraints._GreaterThan):
                    args = ("lower_bound",)
                elif isinstance(support, constraints._Interval):
                    args = ("lower_bound", "upper_bound")
                for arg in args:
                    if not isinstance(getattr(support, arg), (int, float)):
                        replay_model = True
        elif v["type"] == "deterministic":
            replay_model = True
    return inv_transforms, replay_model, has_enumerate_support, model_trace
def _check_mean_field_requirement(model_trace, guide_trace):
    """
    Checks that the guide and model sample sites are ordered identically.
    This is sufficient but not necessary for correctness.
    """
    model_sites = [
        name
        for name, site in model_trace.items()
        if site["type"] == "sample" and name in guide_trace
    ]
    guide_sites = [
        name
        for name, site in guide_trace.items()
        if site["type"] == "sample" and name in model_trace
    ]
    assert set(model_sites) == set(guide_sites)
    if model_sites != guide_sites:
        warnings.warn(
            "Failed to verify mean field restriction on the guide. "
            "To eliminate this warning, ensure model and guide sites "
            "occur in the same order.\n"
            + "Model sites:\n " + "\n ".join(model_sites)
            + "\nGuide sites:\n " + "\n ".join(guide_sites),
            stacklevel=find_stack_level(),
        )
def process_message(self, msg):
    if msg["type"] != "sample":
        return
    if (
        msg.get("_intervener_id", None) != self._intervener_id
        and self.data.get(msg["name"]) is not None
    ):
        if msg.get("_intervener_id", None) is not None:
            warnings.warn(
                "Attempting to intervene on variable {} multiple times,"
                " this is almost certainly incorrect behavior".format(msg["name"]),
                RuntimeWarning,
                stacklevel=find_stack_level(),
            )
        msg["_intervener_id"] = self._intervener_id

        # split node, avoid reapplying self recursively to new node
        new_msg = msg.copy()
        apply_stack(new_msg)

        intervention = self.data.get(msg["name"])
        msg["name"] = msg["name"] + "__CF"  # mangle old name
        msg["value"] = intervention
        msg["is_observed"] = True
        msg["stop"] = True
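# This handler backs numpyro.handlers.do. A minimal intervention sketch on a
# toy model: the intervened site returns the fixed value (downstream sites see
# it), while a fresh draw is re-applied under the original name and the
# intervened message is renamed with the "__CF" suffix, as in the code above.
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import do, seed, trace

def scm():
    x = numpyro.sample("x", dist.Normal(0.0, 1.0))
    numpyro.sample("y", dist.Normal(x, 1.0))

intervened = do(scm, data={"x": jnp.array(2.0)})
tr = trace(seed(intervened, jax.random.PRNGKey(0))).get_trace()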
def _subsample(name, size, subsample_size, dim):
    msg = {
        "type": "plate",
        "fn": _subsample_fn,
        "name": name,
        "args": (size, subsample_size),
        "kwargs": {"rng_key": None},
        "value": (
            None
            if (subsample_size is not None and size != subsample_size)
            else jnp.arange(size)
        ),
        "scale": 1.0,
        "cond_indep_stack": [],
    }
    apply_stack(msg)
    subsample = msg["value"]
    subsample_size = msg["args"][1]
    if subsample_size is not None and subsample_size != subsample.shape[0]:
        warnings.warn(
            "subsample_size does not match len(subsample), {} vs {}.".format(
                subsample_size, len(subsample)
            )
            + " Did you accidentally use different subsample_size in the model and guide?",
            stacklevel=find_stack_level(),
        )
    cond_indep_stack = msg["cond_indep_stack"]
    occupied_dims = {f.dim for f in cond_indep_stack}
    if dim is None:
        new_dim = -1
        while new_dim in occupied_dims:
            new_dim -= 1
        dim = new_dim
    else:
        assert dim not in occupied_dims
    return dim, subsample
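# This helper backs subsampling in numpyro.plate. The user-facing pattern is a
# plate that yields subsample indices (log densities are rescaled by
# size / subsample_size automatically); the sizes here are illustrative:
import numpyro
import numpyro.distributions as dist

def subsampled_model(data):
    with numpyro.plate("N", data.shape[0], subsample_size=100) as idx:
        numpyro.sample("obs", dist.Normal(0.0, 1.0), obs=data[idx])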
def __init__(
    self,
    model: Callable,
    posterior_samples: Optional[Dict] = None,
    *,
    guide: Optional[Callable] = None,
    params: Optional[Dict] = None,
    num_samples: Optional[int] = None,
    return_sites: Optional[List[str]] = None,
    infer_discrete: bool = False,
    parallel: bool = False,
    batch_ndims: Optional[int] = None,
):
    if posterior_samples is None and num_samples is None:
        raise ValueError(
            "Either posterior_samples or num_samples must be specified."
        )
    batch_ndims = (
        batch_ndims if batch_ndims is not None else 1 if guide is None else 0
    )

    posterior_samples = {} if posterior_samples is None else posterior_samples

    prototype_site = batch_shape = batch_size = None
    for name, sample in posterior_samples.items():
        if batch_shape is not None and sample.shape[:batch_ndims] != batch_shape:
            raise ValueError(
                f"Batch shapes at site {name} and {prototype_site} "
                f"should be the same, but got "
                f"{sample.shape[:batch_ndims]} and {batch_shape}"
            )
        else:
            prototype_site = name
            batch_shape = sample.shape[:batch_ndims]
            batch_size = int(np.prod(batch_shape))
            if (num_samples is not None) and (num_samples != batch_size):
                warnings.warn(
                    "Sample's batch dimension size {} is different from the "
                    "provided {} num_samples argument. Defaulting to {}.".format(
                        batch_size, num_samples, batch_size
                    ),
                    UserWarning,
                    stacklevel=find_stack_level(),
                )
            num_samples = batch_size

    if num_samples is None:
        raise ValueError(
            "No sample sites in posterior samples to infer `num_samples`."
        )

    if batch_shape is None:
        batch_shape = (1,) * (batch_ndims - 1) + (num_samples,)

    if return_sites is not None:
        assert isinstance(return_sites, (list, tuple, set))

    self.model = model
    self.posterior_samples = {} if posterior_samples is None else posterior_samples
    self.num_samples = num_samples
    self.guide = guide
    self.params = {} if params is None else params
    self.infer_discrete = infer_discrete
    self.return_sites = return_sites
    self.parallel = parallel
    self.batch_ndims = batch_ndims
    self._batch_shape = batch_shape
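# The two construction modes validated above, sketched with `gauss_model` from
# the SVI example. Prior prediction needs an explicit num_samples; posterior
# prediction reads it off the samples' leading batch dimension (the posterior
# dict here is hand-built for illustration):
import jax.numpy as jnp
import jax.random as random
from numpyro.infer import Predictive

prior_pred = Predictive(gauss_model, num_samples=200)
prior_samples = prior_pred(random.PRNGKey(1))  # "mu" and "obs", shape (200,)

fake_posterior = {"mu": jnp.zeros((200,))}  # leading dim -> num_samples = 200
post_pred = Predictive(gauss_model, posterior_samples=fake_posterior)
post_samples = post_pred(random.PRNGKey(2))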
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import warnings

from numpyro.ops.indexing import Vindex, vindex  # noqa: F401
from numpyro.util import find_stack_level

warnings.warn(
    "`indexing` module has been moved from `numpyro.contrib` to `numpyro.ops`."
    " Please import Vindex or vindex functions from `numpyro.ops.indexing`.",
    FutureWarning,
    stacklevel=find_stack_level(),
)