def logistic_random_effects(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].

    Sample sites, in order: ``zeta``, ``Omega``, ``Chi`` (per-class
    hyperpriors), ``beta`` (per annotator and class), ``pi`` (class
    prevalence), ``c`` (latent class per item), ``theta`` (per-item effect)
    and the observed site ``y``.

    :param positions: integer annotator indices; the number of annotators is
        taken to be ``max(positions) + 1``. It is used to index the annotator
        axis of ``beta`` (assumed to broadcast against the item/position
        plates -- confirm with the caller).
    :param annotations: 2-D array unpacked as ``(num_items, num_positions)``
        holding the observed labels; the number of classes is taken to be
        ``max(annotations) + 1``.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape
    # Only the first `num_classes - 1` logits are sampled; a 0 is padded on later.
    num_free = num_classes - 1

    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample(
            "zeta", dist.Normal(0, 1).expand([num_free]).to_event(1)
        )
        omega = numpyro.sample(
            "Omega", dist.HalfNormal(1).expand([num_free]).to_event(1)
        )
        chi = numpyro.sample(
            "Chi", dist.HalfNormal(1).expand([num_free]).to_event(1)
        )

    with numpyro.plate("annotator", num_annotators, dim=-2), numpyro.plate(
        "class", num_classes
    ):
        # non-centered parameterization of `beta`
        with handlers.reparam(config={"beta": LocScaleReparam(0)}):
            beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # append a single 0 along the last axis, leave other axes untouched
            beta_pad = [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)]
            beta = jnp.pad(beta, beta_pad)

    class_prior = dist.Dirichlet(jnp.ones(num_classes))
    pi = numpyro.sample("pi", class_prior)

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi))

        # non-centered parameterization of `theta`
        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta_scale = chi[c]
            theta = numpyro.sample("theta", dist.Normal(0, theta_scale).to_event(1))
            theta_pad = [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)]
            theta = jnp.pad(theta, theta_pad)

        with numpyro.plate("position", num_positions):
            # Vindex lets the latent class index `c` broadcast with `positions`.
            annotator_logits = Vindex(beta)[positions, c, :]
            logits = annotator_logits - theta
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
def transition_fn(x, y):
    """
    One step of a discrete-state chain with per-dimension emissions.

    Samples the next state ``x`` from the row of ``transition_probs`` selected
    by the incoming state, then, inside the ``"D"`` plate, samples a Bernoulli
    site ``w`` and observes ``y`` under a unit-scale Normal whose location is
    looked up in ``locs`` by ``(x, w)``.

    ``transition_probs``, ``locs`` and ``D`` are read from the enclosing scope.

    :param x: state carried in from the previous step.
    :param y: observation for this step, passed as ``obs`` to site ``"y"``.
    :return: ``(x, None)`` -- the newly sampled state and no per-step output.
    """
    next_state_dist = dist.Categorical(transition_probs[x])
    x = numpyro.sample("x", next_state_dist)
    with numpyro.plate("D", D, dim=-1):
        w = numpyro.sample("w", dist.Bernoulli(0.6))
        # Vindex broadcasts the two index arrays against each other.
        emission_loc = Vindex(locs)[x, w]
        numpyro.sample("y", dist.Normal(emission_loc, 1), obs=y)
    return x, None
