def test_scan_enum_scan_enum():
    """Two back-to-back ``scan`` chains match two back-to-back ``markov`` loops.

    The two hidden chains ("x" and "w") have their own transition matrices,
    emission locations and observation sequences. The reference ``model``
    unrolls both with ``markov``. ``fun_model`` runs one ``scan`` per chain,
    reusing a single step function specialised with ``partial``. The
    enumerated log joints of the two models must agree.
    """
    n_steps = 11
    obs_x = random.normal(random.PRNGKey(0), (n_steps,))
    # Second chain is one step shorter and shifted, so the two scans differ in length.
    obs_w = obs_x[:-1] + 1
    trans_x = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    trans_w = jnp.array([[0.7, 0.3], [0.6, 0.4]])
    emit_x = jnp.array([-1.0, 1.0])
    emit_w = jnp.array([2.0, 3.0])

    def model(xs, ws):
        state_x = state_w = 0
        for i, obs in markov(enumerate(xs)):
            state_x = numpyro.sample(f"x_{i}", dist.Categorical(trans_x[state_x]))
            numpyro.sample(f"y_x_{i}", dist.Normal(emit_x[state_x], 1), obs=obs)

        for i, obs in markov(enumerate(ws)):
            state_w = numpyro.sample(f"w{i}", dist.Categorical(trans_w[state_w]))
            numpyro.sample(f"y_w_{i}", dist.Normal(emit_w[state_w], 1), obs=obs)

    def fun_model(xs, ws):
        def step(site, trans, emit, state, obs):
            state = numpyro.sample(site, dist.Categorical(trans[state]))
            numpyro.sample("y_" + site, dist.Normal(emit[state], 1), obs=obs)
            return state, None

        scan(partial(step, "x", trans_x, emit_x), 0, xs)
        scan(partial(step, "w", trans_w, emit_w), 0, ws)

    def log_joint(fn):
        return log_density(enum(config_enumerate(fn)), (obs_x, obs_w), {}, {})[0]

    assert_allclose(log_joint(fun_model), log_joint(model))
def test_scan_enum_discrete_outside():
    """A discrete site sampled before ``scan`` can drive the scanned transitions.

    A global Bernoulli ``w`` selects which 2x2 transition matrix the chain
    uses. ``w`` is sampled once outside the loop/scan and captured by the
    step closure. The enumerated log joints must match.
    """
    obs_seq = random.normal(random.PRNGKey(0), (10,))
    # trans[w, x] is the next-state distribution given global `w` and current state `x`.
    trans = jnp.array([[[0.8, 0.2], [0.1, 0.9]], [[0.7, 0.3], [0.6, 0.4]]])
    emit = jnp.array([-1.0, 1.0])

    def model(seq):
        w = numpyro.sample("w", dist.Bernoulli(0.6))
        state = 0
        for i, obs in markov(enumerate(seq)):
            state = numpyro.sample(f"x_{i}", dist.Categorical(trans[w, state]))
            numpyro.sample(f"y_{i}", dist.Normal(emit[state], 1), obs=obs)

    def fun_model(seq):
        w = numpyro.sample("w", dist.Bernoulli(0.6))

        def step(state, obs):
            state = numpyro.sample("x", dist.Categorical(trans[w, state]))
            numpyro.sample("y", dist.Normal(emit[state], 1), obs=obs)
            return state, None

        scan(step, 0, seq)

    def log_joint(fn):
        return log_density(enum(config_enumerate(fn)), (obs_seq,), {}, {})[0]

    assert_allclose(log_joint(fun_model), log_joint(model))
def test_scan_enum_two_latents():
    """Two latent chains advanced together in one scan step match the unrolled model.

    Each step samples ``x`` and ``w`` from their own transition matrices. The
    observation location is then indexed by both states, ``emit[w, x]``.
    """
    n_steps = 11
    obs_seq = random.normal(random.PRNGKey(0), (n_steps,))
    trans_x = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    trans_w = jnp.array([[0.7, 0.3], [0.6, 0.4]])
    emit = jnp.array([[-1.0, 1.0], [2.0, 3.0]])

    def model(seq):
        x = w = 0
        for i, obs in markov(enumerate(seq)):
            x = numpyro.sample(f"x_{i}", dist.Categorical(trans_x[x]))
            w = numpyro.sample(f"w_{i}", dist.Categorical(trans_w[w]))
            numpyro.sample(f"y_{i}", dist.Normal(emit[w, x], 1), obs=obs)

    def fun_model(seq):
        def step(state, obs):
            x, w = state
            x = numpyro.sample("x", dist.Categorical(trans_x[x]))
            w = numpyro.sample("w", dist.Categorical(trans_w[w]))
            numpyro.sample("y", dist.Normal(emit[w, x], 1), obs=obs)
            # Also emit `x` as the per-step output so that recording of
            # scan's stacked `ys` is exercised, not just the carry.
            return (x, w), x

        scan(step, (0, 0), seq)

    def log_joint(fn):
        return log_density(enum(config_enumerate(fn)), (obs_seq,), {}, {})[0]

    assert_allclose(log_joint(fun_model), log_joint(model))
