def test_hmm_smoke(length, temperature):
    """Smoke-test ``infer_discrete`` on an HMM unrolled with ``markov``.

    A seeded run of the model over ``length`` empty (``None``) slots produces
    the reference states and the observations. The same model is then wrapped
    with ``config_enumerate`` and ``infer_discrete`` and run on those
    observations. Only the lengths of the results are asserted; the state
    sequences themselves are merely logged.

    :param length: number of time steps to generate and then decode.
    :param temperature: passed through unchanged to ``infer_discrete``.
    """

    # This should match the example in the infer_discrete docstring.
    def hmm(data, hidden_dim=10):
        transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim)
        means = jnp.arange(float(hidden_dim))
        states = [0]
        for t in markov(range(len(data))):
            previous = states[-1]
            current = numpyro.sample(
                "states_{}".format(t), dist.Categorical(transition[previous])
            )
            states.append(current)
            # ``data`` is updated in place: slot ``t`` is overwritten with
            # whatever the observation site returns.
            data[t] = numpyro.sample(
                "obs_{}".format(t), dist.Normal(means[current], 1.0), obs=data[t]
            )
        return states, data

    # Generate reference states and observations from a seeded model run.
    seeded_hmm = handlers.seed(hmm, 0)
    true_states, data = seeded_hmm([None] * length)
    assert len(data) == length
    # The state list carries the fixed initial state 0 in front.
    assert len(true_states) == 1 + len(data)

    enumerated_hmm = config_enumerate(hmm)
    decoder = infer_discrete(
        enumerated_hmm, temperature=temperature, rng_key=random.PRNGKey(1)
    )
    inferred_states, _ = decoder(data)
    assert len(inferred_states) == len(true_states)

    logger.info("true states: {}".format([int(s) for s in true_states]))
    logger.info("inferred states: {}".format([int(s) for s in inferred_states]))
def test_distribution_3(temperature):
    """Check the joint distribution of ``(z1, z2[0], z2[1])`` from ``infer_discrete``.

    Expected (unnormalised) probabilities come from ``exp(log_prob_sum(...))``
    of traces of the model run with every one of the eight discrete
    assignments passed in as observations.

    * Truthy ``temperature``: the frequencies found in the sampled trace are
      compared against the normalised expected probabilities.
    * ``temperature == 0``: the sampled trace must reach the largest expected
      value, and all observed mass must sit on that single assignment.

    :param temperature: passed through unchanged to ``infer_discrete``.
    """
    #       +---------+  +---------------+
    #  z1 --|--> x1   |  |  z2 ---> x2   |
    #       |       3 |  |             2 |
    #       +---------+  +---------------+
    num_particles = 10000
    data = [np.array([-1.0, -1.0, 0.0]), np.array([-1.0, 1.0])]

    @config_enumerate
    def model(z1=None, z2=None):
        probs = numpyro.param("p", np.array([0.25, 0.75]))
        locs = numpyro.param("loc", jnp.array([-1.0, 1.0]))
        # z1 is drawn once, outside any plate, and shared by the three x1 sites.
        z1 = numpyro.sample("z1", dist.Categorical(probs), obs=z1)
        with numpyro.plate("data[0]", 3):
            numpyro.sample("x1", dist.Normal(locs[z1], 1.0), obs=data[0])
        # z2 is drawn inside its plate, one value per x2 observation.
        with numpyro.plate("data[1]", 2):
            z2 = numpyro.sample("z2", dist.Categorical(probs), obs=z2)
            numpyro.sample("x2", dist.Normal(locs[z2], 1.0), obs=data[1])

    first_available_dim = -3
    if temperature == 0:
        vectorized_model = model
    else:
        vectorized_model = vectorize_model(model, num_particles, dim=-2)
    sampled_model = infer_discrete(
        vectorized_model, first_available_dim, temperature, rng_key=random.PRNGKey(1)
    )
    sampled_trace = handlers.trace(sampled_model).get_trace()

    # One trace per discrete assignment, with the assignment fed in as observed.
    conditioned_traces = {}
    for z1 in [0, 1]:
        for z20 in [0, 1]:
            for z21 in [0, 1]:
                conditioned_traces[(z1, z20, z21)] = handlers.trace(model).get_trace(
                    z1=np.array(z1), z2=np.array([z20, z21])
                )

    # Check joint posterior over (z1, z2[0], z2[1]).
    actual_probs = np.zeros((2, 2, 2))
    expected_probs = np.zeros((2, 2, 2))
    z1_values = sampled_trace["z1"]["value"]
    z2_values = sampled_trace["z2"]["value"]
    for (z1, z20, z21), tr in conditioned_traces.items():
        expected_probs[z1, z20, z21] = jnp.exp(log_prob_sum(tr))
        # Slices (not integer indices) keep the trailing axis on the z2 terms.
        hits = (
            (z1_values == z1)
            & (z2_values[..., :1] == z20)
            & (z2_values[..., 1:] == z21)
        )
        actual_probs[z1, z20, z21] = hits.astype(float).mean()

    if not temperature:
        # ``flat`` is a view onto ``expected_probs``, so writes go through.
        flat = expected_probs.reshape(-1)
        argmax = flat.argmax()
        expected_max = flat[argmax]
        actual_max = np.exp(log_prob_sum(sampled_trace))
        assert_allclose(expected_max, actual_max, atol=1e-5)
        # Collapse the expectation to a point mass on the best assignment.
        flat[:] = 0
        flat[argmax] = 1
    else:
        expected_probs = expected_probs / expected_probs.sum()
    assert_allclose(expected_probs.reshape(-1), actual_probs.reshape(-1), atol=1e-2)
