Code example #1
def test_mcmc_model_side_enumeration(model, temperature):
    # Perform fake inference.
    # Draw from the prior rather than sampling from the MCMC posterior.
    # This has the wrong distribution but the right type for tests.
    mcmc_trace = handlers.trace(
        handlers.block(handlers.enum(infer.config_enumerate(model)),
                       expose=["loc", "scale"])).get_trace()
    mcmc_data = {
        name: site["value"]
        for name, site in mcmc_trace.nodes.items() if site["type"] == "sample"
    }

    # MAP-estimate the discrete latents, conditioned on posterior-sampled continuous latents.
    actual_trace = handlers.trace(
        infer.infer_discrete(
            # TODO support replayed sites in infer_discrete.
            # handlers.replay(infer.config_enumerate(model), mcmc_trace),
            handlers.condition(infer.config_enumerate(model), mcmc_data),
            temperature=temperature,
        )).get_trace()

    # Check site names and shapes.
    expected_trace = handlers.trace(model).get_trace()
    assert set(actual_trace.nodes) == set(expected_trace.nodes)
    assert "z1" not in actual_trace.nodes["scale"]["funsor"]["value"].inputs
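These snippets appear to come from Pyro's contrib.funsor test suite, so they share setup that the examples themselves omit. A minimal sketch of that assumed setup follows; the pyroapi import paths and the backend-registration side effect are assumptions, and helpers such as assert_close, assert_equal, AutoNormal and terms_from_trace are taken to be Pyro test utilities that are likewise not shown here.

# Sketch of the shared setup these examples appear to rely on
# (an assumption, not shown in the snippets themselves).
import torch

import pyro.contrib.funsor  # noqa: F401  -- assumed to register the "contrib.funsor" backend
from pyroapi import distributions as dist
from pyroapi import handlers, infer, pyro, pyro_backend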
Code example #2
def test_svi_model_side_enumeration(model, temperature):
    # Perform fake inference.
    # This has the wrong distribution but the right type for tests.
    guide = AutoNormal(
        handlers.enum(
            handlers.block(infer.config_enumerate(model),
                           expose=["loc", "scale"])))
    guide()  # Initialize but don't bother to train.
    guide_trace = handlers.trace(guide).get_trace()
    guide_data = {
        name: site["value"]
        for name, site in guide_trace.nodes.items() if site["type"] == "sample"
    }

    # MAP-estimate the discrete latents, conditioned on posterior-sampled continuous latents.
    actual_trace = handlers.trace(
        infer.infer_discrete(
            # TODO support replayed sites in infer_discrete.
            # handlers.replay(infer.config_enumerate(model), guide_trace)
            handlers.condition(infer.config_enumerate(model), guide_data),
            temperature=temperature,
        )).get_trace()

    # Check site names and shapes.
    expected_trace = handlers.trace(model).get_trace()
    assert set(actual_trace.nodes) == set(expected_trace.nodes)
    assert "z1" not in actual_trace.nodes["scale"]["funsor"]["value"].inputs
Code example #3
def test_model_enumerated_elbo(model, guide, data, history):
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        if history > 1:
            pytest.xfail(
                reason="TraceMarkovEnum_ELBO does not yet support history > 1")

        model = infer.config_enumerate(model, default="parallel")
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        expected_loss = elbo.loss_and_grads(model, guide, data, history, False)
        expected_grads = (
            value.grad
            for name, value in pyro.get_param_store().named_parameters())

        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
        actual_loss = vectorized_elbo.loss_and_grads(model, guide, data,
                                                     history, True)
        actual_grads = (
            value.grad
            for name, value in pyro.get_param_store().named_parameters())

        assert_close(actual_loss, expected_loss)
        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
            assert_close(actual_grad, expected_grad)
Code example #4
def test_hmm_smoke(length, temperature):

    # This should match the example in the infer_discrete docstring.
    def hmm(data, hidden_dim=10):
        transition = 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim)
        means = torch.arange(float(hidden_dim))
        states = [0]
        for t in pyro.markov(range(len(data))):
            states.append(
                pyro.sample("states_{}".format(t),
                            dist.Categorical(transition[states[-1]])))
            data[t] = pyro.sample("obs_{}".format(t),
                                  dist.Normal(means[states[-1]], 1.0),
                                  obs=data[t])
        return states, data

    true_states, data = hmm([None] * length)
    assert len(data) == length
    assert len(true_states) == 1 + len(data)

    decoder = infer.infer_discrete(infer.config_enumerate(hmm),
                                   temperature=temperature)
    inferred_states, _ = decoder(data)
    assert len(inferred_states) == len(true_states)

    logger.info("true states: {}".format(list(map(int, true_states))))
    logger.info("inferred states: {}".format(list(map(int, inferred_states))))
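As the comment notes, this mirrors the infer_discrete docstring example. A minimal sketch of the two usual temperature settings, reusing the hmm function and data defined above (the decoder variable names are just illustrative), is:

# Sketch reusing hmm and data from the test above.
# temperature=0 requests MAP (Viterbi-like) assignments for the discrete
# states; temperature=1 samples them from the conditional posterior.
map_decoder = infer.infer_discrete(infer.config_enumerate(hmm), temperature=0)
map_states, _ = map_decoder(data)

posterior_decoder = infer.infer_discrete(infer.config_enumerate(hmm), temperature=1)
sampled_states, _ = posterior_decoder(data)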
Code example #5
def _guide_from_model(model):
    try:
        with pyro_backend("contrib.funsor"):
            return handlers.block(
                infer.config_enumerate(model, default="parallel"),
                lambda msg: msg.get("is_observed", False))
    except KeyError:  # for test collection without funsor
        return model
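_guide_from_model enumerates the model and hides its observed sites, falling back to the raw model when funsor is unavailable at test-collection time. A hedged sketch of how such a guide is used, mirroring the replay path of code example #8 and assuming that example's model and arguments, is:

# Sketch mirroring the use_replay branch of code example #8; assumes
# model, weeks_data, days_data and history are defined as in that test.
with pyro_backend("contrib.funsor"), handlers.enum():
    guide_trace = handlers.trace(_guide_from_model(model)).get_trace(
        weeks_data, days_data, history, True)
    replayed_trace = handlers.trace(
        handlers.replay(model, trace=guide_trace)).get_trace(
            weeks_data, days_data, history, True)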
Code example #6
def test_tmc_categoricals(depth, max_plate_nesting, num_samples, tmc_strategy):
    def model():
        x = pyro.sample("x0", dist.Categorical(pyro.param("q0")))
        with pyro.plate("local", 3):
            for i in range(1, depth):
                x = pyro.sample(
                    "x{}".format(i),
                    dist.Categorical(pyro.param("q{}".format(i))[..., x, :]))
            with pyro.plate("data", 4):
                pyro.sample("y",
                            dist.Bernoulli(pyro.param("qy")[..., x]),
                            obs=data)

    with pyro_backend("pyro"):
        # Initialize parameters and observed data under the standard pyro backend.
        qs = [pyro.param("q0", torch.tensor([0.4, 0.6], requires_grad=True))]
        for i in range(1, depth):
            qs.append(
                pyro.param("q{}".format(i),
                           torch.randn(2, 2).abs().detach().requires_grad_(),
                           constraint=constraints.simplex))
        qs.append(
            pyro.param("qy", torch.tensor([0.75, 0.25], requires_grad=True)))
        qs = [q.unconstrained() for q in qs]
        data = (torch.rand(4, 3) > 0.5).to(dtype=qs[-1].dtype,
                                           device=qs[-1].device)

    with pyro_backend("pyro"):
        elbo = infer.TraceTMC_ELBO(max_plate_nesting=max_plate_nesting)
        enum_model = infer.config_enumerate(model,
                                            default="parallel",
                                            expand=False,
                                            num_samples=num_samples,
                                            tmc=tmc_strategy)
        expected_loss = (
            -elbo.differentiable_loss(enum_model, lambda: None)).exp()
        expected_grads = grad(expected_loss, qs)

    with pyro_backend("contrib.funsor"):
        tmc = infer.TraceTMC_ELBO(max_plate_nesting=max_plate_nesting)
        tmc_model = infer.config_enumerate(model,
                                           default="parallel",
                                           expand=False,
                                           num_samples=num_samples,
                                           tmc=tmc_strategy)
        actual_loss = (-tmc.differentiable_loss(tmc_model, lambda: None)).exp()
        actual_grads = grad(actual_loss, qs)

    prec = 0.05
    assert_equal(actual_loss,
                 expected_loss,
                 prec=prec,
                 msg="".join([
                     "\nexpected loss = {}".format(expected_loss),
                     "\n  actual loss = {}".format(actual_loss),
                 ]))

    for actual_grad, expected_grad in zip(actual_grads, expected_grads):
        assert_equal(actual_grad,
                     expected_grad,
                     prec=prec,
                     msg="".join([
                         "\nexpected grad = {}".format(
                             expected_grad.detach().cpu().numpy()),
                         "\n  actual grad = {}".format(
                             actual_grad.detach().cpu().numpy()),
                     ]))
Code example #7
def test_tmc_normals_chain_gradient(depth, num_samples, max_plate_nesting,
                                    expand, guide_type, reparameterized,
                                    tmc_strategy):
    def model(reparameterized):
        Normal = dist.Normal if reparameterized else dist.testing.fakes.NonreparameterizedNormal
        x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth)))
        for i in range(1, depth):
            x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1. / depth)))
        pyro.sample("y", Normal(x, 1.), obs=torch.tensor(float(1)))

    def factorized_guide(reparameterized):
        Normal = dist.Normal if reparameterized else dist.testing.fakes.NonreparameterizedNormal
        pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth)))
        for i in range(1, depth):
            pyro.sample("x{}".format(i),
                        Normal(0., math.sqrt(float(i + 1) / depth)))

    def nonfactorized_guide(reparameterized):
        Normal = dist.Normal if reparameterized else dist.testing.fakes.NonreparameterizedNormal
        x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth)))
        for i in range(1, depth):
            x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1. / depth)))

    with pyro_backend("contrib.funsor"):
        # compare reparameterized and nonreparameterized gradient estimates
        q2 = pyro.param("q2", torch.tensor(0.5, requires_grad=True))
        qs = (q2.unconstrained(), )

        tmc = infer.TraceTMC_ELBO(max_plate_nesting=max_plate_nesting)
        tmc_model = infer.config_enumerate(model,
                                           default="parallel",
                                           expand=expand,
                                           num_samples=num_samples,
                                           tmc=tmc_strategy)
        guide = factorized_guide if guide_type == "factorized" else \
            nonfactorized_guide if guide_type == "nonfactorized" else \
            lambda *args: None
        tmc_guide = infer.config_enumerate(guide,
                                           default="parallel",
                                           expand=expand,
                                           num_samples=num_samples,
                                           tmc=tmc_strategy)

        # convert to linear space for unbiasedness
        actual_loss = (-tmc.differentiable_loss(tmc_model, tmc_guide,
                                                reparameterized)).exp()
        actual_grads = grad(actual_loss, qs)

    # gold values from Funsor
    expected_grads = (torch.tensor({
        1: 0.0999,
        2: 0.0860,
        3: 0.0802,
        4: 0.0771
    }[depth]), )

    grad_prec = 0.05 if reparameterized else 0.1

    for actual_grad, expected_grad in zip(actual_grads, expected_grads):
        assert_equal(actual_grad,
                     expected_grad,
                     prec=grad_prec,
                     msg="".join([
                         "\nexpected grad = {}".format(
                             expected_grad.detach().cpu().numpy()),
                         "\n  actual grad = {}".format(
                             actual_grad.detach().cpu().numpy()),
                     ]))
Code example #8
def test_enumeration_multi(model, weeks_data, days_data, vars1, vars2, history,
                           use_replay):
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        with handlers.enum():
            enum_model = infer.config_enumerate(model, default="parallel")
            # sequential factors
            trace = handlers.trace(enum_model).get_trace(
                weeks_data, days_data, history, False)

            # vectorized trace
            if use_replay:
                guide_trace = handlers.trace(
                    _guide_from_model(model)).get_trace(
                        weeks_data, days_data, history, True)
                vectorized_trace = handlers.trace(
                    handlers.replay(model, trace=guide_trace)).get_trace(
                        weeks_data, days_data, history, True)
            else:
                vectorized_trace = handlers.trace(enum_model).get_trace(
                    weeks_data, days_data, history, True)

        factors = list()
        # sequential weeks factors
        for i in range(len(weeks_data)):
            for v in vars1:
                factors.append(trace.nodes["{}_{}".format(
                    v, i)]["funsor"]["log_prob"])
        # sequential days factors
        for j in range(len(days_data)):
            for v in vars2:
                factors.append(trace.nodes["{}_{}".format(
                    v, j)]["funsor"]["log_prob"])

        vectorized_factors = list()
        # vectorized weeks factors
        for i in range(history):
            for v in vars1:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, i)]["funsor"]["log_prob"])
        for i in range(history, len(weeks_data)):
            for v in vars1:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, slice(history,
                                 len(weeks_data)))]["funsor"]["log_prob"](**{
                                     "weeks":
                                     i - history
                                 }, **{
                                     "{}_{}".format(
                                         k,
                                         slice(history - j,
                                               len(weeks_data) - j)):
                                     "{}_{}".format(k, i - j)
                                     for j in range(history + 1) for k in vars1
                                 }))
        # vectorized days factors
        for i in range(history):
            for v in vars2:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, i)]["funsor"]["log_prob"])
        for i in range(history, len(days_data)):
            for v in vars2:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, slice(history,
                                 len(days_data)))]["funsor"]["log_prob"](**{
                                     "days":
                                     i - history
                                 }, **{
                                     "{}_{}".format(
                                         k,
                                         slice(history - j,
                                               len(days_data) - j)):
                                     "{}_{}".format(k, i - j)
                                     for j in range(history + 1) for k in vars2
                                 }))

        # assert correct factors
        for f1, f2 in zip(factors, vectorized_factors):
            assert_close(f2, f1.align(tuple(f2.inputs)))

        # assert correct step

        expected_measure_vars = frozenset()
        actual_weeks_step = vectorized_trace.nodes["weeks"]["value"]
        # expected step: assume that all but the last var are Markov
        expected_weeks_step = frozenset()
        for v in vars1[:-1]:
            v_step = tuple("{}_{}".format(v, i) for i in range(history)) \
                     + tuple("{}_{}".format(v, slice(j, len(weeks_data)-history+j)) for j in range(history+1))
            expected_weeks_step |= frozenset({v_step})
            # grab measure_vars, found only at sites that are not replayed
            if not use_replay:
                expected_measure_vars |= frozenset(v_step)

        actual_days_step = vectorized_trace.nodes["days"]["value"]
        # expected step: assume that all but the last var are Markov
        expected_days_step = frozenset()
        for v in vars2[:-1]:
            v_step = tuple("{}_{}".format(v, i) for i in range(history)) \
                     + tuple("{}_{}".format(v, slice(j, len(days_data)-history+j)) for j in range(history+1))
            expected_days_step |= frozenset({v_step})
            # grab measure_vars, found only at sites that are not replayed
            if not use_replay:
                expected_measure_vars |= frozenset(v_step)

        assert actual_weeks_step == expected_weeks_step
        assert actual_days_step == expected_days_step

        # check measure_vars
        actual_measure_vars = terms_from_trace(
            vectorized_trace)["measure_vars"]
        assert actual_measure_vars == expected_measure_vars