def test_chain_jit_args_smoke(chain_method, compile_args):
    """Smoke test: multi-chain MCMC runs under each `chain_method`, with and
    without `jit_model_args`, including a re-run on different data."""

    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample("p_latent", dist.Dirichlet(concentration))
        numpyro.sample("obs", dist.Categorical(p_latent), obs=data)
        return p_latent

    dataset_a = dist.Categorical(jnp.array([0.1, 0.6, 0.3])).sample(
        random.PRNGKey(1), (50, ))
    dataset_b = dist.Categorical(jnp.array([0.2, 0.4, 0.4])).sample(
        random.PRNGKey(1), (50, ))
    sampler = MCMC(
        NUTS(model),
        num_warmup=2,
        num_samples=5,
        num_chains=2,
        chain_method=chain_method,
        jit_model_args=compile_args,
    )
    sampler.warmup(random.PRNGKey(0), dataset_a)
    sampler.run(random.PRNGKey(1), dataset_a)
    # Re-running on new data should be fast when jit_model_args=True.
    sampler.run(random.PRNGKey(2), dataset_b)
def semi_supervised_hmm(transition_prior, emission_prior, supervised_categories,
                        supervised_words, unsupervised_words):
    """Semi-supervised HMM.

    Latent sites: `transition_prob` (row-wise Dirichlet over next categories)
    and `emission_prob` (row-wise Dirichlet over words). Supervised
    (category, word) pairs are observed directly; the unsupervised words
    contribute a forward-algorithm log-likelihood via `numpyro.factor`.
    """
    num_categories, num_words = transition_prior.shape[
        0], emission_prior.shape[0]
    transition_prob = numpyro.sample(
        'transition_prob',
        dist.Dirichlet(
            np.broadcast_to(transition_prior,
                            (num_categories, num_categories))))
    emission_prob = numpyro.sample(
        'emission_prob',
        dist.Dirichlet(
            np.broadcast_to(emission_prior, (num_categories, num_words))))
    # models supervised data;
    # here we don't make any assumption about the first supervised category, in other words,
    # we place a flat/uniform prior on it.
    numpyro.sample('supervised_categories',
                   dist.Categorical(
                       transition_prob[supervised_categories[:-1]]),
                   obs=supervised_categories[1:])
    numpyro.sample('supervised_words',
                   dist.Categorical(emission_prob[supervised_categories]),
                   obs=supervised_words)
    # computes log prob of unsupervised data with the forward algorithm
    transition_log_prob = np.log(transition_prob)
    emission_log_prob = np.log(emission_prob)
    init_log_prob = emission_log_prob[:, unsupervised_words[0]]
    log_prob = forward_log_prob(init_log_prob, unsupervised_words[1:],
                                transition_log_prob, emission_log_prob)
    # marginalize over the final hidden state
    log_prob = logsumexp(log_prob, axis=0, keepdims=True)
    # inject log_prob to potential function
    numpyro.factor('forward_log_prob', log_prob)
def simulate_data(rng_key, num_categories, num_words, num_supervised_data,
                  num_unsupervised_data):
    """Draw HMM parameters and a category/word sequence, then split it into a
    supervised prefix and an unsupervised suffix.

    Returns (transition_prior, emission_prior, transition_prob, emission_prob,
    supervised_categories, supervised_words, unsupervised_words).
    """
    rng_key, key_trans, key_emit = random.split(rng_key, 3)
    transition_prior = np.ones(num_categories)
    emission_prior = np.repeat(0.1, num_words)
    transition_prob = dist.Dirichlet(transition_prior).sample(
        key=key_trans, sample_shape=(num_categories, ))
    emission_prob = dist.Dirichlet(emission_prior).sample(
        key=key_emit, sample_shape=(num_categories, ))
    start_prob = np.repeat(1. / num_categories, num_categories)

    categories, words = [], []
    for step in range(num_supervised_data + num_unsupervised_data):
        rng_key, key_trans, key_emit = random.split(rng_key, 3)
        # Restart from the uniform initial distribution at t=0 and again at
        # the boundary between the supervised and unsupervised segments.
        if step in (0, num_supervised_data):
            category = dist.Categorical(start_prob).sample(key=key_trans)
        else:
            category = dist.Categorical(
                transition_prob[category]).sample(key=key_trans)
        word = dist.Categorical(
            emission_prob[category]).sample(key=key_emit)
        categories.append(category)
        words.append(word)

    # split into supervised data and unsupervised data
    categories, words = np.stack(categories), np.stack(words)
    return (transition_prior, emission_prior, transition_prob, emission_prob,
            categories[:num_supervised_data], words[:num_supervised_data],
            words[num_supervised_data:])
