def logistic_random_effects(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].

    Each annotator gets per-class logits ``beta`` centred on class-level means
    ``zeta`` (scales ``Omega``). Each item gets a random effect ``theta``, whose
    scale ``Chi`` depends on the item's true class ``c``. ``theta`` is subtracted
    from the annotator logits before the observed labels are scored.

    :param positions: integer annotator ids used to index ``beta``; assumed to
        broadcast against ``annotations`` (presumably shape
        ``(num_items, num_positions)`` — confirm with caller).
    :param annotations: 2-D integer array of observed labels with shape
        ``(num_items, num_positions)``.
    """
    # NOTE(review): `item_difficulty` cites the same Figure 5; confirm the figure number.
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # Logits are invariant up to a constant, so only the first
        # `num_classes - 1` components get priors; a 0 is padded on below.
        zeta = numpyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))
        chi = numpyro.sample("Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # non-centered parameterization
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
                # pad 0 to the last logit component
                beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        # True class of each item. It is marked for parallel enumeration so it is
        # summed out explicitly, the same way `item_difficulty` marks its `c`.
        c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(0, chi[c]).to_event(1))
            theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            # pick each annotator's logits row for the item's true class
            logits = Vindex(beta)[positions, c, :] - theta
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
def item_difficulty(annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].

    Each item has logits ``theta`` drawn around the class-level means ``eta[c]``
    (scales ``Chi[c]``) of its true class ``c``. The observed labels at every
    position are drawn from those logits.

    :param annotations: 2-D integer array of observed labels with shape
        ``(num_items, num_positions)``.
    """
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # only the first `num_classes - 1` logit components get priors;
        # the last is fixed to 0 by padding below
        eta = numpyro.sample(
            "eta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        chi = numpyro.sample(
            "Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        # true class per item, summed out by parallel enumeration
        c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

        # non-centered parameterization
        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(eta[c], chi[c]).to_event(1))
            theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        # use the already-unpacked position count (same value as annotations.shape[-1])
        with numpyro.plate("position", num_positions):
            numpyro.sample("y", dist.Categorical(logits=theta), obs=annotations)
def test_model_with_transformed_distribution():
    """TransformReparam on a transformed-distribution site must agree with the
    direct constrain-and-Jacobian computation for both samples and potential energy.
    """
    x_prior = dist.HalfNormal(2)
    y_prior = dist.LogNormal(scale=3.)  # transformed distribution

    def model():
        numpyro.sample('x', x_prior)
        numpyro.sample('y', y_prior)

    unconstrained = {'x': jnp.array(-5.), 'y': jnp.array(7.)}
    seeded_model = handlers.seed(model, random.PRNGKey(0))

    # bijections from the unconstrained space onto each prior's support
    bijections = {}
    for site_name, prior in (('x', x_prior), ('y', y_prior)):
        bijections[site_name] = biject_to(prior.support)

    expected_samples = transform_fn(bijections, unconstrained)
    x_val, y_val = expected_samples['x'], expected_samples['y']
    x_jac = bijections['x'].log_abs_det_jacobian(unconstrained['x'], x_val)
    y_jac = bijections['y'].log_abs_det_jacobian(unconstrained['y'], y_val)
    expected_potential_energy = (-x_prior.log_prob(x_val) - y_prior.log_prob(y_val)
                                 - x_jac - y_jac)

    # After reparameterization the latent `y` is replaced by its base site `y_base`.
    reparam_model = handlers.reparam(seeded_model, config={'y': TransformReparam()})
    base_params = {'x': unconstrained['x'], 'y_base': unconstrained['y']}
    actual_samples = constrain_fn(handlers.seed(reparam_model, random.PRNGKey(0)),
                                  (), {}, base_params, return_deterministic=True)
    actual_potential_energy = potential_energy(reparam_model, (), {}, base_params)

    for site_name in ('x', 'y'):
        assert_allclose(expected_samples[site_name], actual_samples[site_name])
    assert_allclose(actual_potential_energy, expected_potential_energy)
def model(data):
    """Score `data` as Normal(loc, 0.1) noise around a latent `loc`.

    `loc` is a Uniform(0, 1) base variable scaled by a random `alpha`. The base
    Uniform is wrapped in `.mask(False)`, and the site is reparameterized with
    TransformReparam so that the sampler works on the base variable.
    """
    alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
    # NOTE(review): mask(False) masks the base distribution's log-density, so the
    # base site presumably contributes no prior density term — confirm intent.
    masked_base = dist.Uniform(0, 1).mask(False)
    loc_prior = dist.TransformedDistribution(masked_base, AffineTransform(0, alpha))
    with handlers.reparam(config={'loc': TransformReparam()}):
        loc = numpyro.sample('loc', loc_prior)
    numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)
