Example #1
0
def test_scan_enum_scan_enum():
    """Compare two back-to-back enumerated ``scan`` calls with ``markov`` loops.

    Two discrete chains are defined: ``x`` driven by ``data_x`` and ``w``
    driven by ``data_w``.  The reference ``model`` writes each chain as a
    Python loop wrapped in ``markov``; ``fun_model`` writes each chain as a
    ``scan`` call.  The test asserts that the first element returned by
    ``log_density`` is the same for both formulations after wrapping them in
    ``enum(config_enumerate(...))``.
    """
    num_steps = 11
    data_x = random.normal(random.PRNGKey(0), (num_steps,))
    # The second sequence has one observation fewer than the first.
    data_w = data_x[:-1] + 1
    probs_x = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    probs_w = jnp.array([[0.7, 0.3], [0.6, 0.4]])
    locs_x = jnp.array([-1.0, 1.0])
    locs_w = jnp.array([2.0, 3.0])

    def model(data_x, data_w):
        """Reference formulation: one ``markov``-wrapped Python loop per chain."""
        x = w = 0
        for i, y in markov(enumerate(data_x)):
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs_x[x]))
            numpyro.sample(f"y_x_{i}", dist.Normal(locs_x[x], 1), obs=y)

        for i, y in markov(enumerate(data_w)):
            # Site names follow the same ``<name>_<i>`` pattern as the x chain.
            w = numpyro.sample(f"w_{i}", dist.Categorical(probs_w[w]))
            numpyro.sample(f"y_w_{i}", dist.Normal(locs_w[w], 1), obs=y)

    def fun_model(data_x, data_w):
        """Same two chains, each expressed as a ``scan`` over its data."""
        def transition_fn(name, probs, locs, x, y):
            # One step: sample the next state from the row selected by the
            # previous state, then observe ``y`` around that state's location.
            x = numpyro.sample(name, dist.Categorical(probs[x]))
            numpyro.sample("y_" + name, dist.Normal(locs[x], 1), obs=y)
            return x, None

        # ``partial`` binds the site name and parameters, leaving the
        # ``(carry, y)`` signature that is handed to ``scan``.
        scan(partial(transition_fn, "x", probs_x, locs_x), 0, data_x)
        scan(partial(transition_fn, "w", probs_w, locs_w), 0, data_w)

    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data_x, data_w), {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model)), (data_x, data_w), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
Example #2
0
def test_scan_enum_discrete_outside():
    """Check ``scan`` enumeration when a discrete site lives outside the scan.

    A Bernoulli site ``w`` is sampled once, before the chain, and selects
    which transition matrix the chain uses.  The chain is written once as a
    ``markov``-wrapped Python loop and once as a ``scan`` call; the first
    element returned by ``log_density`` must match for both.
    """
    observations = random.normal(random.PRNGKey(0), (10,))
    emission_locs = jnp.array([-1.0, 1.0])
    # Indexed as [w, previous_state] to obtain the next-state probabilities.
    switch_probs = jnp.array([
        [[0.8, 0.2], [0.1, 0.9]],
        [[0.7, 0.3], [0.6, 0.4]],
    ])

    def loop_model(seq):
        """Reference formulation using an explicit Python loop."""
        w = numpyro.sample("w", dist.Bernoulli(0.6))
        state = 0
        for i, y in markov(enumerate(seq)):
            state = numpyro.sample(f"x_{i}", dist.Categorical(switch_probs[w, state]))
            numpyro.sample(f"y_{i}", dist.Normal(emission_locs[state], 1), obs=y)

    def scan_model(seq):
        """Same model with the chain expressed through ``scan``."""
        w = numpyro.sample("w", dist.Bernoulli(0.6))

        def step(state, y):
            # ``w`` is captured from the enclosing scope, not carried.
            state = numpyro.sample("x", dist.Categorical(switch_probs[w, state]))
            numpyro.sample("y", dist.Normal(emission_locs[state], 1), obs=y)
            return state, None

        scan(step, 0, seq)

    def enumerated_log_joint(fn):
        """Return the first element of ``log_density`` for the enumerated ``fn``."""
        return log_density(enum(config_enumerate(fn)), (observations,), {}, {})[0]

    actual = enumerated_log_joint(scan_model)
    expected = enumerated_log_joint(loop_model)
    assert_allclose(actual, expected)