def transition_fn(carry, y):
    """One scan step: advance the two latent Markov chains and observe y."""
    x_prev, w_prev = carry
    x_cur = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
    w_cur = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
    numpyro.sample("y", dist.Normal(locs[w_cur, x_cur], 1), obs=y)
    # also test if scan's `ys` are recorded correctly
    return (x_cur, w_cur), x_cur
def hierarchical_dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].

    Hierarchical Dawid-Skene: per-annotator confusion logits `beta` are drawn
    around class-level means `zeta` with scales `Omega` (non-centered via
    LocScaleReparam); each item's latent class `c` follows Dirichlet prior `pi`.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
        # invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
        # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
        zeta = numpyro.sample(
            "zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample(
            "Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # non-centered parameterization
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample(
                    "beta", dist.Normal(zeta, omega).to_event(1))
            # pad 0 to the last item
            beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi))

        with numpyro.plate("position", num_positions):
            # Vindex broadcasts the (positions, c) index tensors
            logits = Vindex(beta)[positions, c, :]
            numpyro.sample("y", dist.Categorical(logits=logits),
                           obs=annotations)
def logistic_random_effects(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].

    Logistic random effects: per-annotator logits `beta` (hierarchical over
    classes) are offset by a per-item difficulty `theta`; both use non-centered
    parameterizations via LocScaleReparam.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample(
            "zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample(
            "Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))
        chi = numpyro.sample(
            "Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample(
                    "beta", dist.Normal(zeta, omega).to_event(1))
            # pad a fixed 0 logit as the last entry (logits are shift-invariant)
            beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi))

        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample(
                "theta", dist.Normal(0, chi[c]).to_event(1))
            theta = jnp.pad(theta,
                            [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :] - theta
            numpyro.sample("y", dist.Categorical(logits=logits),
                           obs=annotations)
def semi_supervised_hmm(transition_prior, emission_prior, supervised_categories,
                        supervised_words, unsupervised_words):
    """Semi-supervised HMM (legacy bare-`sample` API).

    Supervised (category, word) pairs are observed directly; the unsupervised
    words contribute via a forward-algorithm log-probability injected into the
    potential function through an observed Multinomial site.
    """
    num_categories, num_words = transition_prior.shape[0], emission_prior.shape[0]
    transition_prob = sample('transition_prob', dist.Dirichlet(
        np.broadcast_to(transition_prior, (num_categories, num_categories))))
    emission_prob = sample('emission_prob', dist.Dirichlet(
        np.broadcast_to(emission_prior, (num_categories, num_words))))
    # models supervised data;
    # here we don't make any assumption about the first supervised category, in other words,
    # we place a flat/uniform prior on it.
    sample('supervised_categories',
           dist.Categorical(transition_prob[supervised_categories[:-1]]),
           obs=supervised_categories[1:])
    sample('supervised_words',
           dist.Categorical(emission_prob[supervised_categories]),
           obs=supervised_words)
    # computes log prob of unsupervised data
    transition_log_prob = np.log(transition_prob)
    emission_log_prob = np.log(emission_prob)
    init_log_prob = emission_log_prob[:, unsupervised_words[0]]
    log_prob = forward_log_prob(init_log_prob, unsupervised_words[1:],
                                transition_log_prob, emission_log_prob)
    log_prob = logsumexp(log_prob, axis=0, keepdims=True)
    # inject log_prob to potential function
    # NB: This is a trick (uses an invalid value `0` of Multinomial distribution) to add an
    # additional term to potential energy.
    # NOTE(review): the sign here (`-log_prob`) differs from the factor-based
    # variant of this model, which uses `+log_prob` — confirm the convention of
    # the potential-energy trick before changing either.
    return sample('forward_log_prob', dist.Multinomial(logits=-log_prob), obs=0)
def dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 2 of reference [1].

    Classic Dawid-Skene: each annotator has a per-class confusion row `beta`;
    the latent item class `c` is marginalized by parallel enumeration.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            beta = numpyro.sample("beta",
                                  dist.Dirichlet(jnp.ones(num_classes)))

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi),
                           infer={"enumerate": "parallel"})

        # here we use Vindex to allow broadcasting for the second index `c`
        # ref: http://num.pyro.ai/en/latest/utilities.html#numpyro.contrib.indexing.vindex
        with numpyro.plate("position", num_positions):
            numpyro.sample("y",
                           dist.Categorical(Vindex(beta)[positions, c, :]),
                           obs=annotations)
