Example 1: estimating the EIG of a linear model with vnmc_eig (variational nested Monte Carlo)
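All seven tests estimate the expected information gain (EIG) of an experimental design under the same scaffold and compare against an analytic ground truth. A minimal sketch of the imports they assume (the estimators all live in pyro.contrib.oed.eig; the fixtures linear_model and one_point_design, the guides, the critic dv_critic, and helpers such as make_lfire_classifier, linear_model_ground_truth and assert_equal are defined elsewhere in the surrounding suite and are not reproduced here):

import torch
import pyro
import pyro.optim as optim
from pyro.infer import Trace_ELBO
from pyro.contrib.oed.eig import (
    vnmc_eig, marginal_likelihood_eig, donsker_varadhan_eig,
    lfire_eig, laplace_eig, nmc_eig,
)
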
def test_vnmc_linear_model(linear_model, one_point_design):
    pyro.set_rng_seed(42)
    pyro.clear_param_store()
    # Pre-train (large learning rate)
    vnmc_eig(linear_model,
             one_point_design,
             "y",
             "w",
             num_samples=[9, 3],
             num_steps=250,
             guide=posterior_guide,
             optim=optim.Adam({"lr": 0.1}))
    # Fine-tune (small learning rate)
    estimated_eig = vnmc_eig(linear_model,
                             one_point_design,
                             "y",
                             "w",
                             num_samples=[9, 3],
                             num_steps=250,
                             guide=posterior_guide,
                             optim=optim.Adam({"lr": 0.01}),
                             final_num_samples=[500, 100])
    expected_eig = linear_model_ground_truth(linear_model, one_point_design,
                                             "y", "w")
    assert_equal(estimated_eig, expected_eig, prec=5e-2)
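In vnmc_eig, num_samples=[9, 3] sets the outer (N) and inner (M) Monte Carlo sample counts used per gradient step, and final_num_samples=[500, 100] re-evaluates the trained variational bound with far more samples so the returned estimate has low variance.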
Example 2: estimating the EIG with marginal_likelihood_eig, using separate marginal and likelihood guides
def test_marginal_likelihood_linear_model(linear_model, one_point_design):
    pyro.set_rng_seed(42)
    pyro.clear_param_store()
    # Pre-train (large learning rate)
    marginal_likelihood_eig(linear_model,
                            one_point_design,
                            "y",
                            "w",
                            num_samples=10,
                            num_steps=250,
                            marginal_guide=marginal_guide,
                            cond_guide=likelihood_guide,
                            optim=optim.Adam({"lr": 0.1}))
    # Fine-tune (small learning rate)
    estimated_eig = marginal_likelihood_eig(linear_model,
                                            one_point_design,
                                            "y",
                                            "w",
                                            num_samples=10,
                                            num_steps=250,
                                            marginal_guide=marginal_guide,
                                            cond_guide=likelihood_guide,
                                            optim=optim.Adam({"lr": 0.01}),
                                            final_num_samples=500)
    expected_eig = linear_model_ground_truth(linear_model, one_point_design,
                                             "y", "w")
    assert_equal(estimated_eig, expected_eig, prec=5e-2)
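marginal_likelihood_eig computes EIG(d) = E[log p(y | w, d)] - E[log p(y | d)] by fitting two approximations: marginal_guide for the marginal p(y | d) and cond_guide for the likelihood p(y | w, d).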
Example 3: estimating the EIG with donsker_varadhan_eig, a Donsker-Varadhan mutual-information bound
def test_dv_linear_model(linear_model, one_point_design):
    pyro.set_rng_seed(42)
    pyro.clear_param_store()
    # Pre-train the critic (large learning rate)
    donsker_varadhan_eig(
        linear_model,
        one_point_design,
        "y",
        "w",
        num_samples=100,
        num_steps=500,
        T=dv_critic,
        optim=optim.Adam({"lr": 0.1}),
    )
    # Fine-tune (small learning rate) and evaluate with more samples
    estimated_eig = donsker_varadhan_eig(
        linear_model,
        one_point_design,
        "y",
        "w",
        num_samples=100,
        num_steps=650,
        T=dv_critic,
        optim=optim.Adam({"lr": 0.001}),
        final_num_samples=2000,
    )
    expected_eig = linear_model_ground_truth(linear_model, one_point_design, "y", "w")
    assert_equal(estimated_eig, expected_eig, prec=5e-2)
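The Donsker-Varadhan estimator treats EIG as the mutual information I(w; y) and maximises the lower bound E_{p(w, y)}[T(w, y)] - log E_{p(w)p(y)}[exp(T(w, y))] over the critic T (dv_critic here); the second call continues training at a much smaller learning rate and evaluates the trained bound with 2000 samples.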
Example 4: estimating the EIG with lfire_eig (likelihood-free inference by ratio estimation)
def test_lfire_linear_model(linear_model, one_point_design):
    pyro.set_rng_seed(42)
    pyro.clear_param_store()
    estimated_eig = lfire_eig(linear_model,
                              one_point_design,
                              "y",
                              "w",
                              num_y_samples=2,
                              num_theta_samples=50,
                              num_steps=1200,
                              classifier=make_lfire_classifier(50),
                              optim=optim.Adam({"lr": 0.0025}),
                              final_num_samples=100)
    expected_eig = linear_model_ground_truth(linear_model, one_point_design, "y", "w")
    assert_equal(estimated_eig, expected_eig, prec=5e-2)
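LFIRE (likelihood-free inference by ratio estimation) never evaluates the likelihood directly: for the 50 theta samples it trains classifiers (built by make_lfire_classifier(50)) to distinguish y drawn conditionally on each theta from y drawn from the marginal, and the fitted classifier yields the log density ratio log p(y | w, d) - log p(y | d) that the EIG integrand needs.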
Example 5: estimating the EIG with laplace_eig, a Laplace approximation to the posterior
def test_laplace_linear_model(linear_model, one_point_design):
    pyro.set_rng_seed(42)
    pyro.clear_param_store()
    # You can use 1 final sample here because linear models have a posterior entropy that is independent of `y`
    estimated_eig = laplace_eig(linear_model, one_point_design, "y", "w",
                                guide=laplace_guide, num_steps=250, final_num_samples=1,
                                optim=optim.Adam({"lr": 0.05}),
                                loss=Trace_ELBO().differentiable_loss)
    expected_eig = linear_model_ground_truth(linear_model, one_point_design, "y", "w")
    assert_equal(estimated_eig, expected_eig, prec=5e-2)
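The one-sample trick in the comment above rests on a standard conjugacy fact: for a linear-Gaussian model y = Xw + eps, the posterior covariance (Sigma_prior^-1 + X^T Sigma_eps^-1 X)^-1 depends only on the design X and never on the observed y, so the posterior entropy inside the EIG expectation is identical for every sampled y.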
Example 6: estimating the EIG with nmc_eig (plain nested Monte Carlo)
def test_nmc_eig_linear_model(linear_model, one_point_design):
    pyro.set_rng_seed(42)
    pyro.clear_param_store()
    estimated_eig = nmc_eig(linear_model,
                            one_point_design,
                            "y",
                            "w",
                            M=60,
                            N=60 * 60)
    expected_eig = linear_model_ground_truth(linear_model, one_point_design,
                                             "y", "w")
    assert_equal(estimated_eig, expected_eig, prec=5e-2)
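nmc_eig is the brute-force baseline: a nested Monte Carlo estimate with N outer draws of (w, y) and, for each outer draw, M inner draws to approximate the marginal p(y | d). Setting N = M ** 2 (here 3600 = 60 * 60) follows the usual NMC guidance that the squared bias shrinks with M while the variance shrinks with N.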
Example 7: a generic test that compares any EIG estimator against analytic ground truth
def test_eig_lm(model, design, observation_labels, target_labels, estimator,
                args, eig, allow_error):
    pyro.set_rng_seed(42)
    pyro.clear_param_store()
    y = estimator(model, design, observation_labels, target_labels, *args)
    if model is bernoulli_model:
        y_true = bernoulli_ground_truth(model,
                                        design,
                                        observation_labels,
                                        target_labels,
                                        eig=eig)
    else:
        y_true = linear_model_ground_truth(model,
                                           design,
                                           observation_labels,
                                           target_labels,
                                           eig=eig)
    logger.debug(estimator.__name__)
    logger.debug(y)
    logger.debug(y_true)
    error = torch.max(torch.abs(y - y_true))
    assert error < allow_error
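Example 7 is a generic driver rather than a single-estimator test: it runs whichever estimator and argument tuple it is handed (presumably supplied by pytest parametrization in the surrounding suite), picks the matching analytic ground truth for the Bernoulli or linear model, and asserts that the worst-case absolute error stays below the per-case tolerance allow_error.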