Example #1
0
def test_chain_jit_args_smoke(chain_method, compile_args):
    """Smoke test: a multi-chain MCMC object can be re-run on new data.

    When ``jit_model_args=True`` the second ``run`` with same-shaped but
    differently-valued data should reuse the compiled kernel.
    """
    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample("p_latent", dist.Dirichlet(concentration))
        numpyro.sample("obs", dist.Categorical(p_latent), obs=data)
        return p_latent

    # Two datasets with identical shape but different generating probabilities.
    data1 = dist.Categorical(jnp.array([0.1, 0.6,
                                        0.3])).sample(random.PRNGKey(1),
                                                      (50, ))
    data2 = dist.Categorical(jnp.array([0.2, 0.4,
                                        0.4])).sample(random.PRNGKey(1),
                                                      (50, ))
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=2,
        num_samples=5,
        num_chains=2,
        chain_method=chain_method,
        jit_model_args=compile_args,
    )
    mcmc.warmup(random.PRNGKey(0), data1)
    mcmc.run(random.PRNGKey(1), data1)
    # this should be fast if jit_model_args=True
    mcmc.run(random.PRNGKey(2), data2)
Example #2
0
def semi_supervised_hmm(transition_prior, emission_prior,
                        supervised_categories, supervised_words,
                        unsupervised_words):
    """Semi-supervised HMM with Dirichlet priors over the rows of the
    transition and emission matrices.

    Supervised (category, word) pairs are observed directly; unsupervised
    words contribute via the forward-algorithm marginal likelihood, which is
    injected into the potential through ``numpyro.factor``.
    """
    num_categories, num_words = transition_prior.shape[
        0], emission_prior.shape[0]
    transition_prob = numpyro.sample(
        'transition_prob',
        dist.Dirichlet(
            np.broadcast_to(transition_prior,
                            (num_categories, num_categories))))
    emission_prob = numpyro.sample(
        'emission_prob',
        dist.Dirichlet(
            np.broadcast_to(emission_prior, (num_categories, num_words))))

    # models supervised data;
    # here we don't make any assumption about the first supervised category, in other words,
    # we place a flat/uniform prior on it.
    numpyro.sample('supervised_categories',
                   dist.Categorical(
                       transition_prob[supervised_categories[:-1]]),
                   obs=supervised_categories[1:])
    numpyro.sample('supervised_words',
                   dist.Categorical(emission_prob[supervised_categories]),
                   obs=supervised_words)

    # computes log prob of unsupervised data
    transition_log_prob = np.log(transition_prob)
    emission_log_prob = np.log(emission_prob)
    init_log_prob = emission_log_prob[:, unsupervised_words[0]]
    log_prob = forward_log_prob(init_log_prob, unsupervised_words[1:],
                                transition_log_prob, emission_log_prob)
    # marginalize over the final hidden state
    log_prob = logsumexp(log_prob, axis=0, keepdims=True)
    # inject log_prob to potential function
    numpyro.factor('forward_log_prob', log_prob)
Example #3
0
def simulate_data(rng_key, num_categories, num_words, num_supervised_data,
                  num_unsupervised_data):
    """Simulate HMM data: draw true transition/emission probabilities, then
    roll out a chain of categories with emitted words.

    Returns the priors, true probabilities, the supervised (categories,
    words) segment, and the unsupervised words.
    """
    rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)

    transition_prior = np.ones(num_categories)
    emission_prior = np.repeat(0.1, num_words)

    transition_prob = dist.Dirichlet(transition_prior).sample(
        key=rng_key_transition, sample_shape=(num_categories, ))
    emission_prob = dist.Dirichlet(emission_prior).sample(
        key=rng_key_emission, sample_shape=(num_categories, ))

    start_prob = np.repeat(1. / num_categories, num_categories)
    categories, words = [], []
    for t in range(num_supervised_data + num_unsupervised_data):
        rng_key, rng_key_transition, rng_key_emission = random.split(
            rng_key, 3)
        # restart from the uniform start distribution at t=0 and at the
        # boundary between the supervised and unsupervised segments
        if t == 0 or t == num_supervised_data:
            category = dist.Categorical(start_prob).sample(
                key=rng_key_transition)
        else:
            category = dist.Categorical(
                transition_prob[category]).sample(key=rng_key_transition)
        word = dist.Categorical(
            emission_prob[category]).sample(key=rng_key_emission)
        categories.append(category)
        words.append(word)

    # split into supervised data and unsupervised data
    categories, words = np.stack(categories), np.stack(words)
    supervised_categories = categories[:num_supervised_data]
    supervised_words = words[:num_supervised_data]
    unsupervised_words = words[num_supervised_data:]
    return (transition_prior, emission_prior, transition_prob, emission_prob,
            supervised_categories, supervised_words, unsupervised_words)