def item_difficulty(annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].

    Item-difficulty model: each item's logits `theta` are drawn around
    class-level means `eta` with scales `Chi` (non-centered parameterization);
    annotator identity is not used.
    """
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        eta = numpyro.sample(
            "eta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        chi = numpyro.sample(
            "Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi),
                           infer={"enumerate": "parallel"})

        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample(
                "theta", dist.Normal(eta[c], chi[c]).to_event(1))
            # fix the last logit to 0 (logits are invariant up to a constant)
            theta = jnp.pad(theta,
                            [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", annotations.shape[-1]):
            numpyro.sample("y", dist.Categorical(logits=theta),
                           obs=annotations)
def model(z1=None, z2=None):
    """Two categorical latents (optionally observed via z1/z2) with Normal
    likelihoods over `data`, which is read from the enclosing scope.
    """
    # NOTE(review): `p` is initialized with np while `loc` uses jnp —
    # presumably both become JAX arrays inside the trace; confirm intentional.
    p = numpyro.param("p", np.array([0.25, 0.75]))
    loc = numpyro.param("loc", jnp.array([-1.0, 1.0]))
    z1 = numpyro.sample("z1", dist.Categorical(p), obs=z1)
    with numpyro.plate("data[0]", 3):
        numpyro.sample("x1", dist.Normal(loc[z1], 1.0), obs=data[0])
    with numpyro.plate("data[1]", 2):
        z2 = numpyro.sample("z2", dist.Categorical(p), obs=z2)
        numpyro.sample("x2", dist.Normal(loc[z2], 1.0), obs=data[1])
def model(data_x, data_w):
    """Two independent latent Markov chains over data_x and data_w, declared
    inside `markov` so adjacent steps can be enumerated.
    """
    x = w = 0
    for i, y in markov(enumerate(data_x)):
        x = numpyro.sample(f"x_{i}", dist.Categorical(probs_x[x]))
        numpyro.sample(f"y_x_{i}", dist.Normal(locs_x[x], 1), obs=y)
    for i, y in markov(enumerate(data_w)):
        # NOTE(review): site name "w{i}" lacks the underscore used by "x_{i}";
        # confirm whether downstream code depends on this exact naming.
        w = numpyro.sample(f"w{i}", dist.Categorical(probs_w[w]))
        numpyro.sample(f"y_w_{i}", dist.Normal(locs_w[w], 1), obs=y)
def transition_fn(carry, y):
    """One scan step of a factorial HMM: coupled latent chains (w, x) with
    Bernoulli tone emissions; `mask` disables padded timesteps past `lengths`.
    """
    w_prev, x_prev, t = carry
    with numpyro.plate("sequences", num_sequences, dim=-2):
        with mask(mask=(t < lengths)[..., None]):
            w = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
            # x depends on both the new w and the previous x; Vindex handles
            # broadcasting of the two index tensors.
            x = numpyro.sample("x",
                               dist.Categorical(Vindex(probs_x)[w, x_prev]))
            with numpyro.plate("tones", data_dim, dim=-1) as tones:
                numpyro.sample("y", dist.Bernoulli(probs_y[w, x, tones]),
                               obs=y)
    return (w, x, t + 1), None
def model(z1=None, z2=None):
    """Chain of two categorical latents where z2's distribution depends on z1;
    both may be conditioned via the z1/z2 arguments. `data` and `logger` are
    read from the enclosing scope.
    """
    p = numpyro.param("p", jnp.array([[0.25, 0.75], [0.1, 0.9]]))
    loc = numpyro.param("loc", jnp.array([-1.0, 1.0]))
    z1 = numpyro.sample("z1", dist.Categorical(p[0]), obs=z1)
    z2 = numpyro.sample("z2", dist.Categorical(p[z1]), obs=z2)
    # shapes logged to inspect enumeration broadcasting
    logger.info("z1.shape = {}".format(z1.shape))
    logger.info("z2.shape = {}".format(z2.shape))
    with numpyro.plate("data", 3):
        numpyro.sample("x1", dist.Normal(loc[z1], 1.0), obs=data[0])
        numpyro.sample("x2", dist.Normal(loc[z2], 1.0), obs=data[1])
def transition_fn(carry, y):
    """One scan step of a factorial HMM with independently-evolving chains w
    and x; emissions depend on both. Padded steps are masked via `lengths`.
    """
    w_prev, x_prev, t = carry
    with numpyro.plate("sequences", num_sequences, dim=-2):
        with mask(mask=(t < lengths)[..., None]):
            w = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
            x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
            # Note the broadcasting tricks here: to index probs_y on tensors x and y,
            # we also need a final tensor for the tones dimension. This is conveniently
            # provided by the plate associated with that dimension.
            with numpyro.plate("tones", data_dim, dim=-1) as tones:
                numpyro.sample("y", dist.Bernoulli(probs_y[w, x, tones]),
                               obs=y)
    return (w, x, t + 1), None
def simulate_data(
    rng_key: np.ndarray,
    num_categories: int,
    num_words: int,
    num_supervised: int,
    num_unsupservised: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Sample HMM parameters and a category/word sequence, then split it into a
    supervised prefix and an unsupervised suffix.

    NOTE(review): the last parameter name contains a typo ("unsupservised");
    kept as-is for keyword-call compatibility.
    """
    rng_key, key_trans, key_emit = random.split(rng_key, 3)
    transition_prior = jnp.ones(num_categories)
    emission_prior = jnp.repeat(0.1, num_words)
    transition_prob = dist.Dirichlet(transition_prior).sample(
        key_trans, sample_shape=(num_categories, ))
    emission_prob = dist.Dirichlet(emission_prior).sample(
        key_emit, sample_shape=(num_categories, ))
    start_prob = jnp.repeat(1.0 / num_categories, num_categories)

    category = 0
    categories = []
    words = []
    for step in range(num_supervised + num_unsupservised):
        rng_key, key_trans, key_emit = random.split(rng_key, 3)
        # Restart from the uniform initial distribution at t=0 and again at
        # the supervised/unsupervised boundary.
        if step in (0, num_supervised):
            category = dist.Categorical(start_prob).sample(key_trans)
        else:
            category = dist.Categorical(
                transition_prob[category]).sample(key_trans)
        word = dist.Categorical(emission_prob[category]).sample(key_emit)
        categories.append(category)
        words.append(word)

    # Split data into supervised and unsupervised
    categories = jnp.stack(categories)
    words = jnp.stack(words)
    return (
        transition_prob,
        emission_prob,
        categories[:num_supervised],
        words[:num_supervised],
        words[num_supervised:],
    )