Example #3
0
def test_scan_enum_two_latents():
    """Check ``scan`` enumeration with two discrete latents carried per step.

    Each step samples ``x`` and then ``w`` from their own transition matrices
    and observes ``y`` around ``locs[w, x]``.  The ``markov``-loop and
    ``scan`` formulations must give the same first element of
    ``log_density``.
    """
    n_steps = 11
    observations = random.normal(random.PRNGKey(0), (n_steps,))
    x_transitions = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    w_transitions = jnp.array([[0.7, 0.3], [0.6, 0.4]])
    emission_locs = jnp.array([[-1.0, 1.0], [2.0, 3.0]])

    def loop_model(seq):
        """Reference formulation using an explicit Python loop."""
        x = 0
        w = 0
        for i, y in markov(enumerate(seq)):
            x = numpyro.sample(f"x_{i}", dist.Categorical(x_transitions[x]))
            w = numpyro.sample(f"w_{i}", dist.Categorical(w_transitions[w]))
            numpyro.sample(f"y_{i}", dist.Normal(emission_locs[w, x], 1), obs=y)

    def scan_model(seq):
        """Same model with both latents carried through ``scan`` as a tuple."""
        def step(state, y):
            x, w = state
            x = numpyro.sample("x", dist.Categorical(x_transitions[x]))
            w = numpyro.sample("w", dist.Categorical(w_transitions[w]))
            numpyro.sample("y", dist.Normal(emission_locs[w, x], 1), obs=y)
            # Emit ``x`` as the per-step output so that scan's ``ys``
            # collection path is exercised as well.
            return (x, w), x

        scan(step, (0, 0), seq)

    def enumerated_log_joint(fn):
        """Return the first element of ``log_density`` for the enumerated ``fn``."""
        return log_density(enum(config_enumerate(fn)), (observations,), {}, {})[0]

    actual = enumerated_log_joint(scan_model)
    expected = enumerated_log_joint(loop_model)
    assert_allclose(actual, expected)
Example #4
0
def test_scan_enum_separated_plate_discrete():
    """Check ``scan`` enumeration with a discrete site inside a per-step plate.

    The chain state ``x`` is sampled outside the plate ``"D"``; inside the
    plate a Bernoulli site ``w`` is sampled and ``y`` is observed around
    ``Vindex(locs)[x, w]``.  Both formulations are enumerated with ``-2``
    passed as the second argument of ``enum`` and must yield the same first
    element of ``log_density``.
    """
    N, D = 10, 3
    observations = random.normal(random.PRNGKey(0), (N, D))
    emission_locs = jnp.array([[-1.0, 1.0], [2.0, 3.0]])
    chain_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])

    def loop_model(seq):
        """Reference formulation using an explicit Python loop."""
        state = 0
        # The plate object is created once and re-entered on every step.
        plate_d = numpyro.plate("D", D, dim=-1)
        for i, y in markov(enumerate(seq)):
            row = chain_probs[state]
            state = numpyro.sample(f"x_{i}", dist.Categorical(row))
            with plate_d:
                w = numpyro.sample(f"w_{i}", dist.Bernoulli(0.6))
                numpyro.sample(f"y_{i}", dist.Normal(Vindex(emission_locs)[state, w], 1), obs=y)

    def scan_model(seq):
        """Same model with the chain expressed through ``scan``."""
        def step(state, y):
            row = chain_probs[state]
            state = numpyro.sample("x", dist.Categorical(row))
            with numpyro.plate("D", D, dim=-1):
                w = numpyro.sample("w", dist.Bernoulli(0.6))
                numpyro.sample("y", dist.Normal(Vindex(emission_locs)[state, w], 1), obs=y)
            return state, None

        scan(step, 0, seq)

    def enumerated_log_joint(fn):
        """Return the first element of ``log_density`` for the enumerated ``fn``."""
        return log_density(enum(config_enumerate(fn), -2), (observations,), {}, {})[0]

    actual = enumerated_log_joint(scan_model)
    expected = enumerated_log_joint(loop_model)
    assert_allclose(actual, expected)