def hierarchical_dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].

    Sample sites, in order: ``zeta`` and ``Omega`` (per-class hyperpriors),
    ``beta`` (per annotator and class), ``pi`` (class prevalence), ``c``
    (latent class per item) and the observed site ``y``.

    :param positions: integer annotator indices; the number of annotators is
        taken to be ``max(positions) + 1``. It is used to index the annotator
        axis of ``beta`` (assumed to broadcast against the item/position
        plates -- confirm with the caller).
    :param annotations: 2-D array unpacked as ``(num_items, num_positions)``
        holding the observed labels; the number of classes is taken to be
        ``max(annotations) + 1``.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape
    num_free = num_classes - 1

    with numpyro.plate("class", num_classes):
        # NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
        # invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
        # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
        zeta = numpyro.sample(
            "zeta", dist.Normal(0, 1).expand([num_free]).to_event(1)
        )
        omega = numpyro.sample(
            "Omega", dist.HalfNormal(1).expand([num_free]).to_event(1)
        )

    with numpyro.plate("annotator", num_annotators, dim=-2), numpyro.plate(
        "class", num_classes
    ):
        # non-centered parameterization
        with handlers.reparam(config={"beta": LocScaleReparam(0)}):
            beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
        # pad 0 to the last item
        pad_widths = [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)]
        beta = jnp.pad(beta, pad_widths)

    class_prior = dist.Dirichlet(jnp.ones(num_classes))
    pi = numpyro.sample("pi", class_prior)

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi))

        with numpyro.plate("position", num_positions):
            # Vindex lets the latent class index `c` broadcast with `positions`.
            logits = Vindex(beta)[positions, c, :]
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
def test_value(x_shape, i_shape, j_shape, event_shape):
    """
    Check ``Vindex`` against an element-by-element reference lookup.

    Builds a random array of shape ``x_shape + (5, 6) + event_shape`` and
    random index arrays ``i`` (values in ``[0, 5)``) and ``j`` (values in
    ``[0, 6)``), indexes with ``Vindex``, and compares with a NumPy array
    filled one broadcast position at a time.

    :param x_shape: leading batch shape (tuple) of the indexed array.
    :param i_shape: sample shape (tuple) of the first index array.
    :param j_shape: sample shape (tuple) of the second index array.
    :param event_shape: trailing shape (tuple) kept intact by the lookup;
        when empty, no trailing slice is used.
    """
    x = jnp.array(np.random.rand(*(x_shape + (5, 6) + event_shape)))
    i = dist.Categorical(jnp.ones((5,))).sample(random.PRNGKey(1), i_shape)
    j = dist.Categorical(jnp.ones((6,))).sample(random.PRNGKey(2), j_shape)
    actual = Vindex(x)[..., i, j, :] if event_shape else Vindex(x)[..., i, j]

    # Reference result: broadcast everything explicitly, then gather per position.
    batch_shape = lax.broadcast_shapes(x_shape, i_shape, j_shape)
    x_full = jnp.broadcast_to(x, batch_shape + (5, 6) + event_shape)
    i_full = jnp.broadcast_to(i, batch_shape)
    j_full = jnp.broadcast_to(j, batch_shape)
    expected = np.empty(batch_shape + event_shape, dtype=x_full.dtype)

    if batch_shape:
        batch_indices = itertools.product(*(range(size) for size in batch_shape))
    else:
        # scalar batch: a single empty index
        batch_indices = [()]
    for ind in batch_indices:
        picked = (i_full[ind].item(), j_full[ind].item())
        expected[ind] = x_full[ind + picked]

    assert jnp.all(actual == jnp.array(expected, dtype=x_full.dtype))
def model(data):
    """
    Unrolled (``markov``-loop) chain model with per-step site names.

    Starting from state ``0``, each step ``i`` samples ``x_{i}`` from the row
    of ``transition_probs`` selected by the previous state, then inside the
    shared ``"D"`` plate samples ``w_{i}`` and observes ``y_{i}`` under a
    unit-scale Normal whose location is looked up in ``locs`` by ``(x, w)``.

    ``transition_probs``, ``locs`` and ``D`` are read from the enclosing scope.

    :param data: iterable of per-step observations.
    """
    x = 0
    # The plate object is created once and re-entered at every step.
    D_plate = numpyro.plate("D", D, dim=-1)
    for i, y in markov(enumerate(data)):
        step_dist = dist.Categorical(transition_probs[x])
        x = numpyro.sample(f"x_{i}", step_dist)
        with D_plate:
            w = numpyro.sample(f"w_{i}", dist.Bernoulli(0.6))
            emission_loc = Vindex(locs)[x, w]
            numpyro.sample(f"y_{i}", dist.Normal(emission_loc, 1), obs=y)
def transition_fn(carry, y):
    """
    One step of a second-order hidden-state chain.

    The next hidden state ``x`` is sampled from the slice of ``probs_x``
    selected by the two previous states; ``y`` is then observed, inside the
    ``"tones"`` plate, under a Bernoulli whose probabilities come from
    ``probs_y`` indexed by the new state. Steps at or beyond a sequence's
    length (``t >= lengths``) are masked out.

    ``probs_x``, ``probs_y``, ``lengths``, ``num_sequences`` and ``data_dim``
    are read from the enclosing scope.

    :param carry: tuple ``(x_prev, x_curr, t)`` -- the two most recent hidden
        states and the current step counter.
    :param y: observation for this step, passed as ``obs`` to site ``"y"``.
    :return: ``((x_prev, x_curr, t + 1), None)`` with the state window
        shifted by one step.
    """
    x_prev, x_curr, t = carry
    # Trailing axis added so the mask lines up with the "tones" dimension.
    step_is_valid = (t < lengths)[..., None]
    with numpyro.plate("sequences", num_sequences, dim=-2), mask(mask=step_is_valid):
        probs_x_t = Vindex(probs_x)[x_prev, x_curr]
        x_next = numpyro.sample("x", dist.Categorical(probs_x_t))
        # slide the two-state window forward
        x_prev, x_curr = x_curr, x_next
        with numpyro.plate("tones", data_dim, dim=-1):
            probs_y_t = probs_y[x_curr.squeeze(-1)]
            numpyro.sample("y", dist.Bernoulli(probs_y_t), obs=y)
    return (x_prev, x_curr, t + 1), None