def model2():
    """Mixture model over two datasets sharing `loc` and `scale`; the prior
    mean of `scale` depends on the first assignment z1.
    """
    data = [np.array([-1.0, -1.0, 0.0]), np.array([-1.0, 1.0])]
    p = numpyro.param("p", np.array([0.25, 0.75]))
    loc = numpyro.sample("loc", dist.Normal(0, 1).expand([2]).to_event(1))
    # FIXME results in infinite loop in transformeddist_to_funsor.
    # scale = numpyro.sample("scale", dist.LogNormal(0, 1))
    z1 = numpyro.sample("z1", dist.Categorical(p))
    scale = numpyro.sample("scale",
                           dist.LogNormal(jnp.array([0.0, 1.0])[z1], 1))
    with numpyro.plate("data[0]", 3):
        numpyro.sample("x1", dist.Normal(loc[z1], scale), obs=data[0])
    with numpyro.plate("data[1]", 2):
        z2 = numpyro.sample("z2", dist.Categorical(p))
        numpyro.sample("x2", dist.Normal(loc[z2], scale), obs=data[1])
def transition_fn(x, y):
    """One scan step: advance the latent Markov state x, draw per-dimension
    Bernoulli w, and observe y from a Normal whose mean is indexed by (x, w).
    """
    probs = transition_probs[x]
    x = numpyro.sample("x", dist.Categorical(probs))
    with numpyro.plate("D", D, dim=-1):
        w = numpyro.sample("w", dist.Bernoulli(0.6))
        # Vindex broadcasts the two index tensors x and w
        numpyro.sample("y", dist.Normal(Vindex(locs)[x, w], 1), obs=y)
    return x, None
