Esempio n. 1
0
def test_guide_enumerated_elbo(model, guide, data, history):
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"), \
        pytest.raises(
            NotImplementedError,
            match="TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration"):

        if history > 1:
            pytest.xfail(
                reason="TraceMarkovEnum_ELBO does not yet support history > 1")

        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        expected_loss = elbo.loss_and_grads(model, guide, data, history, False)
        expected_grads = (
            value.grad
            for name, value in pyro.get_param_store().named_parameters())

        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
        actual_loss = vectorized_elbo.loss_and_grads(model, guide, data,
                                                     history, True)
        actual_grads = (
            value.grad
            for name, value in pyro.get_param_store().named_parameters())

        assert_close(actual_loss, expected_loss)
        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
            assert_close(actual_grad, expected_grad)
Esempio n. 2
0
def test_model_enumerated_elbo(model, guide, data, history):
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        if history > 1:
            pytest.xfail(
                reason="TraceMarkovEnum_ELBO does not yet support history > 1")

        model = infer.config_enumerate(model, default="parallel")
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        expected_loss = elbo.loss_and_grads(model, guide, data, history, False)
        expected_grads = (
            value.grad
            for name, value in pyro.get_param_store().named_parameters())

        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
        actual_loss = vectorized_elbo.loss_and_grads(model, guide, data,
                                                     history, True)
        actual_grads = (
            value.grad
            for name, value in pyro.get_param_store().named_parameters())

        assert_close(actual_loss, expected_loss)
        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
            assert_close(actual_grad, expected_grad)