Example #4
0
 def transition_fn(carry, y):
     """One scan step: advance the two hidden chains (x, w) and observe y
     from a Normal whose location is indexed by both states."""
     x, w = carry
     x = numpyro.sample("x", dist.Categorical(probs_x[x]))
     w = numpyro.sample("w", dist.Categorical(probs_w[w]))
     numpyro.sample("y", dist.Normal(locs[w, x], 1), obs=y)
     # also test if scan's `ys` are recorded correctly
     return (x, w), x
Example #5
0
def hierarchical_dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
        # invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
        # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
        zeta = numpyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # non-centered parameterization
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # pad 0 to the last item
            beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    # prior over the true class of each item
    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi))

        with numpyro.plate("position", num_positions):
            # select each annotator's logits for the item's latent class c
            logits = Vindex(beta)[positions, c, :]
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
Example #6
0
def logistic_random_effects(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    # hyperpriors for the first num_classes - 1 logits (the last is fixed to 0)
    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))
        chi = numpyro.sample("Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # non-centered parameterization of per-annotator logits
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
                # pad a zero logit for the last class
                beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi))

        # per-item random effect, also padded with a trailing zero
        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(0, chi[c]).to_event(1))
            theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :] - theta
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
Example #7
0
def semi_supervised_hmm(transition_prior, emission_prior,
                        supervised_categories, supervised_words,
                        unsupervised_words):
    """Semi-supervised HMM with Dirichlet priors over transition and
    emission rows; unsupervised words enter through the forward-algorithm
    marginal likelihood (see the Multinomial trick at the end)."""
    num_categories, num_words = transition_prior.shape[0], emission_prior.shape[0]
    transition_prob = sample('transition_prob', dist.Dirichlet(
        np.broadcast_to(transition_prior, (num_categories, num_categories))))
    emission_prob = sample('emission_prob', dist.Dirichlet(
        np.broadcast_to(emission_prior, (num_categories, num_words))))

    # models supervised data;
    # here we don't make any assumption about the first supervised category, in other words,
    # we place a flat/uniform prior on it.
    sample('supervised_categories', dist.Categorical(transition_prob[supervised_categories[:-1]]),
           obs=supervised_categories[1:])
    sample('supervised_words', dist.Categorical(emission_prob[supervised_categories]),
           obs=supervised_words)

    # computes log prob of unsupervised data
    transition_log_prob = np.log(transition_prob)
    emission_log_prob = np.log(emission_prob)
    init_log_prob = emission_log_prob[:, unsupervised_words[0]]
    log_prob = forward_log_prob(init_log_prob, unsupervised_words[1:],
                                transition_log_prob, emission_log_prob)
    log_prob = logsumexp(log_prob, axis=0, keepdims=True)
    # inject log_prob to potential function
    # NB: This is a trick (uses an invalid value `0` of Multinomial distribution) to add an
    # additional term to potential energy.
    return sample('forward_log_prob', dist.Multinomial(logits=-log_prob), obs=0)
Example #8
0
def dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 2 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    # per-annotator confusion rows: beta[a, k] is a label distribution
    # given true class k
    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            beta = numpyro.sample("beta",
                                  dist.Dirichlet(jnp.ones(num_classes)))

    # prior over the true class of each item
    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        # latent true class, marginalized out by parallel enumeration
        c = numpyro.sample("c",
                           dist.Categorical(pi),
                           infer={"enumerate": "parallel"})

        # here we use Vindex to allow broadcasting for the second index `c`
        # ref: http://num.pyro.ai/en/latest/utilities.html#numpyro.contrib.indexing.vindex
        with numpyro.plate("position", num_positions):
            numpyro.sample("y",
                           dist.Categorical(Vindex(beta)[positions, c, :]),
                           obs=annotations)