def hierarchical_dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].

    Each annotator has per-class confusion logits ``beta`` drawn around
    class-level means ``zeta`` with scales ``Omega``. Each item's true class
    ``c`` selects the row of those logits used to score its observed labels.

    :param positions: integer annotator ids used to index ``beta``; assumed to
        broadcast against ``annotations`` (presumably shape
        ``(num_items, num_positions)`` — confirm with caller).
    :param annotations: 2-D integer array of observed labels with shape
        ``(num_items, num_positions)``.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
        # invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
        # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
        zeta = numpyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # non-centered parameterization
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # pad 0 to the last item
            beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        # True class of each item. It is marked for parallel enumeration so it is
        # summed out explicitly rather than treated as a sampled parameter.
        c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

        with numpyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :]
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
def model(data):
    """Observe `data` with Normal noise (scale 0.1) around a latent location.

    The location is `alpha * u` with `u` from a masked Uniform(0, 1) base. The
    `loc` site is reparameterized through its TransformedDistribution, so
    inference runs on the base variable.
    """
    alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
    loc_dist = dist.TransformedDistribution(
        dist.Uniform(0, 1).mask(False),
        AffineTransform(0, alpha),
    )
    reparam_cfg = {"loc": TransformReparam()}
    with handlers.reparam(config=reparam_cfg):
        loc = numpyro.sample("loc", loc_dist)
    numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)