def gmm(data):
    """K-component 1-D Gaussian mixture: Dirichlet weights `phi`, cluster
    means centered at arange(K), and a categorical assignment per data point.
    `K` is read from the enclosing scope.
    """
    mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
    with numpyro.plate("num_clusters", K, dim=-1):
        cluster_means = numpyro.sample("cluster_means",
                                       dist.Normal(jnp.arange(K), 1.))
    with numpyro.plate("data", data.shape[0], dim=-1):
        assignments = numpyro.sample("assignments",
                                     dist.Categorical(mix_proportions))
        numpyro.sample("obs", dist.Normal(cluster_means[assignments], 1.),
                       obs=data)
def model(data):
    """Latent Markov chain with Normal emissions; returns the final state."""
    state = None
    for t, obs_val in markov(enumerate(data)):
        # First step draws from the initial distribution; later steps
        # condition on the previous latent state.
        if state is None:
            step_probs = init_probs
        else:
            step_probs = transition_probs[state]
        state = numpyro.sample(f"x_{t}", dist.Categorical(step_probs))
        numpyro.sample(f"y_{t}", dist.Normal(locs[state], 1), obs=obs_val)
    return state
def __call__(self):
    """Sample a mixture component, then (nu_max, teff), and finally draw
    `log_tau` from the Gaussian conditional of the component's 4-D MVN given
    (log10(nu_max), teff).
    """
    assignment = numpyro.sample("assignment", dist.Categorical(self.weights))
    loc = self.loc[assignment]
    cov = self.cov[assignment]
    nu_max = numpyro.sample("nu_max", self.nu_max)
    log_nu_max = jnp.log10(nu_max)
    teff = numpyro.sample("teff", self.teff)
    # Marginal block of the first two coordinates (log_nu_max, teff).
    loc0101 = loc[0:2]
    cov0101 = jnp.array([[cov[0, 0], cov[0, 1]], [cov[1, 0], cov[1, 1]]])
    L = jax.scipy.linalg.cho_factor(cov0101, lower=True)
    # A = cov0101^{-1} (obs - loc) via the Cholesky factorization.
    A = jax.scipy.linalg.cho_solve(L, jnp.array([log_nu_max, teff]) - loc0101)
    # Remaining block (coordinates 2 and 3) and the cross-covariance.
    loc2323 = loc[2:]
    cov2323 = jnp.array([[cov[2, 2], cov[2, 3]], [cov[3, 2], cov[3, 3]]])
    cov0123 = jnp.array([[cov[0, 2], cov[1, 2]], [cov[0, 3], cov[1, 3]]])
    v = jax.scipy.linalg.cho_solve(L, cov0123.T)
    # Standard Gaussian conditioning: mean shift plus Schur complement.
    cond_loc = loc2323 + jnp.dot(cov0123, A)
    cond_cov = (
        cov2323 - jnp.dot(cov0123, v)
        + self.noise * jnp.eye(2)  # Add white noise
    )
    numpyro.sample("log_tau", dist.MultivariateNormal(cond_loc, cond_cov))