Example #9
0
def item_difficulty(annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].
    """
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    # hyperpriors over the first num_classes - 1 logits (the last is fixed to 0)
    with numpyro.plate("class", num_classes):
        eta = numpyro.sample(
            "eta",
            dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        chi = numpyro.sample(
            "Chi",
            dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        # latent true class, marginalized out by parallel enumeration
        c = numpyro.sample("c",
                           dist.Categorical(pi),
                           infer={"enumerate": "parallel"})

        # non-centered per-item logits, padded with a zero for the last class
        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta",
                                   dist.Normal(eta[c], chi[c]).to_event(1))
            theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", annotations.shape[-1]):
            numpyro.sample("y",
                           dist.Categorical(logits=theta),
                           obs=annotations)
Example #10
0
 def model(z1=None, z2=None):
     """Toy mixture model: discrete z1 and z2 select the Normal location for
     the closed-over ``data[0]`` and ``data[1]`` respectively."""
     p = numpyro.param("p", np.array([0.25, 0.75]))
     loc = numpyro.param("loc", jnp.array([-1.0, 1.0]))
     z1 = numpyro.sample("z1", dist.Categorical(p), obs=z1)
     with numpyro.plate("data[0]", 3):
         numpyro.sample("x1", dist.Normal(loc[z1], 1.0), obs=data[0])
     with numpyro.plate("data[1]", 2):
         z2 = numpyro.sample("z2", dist.Categorical(p), obs=z2)
         numpyro.sample("x2", dist.Normal(loc[z2], 1.0), obs=data[1])
Example #11
0
    def model(data_x, data_w):
        """Two independent markov chains over x and w, each with Normal
        emissions; sites are named per time step."""
        x = w = 0
        for i, y in markov(enumerate(data_x)):
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs_x[x]))
            numpyro.sample(f"y_x_{i}", dist.Normal(locs_x[x], 1), obs=y)

        # NOTE(review): site name "w{i}" lacks the underscore used by "x_{i}";
        # presumably intentional, but confirm before relying on the names.
        for i, y in markov(enumerate(data_w)):
            w = numpyro.sample(f"w{i}", dist.Categorical(probs_w[w]))
            numpyro.sample(f"y_w_{i}", dist.Normal(locs_w[w], 1), obs=y)
Example #12
0
 def transition_fn(carry, y):
     """One scan step of a factorial HMM: hidden chains w and x drive
     Bernoulli tone emissions; steps past each sequence's length are masked."""
     w_prev, x_prev, t = carry
     with numpyro.plate("sequences", num_sequences, dim=-2):
         with mask(mask=(t < lengths)[..., None]):
             w = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
             # x's transition depends on both the new w and the previous x
             x = numpyro.sample("x", dist.Categorical(Vindex(probs_x)[w, x_prev]))
             with numpyro.plate("tones", data_dim, dim=-1) as tones:
                 numpyro.sample("y", dist.Bernoulli(probs_y[w, x, tones]), obs=y)
     return (w, x, t + 1), None
Example #13
0
 def model(z1=None, z2=None):
     """Chain of two discrete sites (z2's distribution is indexed by z1),
     with Normal emissions over a shared data plate of size 3."""
     p = numpyro.param("p", jnp.array([[0.25, 0.75], [0.1, 0.9]]))
     loc = numpyro.param("loc", jnp.array([-1.0, 1.0]))
     z1 = numpyro.sample("z1", dist.Categorical(p[0]), obs=z1)
     z2 = numpyro.sample("z2", dist.Categorical(p[z1]), obs=z2)
     # shapes depend on whether z1/z2 are observed or enumerated
     logger.info("z1.shape = {}".format(z1.shape))
     logger.info("z2.shape = {}".format(z2.shape))
     with numpyro.plate("data", 3):
         numpyro.sample("x1", dist.Normal(loc[z1], 1.0), obs=data[0])
         numpyro.sample("x2", dist.Normal(loc[z2], 1.0), obs=data[1])
Example #14
0
 def transition_fn(carry, y):
     """One scan step of a factorial HMM: w and x evolve independently;
     Bernoulli tone emissions depend on both. Masked past sequence length."""
     w_prev, x_prev, t = carry
     with numpyro.plate("sequences", num_sequences, dim=-2):
         with mask(mask=(t < lengths)[..., None]):
             w = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
             x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
             # Note the broadcasting tricks here: to index probs_y on tensors x and y,
             # we also need a final tensor for the tones dimension. This is conveniently
             # provided by the plate associated with that dimension.
             with numpyro.plate("tones", data_dim, dim=-1) as tones:
                 numpyro.sample("y", dist.Bernoulli(probs_y[w, x, tones]), obs=y)
     return (w, x, t + 1), None