def transition_fn(carry, y):
    """One time step of the capture-recapture scan body.

    :param carry: tuple ``(first_capture_mask, z)`` — the boolean "captured at
        least once" mask and the previous latent alive state.
    :param y: observations for this time step (cast to bool when updating the mask).
    :returns: ``((first_capture_mask, z), None)`` with the mask updated by ``y``.

    Reads ``phi_logit_mean``, ``phi_sigma``, ``N``, ``rho`` and ``expit`` from the
    enclosing scope.
    """
    first_capture_mask, z = carry
    # non-centered parameterization of this step's survival logit
    with handlers.reparam(config={"phi_logit": LocScaleReparam(0)}):
        phi_logit_t = numpyro.sample("phi_logit", dist.Normal(phi_logit_mean, phi_sigma))
    phi_t = expit(phi_logit_t)
    with numpyro.plate("animals", N, dim=-1):
        with handlers.mask(mask=first_capture_mask):
            # Animals not yet captured get alive-probability 1, and their terms
            # are masked out. Captured animals survive with probability phi_t * z.
            mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
            # NumPyro exactly sums out the discrete states z_t; the site is marked
            # for parallel enumeration so this holds explicitly.
            z = numpyro.sample(
                "z",
                dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
                infer={"enumerate": "parallel"},
            )
            mu_y_t = rho * z
            numpyro.sample("y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y)
    first_capture_mask = first_capture_mask | y.astype(bool)
    return (first_capture_mask, z), None
import os

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import reparam
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.reparam import LocScaleReparam


def model(dim=10):
    """Neal's funnel: `y ~ Normal(0, 3)` and `dim - 1` coordinates `x ~ Normal(0, exp(y / 2))`."""
    y = numpyro.sample("y", dist.Normal(0, 3))
    numpyro.sample("x", dist.Normal(jnp.zeros(dim - 1), jnp.exp(y / 2)))


# Same model with `x` non-centered (LocScaleReparam with centered=0).
reparam_model = reparam(model, config={"x": LocScaleReparam(0)})


def run_inference(model, args, rng_key):
    """Run NUTS on `model` and return its posterior samples.

    :param model: NumPyro model callable taking no arguments.
    :param args: namespace providing `num_warmup`, `num_samples` and `num_chains`.
    :param rng_key: JAX PRNG key for the MCMC run.
    :returns: dict of posterior samples from `mcmc.get_samples()`.
    """
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        # hide the progress bar when building the Sphinx docs
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(rng_key)
    mcmc.print_summary(exclude_deterministic=False)
    return mcmc.get_samples()
def run(self, rng_key, *args, **kwargs):
    """
    Run the nested samplers and collect weighted samples.

    The model is reparameterized so every latent site has a Uniform(0, 1) base.
    jaxns samples that base space, and the results are mapped back to the
    original sites with `Predictive`. The final results are stored on
    `self._results`.

    :param random.PRNGKey rng_key: Random number generator key to be used for the sampling.
    :param args: The arguments needed by the `model`.
    :param kwargs: The keyword arguments needed by the `model`.
    """
    # one key for the nested sampler, one for the Predictive pass at the end
    rng_sampling, rng_predictive = random.split(rng_key)
    # reparam the model so that latent sites have Uniform(0, 1) priors
    prototype_trace = trace(seed(self.model, rng_key)).get_trace(*args, **kwargs)
    # Unobserved sample sites are the parameters to sample. Sites marked for
    # parallel enumeration are excluded: they are summed out, not sampled.
    param_names = [
        site["name"]
        for site in prototype_trace.values()
        if site["type"] == "sample"
        and not site["is_observed"]
        and site["infer"].get("enumerate", "") != "parallel"
    ]
    # deterministic sites are recomputed and returned alongside the parameters
    deterministics = [
        site["name"]
        for site in prototype_trace.values()
        if site["type"] == "deterministic"
    ]
    reparam_model = reparam(
        self.model, config={k: UniformReparam() for k in param_names})

    # enable enumerate if needed
    has_enum = any(site["type"] == "sample"
                   and site["infer"].get("enumerate", "") == "parallel"
                   for site in prototype_trace.values())
    if has_enum:
        # the funsor-based log_density is needed to sum out enumerated sites
        from numpyro.contrib.funsor import enum, log_density as log_density_

        max_plate_nesting = _guess_max_plate_nesting(prototype_trace)
        _validate_model(prototype_trace)
        # enumeration dims are allocated to the left of all plate dims
        reparam_model = enum(reparam_model, -max_plate_nesting - 1)
    else:
        log_density_ = log_density

    def loglik_fn(**params):
        # log joint of the reparameterized model at `params` (first element of
        # log_density's return value)
        return log_density_(reparam_model, args, kwargs, params)[0]

    # use NestedSampler with identity prior chain
    prior_chain = PriorChain()
    for name in param_names:
        # One UniformPrior per latent site, shaped like the site's distribution.
        # Presumably named to match the `<name>_base` site that UniformReparam
        # introduces — confirm.
        prior = UniformPrior(name + "_base", prototype_trace[name]["fn"].shape())
        prior_chain.push(prior)
    # XXX: the `marginalised` keyword in jaxns can be used to get expectation of some
    # quantity over posterior samples; it can be helpful to expose it in this wrapper
    ns = OrigNestedSampler(
        loglik_fn,
        prior_chain,
        sampler_name=self.sampler_name,
        sampler_kwargs={
            "depth": self.depth,
            "num_slices": self.num_slices
        },
        max_samples=self.max_samples,
        num_live_points=self.num_live_points,
        collect_samples=True,
    )
    # Some places in jaxns use float64 and emit warnings when the default dtype is
    # float32, so we suppress them here to avoid confusion.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message=".*will be truncated to dtype float32.*")
        results = ns(rng_sampling, termination_frac=self.termination_frac)
    # transform base samples back to original domains
    # Here we only transform the first valid num_samples samples
    # NB: the number of weighted samples obtained from jaxns is results.num_samples
    # and only the first num_samples values of results.samples are valid.
    num_samples = results.num_samples
    samples = tree_util.tree_map(lambda x: x[:num_samples], results.samples)
    predictive = Predictive(reparam_model, samples,
                            return_sites=param_names + deterministics)
    samples = predictive(rng_predictive, *args, **kwargs)
    # replace base samples in jaxns results by transformed samples
    self._results = results._replace(samples=samples)
import pathlib
from typing import Callable, Dict

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro import handlers, infer
from numpyro.infer import reparam


def model(dim: int = 10) -> None:
    """Neal's funnel: `y ~ Normal(0, 3)` and `dim - 1` coordinates `x ~ Normal(0, exp(y / 2))`."""
    y = numpyro.sample("y", dist.Normal(0, 3))
    numpyro.sample("x", dist.Normal(jnp.zeros(dim - 1), jnp.exp(y / 2)))


# Same model with `x` non-centered (LocScaleReparam with centered=0).
reparam_model = handlers.reparam(model, config={"x": reparam.LocScaleReparam(0)})


def run_inference(model: Callable, rng_key: np.ndarray) -> Dict[str, jnp.ndarray]:
    """Run NUTS on `model` with 1000 warmup and 1000 sampling steps on one chain.

    :param model: NumPyro model callable taking no arguments.
    :param rng_key: JAX PRNG key for the MCMC run.
    :returns: dict of posterior samples from `mcmc.get_samples()`.
    """
    kernel = infer.NUTS(model)
    # num_warmup / num_samples / num_chains are keyword-only in current NumPyro
    mcmc = infer.MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1)
    mcmc.run(rng_key)
    return mcmc.get_samples()


def _plot_results(samples: Dict[str, jnp.ndarray], reparam_samples: Dict[str, jnp.ndarray]) -> None:
    """Plot samples from the original and reparameterized models.

    As visible here, the body only builds the `./data/neals` path.
    """
    # NOTE(review): the rest of this function's body appears to be missing from this view.
    root = pathlib.Path("./data/neals")
def reparam_model(model):
    """Return `model` wrapped so site 'x' uses a non-centered LocScaleReparam (centered=0)."""
    site_config = {'x': LocScaleReparam(0)}
    wrapped = reparam(model, config=site_config)
    return wrapped
def _auto_reparam(self) -> numpyro.handlers.reparam:
    """Build a reparam handler that applies CircularReparam to every circular variable.

    A fresh CircularReparam instance is created for each name returned by
    `self.get_circ_var_names()`.
    """
    circular_names = self.get_circ_var_names()
    per_site = dict((name, CircularReparam()) for name in circular_names)
    return hdl.reparam(config=per_site)