def logistic_random_effects(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].

    The observed label logits are a per-(annotator, true class) term ``beta``
    minus a per-item random effect ``theta`` whose scale ``Chi`` depends on
    the item's (latent) true class ``c``.

    :param positions: integer array used to index annotators in ``beta``;
        ``max(positions) + 1`` is taken as the number of annotators.
    :param annotations: integer array of shape ``(num_items, num_positions)``
        with the observed labels; ``max(annotations) + 1`` is taken as the
        number of classes.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # Logits are invariant up to a constant, so only the first
        # `num_classes - 1` terms get hyperpriors; the last term is padded
        # with a fixed 0 below.
        zeta = numpyro.sample(
            "zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1)
        )
        omega = numpyro.sample(
            "Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
        )
        chi = numpyro.sample(
            "Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
        )

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # Non-centered parameterization (centered=0) of `beta`.
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
                # Append the fixed 0 as the last logit.
                beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        # Explicitly mark the discrete true class for parallel enumeration so
        # it is marginalized out, as `item_difficulty` does for the same site.
        c = numpyro.sample(
            "c", dist.Categorical(pi), infer={"enumerate": "parallel"}
        )

        # Non-centered per-item random effect; its scale depends on `c`.
        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(0, chi[c]).to_event(1))
            theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            # Pick the logits of the annotator at each position, given class c.
            logits = Vindex(beta)[positions, c, :] - theta
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
def test_syntax():
    """Handler composition syntaxes must all record the same trace sites.

    One seed/trace/reparam stack is applied to the same model in four ways:
    context managers, eager wrapping, partial application and decorators.
    The site names of the four resulting traces are compared.
    """
    loc = np.random.uniform(-1.0, 1.0, ())
    scale = np.random.uniform(0.5, 1.5, ())
    data = np.random.uniform(-1.0, 1.0, ())
    config = {"x": LocScaleReparam(), "y": LocScaleReparam()}

    def model():
        x = numpyro.sample("x", dist.Normal(loc, scale))
        y = numpyro.sample("y", dist.Normal(x, scale))
        numpyro.sample("z", dist.Normal(y, scale), obs=data)

    collected = []

    # 1) Context-manager syntax.
    seeder = numpyro.handlers.seed(rng_seed=0)
    tracer = numpyro.handlers.trace()
    reparamer = numpyro.handlers.reparam(config=config)
    with reparamer, tracer, seeder:
        model()
    collected.append(tracer.trace)

    # 2) Eager syntax: each handler wraps a callable at construction time.
    tracer = numpyro.handlers.trace(numpyro.handlers.seed(model, rng_seed=0))
    numpyro.handlers.reparam(tracer, config=config)()
    collected.append(tracer.trace)

    # 3) Partial syntax: build bare handlers first, then apply them.
    seeder = numpyro.handlers.seed(rng_seed=0)
    tracer = numpyro.handlers.trace()
    reparamer = numpyro.handlers.reparam(config=config)
    wrapped = reparamer(tracer(seeder(model)))
    wrapped()
    collected.append(tracer.trace)

    # 4) Decorator syntax.
    seeder = numpyro.handlers.seed(rng_seed=0)
    tracer = numpyro.handlers.trace()
    reparamer = numpyro.handlers.reparam(config=config)

    @reparamer
    @tracer
    @seeder
    def decorated():
        return model()

    decorated()
    collected.append(tracer.trace)

    tr1, tr2, tr3, tr4 = collected
    assert tr1.keys() == tr2.keys() == tr3.keys() == tr4.keys()
