Exemplo n.º 1
0
def logistic_random_effects(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].

    Each item has a latent true class ``c`` drawn from ``pi``. The logits of an
    annotation are the annotator/class effect ``beta[annotator, c]`` minus a
    per-item effect ``theta`` whose scale ``chi[c]`` depends on the class.

    :param positions: integer array mapping each annotation position to an
        annotator id; ``max(positions) + 1`` annotators are assumed
        (presumably shape ``(num_positions,)`` -- TODO confirm).
    :param annotations: integer array of shape ``(num_items, num_positions)``
        holding the observed class labels.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # Logits are invariant up to an additive constant, so only the first
        # `num_classes - 1` terms get priors; the last one is fixed to 0 below.
        zeta = numpyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))
        chi = numpyro.sample("Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # non-centered parameterization
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # pad 0 as the last logit
            beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        # `c` is discrete: mark it for parallel enumeration so it can be
        # marginalized out (needed e.g. by enumeration-based ELBOs / infer_discrete).
        c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

        # non-centered parameterization of the per-item effect
        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(0, chi[c]).to_event(1))
        # pad 0 as the last logit
        theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            # Select each position's annotator row and the item's class column.
            logits = Vindex(beta)[positions, c, :] - theta
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
Exemplo n.º 2
0
def test_syntax():
    """The reparam/trace/seed handlers must record the same trace sites no
    matter which composition syntax is used: context manager, eager wrapping,
    partial application, or decorators."""
    loc = np.random.uniform(-1.0, 1.0, ())
    scale = np.random.uniform(0.5, 1.5, ())
    data = np.random.uniform(-1.0, 1.0, ())
    config = {"x": LocScaleReparam(), "y": LocScaleReparam()}

    def model():
        x = numpyro.sample("x", dist.Normal(loc, scale))
        y = numpyro.sample("y", dist.Normal(x, scale))
        numpyro.sample("z", dist.Normal(y, scale), obs=data)

    recorded = []

    # 1) Context-manager syntax.
    seed_h = numpyro.handlers.seed(rng_seed=0)
    trace_h = numpyro.handlers.trace()
    reparam_h = numpyro.handlers.reparam(config=config)
    with reparam_h, trace_h, seed_h:
        model()
    recorded.append(trace_h.trace)

    # 2) Eager wrapping: each handler receives the callable it wraps.
    trace_h = numpyro.handlers.trace(numpyro.handlers.seed(model, rng_seed=0))
    numpyro.handlers.reparam(trace_h, config=config)()
    recorded.append(trace_h.trace)

    # 3) Partial application: build bare handlers, then apply them to callables.
    seed_h = numpyro.handlers.seed(rng_seed=0)
    trace_h = numpyro.handlers.trace()
    reparam_h = numpyro.handlers.reparam(config=config)
    wrapped = reparam_h(trace_h(seed_h(model)))
    wrapped()
    recorded.append(trace_h.trace)

    # 4) Decorator syntax.
    seed_h = numpyro.handlers.seed(rng_seed=0)
    trace_h = numpyro.handlers.trace()
    reparam_h = numpyro.handlers.reparam(config=config)

    @reparam_h
    @trace_h
    @seed_h
    def decorated():
        return model()

    decorated()
    recorded.append(trace_h.trace)

    reference_sites = recorded[0].keys()
    for other in recorded[1:]:
        assert other.keys() == reference_sites
