Beispiel #1
0
def logistic_random_effects(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].

    Each annotator has class-conditional label logits ``beta`` that share
    hyperpriors ``zeta`` / ``Omega`` across annotators.  Each item also gets a
    random offset ``theta`` whose scale ``Chi`` is selected by the item's latent
    true class ``c``.  That offset is subtracted from the annotator logits
    before the observed labels are scored.

    :param positions: integer annotator index for each annotation column
        (presumably shape ``(num_positions,)`` -- TODO confirm).  The number of
        annotators is inferred as ``max(positions) + 1``.
    :param annotations: 2-D integer array ``(num_items, num_positions)`` of
        observed labels.  The number of classes is inferred as
        ``max(annotations) + 1``.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # Only the first `num_classes - 1` logits get hyperpriors.  The last logit
        # is pinned to 0 by the padding below, because logits are invariant to a
        # shared additive constant.
        zeta = numpyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))
        # Scale of the per-item random effect `theta`; one vector per true class.
        chi = numpyro.sample("Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # Non-centered parameterization (LocScaleReparam with centered=0).
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                # beta[a, k]: annotator a's label logits given true class k.
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
                # Append the fixed 0 logit for the last class.
                beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    # Prior over the true class of an item.
    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        # Latent true class of each item.
        c = numpyro.sample("c", dist.Categorical(pi))

        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            # Per-item offset; its scale is chosen by the item's true class.
            theta = numpyro.sample("theta", dist.Normal(0, chi[c]).to_event(1))
            # Same 0-padding of the last class as for `beta`.
            theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            # Vindex broadcasts `positions` against `c` (which may carry extra
            # enumeration dims); the item offset is then subtracted.
            logits = Vindex(beta)[positions, c, :] - theta
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
Beispiel #2
0
 def transition_fn(x, y):
     """Single transition step: draw the next hidden state, then the ``D``
     switch variables ``w`` and observations ``y`` that depend on it.

     Returns ``(new_state, None)`` -- a ``(carry, output)`` pair, presumably
     consumed by numpyro's ``scan`` (confirm against the caller).
     """
     next_state = numpyro.sample("x", dist.Categorical(transition_probs[x]))
     with numpyro.plate("D", D, dim=-1):
         switch = numpyro.sample("w", dist.Bernoulli(0.6))
         emission_loc = Vindex(locs)[next_state, switch]
         numpyro.sample("y", dist.Normal(emission_loc, 1), obs=y)
     return next_state, None
Beispiel #3
0
def hierarchical_dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].

    Each annotator has class-conditional label logits ``beta`` that share
    hyperpriors ``zeta`` / ``Omega`` across annotators.  Each item has a latent
    true class ``c`` drawn from ``pi``.

    :param positions: integer annotator index for each annotation column
        (presumably shape ``(num_positions,)`` -- TODO confirm).  The number of
        annotators is inferred as ``max(positions) + 1``.
    :param annotations: 2-D integer array ``(num_items, num_positions)`` of
        observed labels.  The number of classes is inferred as
        ``max(annotations) + 1``.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # NB: `beta` is used as the `logits` of the `y` likelihood.  Logits are
        # invariant up to an additive constant, so, following [1], the last term
        # of `beta` is fixed to 0 and hyperpriors are defined only for the first
        # `num_classes - 1` terms.
        zeta = numpyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # Non-centered parameterization (LocScaleReparam with centered=0).
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                # beta[a, k]: annotator a's label logits given true class k.
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # Pad a fixed 0 logit onto the last class.
            beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    # Prior over the true class of an item.
    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        # Latent true class of each item.
        c = numpyro.sample("c", dist.Categorical(pi))

        with numpyro.plate("position", num_positions):
            # Vindex broadcasts `positions` against `c` (which may carry extra
            # enumeration dims).
            logits = Vindex(beta)[positions, c, :]
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
Beispiel #4
0
def test_value(x_shape, i_shape, j_shape, event_shape):
    """Check ``Vindex`` gathering against an explicit per-element reference.

    ``x`` has batch shape ``x_shape``, then two indexed dims of sizes 5 and 6,
    then ``event_shape``.  ``i`` and ``j`` are random indices whose batch shapes
    must broadcast with ``x_shape``.
    """
    full_shape = x_shape + (5, 6) + event_shape
    x = jnp.array(np.random.rand(*full_shape))
    i = dist.Categorical(jnp.ones((5,))).sample(random.PRNGKey(1), i_shape)
    j = dist.Categorical(jnp.ones((6,))).sample(random.PRNGKey(2), j_shape)
    indexer = Vindex(x)
    actual = indexer[..., i, j, :] if event_shape else indexer[..., i, j]

    # Reference result: broadcast everything to one batch shape and gather by hand.
    batch = lax.broadcast_shapes(x_shape, i_shape, j_shape)
    x_b = jnp.broadcast_to(x, batch + (5, 6) + event_shape)
    i_b = jnp.broadcast_to(i, batch)
    j_b = jnp.broadcast_to(j, batch)
    expected = np.empty(batch + event_shape, dtype=x_b.dtype)
    # itertools.product() with no ranges yields exactly one empty index tuple,
    # which covers the scalar-batch case.
    for idx in itertools.product(*(range(n) for n in batch)):
        expected[idx] = x_b[idx + (i_b[idx].item(), j_b[idx].item())]
    assert jnp.all(actual == jnp.array(expected, dtype=x_b.dtype))