def test_scan_hmm_smoke(length, temperature):
    """Smoke-test ``infer_discrete`` on an HMM written with ``scan``.

    A seeded run of the model with ``None`` as data produces the reference
    states and the observations. The same model is then wrapped with
    ``config_enumerate`` and ``infer_discrete`` and run on those observations.
    Only the lengths of the results are asserted; the state sequences
    themselves are merely logged.

    :param length: number of steps handed to ``scan`` (read by the nested
        model through its closure).
    :param temperature: passed through unchanged to ``infer_discrete``.
    """

    # This should match the example in the infer_discrete docstring.
    def hmm(data, hidden_dim=10):
        transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim)
        means = jnp.arange(float(hidden_dim))

        def transition_fn(state, y):
            state = numpyro.sample("states", dist.Categorical(transition[state]))
            y = numpyro.sample("obs", dist.Normal(means[state], 1.0), obs=y)
            # The new state is both the carry and part of the collected output.
            return state, (state, y)

        _, (states, data) = scan(transition_fn, 0, data, length=length)
        # Put the fixed initial state 0 in front of the scanned states.
        return [0, *states], data

    seeded_hmm = handlers.seed(hmm, 0)
    true_states, data = seeded_hmm(None)
    assert len(data) == length
    assert len(true_states) == 1 + len(data)

    enumerated_hmm = config_enumerate(hmm)
    decoder = infer_discrete(
        enumerated_hmm, temperature=temperature, rng_key=random.PRNGKey(1)
    )
    inferred_states, _ = decoder(data)
    assert len(inferred_states) == len(true_states)

    logger.info("true states: {}".format([int(s) for s in true_states]))
    logger.info("inferred states: {}".format([int(s) for s in inferred_states]))