def mace(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 3 of reference [1].

    MACE: each annotator has a competence `theta` and a spamming distribution
    `epsilon`; the spam indicator `s` and the true class `c` are both
    marginalized by parallel enumeration.
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("annotator", num_annotators):
        epsilon = numpyro.sample("epsilon",
                                 dist.Dirichlet(jnp.full(num_classes, 10)))
        theta = numpyro.sample("theta", dist.Beta(0.5, 0.5))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample(
            "c",
            dist.DiscreteUniform(0, num_classes - 1),
            infer={"enumerate": "parallel"},
        )

        with numpyro.plate("position", num_positions):
            s = numpyro.sample(
                "s",
                dist.Bernoulli(1 - theta[positions]),
                infer={"enumerate": "parallel"},
            )
            # If not spamming (s == 0) the label is the one-hot true class;
            # otherwise it follows the annotator's spam distribution epsilon.
            probs = jnp.where(s[..., None] == 0, nn.one_hot(c, num_classes),
                              epsilon[positions])
            numpyro.sample("y", dist.Categorical(probs), obs=annotations)
def test_dirichlet_categorical(algo, dense_mass):
    """MCMC on a Dirichlet-Categorical model should recover the true category
    probabilities; the dtype check only applies in x64 mode.
    """
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = np.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000, ))
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, data)
    samples = mcmc(warmup_steps, num_samples, init_params,
                   constrain_fn=constrain_fn, progbar=False,
                   print_summary=False, potential_fn=potential_fn, algo=algo,
                   trajectory_length=1., dense_mass=dense_mass)
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.02)
    # Bug fix: the env var JAX reads is `JAX_ENABLE_X64` (uppercase X); the
    # previous lowercase `JAX_ENABLE_x64` could never match (env vars are
    # case-sensitive), so the dtype assertion was silently skipped.
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == np.float64
def test_gaussian_mixture_model():
    """NUTS on a small GMM recovers mixture weights and cluster means."""
    K, N = 3, 1000

    def gmm(data):
        mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
        with numpyro.plate("num_clusters", K, dim=-1):
            cluster_means = numpyro.sample("cluster_means",
                                           dist.Normal(jnp.arange(K), 1.0))
        with numpyro.plate("data", data.shape[0], dim=-1):
            assignments = numpyro.sample("assignments",
                                         dist.Categorical(mix_proportions))
            numpyro.sample("obs",
                           dist.Normal(cluster_means[assignments], 1.0),
                           obs=data)

    true_cluster_means = jnp.array([1.0, 5.0, 10.0])
    true_mix_proportions = jnp.array([0.1, 0.3, 0.6])
    cluster_assignments = dist.Categorical(true_mix_proportions).sample(
        random.PRNGKey(0), (N, ))
    data = dist.Normal(true_cluster_means[cluster_assignments],
                       1.0).sample(random.PRNGKey(1))
    nuts_kernel = NUTS(gmm, init_strategy=init_to_median)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    # sort() because mixture components are identified only up to permutation
    assert_allclose(samples["phi"].mean(0).sort(), true_mix_proportions,
                    atol=0.05)
    assert_allclose(samples["cluster_means"].mean(0).sort(),
                    true_cluster_means, atol=0.2)
def test_missing_plate(monkeypatch):
    """Model validation should flag a missing plate; with validation patched
    out, the run still raises, and the effect-handler stack is left clean.
    """
    K, N = 3, 1000

    def gmm(data):
        mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
        # plate/to_event is missing here
        cluster_means = numpyro.sample("cluster_means",
                                       dist.Normal(jnp.arange(K), 1.0))
        with numpyro.plate("data", data.shape[0], dim=-1):
            assignments = numpyro.sample("assignments",
                                         dist.Categorical(mix_proportions))
            numpyro.sample("obs",
                           dist.Normal(cluster_means[assignments], 1.0),
                           obs=data)

    true_cluster_means = jnp.array([1.0, 5.0, 10.0])
    true_mix_proportions = jnp.array([0.1, 0.3, 0.6])
    cluster_assignments = dist.Categorical(true_mix_proportions).sample(
        random.PRNGKey(0), (N, ))
    data = dist.Normal(true_cluster_means[cluster_assignments],
                       1.0).sample(random.PRNGKey(1))
    nuts_kernel = NUTS(gmm)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
    with pytest.raises(AssertionError, match="Missing plate statement"):
        mcmc.run(random.PRNGKey(2), data)
    # Disable validation: the run should still fail, but must not leave
    # handlers on the global effect stack.
    monkeypatch.setattr(numpyro.infer.util, "_validate_model",
                        lambda model_trace: None)
    with pytest.raises(Exception):
        mcmc.run(random.PRNGKey(2), data)
    assert len(_PYRO_STACK) == 0
def per_component_fun(j):
    """Per-datapoint joint log-density log p(obs, z=j) for mixture component j."""
    # Likelihood term: sum per-dimension Normal log-densities for component j.
    per_dim = dist.Normal(mus[j], sigs[j]).log_prob(obs)
    log_prob_x_given_z = jnp.sum(per_dim, axis=1).flatten()
    assert jnp.atleast_1d(log_prob_x_given_z).shape == (N, )
    # Prior term: log-probability of the component assignment itself.
    log_prior_z = dist.Categorical(pis_prior).log_prob(j)
    joint = log_prob_x_given_z + log_prior_z
    assert jnp.atleast_1d(joint).shape == (N, )
    return joint