Example #5
0
def test_scan_enum_plate():
    """Check ``scan`` enumeration when the chain state is sampled inside a plate.

    Both the discrete state ``x`` and the observation ``y`` are sampled under
    the plate ``"D"`` at every step.  The first step uses ``init_probs``
    (the carried state starts as ``None``); later steps index
    ``transition_probs`` with the previous state.  Both formulations are
    enumerated with ``-2`` passed as the second argument of ``enum`` and must
    yield the same first element of ``log_density``.
    """
    N, D = 10, 3
    observations = random.normal(random.PRNGKey(0), (N, D))
    emission_locs = jnp.array([-1.0, 1.0])
    start_probs = jnp.array([0.6, 0.4])
    chain_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])

    def loop_model(seq):
        """Reference formulation using an explicit Python loop."""
        state = None
        # The plate object is created once and re-entered on every step.
        plate_d = numpyro.plate("D", D, dim=-1)
        for i, y in markov(enumerate(seq)):
            with plate_d:
                if state is None:
                    step_probs = start_probs
                else:
                    step_probs = chain_probs[state]
                state = numpyro.sample(f"x_{i}", dist.Categorical(step_probs))
                numpyro.sample(f"y_{i}", dist.Normal(emission_locs[state], 1), obs=y)

    def scan_model(seq):
        """Same model with the chain expressed through ``scan``."""
        def step(state, y):
            # Here the probabilities are selected before entering the plate.
            if state is None:
                step_probs = start_probs
            else:
                step_probs = chain_probs[state]
            with numpyro.plate("D", D, dim=-1):
                state = numpyro.sample("x", dist.Categorical(step_probs))
                numpyro.sample("y", dist.Normal(emission_locs[state], 1), obs=y)
            return state, None

        scan(step, None, seq)

    def enumerated_log_joint(fn):
        """Return the first element of ``log_density`` for the enumerated ``fn``."""
        return log_density(enum(config_enumerate(fn), -2), (observations,), {}, {})[0]

    actual = enumerated_log_joint(scan_model)
    expected = enumerated_log_joint(loop_model)
    assert_allclose(actual, expected)
Example #6
0
def test_scan_enum_one_latent(num_steps):
    """Check ``scan`` enumeration for a single discrete chain.

    The chain is written once as a ``markov``-wrapped Python loop and once as
    a ``scan`` call.  Two things are compared: the first element returned by
    ``log_density`` and the final state returned by each enumerated model.

    :param num_steps: length of the randomly generated observation sequence
        (supplied by the caller; the source of the value is not visible in
        this block).
    """
    observations = random.normal(random.PRNGKey(0), (num_steps,))
    emission_locs = jnp.array([-1.0, 1.0])
    start_probs = jnp.array([0.6, 0.4])
    chain_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])

    def loop_model(seq):
        """Reference formulation; returns the last sampled state."""
        state = None
        for i, y in markov(enumerate(seq)):
            if state is None:
                step_probs = start_probs
            else:
                step_probs = chain_probs[state]
            state = numpyro.sample(f"x_{i}", dist.Categorical(step_probs))
            numpyro.sample(f"y_{i}", dist.Normal(emission_locs[state], 1), obs=y)
        return state

    def scan_model(seq):
        """``scan`` formulation; returns the final carry."""
        def step(state, y):
            if state is None:
                step_probs = start_probs
            else:
                step_probs = chain_probs[state]
            state = numpyro.sample("x", dist.Categorical(step_probs))
            numpyro.sample("y", dist.Normal(emission_locs[state], 1), obs=y)
            return state, None

        last_state, outputs = scan(step, None, seq)
        # The step function emits no per-step output, so nothing is collected.
        assert outputs is None
        return last_state

    def enumerated(fn):
        """Wrap ``fn`` in a fresh ``enum(config_enumerate(...))`` handler."""
        return enum(config_enumerate(fn))

    actual_log_joint = log_density(enumerated(scan_model), (observations,), {}, {})[0]
    expected_log_joint = log_density(enumerated(loop_model), (observations,), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)

    actual_last_x = enumerated(scan_model)(observations)
    expected_last_x = enumerated(loop_model)(observations)
    assert_allclose(actual_last_x, expected_last_x)