def item_difficulty(annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].

    Each item has a random effect ``theta`` whose location ``eta`` and scale
    ``Chi`` are shared by all items of the same latent class ``c``; observed
    labels are drawn from ``Categorical(logits=theta)``.

    :param annotations: integer array of shape ``(num_items, num_positions)``
        with the observed labels; ``max(annotations) + 1`` is taken as the
        number of classes.
    """
    n_classes = int(np.max(annotations)) + 1
    n_items, n_positions = annotations.shape
    # Only the first `n_classes - 1` logits are free; the last is fixed to 0.
    n_free = n_classes - 1

    with numpyro.plate("class", n_classes):
        loc_per_class = numpyro.sample(
            "eta", dist.Normal(0, 1).expand([n_free]).to_event(1)
        )
        scale_per_class = numpyro.sample(
            "Chi", dist.HalfNormal(1).expand([n_free]).to_event(1)
        )

    class_probs = numpyro.sample("pi", dist.Dirichlet(jnp.ones(n_classes)))

    with numpyro.plate("item", n_items, dim=-2):
        true_class = numpyro.sample(
            "c", dist.Categorical(class_probs), infer={"enumerate": "parallel"}
        )

        # Non-centered parameterization of the item effect.
        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            effect_prior = dist.Normal(
                loc_per_class[true_class], scale_per_class[true_class]
            ).to_event(1)
            effect = numpyro.sample("theta", effect_prior)
            pad_width = [(0, 0)] * (jnp.ndim(effect) - 1) + [(0, 1)]
            effect = jnp.pad(effect, pad_width)

        # `n_positions` equals annotations.shape[-1] (2-D unpacking above).
        with numpyro.plate("position", n_positions):
            numpyro.sample("y", dist.Categorical(logits=effect), obs=annotations)
def hierarchical_dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].

    Each (annotator, true class) pair has response logits ``beta`` drawn
    around shared per-class hyperparameters ``zeta`` / ``Omega``.

    :param positions: integer array used to index annotators in ``beta``;
        ``max(positions) + 1`` is taken as the number of annotators.
    :param annotations: integer array of shape ``(num_items, num_positions)``
        with the observed labels; ``max(annotations) + 1`` is taken as the
        number of classes.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
        # invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
        # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
        zeta = numpyro.sample(
            "zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1)
        )
        omega = numpyro.sample(
            "Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
        )

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # non-centered parameterization
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
                # pad 0 to the last item
                beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        # Explicitly mark the discrete true class for parallel enumeration so
        # it is marginalized out, as `item_difficulty` does for the same site.
        c = numpyro.sample(
            "c", dist.Categorical(pi), infer={"enumerate": "parallel"}
        )

        with numpyro.plate("position", num_positions):
            # Logits of the annotator at each position, given true class c.
            logits = Vindex(beta)[positions, c, :]
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
def test_loc_scale(dist_type, centered, shape, event_dim):
    """LocScaleReparam must preserve sample moments and their gradients.

    Draws 10000 particles of ``x`` with and without the reparameterization.
    It compares their moments (``atol=0.1``) and the Jacobians of those
    moments w.r.t. ``loc`` and ``scale`` (``atol=0.05``).

    :param dist_type: ``"Normal"`` selects a Normal base distribution; any
        other value selects ``StudentT(10.0, loc, scale)``.
    :param centered: forwarded to ``LocScaleReparam`` as its centering.
    :param shape: shape of ``loc`` and ``scale``.
    :param event_dim: number of rightmost dims treated as event dims
        (clipped to ``len(shape)``).
    """
    loc = np.random.uniform(-1.0, 1.0, shape)
    scale = np.random.uniform(0.5, 1.5, shape)
    event_dim = min(event_dim, len(shape))
    # Compare the argument's value (not the literal string "dist_type").
    # The model branch and the reparam choice below must agree.
    is_normal = dist_type == "Normal"

    def model(loc, scale):
        with numpyro.plate_stack("plates", shape[: len(shape) - event_dim]):
            with numpyro.plate("particles", 10000):
                if is_normal:
                    numpyro.sample("x", dist.Normal(loc, scale).to_event(event_dim))
                else:
                    numpyro.sample(
                        "x", dist.StudentT(10.0, loc, scale).to_event(event_dim)
                    )

    def get_expected_probe(loc, scale):
        with numpyro.handlers.trace() as trace:
            with numpyro.handlers.seed(rng_seed=0):
                model(loc, scale)
        return get_moments(trace["x"]["value"])

    # StudentT's `df` is a shape parameter that the reparam passes through.
    if is_normal:
        reparam = LocScaleReparam(centered)
    else:
        reparam = LocScaleReparam(centered, shape_params=["df"])

    def get_actual_probe(loc, scale):
        with numpyro.handlers.trace() as trace:
            with numpyro.handlers.seed(rng_seed=0):
                with numpyro.handlers.reparam(config={"x": reparam}):
                    model(loc, scale)
        return get_moments(trace["x"]["value"])

    expected_probe = get_expected_probe(loc, scale)
    actual_probe = get_actual_probe(loc, scale)
    assert_allclose(actual_probe, expected_probe, atol=0.1)

    expected_grad = jacobian(get_expected_probe, argnums=(0, 1))(loc, scale)
    actual_grad = jacobian(get_actual_probe, argnums=(0, 1))(loc, scale)
    assert_allclose(actual_grad[0], expected_grad[0], atol=0.05)  # loc grad
    assert_allclose(actual_grad[1], expected_grad[1], atol=0.05)  # scale grad
def transition_fn(carry, y):
    """One time step of the latent alive-state / capture model.

    The ``(carry, y) -> (carry, None)`` signature suggests this is the body of
    a ``scan`` over time steps (TODO confirm against the caller). It reads
    ``phi_logit_mean``, ``phi_sigma``, ``N`` and ``rho`` from an enclosing
    scope that is not visible here.

    :param carry: tuple ``(first_capture_mask, z)``. The first element is a
        mask of animals captured at least once so far; the second is the
        previous latent state.
    :param y: this step's observed captures; OR-ed (as bool) into the mask.
    :returns: ``((first_capture_mask, z), None)``.
    """
    first_capture_mask, z = carry
    # Non-centered parameterization of this step's survival logit.
    with handlers.reparam(config={"phi_logit": LocScaleReparam(0)}):
        phi_logit_t = numpyro.sample(
            "phi_logit", dist.Normal(phi_logit_mean, phi_sigma)
        )
    phi_t = expit(phi_logit_t)
    with numpyro.plate("animals", N, dim=-1):
        # Log-prob terms of animals not yet captured are masked out.
        with handlers.mask(mask=first_capture_mask):
            # Masked-off animals get probability 1; others survive with
            # probability phi_t times their previous state.
            mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
            # NumPyro exactly sums out the discrete states z_t; mark the site
            # for parallel enumeration explicitly rather than relying on
            # implicit configuration.
            z = numpyro.sample(
                "z",
                dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
                infer={"enumerate": "parallel"},
            )
            mu_y_t = rho * z
            numpyro.sample(
                "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
            )
    first_capture_mask = first_capture_mask | y.astype(bool)
    return (first_capture_mask, z), None
import os

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import reparam
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.reparam import LocScaleReparam


def model(dim=10):
    """Neal's funnel: ``y ~ Normal(0, 3)`` and ``x | y ~ Normal(0, exp(y / 2))``.

    :param dim: ``x`` has ``dim - 1`` coordinates.
    """
    y = numpyro.sample("y", dist.Normal(0, 3))
    numpyro.sample("x", dist.Normal(jnp.zeros(dim - 1), jnp.exp(y / 2)))


# Same model with `x` fully non-centered (LocScaleReparam centered=0).
reparam_model = reparam(model, config={"x": LocScaleReparam(0)})


def run_inference(model, args, rng_key):
    """Run NUTS on ``model``, print a summary and return the samples.

    :param model: a NumPyro model callable; it is run with no arguments.
    :param args: object providing ``num_warmup``, ``num_samples`` and
        ``num_chains``.
    :param rng_key: JAX PRNG key for the MCMC run.
    :returns: samples as returned by ``MCMC.get_samples()``.
    """
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        # Hide the progress bar when NUMPYRO_SPHINXBUILD is set (presumably
        # during the Sphinx docs build).
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key)
    mcmc.print_summary(exclude_deterministic=False)
    return mcmc.get_samples()
def reparam_model(dim=10):
    """Funnel model whose ``x`` site is reparameterized via a context manager.

    ``y ~ Normal(0, 3)`` and ``x`` has ``dim - 1`` coordinates with scale
    ``exp(y / 2)``. Inside the handler, ``x`` is drawn through
    ``LocScaleReparam(0)``, i.e. fully non-centered.
    """
    y = numpyro.sample('y', dist.Normal(0, 3))
    x_loc = jnp.zeros(dim - 1)
    x_scale = jnp.exp(y / 2)
    x_config = {'x': LocScaleReparam(0)}
    with numpyro.handlers.reparam(config=x_config):
        numpyro.sample('x', dist.Normal(x_loc, x_scale))
def reparam_model(model):
    """Return ``model`` wrapped so that its ``'x'`` site is fully non-centered.

    :param model: a NumPyro model; the reparameterization targets the sample
        site named ``'x'``.
    :returns: the ``reparam`` handler wrapping ``model`` with
        ``LocScaleReparam(0)`` configured for ``'x'``.
    """
    non_centered = LocScaleReparam(0)
    site_config = {'x': non_centered}
    return reparam(model, config=site_config)