Exemplo n.º 3
0
def item_difficulty(annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].

    Each item has a latent true class ``c`` drawn from ``pi``. A per-item
    effect ``theta ~ Normal(eta[c], chi[c])``, padded with a trailing 0, is
    used directly as the logits of every annotation of that item.

    :param annotations: integer array of shape ``(num_items, num_positions)``
        holding the observed class labels.
    """
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # Logits are invariant up to an additive constant, so only the first
        # `num_classes - 1` terms get priors; the last one is fixed to 0 below.
        eta = numpyro.sample(
            "eta",
            dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        chi = numpyro.sample(
            "Chi",
            dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        # `c` is discrete and is enumerated in parallel.
        c = numpyro.sample("c",
                           dist.Categorical(pi),
                           infer={"enumerate": "parallel"})

        # non-centered parameterization
        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta",
                                   dist.Normal(eta[c], chi[c]).to_event(1))
        # pad 0 as the last logit
        theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            numpyro.sample("y",
                           dist.Categorical(logits=theta),
                           obs=annotations)
Exemplo n.º 4
0
def hierarchical_dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].

    Each item has a latent true class ``c`` drawn from ``pi``. Each annotator
    has a per-class row of logits ``beta`` drawn from shared hyperpriors
    ``zeta``/``Omega``, and an annotation's logits are
    ``beta[annotator, c]``.

    :param positions: integer array mapping each annotation position to an
        annotator id; ``max(positions) + 1`` annotators are assumed
        (presumably shape ``(num_positions,)`` -- TODO confirm).
    :param annotations: integer array of shape ``(num_items, num_positions)``
        holding the observed class labels.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
        # invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
        # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
        zeta = numpyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # non-centered parameterization
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # pad 0 to the last item
            beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        # `c` is discrete: mark it for parallel enumeration so it can be
        # marginalized out (needed e.g. by enumeration-based ELBOs / infer_discrete).
        c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

        with numpyro.plate("position", num_positions):
            # Select each position's annotator row and the item's class column.
            logits = Vindex(beta)[positions, c, :]
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
Exemplo n.º 5
0
def test_loc_scale(dist_type, centered, shape, event_dim):
    """LocScaleReparam must preserve the sample moments of a Normal or StudentT
    site, and their gradients with respect to ``loc`` and ``scale``.

    ``dist_type`` selects ``"Normal"`` or (otherwise) ``StudentT``;
    ``centered`` is forwarded to ``LocScaleReparam``; ``event_dim`` is
    clipped to the rank of ``shape``.
    """
    loc = np.random.uniform(-1., 1., shape)
    scale = np.random.uniform(0.5, 1.5, shape)
    event_dim = min(event_dim, len(shape))
    # Compare the parameter itself (the original compared the literal string
    # "dist_type", which was always False, so Normal was never exercised).
    is_normal = dist_type == "Normal"

    def model(loc, scale):
        with numpyro.plate_stack("plates", shape[:len(shape) - event_dim]):
            with numpyro.plate("particles", 10000):
                if is_normal:
                    numpyro.sample("x",
                                   dist.Normal(loc, scale).to_event(event_dim))
                else:
                    numpyro.sample(
                        "x",
                        dist.StudentT(10.0, loc, scale).to_event(event_dim))

    def get_expected_probe(loc, scale):
        with numpyro.handlers.trace() as trace:
            with numpyro.handlers.seed(rng_seed=0):
                model(loc, scale)
        return get_moments(trace["x"]["value"])

    # `centered` is a test parameter and must reach the reparameterizer;
    # StudentT additionally needs `df` passed through unchanged.
    if is_normal:
        reparam = LocScaleReparam(centered)
    else:
        reparam = LocScaleReparam(centered, shape_params=["df"])

    def get_actual_probe(loc, scale):
        with numpyro.handlers.trace() as trace:
            with numpyro.handlers.seed(rng_seed=0):
                with numpyro.handlers.reparam(config={"x": reparam}):
                    model(loc, scale)
        return get_moments(trace["x"]["value"])

    expected_probe = get_expected_probe(loc, scale)
    actual_probe = get_actual_probe(loc, scale)
    assert_allclose(actual_probe, expected_probe, atol=0.1)

    expected_grad = jacobian(get_expected_probe, argnums=(0, 1))(loc, scale)
    actual_grad = jacobian(get_actual_probe, argnums=(0, 1))(loc, scale)
    assert_allclose(actual_grad[0], expected_grad[0], atol=0.05)  # loc grad
    assert_allclose(actual_grad[1], expected_grad[1], atol=0.05)  # scale grad
Exemplo n.º 6
0
    def transition_fn(carry, y):
        """Advance the latent Bernoulli state ``z`` by one step and score ``y``.

        The ``(carry, y) -> (new_carry, None)`` signature suggests this is the
        body of a ``scan`` over time steps (TODO confirm at the call site).
        ``phi_logit_mean``, ``phi_sigma``, ``N`` and ``rho`` are read from the
        enclosing scope.

        :param carry: tuple ``(first_capture_mask, z)`` from the previous step;
            ``first_capture_mask`` is used as a per-animal boolean mask.
        :param y: observations for this step, one per animal (presumably 0/1
            capture indicators -- TODO confirm).
        :return: ``((first_capture_mask, z), None)``: the updated carry and no
            per-step output.
        """
        first_capture_mask, z = carry
        # Sampled outside the "animals" plate, so one phi_logit_t is shared by
        # all N animals at this step. LocScaleReparam(0) makes the site fully
        # non-centered.
        with handlers.reparam(config={"phi_logit": LocScaleReparam(0)}):
            phi_logit_t = numpyro.sample("phi_logit", dist.Normal(phi_logit_mean, phi_sigma))
        phi_t = expit(phi_logit_t)
        with numpyro.plate("animals", N, dim=-1):
            # Sites below only contribute to the log density where
            # first_capture_mask is True.
            with handlers.mask(mask=first_capture_mask):
                # Where the mask is set: P(z_t = 1) = phi_t * z_{t-1};
                # elsewhere the probability is 1.
                mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
                # NumPyro exactly sums out the discrete states z_t.
                z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
                # Observation probability is rho when z is 1 and 0 when z is 0
                # (clamp_probs keeps both away from the exact 0/1 boundaries).
                mu_y_t = rho * z
                numpyro.sample("y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y)

        # Once an animal has an observation with y true, it stays in the mask.
        first_capture_mask = first_capture_mask | y.astype(bool)
        return (first_capture_mask, z), None
Exemplo n.º 7
0
import os

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import reparam
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.reparam import LocScaleReparam


def model(dim=10):
    """Funnel-shaped density: ``y ~ Normal(0, 3)`` controls the scale of ``x``.

    ``x`` has ``dim - 1`` independent coordinates, each distributed as
    ``Normal(0, exp(y / 2))``.
    """
    y = numpyro.sample("y", dist.Normal(0, 3))
    half_y = y / 2
    x_loc = jnp.zeros(dim - 1)
    numpyro.sample("x", dist.Normal(x_loc, jnp.exp(half_y)))


# `model` with its "x" site fully non-centered via LocScaleReparam(0).
reparam_model = reparam(model, config=dict(x=LocScaleReparam(0)))


def run_inference(model, args, rng_key):
    """Run NUTS on ``model`` and return its posterior samples.

    :param model: NumPyro model callable; it is run with no model arguments.
    :param args: namespace providing ``num_warmup``, ``num_samples`` and
        ``num_chains``.
    :param rng_key: JAX PRNG key that seeds the MCMC run.
    :return: the posterior samples from ``mcmc.get_samples()``.
    """
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        # Hide the progress bar when NUMPYRO_SPHINXBUILD is set
        # (presumably during the documentation build).
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key)
    mcmc.print_summary(exclude_deterministic=False)
    return mcmc.get_samples()
Exemplo n.º 8
0
def reparam_model(dim=10):
    """Funnel model whose ``x`` site is fully non-centered inside a reparam
    handler context (``y`` is sampled outside it and left untouched)."""
    y = numpyro.sample('y', dist.Normal(0, 3))
    x_scale = jnp.exp(y / 2)
    x_loc = jnp.zeros(dim - 1)
    with numpyro.handlers.reparam(config={'x': LocScaleReparam(0)}):
        numpyro.sample('x', dist.Normal(x_loc, x_scale))
Exemplo n.º 9
0
def reparam_model(model):
    """Wrap ``model`` so that its ``'x'`` site uses LocScaleReparam(0), i.e. a
    fully non-centered parameterization."""
    site_config = {'x': LocScaleReparam(0)}
    return reparam(model, config=site_config)