def test_vnmc_linear_model(linear_model, one_point_design):
    """Check the VNMC EIG estimate on ``linear_model`` against its ground truth.

    ``vnmc_eig`` is called twice with the same guide: first with Adam at
    lr=0.1, then with Adam at lr=0.01. Only the second call's estimate is
    compared with ``linear_model_ground_truth`` (tolerance 5e-2).
    """
    pyro.set_rng_seed(42)
    pyro.clear_param_store()
    eig_args = (linear_model, one_point_design, "y", "w")

    # Stage 1 (large learning rate): the returned estimate is discarded.
    # NOTE(review): presumably the guide's parameters persist in the Pyro
    # param store so stage 2 starts from them — confirm.
    vnmc_eig(
        *eig_args,
        num_samples=[9, 3],
        num_steps=250,
        guide=posterior_guide,
        optim=optim.Adam({"lr": 0.1}),
    )

    # Stage 2 (small learning rate): this call's result is the one checked.
    estimated_eig = vnmc_eig(
        *eig_args,
        num_samples=[9, 3],
        num_steps=250,
        guide=posterior_guide,
        optim=optim.Adam({"lr": 0.01}),
        final_num_samples=[500, 100],
    )

    expected_eig = linear_model_ground_truth(*eig_args)
    assert_equal(estimated_eig, expected_eig, prec=5e-2)
def test_marginal_likelihood_linear_model(linear_model, one_point_design):
    """Check the marginal-likelihood EIG estimate on ``linear_model``.

    ``marginal_likelihood_eig`` runs twice with the same marginal and
    conditional guides: first with Adam at lr=0.1, then with Adam at lr=0.01.
    The second estimate is compared with ``linear_model_ground_truth``
    (tolerance 5e-2).
    """
    pyro.set_rng_seed(42)
    pyro.clear_param_store()
    eig_args = (linear_model, one_point_design, "y", "w")

    # Stage 1 (large learning rate): result discarded; only a warm-up run.
    marginal_likelihood_eig(
        *eig_args,
        num_samples=10,
        num_steps=250,
        marginal_guide=marginal_guide,
        cond_guide=likelihood_guide,
        optim=optim.Adam({"lr": 0.1}),
    )

    # Stage 2 (small learning rate): produces the estimate under test.
    estimated_eig = marginal_likelihood_eig(
        *eig_args,
        num_samples=10,
        num_steps=250,
        marginal_guide=marginal_guide,
        cond_guide=likelihood_guide,
        optim=optim.Adam({"lr": 0.01}),
        final_num_samples=500,
    )

    expected_eig = linear_model_ground_truth(*eig_args)
    assert_equal(estimated_eig, expected_eig, prec=5e-2)
def test_dv_linear_model(linear_model, one_point_design):
    """Check the Donsker-Varadhan EIG estimate on ``linear_model``.

    The critic ``dv_critic`` is trained for 500 steps with Adam at lr=0.1.
    It is then trained for 650 more steps with Adam at lr=0.001. The second
    call's estimate (with ``final_num_samples=2000``) is compared with
    ``linear_model_ground_truth`` (tolerance 5e-2).
    """
    pyro.set_rng_seed(42)
    pyro.clear_param_store()
    eig_args = (linear_model, one_point_design, "y", "w")

    # Warm-up run with a large learning rate; the return value is unused.
    donsker_varadhan_eig(
        *eig_args,
        num_samples=100,
        num_steps=500,
        T=dv_critic,
        optim=optim.Adam({"lr": 0.1}),
    )

    # Refinement run with a much smaller learning rate; this is what is tested.
    estimated_eig = donsker_varadhan_eig(
        *eig_args,
        num_samples=100,
        num_steps=650,
        T=dv_critic,
        optim=optim.Adam({"lr": 0.001}),
        final_num_samples=2000,
    )

    expected_eig = linear_model_ground_truth(*eig_args)
    assert_equal(estimated_eig, expected_eig, prec=5e-2)