def test_scan_enum_separated_plate_discrete():
    """A plated discrete site inside each scan step matches the unrolled model.

    Each step draws a chain state ``x`` and, inside a size-``D`` plate at
    dim -1, a Bernoulli ``w``. The observation location is looked up with
    ``Vindex(emit)[x, w]``. ``enum`` is given -2 here, and the plate occupies
    dim -1.
    """
    N, D = 10, 3
    obs_grid = random.normal(random.PRNGKey(0), (N, D))
    trans = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    emit = jnp.array([[-1.0, 1.0], [2.0, 3.0]])

    def model(data):
        x = 0
        D_plate = numpyro.plate("D", D, dim=-1)
        for i, y in markov(enumerate(data)):
            x = numpyro.sample(f"x_{i}", dist.Categorical(trans[x]))
            with D_plate:
                w = numpyro.sample(f"w_{i}", dist.Bernoulli(0.6))
                numpyro.sample(f"y_{i}", dist.Normal(Vindex(emit)[x, w], 1), obs=y)

    def fun_model(data):
        def step(x, y):
            x = numpyro.sample("x", dist.Categorical(trans[x]))
            with numpyro.plate("D", D, dim=-1):
                w = numpyro.sample("w", dist.Bernoulli(0.6))
                numpyro.sample("y", dist.Normal(Vindex(emit)[x, w], 1), obs=y)
            return x, None

        scan(step, 0, data)

    def log_joint(fn):
        return log_density(enum(config_enumerate(fn), -2), (obs_grid,), {}, {})[0]

    assert_allclose(log_joint(fun_model), log_joint(model))
def test_scan_enum_plate():
    """A chain whose state is itself plated (size ``D``) matches the unrolled model.

    The first step draws from ``init`` (the carry starts as ``None``). Later
    steps use ``trans[x]``. Enumeration is given dim -2 because the plate
    sits at dim -1.
    """
    N, D = 10, 3
    obs_grid = random.normal(random.PRNGKey(0), (N, D))
    init = jnp.array([0.6, 0.4])
    trans = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    emit = jnp.array([-1.0, 1.0])

    def model(data):
        x = None
        D_plate = numpyro.plate("D", D, dim=-1)
        for i, y in markov(enumerate(data)):
            with D_plate:
                if x is None:
                    probs = init
                else:
                    probs = trans[x]
                x = numpyro.sample(f"x_{i}", dist.Categorical(probs))
                numpyro.sample(f"y_{i}", dist.Normal(emit[x], 1), obs=y)

    def fun_model(data):
        def step(x, y):
            if x is None:
                probs = init
            else:
                probs = trans[x]
            with numpyro.plate("D", D, dim=-1):
                x = numpyro.sample("x", dist.Categorical(probs))
                numpyro.sample("y", dist.Normal(emit[x], 1), obs=y)
            return x, None

        scan(step, None, data)

    def log_joint(fn):
        return log_density(enum(config_enumerate(fn), -2), (obs_grid,), {}, {})[0]

    assert_allclose(log_joint(fun_model), log_joint(model))
def test_scan_enum_one_latent(num_steps):
    """A single-latent HMM via ``scan`` matches the ``markov`` loop.

    The two implementations must agree in two ways: their enumerated log
    joints, and the value each returns for the final latent state.

    Args:
        num_steps: length of the observation sequence (supplied by the caller,
            presumably a pytest parametrization outside this view).
    """
    obs_seq = random.normal(random.PRNGKey(0), (num_steps,))
    init = jnp.array([0.6, 0.4])
    trans = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    emit = jnp.array([-1.0, 1.0])

    def model(data):
        x = None
        for i, y in markov(enumerate(data)):
            if x is None:
                probs = init
            else:
                probs = trans[x]
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs))
            numpyro.sample(f"y_{i}", dist.Normal(emit[x], 1), obs=y)
        return x

    def fun_model(data):
        def step(x, y):
            if x is None:
                probs = init
            else:
                probs = trans[x]
            x = numpyro.sample("x", dist.Categorical(probs))
            numpyro.sample("y", dist.Normal(emit[x], 1), obs=y)
            return x, None

        last, stacked = scan(step, None, data)
        # The step returns `None` as its per-step output, so nothing is stacked.
        assert stacked is None
        return last

    def enumerated(fn):
        return enum(config_enumerate(fn))

    actual_lj = log_density(enumerated(fun_model), (obs_seq,), {}, {})[0]
    expected_lj = log_density(enumerated(model), (obs_seq,), {}, {})[0]
    assert_allclose(actual_lj, expected_lj)

    actual_last = enumerated(fun_model)(obs_seq)
    expected_last = enumerated(model)(obs_seq)
    assert_allclose(actual_last, expected_last)