Beispiel #5
0
 def model(data):
     """Python-loop (unrolled) version of the transition model.

     At each step ``i`` this samples ``x_{i}``, then, inside the ``D`` plate, the
     Bernoulli switches ``w_{i}`` and the Normal observations ``y_{i}``.  The
     loop is wrapped in ``markov`` (presumably to let enumeration recycle dims
     across steps -- confirm against the numpyro docs).
     """
     state = 0
     d_plate = numpyro.plate("D", D, dim=-1)
     for i, y in markov(enumerate(data)):
         state = numpyro.sample(f"x_{i}", dist.Categorical(transition_probs[state]))
         with d_plate:
             switch = numpyro.sample(f"w_{i}", dist.Bernoulli(0.6))
             emission_loc = Vindex(locs)[state, switch]
             numpyro.sample(f"y_{i}", dist.Normal(emission_loc, 1), obs=y)
Beispiel #6
0
 def transition_fn(carry, y):
     """Scan step of a second-order HMM: the next state depends on the two
     preceding states.

     ``carry`` is ``(x_prev, x_curr, t)``.  Steps where ``t >= lengths`` are
     masked out per sequence.  Returns the shifted carry and no per-step output.
     """
     two_back, one_back, step = carry
     with numpyro.plate("sequences", num_sequences, dim=-2):
         # Trailing axis so the per-sequence mask lines up with the dim=-2 plate.
         with mask(mask=(step < lengths)[..., None]):
             next_probs = Vindex(probs_x)[two_back, one_back]
             new_state = numpyro.sample("x", dist.Categorical(next_probs))
             with numpyro.plate("tones", data_dim, dim=-1):
                 # Drop the trailing size-1 dim the state presumably carries from
                 # the dim=-2 plate before indexing the emission table.
                 emission = probs_y[new_state.squeeze(-1)]
                 numpyro.sample("y", dist.Bernoulli(emission), obs=y)
     return (one_back, new_state, step + 1), None
Beispiel #7
0
 def transition_fn(carry, y):
     """Scan step with two coupled hidden chains.

     ``w`` depends on ``w_prev``.  ``x`` depends on the current ``w`` and on
     ``x_prev``.  The observations ``y`` depend on ``(w, x)`` and the tone index.
     ``carry`` is ``(w_prev, x_prev, t)``, and steps with ``t >= lengths`` are
     masked out.
     """
     w_prev, x_prev, t = carry
     # Per-sequence validity with a trailing axis to line up with the dim=-2 plate.
     active = (t < lengths)[..., None]
     with numpyro.plate("sequences", num_sequences, dim=-2):
         with mask(mask=active):
             w_t = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
             x_probs = Vindex(probs_x)[w_t, x_prev]
             x_t = numpyro.sample("x", dist.Categorical(x_probs))
             with numpyro.plate("tones", data_dim, dim=-1) as tones:
                 # The plate's index array supplies the tones coordinate.
                 y_probs = probs_y[w_t, x_t, tones]
                 numpyro.sample("y", dist.Bernoulli(y_probs), obs=y)
     return (w_t, x_t, t + 1), None
Beispiel #8
0
def test_hmm_example(prev_enum_dim, curr_enum_dim):
    """Vindex on a 3-D transition tensor indexed by enumerated state vectors.

    ``x_prev`` and ``x_curr`` are ``arange`` vectors moved to the (negative)
    dims ``prev_enum_dim`` / ``curr_enum_dim``.  The result must match plain
    advanced indexing once a trailing axis is appended to each index for the
    kept last dim.
    """
    hidden_dim = 8
    probs_x = jnp.array(np.random.rand(hidden_dim, hidden_dim, hidden_dim))

    def _at_enum_dim(enum_dim):
        # Pad with trailing 1s so the hidden_dim axis sits at `enum_dim`
        # (assumes enum_dim <= -1).
        return jnp.arange(hidden_dim).reshape((-1,) + (1,) * (-1 - enum_dim))

    x_prev = _at_enum_dim(prev_enum_dim)
    x_curr = _at_enum_dim(curr_enum_dim)

    actual = Vindex(probs_x)[x_prev, x_curr, :]
    expected = probs_x[x_prev[..., None], x_curr[..., None], jnp.arange(hidden_dim)]
    assert jnp.all(actual == expected)
Beispiel #9
0
def dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 2 of reference [1].

    Each annotator has a confusion matrix ``beta`` (one label distribution per
    true class).  Each item has a latent true class ``c`` drawn from ``pi``.

    :param positions: integer annotator index for each annotation column
        (presumably shape ``(num_positions,)`` -- TODO confirm).
    :param annotations: 2-D integer array ``(num_items, num_positions)`` of
        observed labels.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape
    flat_concentration = jnp.ones(num_classes)

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # beta[a, k]: annotator a's label distribution given true class k.
            beta = numpyro.sample("beta", dist.Dirichlet(flat_concentration))

    pi = numpyro.sample("pi", dist.Dirichlet(flat_concentration))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi))
        # Vindex lets the second index `c` broadcast (it may be enumerated).
        # ref: http://num.pyro.ai/en/latest/utilities.html#numpyro.contrib.indexing.vindex
        label_probs = Vindex(beta)[positions, c, :]
        with numpyro.plate("position", num_positions):
            numpyro.sample("y", dist.Categorical(label_probs), obs=annotations)