Example #1
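All snippets below are Pyro tests for infer_discrete and assume roughly the following imports; assert_equal is a helper from Pyro's own test suite, not public API:

import logging
import warnings

import torch
from torch.autograd import grad

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import TraceEnum_ELBO, config_enumerate, infer_discrete

from tests.common import assert_equal  # Pyro-repo test helper

logger = logging.getLogger(__name__)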
def test_prob(nderivs):
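    # nderivs is presumably pytest-parametrized over [0, 1]: 0 checks the
    # log-probability value itself, 1 checks its gradient w.r.t. param "p".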
    #      +-------+
    #  z --|--> x  |
    #      +-------+
    num_particles = 10000
    data = torch.tensor([0.5, 1., 1.5])
    p = pyro.param("p", torch.tensor(0.25))

    @config_enumerate
    def model(num_particles):
        p = pyro.param("p")
        with pyro.plate("num_particles", num_particles, dim=-2):
            z = pyro.sample("z", dist.Bernoulli(p))
            with pyro.plate("data", 3):
                pyro.sample("x", dist.Normal(z, 1.), obs=data)

    def guide(num_particles):
        pass

    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    expected_logprob = -elbo.differentiable_loss(model, guide, num_particles=1)

    posterior_model = infer_discrete(config_enumerate(model, "parallel"),
                                     first_available_dim=-3)
    posterior_trace = poutine.trace(posterior_model).get_trace(
        num_particles=num_particles)
    actual_logprob = log_mean_prob(posterior_trace, particle_dim=-2)

    if nderivs == 0:
        assert_equal(expected_logprob, actual_logprob, prec=1e-3)
    elif nderivs == 1:
        expected_grad = grad(expected_logprob, [p])[0]
        actual_grad = grad(actual_logprob, [p])[0]
        assert_equal(expected_grad, actual_grad, prec=1e-3)
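log_mean_prob is defined elsewhere in the same test file. Here is a minimal sketch of what it plausibly computes (the log of the average per-particle joint probability), assuming particle_dim indexes the "num_particles" plate; this is a reconstruction, not the exact helper:

import math

from pyro.poutine.util import site_is_subsample

def log_mean_prob(trace, particle_dim):
    # Sum each sample site's log_prob over every dim except particle_dim,
    # giving the joint log-probability of each particle, then return
    # log((1/N) * sum_i exp(logp_i)) over the N particles.
    trace.compute_log_prob()
    total = 0.
    for node in trace.nodes.values():
        if node["type"] == "sample" and not site_is_subsample(node):
            logp = node["log_prob"]
            # Flatten every dim to the right of particle_dim and sum it away.
            logp = logp.reshape(logp.shape[:logp.dim() + particle_dim + 1] + (-1,)).sum(-1)
            total = total + logp
    return total.logsumexp(-1) - math.log(total.numel())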
Example #2
def test_distribution_1(temperature):
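    # temperature is pytest-parametrized; per infer_discrete's docs,
    # temperature=1 draws posterior samples (forward-filter backward-sample)
    # and temperature=0 computes a Viterbi-like MAP assignment.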
    #      +-------+
    #  z --|--> x  |
    #      +-------+
    num_particles = 10000
    data = torch.tensor([1., 2., 3.])

    @config_enumerate
    def model(num_particles=1, z=None):
        p = pyro.param("p", torch.tensor(0.25))
        with pyro.plate("num_particles", num_particles, dim=-2):
            z = pyro.sample("z", dist.Bernoulli(p), obs=z)
            logger.info("z.shape = {}".format(z.shape))
            with pyro.plate("data", 3):
                pyro.sample("x", dist.Normal(z, 1.), obs=data)

    first_available_dim = -3
    sampled_model = infer_discrete(model, first_available_dim, temperature)
    sampled_trace = poutine.trace(sampled_model).get_trace(num_particles)
    conditioned_traces = {
        z: poutine.trace(model).get_trace(z=torch.tensor(z))
        for z in [0., 1.]
    }

    # Check the posterior over z.
    actual_z_mean = sampled_trace.nodes["z"]["value"].mean()
    if temperature:
        expected_z_mean = 1 / (1 +
                               (conditioned_traces[0].log_prob_sum() -
                                conditioned_traces[1].log_prob_sum()).exp())
    else:
        expected_z_mean = (conditioned_traces[1].log_prob_sum() >
                           conditioned_traces[0].log_prob_sum()).float()
    assert_equal(actual_z_mean, expected_z_mean, prec=1e-2)
Example #3
def test_hmm_smoke(temperature, length):

    # This should match the example in the infer_discrete docstring.
    def hmm(data, hidden_dim=10):
        transition = 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim)
        means = torch.arange(float(hidden_dim))
        states = [0]
        for t in pyro.markov(range(len(data))):
            states.append(
                pyro.sample("states_{}".format(t),
                            dist.Categorical(transition[states[-1]])))
            data[t] = pyro.sample("obs_{}".format(t),
                                  dist.Normal(means[states[-1]], 1.),
                                  obs=data[t])
        return states, data

    true_states, data = hmm([None] * length)
    assert len(data) == length
    assert len(true_states) == 1 + len(data)

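    # infer_discrete replays hmm with its discrete states drawn from the
    # posterior: sampled at temperature=1, MAP-decoded at temperature=0.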
    decoder = infer_discrete(config_enumerate(hmm),
                             first_available_dim=-1,
                             temperature=temperature)
    inferred_states, _ = decoder(data)
    assert len(inferred_states) == len(true_states)

    logger.info("true states: {}".format(list(map(int, true_states))))
    logger.info("inferred states: {}".format(list(map(int, inferred_states))))
Example #4
def test_warning():
    data = torch.randn(4)

    def model():
        x = pyro.sample("x", dist.Categorical(torch.ones(3)))
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(x.float(), 1), obs=data)

    model_1 = infer_discrete(model, first_available_dim=-2)
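    # The latent site "x" above is never configured for enumeration, so
    # infer_discrete's strict_enumeration_warning should fire for model_1.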
    model_2 = infer_discrete(model,
                             first_available_dim=-2,
                             strict_enumeration_warning=False)
    model_3 = infer_discrete(config_enumerate(model), first_available_dim=-2)

    # model_1 should raise warnings.
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        model_1()
    assert w, 'No warnings were raised'

    # model_2 and model_3 should both be valid.
    model_2()
    model_3()
Example #5
def test_distribution_3(temperature):
    #       +---------+  +---------------+
    #  z1 --|--> x1   |  |  z2 ---> x2   |
    #       |       3 |  |             2 |
    #       +---------+  +---------------+
    num_particles = 10000
    data = [torch.tensor([-1., -1., 0.]), torch.tensor([-1., 1.])]

    @config_enumerate
    def model(num_particles=1, z1=None, z2=None):
        p = pyro.param("p", torch.tensor([0.25, 0.75]))
        loc = pyro.param("loc", torch.tensor([-1., 1.]))
        with pyro.plate("num_particles", num_particles, dim=-2):
            z1 = pyro.sample("z1", dist.Categorical(p), obs=z1)
            with pyro.plate("data[0]", 3):
                pyro.sample("x1", dist.Normal(loc[z1], 1.), obs=data[0])
            with pyro.plate("data[1]", 2):
                z2 = pyro.sample("z2", dist.Categorical(p), obs=z2)
                pyro.sample("x2", dist.Normal(loc[z2], 1.), obs=data[1])

    first_available_dim = -3
    sampled_model = infer_discrete(model, first_available_dim, temperature)
    sampled_trace = poutine.trace(sampled_model).get_trace(num_particles)
    conditioned_traces = {
        (z1, z20, z21):
        poutine.trace(model).get_trace(z1=torch.tensor(z1),
                                       z2=torch.tensor([z20, z21]))
        for z1 in [0, 1] for z20 in [0, 1] for z21 in [0, 1]
    }

    # Check joint posterior over (z1, z2[0], z2[1]).
    actual_probs = torch.empty(2, 2, 2)
    expected_probs = torch.empty(2, 2, 2)
    for (z1, z20, z21), tr in conditioned_traces.items():
        expected_probs[z1, z20, z21] = tr.log_prob_sum().exp()
        actual_probs[z1, z20, z21] = (
            (sampled_trace.nodes["z1"]["value"] == z1) &
            (sampled_trace.nodes["z2"]["value"][..., :1] == z20) &
            (sampled_trace.nodes["z2"]["value"][..., 1:] == z21)
        ).float().mean()
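    # temperature=1: the empirical joint should match the normalized exact
    # posterior; temperature=0: all mass should sit on the MAP assignment.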
    if temperature:
        expected_probs = expected_probs / expected_probs.sum()
    else:
        argmax = expected_probs.reshape(-1).max(0)[1]
        expected_probs[:] = 0
        expected_probs.reshape(-1)[argmax] = 1
    assert_equal(expected_probs.reshape(-1),
                 actual_probs.reshape(-1),
                 prec=1e-2)