def test_distribution_1(temperature):
    """Check the distribution of the single discrete site ``z`` from ``infer_discrete``.

    The reference values come from ``log_prob_sum`` of two traces of the
    model, one with ``z=0`` and one with ``z=1`` passed in as observed.

    * Truthy ``temperature``: the mean of the sampled ``z`` values is compared
      (``atol=1e-2``) with ``1 / (1 + exp(log_p0 - log_p1))``.
    * ``temperature == 0``: ``z`` must equal the assignment with the larger
      log-probability (``atol=1e-5``), and the sampled trace's
      ``log_prob_sum`` must match that larger value.

    :param temperature: passed through unchanged to ``infer_discrete``.
    """
    #      +-------+
    #  z --|--> x  |
    #      +-------+
    num_particles = 10000
    data = np.array([1.0, 2.0, 3.0])

    @config_enumerate
    def model(z=None):
        p = numpyro.param("p", np.array([0.75, 0.25]))
        iz = numpyro.sample("z", dist.Categorical(p), obs=z)
        # Turn the sampled index into the location used by the likelihood.
        z = jnp.array([0.0, 1.0])[iz]
        logger.info("z.shape = {}".format(z.shape))
        with numpyro.plate("data", 3):
            numpyro.sample("x", dist.Normal(z, 1.0), obs=data)

    first_available_dim = -3
    if temperature == 0:
        vectorized_model = model
    else:
        vectorized_model = vectorize_model(model, num_particles, dim=-2)
    sampled_model = infer_discrete(
        vectorized_model, first_available_dim, temperature, rng_key=random.PRNGKey(1)
    )
    sampled_trace = handlers.trace(sampled_model).get_trace()

    conditioned_traces = {}
    for z in [0, 1]:
        conditioned_traces[z] = handlers.trace(model).get_trace(z=np.array(z))

    # Check posterior over z.
    actual_z_mean = sampled_trace["z"]["value"].astype(float).mean()
    log_p0 = log_prob_sum(conditioned_traces[0])
    log_p1 = log_prob_sum(conditioned_traces[1])
    if temperature:
        # Probability of z == 1 after normalising the two joint probabilities.
        expected_z_mean = 1 / (1 + np.exp(log_p0 - log_p1))
        tolerance = 1e-2
    else:
        expected_z_mean = (log_p1 > log_p0).astype(float)
        expected_max = max(log_p0, log_p1)
        actual_max = log_prob_sum(sampled_trace)
        assert_allclose(expected_max, actual_max, atol=1e-5)
        tolerance = 1e-5
    assert_allclose(actual_z_mean, expected_z_mean, atol=tolerance)
def test_mcmc_model_side_enumeration(model, temperature):
    """Run ``infer_discrete`` on a model conditioned on NUTS-drawn latents.

    A NUTS chain with no warmup and a single sample supplies values for the
    ``"loc"`` and ``"scale"`` sites. The enumerated model is conditioned on
    those values and run through ``infer_discrete``. The resulting trace must
    contain exactly the same site names as a plain trace of the seeded model.

    :param model: the model under test.
    :param temperature: passed through unchanged to ``infer_discrete``.
    """
    # Pass the sample counts by keyword: this is accepted both by MCMC
    # signatures that take them positionally and by keyword-only ones.
    mcmc = infer.MCMC(infer.NUTS(model), num_warmup=0, num_samples=1)
    mcmc.run(random.PRNGKey(0))
    # Keep only the first (and only) draw of the "loc" and "scale" sites.
    mcmc_data = {
        k: v[0] for k, v in mcmc.get_samples().items() if k in ["loc", "scale"]
    }

    # MAP estimate discretes, conditioned on posterior sampled continuous latents.
    model = handlers.seed(model, rng_seed=1)
    actual_trace = handlers.trace(
        infer_discrete(
            # TODO support replayed sites in infer_discrete.
            # handlers.replay(config_enumerate(model), mcmc_trace),
            handlers.condition(config_enumerate(model), mcmc_data),
            temperature=temperature,
            rng_key=random.PRNGKey(1),
        )
    ).get_trace()

    # Check site names and shapes.
    expected_trace = handlers.trace(model).get_trace()
    assert set(actual_trace) == set(expected_trace)