def test_dirichlet_categorical_x64(kernel_cls, dense_mass):
    """Each MCMC kernel recovers the Dirichlet-Categorical posterior mean; the
    dtype check only runs when JAX_ENABLE_X64 is set.
    """
    num_warmup, num_samples = 100, 20000

    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample("p_latent", dist.Dirichlet(concentration))
        numpyro.sample("obs", dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000, ))
    # BarkerMH has no trajectory_length parameter, hence the special case.
    if kernel_cls is BarkerMH:
        kernel = BarkerMH(model=model, dense_mass=dense_mass)
    else:
        kernel = kernel_cls(model, trajectory_length=1.0,
                            dense_mass=dense_mass)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["p_latent"], 0), true_probs, atol=0.02)
    if "JAX_ENABLE_X64" in os.environ:
        assert samples["p_latent"].dtype == jnp.float64
def test_missing_plate():
    """A model lacking a plate/to_event on `cluster_means` should be rejected
    by model validation with a "Missing plate statement" assertion.
    """
    K, N = 3, 1000

    def gmm(data):
        mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
        # plate/to_event is missing here
        cluster_means = numpyro.sample("cluster_means",
                                       dist.Normal(jnp.arange(K), 1.))
        with numpyro.plate("data", data.shape[0], dim=-1):
            assignments = numpyro.sample("assignments",
                                         dist.Categorical(mix_proportions))
            numpyro.sample("obs",
                           dist.Normal(cluster_means[assignments], 1.),
                           obs=data)

    true_cluster_means = jnp.array([1., 5., 10.])
    true_mix_proportions = jnp.array([0.1, 0.3, 0.6])
    cluster_assignments = dist.Categorical(true_mix_proportions).sample(
        random.PRNGKey(0), (N, ))
    data = dist.Normal(true_cluster_means[cluster_assignments],
                       1.0).sample(random.PRNGKey(1))
    nuts_kernel = NUTS(gmm)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
    with pytest.raises(AssertionError, match="Missing plate statement"):
        mcmc.run(random.PRNGKey(2), data)
def test_dirichlet_categorical(algo, dense_mass):
    """HMC (low-level API) on a Dirichlet-Categorical model should recover the
    true category probabilities; the dtype check only applies in x64 mode.
    """
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = np.array([1.0, 1.0, 1.0])
        p_latent = sample('p_latent', dist.Dirichlet(concentration))
        sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, data)
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    hmc_state = init_kernel(init_params,
                            trajectory_length=1.,
                            num_warmup=warmup_steps,
                            progbar=False,
                            dense_mass=dense_mass)
    hmc_states = fori_collect(num_samples, sample_kernel, hmc_state,
                              transform=lambda x: constrain_fn(x.z),
                              progbar=False)
    assert_allclose(np.mean(hmc_states['p_latent'], 0), true_probs, atol=0.02)
    # Bug fix: the env var JAX reads is `JAX_ENABLE_X64` (uppercase X); the
    # previous lowercase `JAX_ENABLE_x64` could never match (env vars are
    # case-sensitive), so the dtype assertion was silently skipped.
    if 'JAX_ENABLE_X64' in os.environ:
        assert hmc_states['p_latent'].dtype == np.float64
def model(z=None):
    """Categorical latent mapped to a location in {0., 1.}, observed through a
    Normal likelihood over `data` (read from the enclosing scope).
    """
    p = numpyro.param("p", np.array([0.75, 0.25]))
    iz = numpyro.sample("z", dist.Categorical(p), obs=z)
    # map the integer index to the actual location value
    z = jnp.array([0.0, 1.0])[iz]
    logger.info("z.shape = {}".format(z.shape))
    with numpyro.plate("data", 3):
        numpyro.sample("x", dist.Normal(z, 1.0), obs=data)
def sample_single(single_key):
    """Draw one mixture sample: pick a component z, then one value per feature."""
    vals_key, pis_key = jax.random.split(single_key, 2)
    # The component assignment is drawn first so every feature shares z.
    z = dist.Categorical(self._pis).sample(pis_key)
    feature_keys = jax.random.split(vals_key, len(self.dists))
    vals = []
    for idx, dbn in enumerate(self.dists):
        vals.append(dbn.sample(feature_keys[idx])[z])
    return jnp.stack(vals).T, z