def test_scan_history(history, T):
    """Second-order transitions via ``scan(..., history=history)`` match ``markov``.

    Each new Bernoulli state is drawn with probability ``p[x_{t-2}, x_{t-1}, z]``.
    That is, it depends on the two previous states and a global ``z``. The same
    ``history`` is forwarded to both ``markov`` and ``scan``. The log joints and
    the returned pair of final states must agree.

    Args:
        history: value passed through as ``history=`` to ``markov`` and ``scan``.
        T: number of time steps.
    """

    def model():
        p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
        q = numpyro.param("q", 0.25 * jnp.ones(2))
        z = numpyro.sample("z", dist.Bernoulli(0.5))
        prev_x = 0
        cur_x = 0
        for t in markov(range(T), history=history):
            new_x = numpyro.sample("x_{}".format(t), dist.Bernoulli(p[prev_x, cur_x, z]))
            prev_x, cur_x = cur_x, new_x
            numpyro.sample("y_{}".format(t), dist.Bernoulli(q[cur_x]), obs=0)
        return prev_x, cur_x

    def fun_model():
        p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
        q = numpyro.param("q", 0.25 * jnp.ones(2))
        z = numpyro.sample("z", dist.Bernoulli(0.5))

        def step(carry, y):
            prev_x, cur_x = carry
            new_x = numpyro.sample("x", dist.Bernoulli(p[prev_x, cur_x, z]))
            numpyro.sample("y", dist.Bernoulli(q[new_x]), obs=y)
            return (cur_x, new_x), None

        # Zero observations, matching `obs=0` in the unrolled model.
        (prev_x, cur_x), _ = scan(step, (0, 0), jnp.zeros(T), history=history)
        return prev_x, cur_x

    def enumerated(fn):
        return enum(config_enumerate(fn))

    expected_lj = log_density(enumerated(model), (), {}, {})[0]
    actual_lj = log_density(enumerated(fun_model), (), {}, {})[0]
    assert_allclose(actual_lj, expected_lj)

    expected_prev, expected_cur = enumerated(model)()
    actual_prev, actual_cur = enumerated(fun_model)()
    assert_allclose(actual_prev, expected_prev)
    assert_allclose(actual_cur, expected_cur)
def test_scan_enum_separated_plates_same_dim():
    """Two sibling plates sharing dim -1 inside each scan step match the unrolled model.

    Each step draws one chain state ``x``. Two observation blocks, of widths
    ``D1`` and ``D2``, sit in separate (non-nested) plates that both use
    dim -1. The scan's ``xs`` is a tuple of the two column blocks.
    """
    N, D1, D2 = 10, 3, 4
    obs_grid = random.normal(random.PRNGKey(0), (N, D1 + D2))
    left = obs_grid[:, :D1]
    right = obs_grid[:, D1:]
    init = jnp.array([0.6, 0.4])
    trans = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    emit = jnp.array([-1.0, 1.0])

    def model(data1, data2):
        x = None
        D1_plate = numpyro.plate("D1", D1, dim=-1)
        D2_plate = numpyro.plate("D2", D2, dim=-1)
        for i, (y1, y2) in markov(enumerate(zip(data1, data2))):
            if x is None:
                probs = init
            else:
                probs = trans[x]
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs))
            with D1_plate:
                numpyro.sample(f"y1_{i}", dist.Normal(emit[x], 1), obs=y1)
            with D2_plate:
                numpyro.sample(f"y2_{i}", dist.Normal(emit[x], 1), obs=y2)

    def fun_model(data1, data2):
        def step(x, y):
            y1, y2 = y
            if x is None:
                probs = init
            else:
                probs = trans[x]
            x = numpyro.sample("x", dist.Categorical(probs))
            with numpyro.plate("D1", D1, dim=-1):
                numpyro.sample("y1", dist.Normal(emit[x], 1), obs=y1)
            with numpyro.plate("D2", D2, dim=-1):
                numpyro.sample("y2", dist.Normal(emit[x], 1), obs=y2)
            return x, None

        scan(step, None, (data1, data2))

    def log_joint(fn):
        return log_density(enum(config_enumerate(fn), -2), (left, right), {}, {})[0]

    assert_allclose(log_joint(fun_model), log_joint(model))