def test_laplace_approximation_custom_hessian():
    """Fit a Laplace guide that was given a user-supplied ``hessian_fn``.

    A linear model is fitted with SVI and the guide's transform is then
    requested from the fitted parameters. The test makes no assertions; it
    passes as long as none of these steps raises.
    """

    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10))
        mu = a + b * x
        numpyro.sample("y", dist.Normal(mu, 1), obs=y)

    def nested_jacobian_hessian(f, x):
        # Hessian of ``f`` at ``x`` obtained by differentiating twice.
        return jacobian(jacobian(f))(x)

    # Noise-free data lying exactly on the line y = 1 + 2 * x.
    x_obs = random.normal(random.PRNGKey(0), (100,))
    y_obs = 1 + 2 * x_obs

    guide = AutoLaplaceApproximation(model, hessian_fn=nested_jacobian_hessian)
    svi = SVI(
        model,
        guide,
        optim.Adam(0.1),
        Trace_ELBO(),
        x=x_obs,
        y=y_obs,
    )
    svi_result = svi.run(random.PRNGKey(0), 10000, progress_bar=False)
    guide.get_transform(svi_result.params)
def test_laplace_approximation_warning():
    """Expect a ``UserWarning`` when sampling from the fitted Laplace guide.

    The warning raised by ``guide.sample_posterior`` must match
    "Hessian of log posterior".
    """

    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10), sample_shape=(3,))
        linear = b[0] * x
        quadratic = b[1] * x ** 2
        cubic = b[2] * x ** 3
        mu = a + linear + quadratic + cubic
        numpyro.sample("y", dist.Normal(mu, 0.001), obs=y)

    def run_one_update(_, state):
        # ``svi.update`` returns a tuple; only its first element is carried on.
        return svi.update(state)[0]

    # Three observations for four latent values (``a`` plus three ``b``
    # entries) with an observation scale of 0.001.
    # NOTE(review): presumably this is what provokes the Hessian warning
    # checked below -- confirm against the guide implementation.
    x_obs = random.normal(random.PRNGKey(0), (3,))
    y_obs = 1 + 2 * x_obs + 3 * x_obs ** 2 + 4 * x_obs ** 3

    guide = AutoLaplaceApproximation(model)
    svi = SVI(
        model,
        guide,
        optim.Adam(0.1),
        Trace_ELBO(),
        x=x_obs,
        y=y_obs,
    )
    initial_state = svi.init(random.PRNGKey(0))
    final_state = fori_loop(0, 10000, run_one_update, initial_state)
    params = svi.get_params(final_state)

    with pytest.warns(UserWarning, match="Hessian of log posterior"):
        guide.sample_posterior(random.PRNGKey(1), params)
# Standardize each column: subtract its mean and divide by its standard
# deviation.
d["D"] = d.Divorce.pipe(lambda x: (x - x.mean()) / x.std())
d["M"] = d.Marriage.pipe(lambda x: (x - x.mean()) / x.std())


# Model
def model(M, A, D=None):
    """Linear regression of ``D`` on ``M`` and ``A`` with Normal noise.

    Args:
        M: first predictor, multiplied by the slope ``bM``.
        A: second predictor, multiplied by the slope ``bA``.
        D: observed outcome; when left as ``None`` the ``"D"`` site has no
            observation attached.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # Registered as a deterministic site under the name "mu".
    mu = numpyro.deterministic("mu", a + bM * M + bA * A)
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


m5_3 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_3,
    optim.Adam(1),
    Trace_ELBO(),
    M=d.M.values,
    A=d.A.values,
    D=d.D.values,
)
# Read the run result by field name instead of unpacking it into two names:
# newer NumPyro releases add a ``state`` field to the result, which makes
# ``params, losses = svi.run(...)`` fail with "too many values to unpack".
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_3 = svi_result.params
losses = svi_result.losses
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (1000,))

# Posterior
param_names = {'a', 'bA', 'bM', 'sigma'}
# A set of strings has no stable iteration order between interpreter runs;
# sort it so the report is printed in the same order every time.
for p in sorted(param_names):
    print(f'posterior for {p}')
ax=ax) pml.savefig(f'multicollinear_joint_post_{method}.pdf') plt.title(method) plt.show() sum_blbr = post["bl"] + post["br"] fig, ax = plt.subplots() az.plot_kde(sum_blbr, label="sum of bl and br", ax=ax) plt.title(method) pml.savefig(f'multicollinear_sum_post_{method}.pdf') plt.show() # Laplace fit m6_1 = AutoLaplaceApproximation(model) svi = SVI(model, m6_1, optim.Adam(0.1), Trace_ELBO(), leg_left=df.leg_left.values, leg_right=df.leg_right.values, height=df.height.values, br_positive=False) p6_1, losses = svi.run(random.PRNGKey(0), 2000) post_laplace = m6_1.sample_posterior(random.PRNGKey(1), p6_1, (1000, )) analyze_post(post_laplace, 'laplace') # MCMC fit # code from p298 (code 9.28) of rethinking2