def transition_fn(carry, y):
    """
    One step of a model with two coupled hidden chains, ``w`` and ``x``.

    ``w`` is sampled from ``probs_w`` indexed by its previous value; ``x`` is
    sampled from ``probs_x`` indexed by the new ``w`` and the previous ``x``;
    ``y`` is observed, inside the ``"tones"`` plate, under a Bernoulli whose
    probabilities are ``probs_y[w, x, tones]``. Steps at or beyond a
    sequence's length (``t >= lengths``) are masked out.

    ``probs_w``, ``probs_x``, ``probs_y``, ``lengths``, ``num_sequences`` and
    ``data_dim`` are read from the enclosing scope.

    :param carry: tuple ``(w_prev, x_prev, t)`` -- previous hidden states and
        the current step counter.
    :param y: observation for this step, passed as ``obs`` to site ``"y"``.
    :return: ``((w, x, t + 1), None)``.
    """
    w_prev, x_prev, t = carry
    # Trailing axis added so the mask lines up with the "tones" dimension.
    step_is_valid = (t < lengths)[..., None]
    with numpyro.plate("sequences", num_sequences, dim=-2), mask(mask=step_is_valid):
        probs_w_t = probs_w[w_prev]
        w = numpyro.sample("w", dist.Categorical(probs_w_t))
        probs_x_t = Vindex(probs_x)[w, x_prev]
        x = numpyro.sample("x", dist.Categorical(probs_x_t))
        with numpyro.plate("tones", data_dim, dim=-1) as tones:
            probs_y_t = probs_y[w, x, tones]
            numpyro.sample("y", dist.Bernoulli(probs_y_t), obs=y)
    return (w, x, t + 1), None
def test_hmm_example(prev_enum_dim, curr_enum_dim):
    """
    Check ``Vindex`` with a trailing slice against explicit advanced indexing.

    Two index arrays enumerate all ``hidden_dim`` values, each laid out along
    its own (negative) dimension; ``Vindex(probs_x)[x_prev, x_curr, :]`` must
    equal plain indexing where both indices get an extra trailing axis and
    the last axis is indexed by ``arange(hidden_dim)``.

    :param prev_enum_dim: negative dim along which ``x_prev`` is enumerated.
    :param curr_enum_dim: negative dim along which ``x_curr`` is enumerated.
    """
    hidden_dim = 8
    probs_x = jnp.array(np.random.rand(hidden_dim, hidden_dim, hidden_dim))

    def enumerate_along(enum_dim):
        # all values 0..hidden_dim-1 placed at dimension `enum_dim`
        trailing_ones = (1,) * (-1 - enum_dim)
        return jnp.arange(hidden_dim).reshape((-1,) + trailing_ones)

    x_prev = enumerate_along(prev_enum_dim)
    x_curr = enumerate_along(curr_enum_dim)

    last_axis = jnp.arange(hidden_dim)
    expected = probs_x[x_prev[..., None], x_curr[..., None], last_axis]
    actual = Vindex(probs_x)[x_prev, x_curr, :]

    assert jnp.all(actual == expected)
def dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 2 of reference [1].

    Sample sites, in order: ``beta`` (a Dirichlet-distributed probability
    vector per annotator and class), ``pi`` (class prevalence), ``c`` (latent
    class per item) and the observed site ``y``.

    :param positions: integer annotator indices; the number of annotators is
        taken to be ``max(positions) + 1``. It is used to index the annotator
        axis of ``beta`` (assumed to broadcast against the item/position
        plates -- confirm with the caller).
    :param annotations: 2-D array unpacked as ``(num_items, num_positions)``
        holding the observed labels; the number of classes is taken to be
        ``max(annotations) + 1``.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("annotator", num_annotators, dim=-2), numpyro.plate(
        "class", num_classes
    ):
        beta = numpyro.sample("beta", dist.Dirichlet(jnp.ones(num_classes)))

    class_prior = dist.Dirichlet(jnp.ones(num_classes))
    pi = numpyro.sample("pi", class_prior)

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi))

        # here we use Vindex to allow broadcasting for the second index `c`
        # ref: http://num.pyro.ai/en/latest/utilities.html#numpyro.contrib.indexing.vindex
        with numpyro.plate("position", num_positions):
            response_probs = Vindex(beta)[positions, c, :]
            numpyro.sample("y", dist.Categorical(response_probs), obs=annotations)