예제 #1
0
def test_marginal_likelihood_linear_model(linear_model, one_point_design):
    """Compare the marginal + likelihood EIG estimate for ``linear_model`` with
    the reference value returned by ``linear_model_ground_truth``.

    The estimator is optimised in two stages after one param-store reset:
    a coarse pass with Adam at lr=0.1, then a refinement pass at lr=0.01.
    The refinement pass's result, evaluated with ``final_num_samples=500``,
    is the estimate under test. It must agree with the reference to 5e-2.
    """
    pyro.set_rng_seed(42)
    pyro.clear_param_store()

    # Settings shared by both optimisation stages.
    shared_kwargs = dict(
        num_samples=10,
        num_steps=250,
        marginal_guide=marginal_guide,
        cond_guide=likelihood_guide,
    )
    # NOTE(review): "y" / "w" are passed positionally and are presumably the
    # observation and target labels respectively; confirm against the
    # marginal_likelihood_eig signature.
    labels = ("y", "w")

    # Stage 1: coarse pre-training. The return value is discarded. The
    # learned guide parameters are presumably kept in Pyro's param store and
    # reused by stage 2.
    marginal_likelihood_eig(linear_model, one_point_design, *labels,
                            optim=optim.Adam({"lr": 0.1}),
                            **shared_kwargs)

    # Stage 2: fine-tune with a smaller step size and keep the estimate.
    estimate = marginal_likelihood_eig(linear_model, one_point_design, *labels,
                                       optim=optim.Adam({"lr": 0.01}),
                                       final_num_samples=500,
                                       **shared_kwargs)

    reference = linear_model_ground_truth(linear_model, one_point_design,
                                          *labels)
    assert_equal(estimate, reference, prec=5e-2)
예제 #2
0
def test_marginal_likelihood_finite_space_model(
    finite_space_model, one_point_design, true_eig
):
    """Compare the marginal + likelihood EIG estimate for ``finite_space_model``
    with the ``true_eig`` fixture value.

    The estimator is optimised in two stages after one param-store reset:
    a coarse pass with Adam at lr=0.1, then a refinement pass at lr=0.01.
    The refinement pass's result, evaluated with ``final_num_samples=1000``,
    is the estimate under test. It must agree with ``true_eig`` to 1e-2.
    """
    pyro.set_rng_seed(42)
    pyro.clear_param_store()

    # Settings shared by both optimisation stages.
    shared_kwargs = dict(
        num_samples=10,
        num_steps=250,
        marginal_guide=marginal_guide,
        cond_guide=likelihood_guide,
    )
    # NOTE(review): "y" / "theta" are passed positionally and are presumably
    # the observation and target labels respectively; confirm against the
    # marginal_likelihood_eig signature.
    labels = ("y", "theta")

    # Stage 1: coarse pre-training. The return value is discarded. The
    # learned guide parameters are presumably kept in Pyro's param store and
    # reused by stage 2.
    marginal_likelihood_eig(
        finite_space_model,
        one_point_design,
        *labels,
        optim=optim.Adam({"lr": 0.1}),
        **shared_kwargs,
    )

    # Stage 2: fine-tune with a smaller step size and keep the estimate.
    estimate = marginal_likelihood_eig(
        finite_space_model,
        one_point_design,
        *labels,
        optim=optim.Adam({"lr": 0.01}),
        final_num_samples=1000,
        **shared_kwargs,
    )

    assert_equal(estimate, true_eig, prec=1e-2)