Example #7
0
def test_scan_history(history, T):
    """Compare ``scan(..., history=history)`` with ``markov(..., history=history)``.

    Each step's Bernoulli probability is looked up from the two preceding
    states and a global Bernoulli site ``z``.  The loop-based ``model`` and
    the ``scan``-based ``fun_model`` must agree on the first element returned
    by ``log_density`` and on the ``(x_prev, x_curr)`` pair they return when
    run under ``enum(config_enumerate(...))``.

    :param history: value forwarded as the ``history`` keyword to both
        ``markov`` and ``scan``.
    :param T: number of time steps.
    """
    def model():
        """Reference formulation using a ``markov``-wrapped Python loop."""
        p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
        q = numpyro.param("q", 0.25 * jnp.ones(2))
        z = numpyro.sample("z", dist.Bernoulli(0.5))
        x_prev = 0
        x_curr = 0
        for t in markov(range(T), history=history):
            probs = p[x_prev, x_curr, z]
            # Shift the window: the current state becomes the previous one.
            x_prev, x_curr = x_curr, numpyro.sample(f"x_{t}", dist.Bernoulli(probs))
            numpyro.sample(f"y_{t}", dist.Bernoulli(q[x_curr]), obs=0)
        return x_prev, x_curr

    def fun_model():
        """Same model with the time loop expressed through ``scan``."""
        p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
        q = numpyro.param("q", 0.25 * jnp.ones(2))
        z = numpyro.sample("z", dist.Bernoulli(0.5))

        def transition_fn(carry, y):
            x_prev, x_curr = carry
            probs = p[x_prev, x_curr, z]
            x_prev, x_curr = x_curr, numpyro.sample("x", dist.Bernoulli(probs))
            numpyro.sample("y", dist.Bernoulli(q[x_curr]), obs=y)
            return (x_prev, x_curr), None

        # All-zero observations mirror the ``obs=0`` used by the loop model.
        (x_prev, x_curr), _ = scan(transition_fn, (0, 0),
                                   jnp.zeros(T),
                                   history=history)
        return x_prev, x_curr

    expected_log_joint = log_density(enum(config_enumerate(model)), (), {},
                                     {})[0]
    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (), {},
                                   {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)

    expected_x_prev, expected_x_curr = enum(config_enumerate(model))()
    actual_x_prev, actual_x_curr = enum(config_enumerate(fun_model))()
    assert_allclose(actual_x_prev, expected_x_prev)
    assert_allclose(actual_x_curr, expected_x_curr)
Example #8
0
def test_scan_enum_separated_plates_same_dim():
    """Check ``scan`` enumeration with two sibling plates that share ``dim=-1``.

    At each step the discrete state ``x`` is sampled, then ``y1`` is observed
    under plate ``"D1"`` and ``y2`` under plate ``"D2"``; both plates are
    declared with ``dim=-1`` but have different sizes.  Both formulations are
    enumerated with ``-2`` passed as the second argument of ``enum`` and must
    yield the same first element of ``log_density``.
    """
    N, D1, D2 = 10, 3, 4
    joint = random.normal(random.PRNGKey(0), (N, D1 + D2))
    # Split the columns into the two observation groups.
    data1 = joint[:, :D1]
    data2 = joint[:, D1:]
    emission_locs = jnp.array([-1.0, 1.0])
    start_probs = jnp.array([0.6, 0.4])
    chain_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])

    def loop_model(seq1, seq2):
        """Reference formulation using an explicit Python loop."""
        state = None
        plate_d1 = numpyro.plate("D1", D1, dim=-1)
        plate_d2 = numpyro.plate("D2", D2, dim=-1)
        for i, (y1, y2) in markov(enumerate(zip(seq1, seq2))):
            if state is None:
                step_probs = start_probs
            else:
                step_probs = chain_probs[state]
            state = numpyro.sample(f"x_{i}", dist.Categorical(step_probs))
            with plate_d1:
                numpyro.sample(f"y1_{i}", dist.Normal(emission_locs[state], 1), obs=y1)
            with plate_d2:
                numpyro.sample(f"y2_{i}", dist.Normal(emission_locs[state], 1), obs=y2)

    def scan_model(seq1, seq2):
        """Same model; ``scan`` receives both sequences as a tuple."""
        def step(state, ys):
            y1, y2 = ys
            if state is None:
                step_probs = start_probs
            else:
                step_probs = chain_probs[state]
            state = numpyro.sample("x", dist.Categorical(step_probs))
            with numpyro.plate("D1", D1, dim=-1):
                numpyro.sample("y1", dist.Normal(emission_locs[state], 1), obs=y1)
            with numpyro.plate("D2", D2, dim=-1):
                numpyro.sample("y2", dist.Normal(emission_locs[state], 1), obs=y2)
            return state, None

        scan(step, None, (seq1, seq2))

    def enumerated_log_joint(fn):
        """Return the first element of ``log_density`` for the enumerated ``fn``."""
        return log_density(enum(config_enumerate(fn), -2), (data1, data2), {}, {})[0]

    actual = enumerated_log_joint(scan_model)
    expected = enumerated_log_joint(loop_model)
    assert_allclose(actual, expected)