def test_lfire_linear_model(linear_model, one_point_design):
    """Check the LFIRE EIG estimate on ``linear_model`` against ground truth.

    A single ``lfire_eig`` call is made with a classifier from
    ``make_lfire_classifier(50)`` and Adam at lr=0.0025. The result is compared
    with ``linear_model_ground_truth`` (tolerance 5e-2).
    """
    pyro.set_rng_seed(42)
    pyro.clear_param_store()
    eig_args = (linear_model, one_point_design, "y", "w")

    # Built after seeding and before the optimizer, matching the original
    # evaluation order (in case classifier construction draws random numbers).
    classifier = make_lfire_classifier(50)
    adam = optim.Adam({"lr": 0.0025})

    estimated_eig = lfire_eig(
        *eig_args,
        num_y_samples=2,
        num_theta_samples=50,
        num_steps=1200,
        classifier=classifier,
        optim=adam,
        final_num_samples=100,
    )

    expected_eig = linear_model_ground_truth(*eig_args)
    assert_equal(estimated_eig, expected_eig, prec=5e-2)
def test_laplace_linear_model(linear_model, one_point_design):
    """Check the Laplace-approximation EIG estimate on ``linear_model``.

    ``laplace_eig`` is run with ``laplace_guide``, Adam at lr=0.05, and the
    differentiable ``Trace_ELBO`` loss. The result is compared with
    ``linear_model_ground_truth`` (tolerance 5e-2).
    """
    pyro.set_rng_seed(42)
    pyro.clear_param_store()
    eig_args = (linear_model, one_point_design, "y", "w")

    adam = optim.Adam({"lr": 0.05})
    elbo_loss = Trace_ELBO().differentiable_loss

    # You can use 1 final sample here because linear models have a posterior entropy that is independent of `y`
    estimated_eig = laplace_eig(
        *eig_args,
        guide=laplace_guide,
        num_steps=250,
        final_num_samples=1,
        optim=adam,
        loss=elbo_loss,
    )

    expected_eig = linear_model_ground_truth(*eig_args)
    assert_equal(estimated_eig, expected_eig, prec=5e-2)
def test_nmc_eig_linear_model(linear_model, one_point_design):
    """Check the nested Monte Carlo EIG estimate on ``linear_model``.

    ``nmc_eig`` is called with ``M=60`` and ``N=60 * 60``. The result is
    compared with ``linear_model_ground_truth`` (tolerance 5e-2).
    """
    pyro.set_rng_seed(42)
    pyro.clear_param_store()
    eig_args = (linear_model, one_point_design, "y", "w")

    m_samples = 60
    # N is chosen as the square of M.
    estimated_eig = nmc_eig(*eig_args, M=m_samples, N=m_samples * m_samples)

    expected_eig = linear_model_ground_truth(*eig_args)
    assert_equal(estimated_eig, expected_eig, prec=5e-2)
def test_eig_lm(model, design, observation_labels, target_labels, estimator, args, eig, allow_error):
    """Compare an EIG ``estimator`` with the model's ground-truth EIG.

    ``estimator`` is called with the model, design and labels, followed by the
    extra positional ``args``. The ground truth comes from
    ``bernoulli_ground_truth`` when ``model`` is ``bernoulli_model``, and from
    ``linear_model_ground_truth`` otherwise. The test passes when the largest
    absolute difference between estimate and ground truth is below
    ``allow_error``.
    """
    pyro.set_rng_seed(42)
    pyro.clear_param_store()
    labels = (observation_labels, target_labels)

    y = estimator(model, design, *labels, *args)

    ground_truth = (
        bernoulli_ground_truth if model is bernoulli_model else linear_model_ground_truth
    )
    y_true = ground_truth(model, design, *labels, eig=eig)

    # Same three debug records as before: estimator name, estimate, truth.
    for message in (estimator.__name__, y, y_true):
        logger.debug(message)

    deviation = torch.abs(y - y_true)
    error = torch.max(deviation)
    assert error < allow_error