Example #15
0
def simulate_data(
    rng_key: np.ndarray,
    num_categories: int,
    num_words: int,
    num_supervised: int,
    num_unsupservised: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Simulate HMM data: true transition/emission probabilities plus a
    rolled-out sequence of categories and emitted words, split into
    supervised and unsupervised segments.

    NOTE(review): the parameter name ``num_unsupservised`` contains a typo
    ("unsupservised"); renaming it would break keyword callers — confirm
    before fixing.
    """

    rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)

    transition_prior = jnp.ones(num_categories)
    emission_prior = jnp.repeat(0.1, num_words)

    transition_prob = dist.Dirichlet(transition_prior).sample(
        rng_key_transition, sample_shape=(num_categories, ))
    emission_prob = dist.Dirichlet(emission_prior).sample(
        rng_key_emission, sample_shape=(num_categories, ))

    start_prob = jnp.repeat(1.0 / num_categories, num_categories)
    category = 0
    categories = []
    words = []

    for t in range(num_supervised + num_unsupservised):
        rng_key, rng_key_transition, rng_key_emission = random.split(
            rng_key, 3)
        # restart from the uniform start distribution at t=0 and at the
        # supervised/unsupervised boundary
        if t == 0 or t == num_supervised:
            category = dist.Categorical(start_prob).sample(rng_key_transition)
        else:
            category = dist.Categorical(
                transition_prob[category]).sample(rng_key_transition)

        word = dist.Categorical(
            emission_prob[category]).sample(rng_key_emission)
        categories.append(category)
        words.append(word)

    # Split data into supervised and unsupervised
    categories = jnp.stack(categories)
    words = jnp.stack(words)

    supervised_categories = categories[:num_supervised]
    supervised_words = words[:num_supervised]
    unsupervised_words = words[num_supervised:]

    return (
        transition_prob,
        emission_prob,
        supervised_categories,
        supervised_words,
        unsupervised_words,
    )
Example #16
0
def model2():
    """Toy mixture model with a z1-dependent LogNormal scale, exercising
    enumeration over the discrete sites z1 and z2."""

    data = [np.array([-1.0, -1.0, 0.0]), np.array([-1.0, 1.0])]
    p = numpyro.param("p", np.array([0.25, 0.75]))
    loc = numpyro.sample("loc", dist.Normal(0, 1).expand([2]).to_event(1))
    # FIXME results in infinite loop in transformeddist_to_funsor.
    # scale = numpyro.sample("scale", dist.LogNormal(0, 1))
    z1 = numpyro.sample("z1", dist.Categorical(p))
    # the LogNormal's loc is selected by the discrete z1
    scale = numpyro.sample("scale", dist.LogNormal(jnp.array([0.0, 1.0])[z1], 1))
    with numpyro.plate("data[0]", 3):
        numpyro.sample("x1", dist.Normal(loc[z1], scale), obs=data[0])
    with numpyro.plate("data[1]", 2):
        z2 = numpyro.sample("z2", dist.Categorical(p))
        numpyro.sample("x2", dist.Normal(loc[z2], scale), obs=data[1])
Example #17
0
 def transition_fn(x, y):
     """One scan step: advance hidden state x, draw a Bernoulli w per
     dimension of the D plate, and observe y at the (x, w)-indexed location."""
     probs = transition_probs[x]
     x = numpyro.sample("x", dist.Categorical(probs))
     with numpyro.plate("D", D, dim=-1):
         w = numpyro.sample("w", dist.Bernoulli(0.6))
         numpyro.sample("y", dist.Normal(Vindex(locs)[x, w], 1), obs=y)
     return x, None
Example #18
0
 def gmm(data):
     """Gaussian mixture with K components: Dirichlet mixing weights,
     per-cluster Normal means, and a categorical assignment per data point."""
     mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
     with numpyro.plate("num_clusters", K, dim=-1):
         cluster_means = numpyro.sample("cluster_means", dist.Normal(jnp.arange(K), 1.))
     with numpyro.plate("data", data.shape[0], dim=-1):
         assignments = numpyro.sample("assignments", dist.Categorical(mix_proportions))
         numpyro.sample("obs", dist.Normal(cluster_means[assignments], 1.), obs=data)
Example #19
0
 def model(data):
     """Discrete HMM: a markov chain over hidden states with Normal
     emissions; the first step uses the initial distribution."""
     state = None
     for t, obs_t in markov(enumerate(data)):
         if state is None:
             step_probs = init_probs
         else:
             step_probs = transition_probs[state]
         state = numpyro.sample(f"x_{t}", dist.Categorical(step_probs))
         numpyro.sample(f"y_{t}", dist.Normal(locs[state], 1), obs=obs_t)
     return state
Example #20
0
    def __call__(self):
        """Mixture model that conditions a 2D Gaussian over ``log_tau`` on
        the sampled (log10 nu_max, teff) pair.

        Picks a mixture component, then applies the standard Gaussian
        conditioning formulas (via Cholesky solves) to get the conditional
        mean/covariance of the remaining two dimensions.
        """
        assignment = numpyro.sample("assignment",
                                    dist.Categorical(self.weights))

        # component-specific 4D mean and covariance
        loc = self.loc[assignment]
        cov = self.cov[assignment]

        nu_max = numpyro.sample("nu_max", self.nu_max)
        log_nu_max = jnp.log10(nu_max)

        teff = numpyro.sample("teff", self.teff)

        # partition: dims 0-1 are the observed block (log_nu_max, teff),
        # dims 2-3 the target block
        loc0101 = loc[0:2]
        cov0101 = jnp.array([[cov[0, 0], cov[0, 1]], [cov[1, 0], cov[1, 1]]])

        # A = Sigma_11^{-1} (x_obs - mu_1) via Cholesky solve
        L = jax.scipy.linalg.cho_factor(cov0101, lower=True)
        A = jax.scipy.linalg.cho_solve(L,
                                       jnp.array([log_nu_max, teff]) - loc0101)

        loc2323 = loc[2:]
        cov2323 = jnp.array([[cov[2, 2], cov[2, 3]], [cov[3, 2], cov[3, 3]]])

        # cross-covariance block; v = Sigma_11^{-1} Sigma_12
        cov0123 = jnp.array([[cov[0, 2], cov[1, 2]], [cov[0, 3], cov[1, 3]]])
        v = jax.scipy.linalg.cho_solve(L, cov0123.T)

        # conditional mean mu_2 + Sigma_21 Sigma_11^{-1} (x - mu_1) and
        # covariance Sigma_22 - Sigma_21 Sigma_11^{-1} Sigma_12
        cond_loc = loc2323 + jnp.dot(cov0123, A)
        cond_cov = (
            cov2323 - jnp.dot(cov0123, v) +
            self.noise * jnp.eye(2)  # Add white noise
        )
        numpyro.sample("log_tau", dist.MultivariateNormal(cond_loc, cond_cov))
Example #21
0
def mace(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 3 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    # per-annotator behavior: epsilon is the label distribution used when
    # spamming, theta the probability of answering with the true class
    with numpyro.plate("annotator", num_annotators):
        epsilon = numpyro.sample("epsilon",
                                 dist.Dirichlet(jnp.full(num_classes, 10)))
        theta = numpyro.sample("theta", dist.Beta(0.5, 0.5))

    with numpyro.plate("item", num_items, dim=-2):
        # latent true class, marginalized out by parallel enumeration
        c = numpyro.sample(
            "c",
            dist.DiscreteUniform(0, num_classes - 1),
            infer={"enumerate": "parallel"},
        )

        with numpyro.plate("position", num_positions):
            # s == 1 means the annotator is "spamming" at this position
            s = numpyro.sample(
                "s",
                dist.Bernoulli(1 - theta[positions]),
                infer={"enumerate": "parallel"},
            )
            # honest answer (s == 0) copies the true class; otherwise use the
            # annotator's spamming distribution
            probs = jnp.where(s[..., None] == 0, nn.one_hot(c, num_classes),
                              epsilon[positions])
            numpyro.sample("y", dist.Categorical(probs), obs=annotations)
Example #22
0
def test_dirichlet_categorical(algo, dense_mass):
    """HMC/NUTS on a Dirichlet-Categorical model recovers the true category
    probabilities; also asserts float64 samples when x64 mode is enabled."""
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = np.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000, ))
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, data)
    samples = mcmc(warmup_steps,
                   num_samples,
                   init_params,
                   constrain_fn=constrain_fn,
                   progbar=False,
                   print_summary=False,
                   potential_fn=potential_fn,
                   algo=algo,
                   trajectory_length=1.,
                   dense_mass=dense_mass)
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.02)

    # Bug fix: environment variable names are case-sensitive, and JAX's
    # 64-bit switch is spelled `JAX_ENABLE_X64`. The previous lowercase
    # 'JAX_ENABLE_x64' check could never match, so the dtype assertion
    # was silently skipped.
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == np.float64
Example #23
0
def test_gaussian_mixture_model():
    """NUTS on a K=3 Gaussian mixture recovers the mixing proportions and
    cluster means; results are sorted to handle label permutation."""
    K, N = 3, 1000

    def gmm(data):
        mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
        with numpyro.plate("num_clusters", K, dim=-1):
            cluster_means = numpyro.sample("cluster_means",
                                           dist.Normal(jnp.arange(K), 1.0))
        with numpyro.plate("data", data.shape[0], dim=-1):
            assignments = numpyro.sample("assignments",
                                         dist.Categorical(mix_proportions))
            numpyro.sample("obs",
                           dist.Normal(cluster_means[assignments], 1.0),
                           obs=data)

    # synthetic data from known mixture parameters
    true_cluster_means = jnp.array([1.0, 5.0, 10.0])
    true_mix_proportions = jnp.array([0.1, 0.3, 0.6])
    cluster_assignments = dist.Categorical(true_mix_proportions).sample(
        random.PRNGKey(0), (N, ))
    data = dist.Normal(true_cluster_means[cluster_assignments],
                       1.0).sample(random.PRNGKey(1))

    nuts_kernel = NUTS(gmm, init_strategy=init_to_median)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(samples["phi"].mean(0).sort(),
                    true_mix_proportions,
                    atol=0.05)
    assert_allclose(samples["cluster_means"].mean(0).sort(),
                    true_cluster_means,
                    atol=0.2)
Example #24
0
def test_missing_plate(monkeypatch):
    """Model validation should flag the missing plate around
    ``cluster_means``; with validation patched out, inference still raises
    but the Pyro handler stack must be left clean."""
    K, N = 3, 1000

    def gmm(data):
        mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
        # plate/to_event is missing here
        cluster_means = numpyro.sample("cluster_means",
                                       dist.Normal(jnp.arange(K), 1.0))

        with numpyro.plate("data", data.shape[0], dim=-1):
            assignments = numpyro.sample("assignments",
                                         dist.Categorical(mix_proportions))
            numpyro.sample("obs",
                           dist.Normal(cluster_means[assignments], 1.0),
                           obs=data)

    true_cluster_means = jnp.array([1.0, 5.0, 10.0])
    true_mix_proportions = jnp.array([0.1, 0.3, 0.6])
    cluster_assignments = dist.Categorical(true_mix_proportions).sample(
        random.PRNGKey(0), (N, ))
    data = dist.Normal(true_cluster_means[cluster_assignments],
                       1.0).sample(random.PRNGKey(1))

    nuts_kernel = NUTS(gmm)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
    with pytest.raises(AssertionError, match="Missing plate statement"):
        mcmc.run(random.PRNGKey(2), data)

    # disable validation: the run should still fail, but not corrupt the stack
    monkeypatch.setattr(numpyro.infer.util, "_validate_model",
                        lambda model_trace: None)
    with pytest.raises(Exception):
        mcmc.run(random.PRNGKey(2), data)
    assert len(_PYRO_STACK) == 0
Example #25
0
 def per_component_fun(j):
     """Per-data-point joint log-density log p(obs, z=j) for component j."""
     # likelihood term: sum per-dimension Normal log-probs for component j
     lik_terms = dist.Normal(mus[j], sigs[j]).log_prob(obs)
     comp_loglik = jnp.sum(lik_terms, axis=1).flatten()
     assert jnp.atleast_1d(comp_loglik).shape == (N,)
     # prior term: log p(z = j) under the mixing weights
     comp_logprior = dist.Categorical(pis_prior).log_prob(j)
     joint = comp_loglik + comp_logprior
     assert jnp.atleast_1d(joint).shape == (N,)
     return joint
Example #26
0
def test_dirichlet_categorical_x64(kernel_cls, dense_mass):
    """MCMC kernels recover the true Dirichlet-Categorical probabilities;
    also asserts float64 samples when JAX_ENABLE_X64 is set."""
    num_warmup, num_samples = 100, 20000

    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample("p_latent", dist.Dirichlet(concentration))
        numpyro.sample("obs", dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000, ))
    # BarkerMH takes different constructor arguments than the HMC family
    if kernel_cls is BarkerMH:
        kernel = BarkerMH(model=model, dense_mass=dense_mass)
    else:
        kernel = kernel_cls(model,
                            trajectory_length=1.0,
                            dense_mass=dense_mass)
    mcmc = MCMC(kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["p_latent"], 0), true_probs, atol=0.02)

    if "JAX_ENABLE_X64" in os.environ:
        assert samples["p_latent"].dtype == jnp.float64
Example #27
0
def test_missing_plate():
    """Model validation should raise on the missing plate around
    ``cluster_means`` when running MCMC."""
    K, N = 3, 1000

    def gmm(data):
        mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
        # plate/to_event is missing here
        cluster_means = numpyro.sample("cluster_means",
                                       dist.Normal(jnp.arange(K), 1.))

        with numpyro.plate("data", data.shape[0], dim=-1):
            assignments = numpyro.sample("assignments",
                                         dist.Categorical(mix_proportions))
            numpyro.sample("obs",
                           dist.Normal(cluster_means[assignments], 1.),
                           obs=data)

    true_cluster_means = jnp.array([1., 5., 10.])
    true_mix_proportions = jnp.array([0.1, 0.3, 0.6])
    cluster_assignments = dist.Categorical(true_mix_proportions).sample(
        random.PRNGKey(0), (N, ))
    data = dist.Normal(true_cluster_means[cluster_assignments],
                       1.0).sample(random.PRNGKey(1))

    nuts_kernel = NUTS(gmm)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
    with pytest.raises(AssertionError, match="Missing plate statement"):
        mcmc.run(random.PRNGKey(2), data)
Example #28
0
def test_dirichlet_categorical(algo, dense_mass):
    """HMC on a Dirichlet-Categorical model (low-level kernel API) recovers
    the true category probabilities; also asserts float64 samples when x64
    mode is enabled."""
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = np.array([1.0, 1.0, 1.0])
        p_latent = sample('p_latent', dist.Dirichlet(concentration))
        sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, data)
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    hmc_state = init_kernel(init_params,
                            trajectory_length=1.,
                            num_warmup=warmup_steps,
                            progbar=False,
                            dense_mass=dense_mass)
    hmc_states = fori_collect(num_samples, sample_kernel, hmc_state,
                              transform=lambda x: constrain_fn(x.z),
                              progbar=False)
    assert_allclose(np.mean(hmc_states['p_latent'], 0), true_probs, atol=0.02)

    # Bug fix: environment variable names are case-sensitive, and JAX's
    # 64-bit switch is spelled `JAX_ENABLE_X64`. The previous lowercase
    # 'JAX_ENABLE_x64' check could never match, so the dtype assertion
    # was silently skipped.
    if 'JAX_ENABLE_X64' in os.environ:
        assert hmc_states['p_latent'].dtype == np.float64
Example #29
0
 def model(z=None):
     """Two-location mixture: discrete index iz selects location 0.0 or 1.0
     for the Normal likelihood over the closed-over ``data``."""
     p = numpyro.param("p", np.array([0.75, 0.25]))
     iz = numpyro.sample("z", dist.Categorical(p), obs=z)
     # map the discrete index to a continuous location
     z = jnp.array([0.0, 1.0])[iz]
     logger.info("z.shape = {}".format(z.shape))
     with numpyro.plate("data", 3):
         numpyro.sample("x", dist.Normal(z, 1.0), obs=data)
Example #30
0
 def sample_single(single_key):
     """Draw one mixture sample: pick a component z, then one value per
     feature distribution, indexed by z."""
     key_vals, key_pis = jax.random.split(single_key, 2)
     # component assignment from the mixing weights
     z = dist.Categorical(self._pis).sample(key_pis)
     # one independent key per feature distribution
     feat_keys = jax.random.split(key_vals, len(self.dists))
     vals = []
     for idx, feature_dist in enumerate(self.dists):
         vals.append(feature_dist.sample(feat_keys[idx])[z])
     return jnp.stack(vals).T, z