def test_mutable_state(stable_update, num_particles, elbo):
    def model():
        x = numpyro.sample("x", dist.Normal(-1, 1))
        numpyro_mutable("x1p", x + 1)

    def guide():
        loc = numpyro.param("loc", 0.0)
        p = numpyro_mutable("loc1p", {"value": None})
        # we can modify the content of `p` if it is a dict
        p["value"] = loc + 2
        numpyro.sample("x", dist.Normal(loc, 0.1))

    svi = SVI(model, guide, optim.Adam(0.1), elbo(num_particles=num_particles))
    if num_particles > 1:
        with pytest.raises(ValueError, match="mutable state"):
            svi_result = svi.run(random.PRNGKey(0), 1000, stable_update=stable_update)
        return
    svi_result = svi.run(random.PRNGKey(0), 1000, stable_update=stable_update)
    params = svi_result.params
    mutable_state = svi_result.state.mutable_state
    assert set(mutable_state) == {"x1p", "loc1p"}
    assert_allclose(mutable_state["loc1p"]["value"], params["loc"] + 2, atol=0.1)
    # here, the initial loc has value 0., hence x1p will have init value near 1;
    # it won't be updated during the SVI run because it is not a mutable state
    assert_allclose(mutable_state["x1p"], 1.0, atol=0.2)
def test_run_with_small_num_steps(num_steps):
    def model():
        pass

    def guide():
        pass

    svi = SVI(model, guide, optim.Adam(1), Trace_ELBO())
    svi.run(random.PRNGKey(0), num_steps)
def test_plate_inconsistent(size, dim):
    def model():
        with numpyro.plate("a", 10, dim=-1):
            numpyro.sample("x", dist.Normal(0, 1))
        with numpyro.plate("a", size, dim=dim):
            numpyro.sample("y", dist.Normal(0, 1))

    guide = AutoDelta(model)
    svi = SVI(model, guide, numpyro.optim.Adam(step_size=0.1), Trace_ELBO())
    with pytest.raises(AssertionError, match="has inconsistent dim or size"):
        svi.run(random.PRNGKey(0), 10)
def test_svi_discrete_latent():
    def model():
        numpyro.sample("x", dist.Bernoulli(0.5))

    def guide():
        probs = numpyro.param("probs", 0.2)
        numpyro.sample("x", dist.Bernoulli(probs))

    svi = SVI(model, guide, optim.Adam(1), Trace_ELBO())
    with pytest.warns(UserWarning, match="SVI does not support models with discrete"):
        svi.run(random.PRNGKey(0), 10)
def find_map(
    self,
    num_steps: int = 10000,
    handlers: Optional[list] = None,
    reparam: Union[str, hdl.reparam] = "auto",
    svi_kwargs: Optional[dict] = None,
):
    """EXPERIMENTAL: find the maximum a posteriori (MAP) estimate.

    Args:
        num_steps (int): number of SVI optimization steps. Defaults to 10000.
        handlers (list, optional): effect handlers to apply to the model.
            Defaults to None.
        reparam (str, or numpyro.handlers.reparam): reparameterization to
            apply to the model, or 'auto'. Defaults to 'auto'.
        svi_kwargs (dict, optional): extra keyword arguments for
            :class:`~numpyro.infer.SVI`; the 'optim' and 'loss' keys override
            the defaults. Defaults to None.
    """
    # avoid a mutable default argument and avoid mutating the caller's dict
    svi_kwargs = {} if svi_kwargs is None else dict(svi_kwargs)
    model = self._add_handlers_to_model(handlers=handlers, reparam=reparam)
    guide = numpyro.infer.autoguide.AutoDelta(model)
    optim = svi_kwargs.pop("optim", numpyro.optim.Minimize())
    loss = svi_kwargs.pop("loss", numpyro.infer.Trace_ELBO())
    map_svi = SVI(model, guide, optim, loss=loss, **svi_kwargs)
    rng_key, self._rng_key = random.split(self._rng_key)
    map_result = map_svi.run(rng_key, num_steps, self.n, nu=self.nu, nu_err=self.nu_err)
    self._map_loss = map_result.losses
    self._map_guide = map_svi.guide
    self._map_params = map_result.params
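# A minimal usage sketch for `find_map`. Hedged: `StarModel` is a hypothetical
# stand-in for whatever class defines this method, and its constructor mirrors
# the attributes the method reads from `self` (`n`, `nu`, `nu_err`); only the
# "optim" and "loss" keys of `svi_kwargs` are special-cased, the rest is
# forwarded to SVI.
star = StarModel(n=n, nu=nu, nu_err=nu_err)
star.find_map(
    num_steps=5000,
    svi_kwargs={
        "optim": numpyro.optim.Adam(0.01),
        "loss": numpyro.infer.Trace_ELBO(num_particles=2),
    },
)
print(star._map_params)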
def fit_svi(model, n_draws=1000, autoguide=AutoLaplaceApproximation, loss=Trace_ELBO(),
            optim=optim.Adam(step_size=1e-5), num_warmup=2000, use_gpu=False,
            num_chains=1, progress_bar=False, sampler=None, **kwargs):
    select_device(use_gpu, num_chains)
    guide = autoguide(model)
    svi = SVI(model=model, guide=guide, loss=loss, optim=optim, **kwargs)
    svi_result = svi.run(jax.random.PRNGKey(0), num_steps=num_warmup,
                         stable_update=True, progress_bar=progress_bar)
    # sample the fitted guide; sample_shape=(1, n_draws) gives az.from_dict a
    # single chain of n_draws draws
    post = guide.sample_posterior(jax.random.PRNGKey(1), params=svi_result.params,
                                  sample_shape=(1, n_draws))
    # alternative via Predictive:
    # predictive = Predictive(guide, params=svi_result.params, num_samples=n_draws)
    # post = predictive(jax.random.PRNGKey(1), **kwargs)
    # old interface:
    # init_state = svi.init(jax.random.PRNGKey(0))
    # state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(n_draws))
    # svi_params = svi.get_params(state)
    # post = guide.sample_posterior(jax.random.PRNGKey(1), svi_params, (1, n_draws))
    trace = az.from_dict(post)
    return trace, post
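# A usage sketch for `fit_svi`, assuming a toy NumPyro model; `toy_model` below
# is illustrative and not part of the original code.
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def toy_model():
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=jnp.array([0.5, 1.0, 1.5]))


trace, post = fit_svi(toy_model, n_draws=500, num_warmup=1000)
print(trace.posterior)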
def run_hmcecs(hmcecs_key, args, data, obs, inner_kernel):
    svi_key, mcmc_key = random.split(hmcecs_key)

    # find reference parameters for second order taylor expansion to estimate
    # likelihood (taylor_proxy)
    optimizer = numpyro.optim.Adam(step_size=1e-3)
    guide = autoguide.AutoDelta(model)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    svi_result = svi.run(svi_key, args.num_svi_steps, data, obs, args.subsample_size)
    params, losses = svi_result.params, svi_result.losses
    ref_params = {"theta": params["theta_auto_loc"]}

    # taylor proxy estimates log likelihood (ll) by
    # taylor_expansion(ll, theta_curr) +
    #   sum_{i in subsample} ll_i(theta_curr) - taylor_expansion(ll_i, theta_curr)
    # around ref_params
    proxy = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(inner_kernel, num_blocks=args.num_blocks, proxy=proxy)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(mcmc_key, data, obs, args.subsample_size)
    mcmc.print_summary()
    return losses, mcmc.get_samples()
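# A usage sketch for `run_hmcecs`, assuming `model`, `data`, and `obs` are
# defined as in the surrounding HMCECS example; the `args` namespace mirrors
# the script's CLI flags, with illustrative values.
from argparse import Namespace

from numpyro.infer import NUTS

args = Namespace(
    num_svi_steps=5000, subsample_size=1000, num_blocks=100,
    num_warmup=500, num_samples=1000,
)
inner_kernel = NUTS(model)
losses, samples = run_hmcecs(random.PRNGKey(0), args, data, obs, inner_kernel)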
def fit(self, X, Y, rng_key, n_step):
    self.X_train = X

    # store moments of training y (to normalize)
    self.y_mean = jnp.mean(Y)
    self.y_std = jnp.std(Y)

    # normalize y
    Y = (Y - self.y_mean) / self.y_std

    # set up optimizer and SVI
    optim = numpyro.optim.Adam(step_size=0.005, b1=0.5)
    svi = SVI(
        model,
        guide=AutoDelta(model),
        optim=optim,
        loss=Trace_ELBO(),
        X=X,
        Y=Y,
    )
    svi_result = svi.run(rng_key, n_step)
    params = svi_result.params

    # get kernel parameters from guide with proper names
    self.kernel_params = svi.guide.median(params)

    # store Cholesky factor of the prior covariance
    self.L = linalg.cho_factor(self.kernel(X, X, **self.kernel_params))

    # store inverted prior covariance multiplied by y
    self.alpha = linalg.cho_solve(self.L, Y)

    return self.kernel_params
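# The Cholesky factor `self.L` and weights `self.alpha` stored by `fit` are
# exactly what a Gaussian-process posterior-mean prediction needs. A minimal
# companion-method sketch (hypothetical, not part of the original class; it
# assumes `self.kernel(X1, X2, **params)` returns the cross-covariance matrix):
def predict_mean(self, X_new):
    # cross-covariance between test and training inputs
    K_star = self.kernel(X_new, self.X_train, **self.kernel_params)
    # GP posterior mean in normalized space: K* @ K(X, X)^{-1} y = K* @ alpha
    mu = K_star @ self.alpha
    # undo the normalization applied in `fit`
    return mu * self.y_std + self.y_mean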
def test_autoguide_deterministic(auto_class):
    def model(y=None):
        n = y.size if y is not None else 1
        mu = numpyro.sample("mu", dist.Normal(0, 5))
        sigma = numpyro.param("sigma", 1, constraint=constraints.positive)
        y = numpyro.sample("y", dist.Normal(mu, sigma).expand((n,)), obs=y)
        numpyro.deterministic("z", (y - mu) / sigma)

    mu, sigma = 2, 3
    y = mu + sigma * random.normal(random.PRNGKey(0), shape=(300,))
    y_train = y[:200]
    y_test = y[200:]

    guide = auto_class(model)
    optimiser = numpyro.optim.Adam(step_size=0.01)
    svi = SVI(model, guide, optimiser, Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(0), num_steps=500, y=y_train)
    params = svi_result.params

    posterior_samples = guide.sample_posterior(
        random.PRNGKey(0), params, sample_shape=(1000,)
    )
    predictive = Predictive(model, posterior_samples, params=params)
    predictive_samples = predictive(random.PRNGKey(0), y_test)

    assert predictive_samples["y"].shape == (1000, 100)
    assert predictive_samples["z"].shape == (1000, 100)
    assert_allclose(
        (predictive_samples["y"] - posterior_samples["mu"][..., None])
        / params["sigma"],
        predictive_samples["z"],
        atol=0.05,
    )
def test_run(progress_bar):
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param(
            "alpha_q", lambda key: random.normal(key), constraint=constraints.positive
        )
        beta_q = numpyro.param(
            "beta_q",
            lambda key: random.exponential(key),
            constraint=constraints.positive,
        )
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.05), Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(1), 1000, data, progress_bar=progress_bar)
    params, losses = svi_result.params, svi_result.losses
    assert losses.shape == (1000,)
    assert_allclose(
        params["alpha_q"] / (params["alpha_q"] + params["beta_q"]),
        0.8,
        atol=0.05,
        rtol=0.05,
    )
def run_svi(rng_key, X, Y, guide_family="AutoDiagonalNormal", K=8):
    assert guide_family in ["AutoDiagonalNormal", "AutoDAIS"]

    if guide_family == "AutoDAIS":
        guide = autoguide.AutoDAIS(model, K=K, eta_init=0.02, eta_max=0.5)
        step_size = 5e-4
    elif guide_family == "AutoDiagonalNormal":
        guide = autoguide.AutoDiagonalNormal(model)
        step_size = 3e-3

    optimizer = numpyro.optim.Adam(step_size=step_size)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    svi_result = svi.run(rng_key, args.num_svi_steps, X, Y)
    params = svi_result.params

    final_elbo = -Trace_ELBO(num_particles=1000).loss(
        rng_key, params, model, guide, X, Y
    )

    guide_name = guide_family
    if guide_family == "AutoDAIS":
        guide_name += "-{}".format(K)
    print("[{}] final elbo: {:.2f}".format(guide_name, final_elbo))

    return guide.sample_posterior(
        random.PRNGKey(1), params, sample_shape=(args.num_samples,)
    )
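# A usage sketch, assuming `model`, `X`, `Y`, and the script-level `args`
# namespace (with `num_svi_steps` and `num_samples` fields) are defined as in
# the surrounding example:
dais_samples = run_svi(random.PRNGKey(0), X, Y, guide_family="AutoDAIS", K=16)
mf_samples = run_svi(random.PRNGKey(0), X, Y, guide_family="AutoDiagonalNormal")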
def test_predictive_with_guide():
    data = jnp.array([1] * 8 + [0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1., 1.))
        with numpyro.plate("plate", 10):
            numpyro.deterministic("beta_sq", f**2)
            numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(1), 3000, data)
    params = svi_result.params
    predictive = Predictive(model, guide=guide, params=params, num_samples=1000)(
        random.PRNGKey(2), data=None
    )
    assert predictive["beta_sq"].shape == (1000,)
    obs_pred = predictive["obs"].astype(np.float32)
    assert_allclose(jnp.mean(obs_pred), 0.8, atol=0.05)
def test_tracegraph_normal_normal():
    # normal-normal; known covariance
    lam0 = jnp.array([0.1, 0.1])  # precision of prior
    loc0 = jnp.array([0.0, 0.5])  # prior mean
    # known precision of observation noise
    lam = jnp.array([6.0, 4.0])
    data = []
    data.append(jnp.array([-0.1, 0.3]))
    data.append(jnp.array([0.0, 0.4]))
    data.append(jnp.array([0.2, 0.5]))
    data.append(jnp.array([0.1, 0.7]))
    n_data = len(data)
    sum_data = data[0] + data[1] + data[2] + data[3]
    analytic_lam_n = lam0 + n_data * lam
    analytic_log_sig_n = -0.5 * jnp.log(analytic_lam_n)
    analytic_loc_n = sum_data * (lam / analytic_lam_n) + loc0 * (lam0 / analytic_lam_n)

    class FakeNormal(dist.Normal):
        reparametrized_params = []

    def model():
        with numpyro.plate("plate", 2):
            loc_latent = numpyro.sample(
                "loc_latent", FakeNormal(loc0, jnp.power(lam0, -0.5))
            )
            for i, x in enumerate(data):
                numpyro.sample(
                    "obs_{}".format(i),
                    dist.Normal(loc_latent, jnp.power(lam, -0.5)),
                    obs=x,
                )
        return loc_latent

    def guide():
        loc_q = numpyro.param("loc_q", analytic_loc_n + jnp.array([0.334, 0.334]))
        log_sig_q = numpyro.param(
            "log_sig_q", analytic_log_sig_n + jnp.array([-0.29, -0.29])
        )
        sig_q = jnp.exp(log_sig_q)
        with numpyro.plate("plate", 2):
            loc_latent = numpyro.sample("loc_latent", FakeNormal(loc_q, sig_q))
        return loc_latent

    adam = optim.Adam(step_size=0.0015, b1=0.97, b2=0.999)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0), 5000)

    loc_error = jnp.sum(jnp.power(analytic_loc_n - svi_result.params["loc_q"], 2.0))
    log_sig_error = jnp.sum(
        jnp.power(analytic_log_sig_n - svi_result.params["log_sig_q"], 2.0)
    )
    assert_allclose(loc_error, 0, atol=0.05)
    assert_allclose(log_sig_error, 0, atol=0.05)
def test_svi_discrete_latent():
    cont_inf_only_cls = [RenyiELBO(), Trace_ELBO(), TraceMeanField_ELBO()]
    mixed_inf_cls = [TraceGraph_ELBO()]

    assert not any([c.can_infer_discrete for c in cont_inf_only_cls])
    assert all([c.can_infer_discrete for c in mixed_inf_cls])

    def model():
        numpyro.sample("x", dist.Bernoulli(0.5))

    def guide():
        probs = numpyro.param("probs", 0.2)
        numpyro.sample("x", dist.Bernoulli(probs))

    for elbo in cont_inf_only_cls:
        svi = SVI(model, guide, optim.Adam(1), elbo)
        s_name = type(elbo).__name__
        w_msg = f"Currently, SVI with {s_name} loss does not support models with discrete latent variables"
        with pytest.warns(UserWarning, match=w_msg):
            svi.run(random.PRNGKey(0), 10)
def run_svi(model, guide_family, args, X, Y):
    if guide_family == "AutoDelta":
        guide = autoguide.AutoDelta(model)
    elif guide_family == "AutoDiagonalNormal":
        guide = autoguide.AutoDiagonalNormal(model)

    optimizer = numpyro.optim.Adam(0.001)
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    svi_results = svi.run(PRNGKey(1), args.maxiter, X=X, Y=Y)
    params = svi_results.params

    return params, guide
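# A usage sketch, assuming `model`, `X`, `Y`, and an `args` namespace with a
# `maxiter` field are defined as in the surrounding script:
map_params, map_guide = run_svi(model, "AutoDelta", args, X, Y)
vi_params, vi_guide = run_svi(model, "AutoDiagonalNormal", args, X, Y)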
def test_stable_run(stable_run):
    def model():
        var = numpyro.sample("var", dist.Exponential(1))
        numpyro.sample("obs", dist.Normal(0, jnp.sqrt(var)), obs=0.0)

    def guide():
        loc = numpyro.param("loc", 0.0)
        numpyro.sample("var", dist.Normal(loc, 10))

    svi = SVI(model, guide, optim.Adam(1), Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(0), 1000, stable_update=stable_run)
    assert jnp.isfinite(svi_result.params["loc"]) == stable_run
def test_subsample_model_with_deterministic():
    def model():
        x = numpyro.sample("x", dist.Normal(0, 1))
        numpyro.deterministic("x2", x * 2)
        with numpyro.plate("N", 10, subsample_size=5):
            numpyro.sample("obs", dist.Normal(x, 1), obs=jnp.ones(5))

    guide = AutoNormal(model)
    svi = SVI(model, guide, optim.Adam(1.0), Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(0), 10)
    samples = guide.sample_posterior(random.PRNGKey(1), svi_result.params)
    assert "x2" in samples
def test_laplace_approximation_custom_hessian():
    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10))
        mu = a + b * x
        numpyro.sample("y", dist.Normal(mu, 1), obs=y)

    x = random.normal(random.PRNGKey(0), (100,))
    y = 1 + 2 * x
    guide = AutoLaplaceApproximation(
        model, hessian_fn=lambda f, x: jacobian(jacobian(f))(x)
    )
    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), x=x, y=y)
    svi_result = svi.run(random.PRNGKey(0), 10000, progress_bar=False)
    guide.get_transform(svi_result.params)
def test_pickle_autoguide(guide_class):
    x = np.random.poisson(1.0, size=(100,))

    guide = guide_class(poisson_regression)
    optim = numpyro.optim.Adam(1e-2)
    svi = SVI(poisson_regression, guide, optim, numpyro.infer.Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(1), 3, x, len(x))
    pickled_guide = pickle.loads(pickle.dumps(guide))
    predictive = Predictive(
        poisson_regression,
        guide=pickled_guide,
        params=svi_result.params,
        num_samples=1,
        return_sites=["param", "x"],
    )
    samples = predictive(random.PRNGKey(1), None, 1)
    assert set(samples.keys()) == {"param", "x"}
def test_tracegraph_beta_bernoulli():
    # bernoulli-beta model
    # beta prior hyperparameters
    alpha0 = 1.0
    beta0 = 1.0
    data = jnp.array([0.0, 1.0, 1.0, 1.0])
    n_data = float(len(data))
    data_sum = data.sum()
    alpha_n = alpha0 + data_sum  # posterior alpha
    beta_n = beta0 - data_sum + n_data  # posterior beta
    log_alpha_n = jnp.log(alpha_n)
    log_beta_n = jnp.log(beta_n)

    class FakeBeta(dist.Beta):
        reparametrized_params = []

    def model():
        p_latent = numpyro.sample("p_latent", FakeBeta(alpha0, beta0))
        with numpyro.plate("data", len(data)):
            numpyro.sample("obs", dist.Bernoulli(p_latent), obs=data)
        return p_latent

    def guide():
        alpha_q_log = numpyro.param("alpha_q_log", log_alpha_n + 0.17)
        beta_q_log = numpyro.param("beta_q_log", log_beta_n - 0.143)
        alpha_q, beta_q = jnp.exp(alpha_q_log), jnp.exp(beta_q_log)
        p_latent = numpyro.sample("p_latent", FakeBeta(alpha_q, beta_q))
        with numpyro.plate("data", len(data)):
            pass
        return p_latent

    adam = optim.Adam(step_size=0.0007, b1=0.95, b2=0.999)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0), 3000)

    alpha_error = jnp.sum(
        jnp.power(log_alpha_n - svi_result.params["alpha_q_log"], 2.0)
    )
    beta_error = jnp.sum(jnp.power(log_beta_n - svi_result.params["beta_q_log"], 2.0))
    assert_allclose(alpha_error, 0, atol=0.03)
    assert_allclose(beta_error, 0, atol=0.04)
def test_tracegraph_gamma_exponential():
    # exponential-gamma model
    # gamma prior hyperparameters
    alpha0 = 1.0
    beta0 = 1.0
    n_data = 2
    data = jnp.array([3.0, 2.0])  # two observations
    alpha_n = alpha0 + n_data  # posterior alpha
    beta_n = beta0 + data.sum()  # posterior beta
    log_alpha_n = jnp.log(alpha_n)
    log_beta_n = jnp.log(beta_n)

    class FakeGamma(dist.Gamma):
        reparametrized_params = []

    def model():
        lambda_latent = numpyro.sample("lambda_latent", FakeGamma(alpha0, beta0))
        with numpyro.plate("data", len(data)):
            numpyro.sample("obs", dist.Exponential(lambda_latent), obs=data)
        return lambda_latent

    def guide():
        alpha_q_log = numpyro.param("alpha_q_log", log_alpha_n + 0.17)
        beta_q_log = numpyro.param("beta_q_log", log_beta_n - 0.143)
        alpha_q, beta_q = jnp.exp(alpha_q_log), jnp.exp(beta_q_log)
        numpyro.sample("lambda_latent", FakeGamma(alpha_q, beta_q))
        with numpyro.plate("data", len(data)):
            pass

    adam = optim.Adam(step_size=0.0007, b1=0.95, b2=0.999)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0), 8000)

    alpha_error = jnp.sum(
        jnp.power(log_alpha_n - svi_result.params["alpha_q_log"], 2.0)
    )
    beta_error = jnp.sum(jnp.power(log_beta_n - svi_result.params["beta_q_log"], 2.0))
    assert_allclose(alpha_error, 0, atol=0.04)
    assert_allclose(beta_error, 0, atol=0.04)
def run_inference(docs, args):
    rng_key = random.PRNGKey(0)
    docs = device_put(docs)

    hyperparams = dict(
        vocab_size=docs.shape[1],
        num_topics=args.num_topics,
        hidden=args.hidden,
        dropout_rate=args.dropout_rate,
        batch_size=args.batch_size,
    )

    optimizer = numpyro.optim.Adam(args.learning_rate)
    svi = SVI(model, guide, optimizer, loss=TraceMeanField_ELBO())

    return svi.run(
        rng_key,
        args.num_steps,
        docs,
        hyperparams,
        is_training=True,
        progress_bar=not args.disable_progbar,
        nn_framework=args.nn_framework,
    )
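# A usage sketch for `run_inference`: `docs` is a document-term count matrix
# and `args` mirrors the example's CLI flags (field names follow the function
# body; the values here are illustrative assumptions).
from argparse import Namespace

args = Namespace(
    num_topics=12, hidden=100, dropout_rate=0.2, batch_size=32,
    learning_rate=1e-3, num_steps=50_000, disable_progbar=False,
    nn_framework="flax",
)
svi_result = run_inference(docs, args)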
pml.savefig(f'multicollinear_sum_post_{method}.pdf')
plt.show()

# Laplace fit

m6_1 = AutoLaplaceApproximation(model)
svi = SVI(model, m6_1, optim.Adam(0.1), Trace_ELBO(),
          leg_left=df.leg_left.values, leg_right=df.leg_right.values,
          height=df.height.values, br_positive=False)
svi_result = svi.run(random.PRNGKey(0), 2000)
p6_1, losses = svi_result.params, svi_result.losses
post_laplace = m6_1.sample_posterior(random.PRNGKey(1), p6_1, (1000,))
analyze_post(post_laplace, 'laplace')

# MCMC fit
# code from p298 (code 9.28) of rethinking2
# https://fehiepsi.github.io/rethinking-numpyro/09-markov-chain-monte-carlo.html

kernel = NUTS(
    model,
    init_strategy=init_to_value(values={
        "a": 10.0,
        "bl": 0.0,
        "br": 0.1,
        "sigma": 1.0
def benchmark_hmc(args, features, labels):
    rng_key = random.PRNGKey(1)
    start = time.time()
    # a MAP estimate at the following source
    # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117
    ref_params = {
        "coefs": jnp.array(
            [
                +2.03420663e00, -3.53567265e-02, -1.49223924e-01, -3.07049364e-01,
                -1.00028366e-01, -1.46827862e-01, -1.64167881e-01, -4.20344204e-01,
                +9.47479829e-02, -1.12681836e-02, +2.64442056e-01, -1.22087866e-01,
                -6.00568838e-02, -3.79419506e-01, -1.06668741e-01, -2.97053963e-01,
                -2.05253899e-01, -4.69537191e-02, -2.78072730e-02, -1.43250525e-01,
                -6.77954629e-02, -4.34899796e-03, +5.90927452e-02, +7.23133609e-02,
                +1.38526391e-02, -1.24497898e-01, -1.50733739e-02, -2.68872194e-02,
                -1.80925727e-02, +3.47936489e-02, +4.03552800e-02, -9.98773426e-03,
                +6.20188080e-02, +1.15002751e-01, +1.32145107e-01, +2.69109547e-01,
                +2.45785132e-01, +1.19035013e-01, -2.59744357e-02, +9.94279515e-04,
                +3.39266285e-02, -1.44057125e-02, -6.95222765e-02, -7.52013028e-02,
                +1.21171586e-01, +2.29205526e-02, +1.47308692e-01, -8.34354162e-02,
                -9.34122875e-02, -2.97472421e-02, -3.03937674e-01, -1.70958012e-01,
                -1.59496680e-01, -1.88516974e-01, -1.20889175e00,
            ]
        )
    }
    if args.algo == "HMC":
        step_size = jnp.sqrt(0.5 / features.shape[0])
        trajectory_length = step_size * args.num_steps
        kernel = HMC(
            model,
            step_size=step_size,
            trajectory_length=trajectory_length,
            adapt_step_size=False,
            dense_mass=args.dense_mass,
        )
        subsample_size = None
    elif args.algo == "NUTS":
        kernel = NUTS(model, dense_mass=args.dense_mass)
        subsample_size = None
    elif args.algo == "HMCECS":
        subsample_size = 1000
        inner_kernel = NUTS(
            model,
            init_strategy=init_to_value(values=ref_params),
            dense_mass=args.dense_mass,
        )
        # note: with num_blocks=100, we update 10 indices at each MCMC step,
        # so it takes 50,000 MCMC steps to iterate over the whole dataset
        kernel = HMCECS(
            inner_kernel, num_blocks=100, proxy=HMCECS.taylor_proxy(ref_params)
        )
    elif args.algo == "SA":
        # NB: this kernel requires large num_warmup and num_samples,
        # and running on GPU is much faster than on CPU
        kernel = SA(
            model, adapt_state_size=1000, init_strategy=init_to_value(values=ref_params)
        )
        subsample_size = None
    elif args.algo == "FlowHMCECS":
        subsample_size = 1000
        guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
        svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
        svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)
        params, losses = svi_result.params, svi_result.losses
        plt.plot(losses)
        plt.show()

        neutra = NeuTraReparam(guide, params)
        neutra_model = neutra.reparam(model)
        neutra_ref_params = {"auto_shared_latent": jnp.zeros(55)}
        # no need to adapt mass matrix if the flow does a good job
        inner_kernel = NUTS(
            neutra_model,
            init_strategy=init_to_value(values=neutra_ref_params),
            adapt_mass_matrix=False,
        )
        kernel = HMCECS(
            inner_kernel, num_blocks=100, proxy=HMCECS.taylor_proxy(neutra_ref_params)
        )
    else:
        raise ValueError(
            "Invalid algorithm; expected one of 'HMC', 'NUTS', 'HMCECS', 'SA', "
            "or 'FlowHMCECS'."
        )

    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(rng_key, features, labels, subsample_size, extra_fields=("accept_prob",))
    print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
    mcmc.print_summary(exclude_deterministic=False)
    print("\nMCMC elapsed time:", time.time() - start)
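# A usage sketch for `benchmark_hmc`, assuming `features` and `labels` come
# from the dataset loader in the surrounding benchmark script; the `args`
# namespace mirrors its CLI flags (`num_steps` is only read by the HMC branch):
from argparse import Namespace

args = Namespace(
    algo="NUTS", dense_mass=True, num_warmup=500, num_samples=1000, num_steps=10
)
benchmark_hmc(args, features, labels)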
# ===================
n_epochs = 1_000
lr = 0.01
optimizer = Adam(step_size=lr)

# ===================
# Training
# ===================
# reproducibility
rng_key = random.PRNGKey(42)

# set up SVI
svi = SVI(sgp_model, delta_guide, optimizer, loss=Trace_ELBO())

# run SVI
svi_results = svi.run(rng_key, n_epochs, X, y.T)

# ===================
# Plot Loss
# ===================
fig, ax = plt.subplots(ncols=1, figsize=(6, 4))
ax.plot(svi_results.losses)
ax.set(title="Loss", xlabel="Iterations", ylabel="Negative Log-Likelihood")
plt.tight_layout()
wandb.log({"loss": [wandb.Image(plt)]})
wandb.log({"nll_loss": np.array(svi_results.losses[-1])})

learned_params = delta_guide.median(svi_results.params)
learned_params["x_u"] = svi_results.params["x_u"]
# =================
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bM * M + bA * A)
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


m5_3 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_3, optim.Adam(1), Trace_ELBO(),
          M=d.M.values, A=d.A.values, D=d.D.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_3, losses = svi_result.params, svi_result.losses
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (1000,))

# Posterior
param_names = {'a', 'bA', 'bM', 'sigma'}
for p in param_names:
    print(f'posterior for {p}')
    print_summary(post[p], 0.95, False)

# PPC
# call predictive without specifying new data,
# so it uses the original data
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (int(1e4),))
post_pred = Predictive(m5_3.model, post)(random.PRNGKey(2),
def test_tracegraph_gaussian_chain(num_latents, num_steps, step_size, atol, difficulty):
    loc0 = 0.2
    data = jnp.array([-0.1, 0.03, 0.2, 0.1])
    n_data = data.shape[0]
    sum_data = data.sum()
    N = num_latents
    lambdas = [1.5 * (k + 1) / N for k in range(N + 1)]
    lambdas = list(map(lambda x: jnp.array([x]), lambdas))
    lambda_tilde_posts = [lambdas[0]]
    for k in range(1, N):
        lambda_tilde_k = (lambdas[k] * lambda_tilde_posts[k - 1]) / (
            lambdas[k] + lambda_tilde_posts[k - 1]
        )
        lambda_tilde_posts.append(lambda_tilde_k)
    # the None entry is never used (just a way of shifting the indexing by 1)
    lambda_posts = [None]
    for k in range(1, N):
        lambda_k = lambdas[k] + lambda_tilde_posts[k - 1]
        lambda_posts.append(lambda_k)
    lambda_N_post = (n_data * lambdas[N]) + lambda_tilde_posts[N - 1]
    lambda_posts.append(lambda_N_post)
    target_kappas = [None]
    target_kappas.extend([lambdas[k] / lambda_posts[k] for k in range(1, N)])
    target_mus = [None]
    target_mus.extend(
        [loc0 * lambda_tilde_posts[k - 1] / lambda_posts[k] for k in range(1, N)]
    )
    target_loc_N = (
        sum_data * lambdas[N] / lambda_N_post
        + loc0 * lambda_tilde_posts[N - 1] / lambda_N_post
    )
    target_mus.append(target_loc_N)
    np.random.seed(0)
    while True:
        mask = np.random.binomial(1, 0.3, (N,))
        if mask.sum() < 0.4 * N and mask.sum() > 0.5:
            which_nodes_reparam = mask
            break

    class FakeNormal(dist.Normal):
        reparametrized_params = []

    def model(difficulty=0.0):
        next_mean = loc0
        for k in range(1, N + 1):
            latent_dist = dist.Normal(next_mean, jnp.power(lambdas[k - 1], -0.5))
            loc_latent = numpyro.sample("loc_latent_{}".format(k), latent_dist)
            next_mean = loc_latent

        loc_N = next_mean
        with numpyro.plate("data", data.shape[0]):
            numpyro.sample(
                "obs", dist.Normal(loc_N, jnp.power(lambdas[N], -0.5)), obs=data
            )
        return loc_N

    def guide(difficulty=0.0):
        previous_sample = None
        for k in reversed(range(1, N + 1)):
            loc_q = numpyro.param(
                f"loc_q_{k}",
                lambda key: target_mus[k]
                + difficulty * (0.1 * random.normal(key) - 0.53),
            )
            log_sig_q = numpyro.param(
                f"log_sig_q_{k}",
                lambda key: -0.5 * jnp.log(lambda_posts[k])
                + difficulty * (0.1 * random.normal(key) - 0.53),
            )
            sig_q = jnp.exp(log_sig_q)
            kappa_q = None
            if k != N:
                kappa_q = numpyro.param(
                    "kappa_q_%d" % k,
                    lambda key: target_kappas[k]
                    + difficulty * (0.1 * random.normal(key) - 0.53),
                )
            mean_function = loc_q if k == N else kappa_q * previous_sample + loc_q
            node_flagged = True if which_nodes_reparam[k - 1] == 1.0 else False
            Normal = dist.Normal if node_flagged else FakeNormal
            loc_latent = numpyro.sample(f"loc_latent_{k}", Normal(mean_function, sig_q))
            previous_sample = loc_latent
        return previous_sample

    adam = optim.Adam(step_size=step_size, b1=0.95, b2=0.999)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0), num_steps, difficulty=difficulty)

    kappa_errors, log_sig_errors, loc_errors = [], [], []
    for k in range(1, N + 1):
        if k != N:
            kappa_error = jnp.sum(
                jnp.power(svi_result.params[f"kappa_q_{k}"] - target_kappas[k], 2)
            )
            kappa_errors.append(kappa_error)
        loc_errors.append(
            jnp.sum(jnp.power(svi_result.params[f"loc_q_{k}"] - target_mus[k], 2))
        )
        log_sig_error = jnp.sum(
            jnp.power(
                svi_result.params[f"log_sig_q_{k}"] + 0.5 * jnp.log(lambda_posts[k]), 2
            )
        )
        log_sig_errors.append(log_sig_error)

    max_errors = (np.max(loc_errors), np.max(log_sig_errors), np.max(kappa_errors))

    for i in range(3):
        assert_allclose(max_errors[i], 0, atol=atol)
def main(args):
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(
        nuts_kernel,
        args.num_warmup,
        args.num_samples,
        num_chains=args.num_chains,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()['x'].copy()

    guide = AutoBNAFNormal(
        dual_moon_model, hidden_factors=[args.hidden_factor, args.hidden_factor]
    )
    svi = SVI(dual_moon_model, guide, optim.Adam(0.003), Trace_ELBO())

    print("Start training guide...")
    svi_result = svi.run(random.PRNGKey(1), args.num_iters)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(
        random.PRNGKey(2), svi_result.params, sample_shape=(args.num_samples,)
    )['x'].copy()

    print("\nStart NeuTra HMC...")
    neutra = NeuTraReparam(guide, svi_result.params)
    neutra_model = neutra.reparam(dual_moon_model)
    nuts_kernel = NUTS(neutra_model)
    mcmc = MCMC(
        nuts_kernel,
        args.num_warmup,
        args.num_samples,
        num_chains=args.num_chains,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(random.PRNGKey(3))
    mcmc.print_summary()
    zs = mcmc.get_samples(group_by_chain=True)["auto_shared_latent"]
    print("Transform samples into unwarped space...")
    samples = neutra.transform_sample(zs)
    print_summary(samples)
    zs = zs.reshape(-1, 2)
    samples = samples['x'].reshape(-1, 2).copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(jnp.zeros(2), 1.).sample(
        random.PRNGKey(4), (1000,)
    )
    guide_trans_samples = neutra.transform_sample(guide_base_samples)['x']

    x1 = jnp.linspace(-3, 3, 100)
    x2 = jnp.linspace(-3, 3, 100)
    X1, X2 = jnp.meshgrid(x1, x2)
    P = jnp.exp(DualMoonDistribution().log_prob(jnp.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0])
    ax3 = fig.add_subplot(gs[0, 1])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[0, 2])
    ax6 = fig.add_subplot(gs[1, 2])

    ax1.plot(svi_result.losses[1000:])
    ax1.set_title('Autoguide training loss\n(after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel='x0',
        ylabel='x1',
        title='Posterior using\nAutoBNAFNormal guide',
    )

    sns.scatterplot(
        guide_base_samples[:, 0],
        guide_base_samples[:, 1],
        ax=ax3,
        hue=guide_trans_samples[:, 0] < 0.,
    )
    ax3.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel='x0',
        ylabel='x1',
        title='AutoBNAFNormal base samples\n(True=left moon; False=right moon)',
    )

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0], vanilla_samples[:, 1], n_levels=30, ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5)
    ax4.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel='x0',
        ylabel='x1',
        title='Posterior using\nvanilla HMC sampler',
    )

    sns.scatterplot(
        zs[:, 0],
        zs[:, 1],
        ax=ax5,
        hue=samples[:, 0] < 0.,
        s=30,
        alpha=0.5,
        edgecolor="none",
    )
    ax5.set(
        xlim=[-5, 5],
        ylim=[-5, 5],
        xlabel='x0',
        ylabel='x1',
        title='Samples from the\nwarped posterior - p(z)',
    )

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel='x0',
        ylabel='x1',
        title='Posterior using\nNeuTra HMC sampler',
    )

    plt.savefig("neutra.pdf")
def test_cond():
    def model():
        def true_fun(_):
            x = numpyro.sample("x", dist.Normal(4.0))
            numpyro.deterministic("z", x - 4.0)

        def false_fun(_):
            x = numpyro.sample("x", dist.Normal(0.0))
            numpyro.deterministic("z", x)

        cluster = numpyro.sample("cluster", dist.Normal())
        cond(cluster > 0, true_fun, false_fun, None)

    def guide():
        m1 = numpyro.param("m1", 2.0)
        s1 = numpyro.param("s1", 0.1, constraint=dist.constraints.positive)
        m2 = numpyro.param("m2", 2.0)
        s2 = numpyro.param("s2", 0.1, constraint=dist.constraints.positive)

        def true_fun(_):
            numpyro.sample("x", dist.Normal(m1, s1))

        def false_fun(_):
            numpyro.sample("x", dist.Normal(m2, s2))

        cluster = numpyro.sample("cluster", dist.Normal())
        cond(cluster > 0, true_fun, false_fun, None)

    svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100))
    svi_result = svi.run(random.PRNGKey(0), num_steps=2500)
    params = svi_result.params

    predictive = Predictive(
        model,
        guide=guide,
        params=params,
        num_samples=1000,
        return_sites=["cluster", "x", "z"],
    )
    result = predictive(random.PRNGKey(0))

    assert result["cluster"].shape == (1000,)
    assert result["x"].shape == (1000,)
    assert result["z"].shape == (1000,)

    mcmc = MCMC(
        NUTS(model),
        num_warmup=500,
        num_samples=2500,
        num_chains=4,
        chain_method="sequential",
    )
    mcmc.run(random.PRNGKey(0))
    x = mcmc.get_samples()["x"]

    assert x.shape == (10_000,)
    assert_allclose(
        [x[x > 2.0].mean(), x[x > 2.0].std(), x[x < 2.0].mean(), x[x < 2.0].std()],
        [4.01, 0.965, -0.01, 0.965],
        atol=0.1,
    )
    assert_allclose([x.mean(), x.std()], [2.0, jnp.sqrt(5.0)], atol=0.5)
plt.show()

# Laplace fit

m6_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_1,
    optim.Adam(0.1),
    Trace_ELBO(),
    leg_left=df.leg_left.values,
    leg_right=df.leg_right.values,
    height=df.height.values,
    br_positive=False,
)
svi_run = svi.run(random.PRNGKey(0), 2000)
p6_1 = svi_run.params
losses = svi_run.losses
post_laplace = m6_1.sample_posterior(random.PRNGKey(1), p6_1, (1000,))
analyze_post(post_laplace, 'laplace')

# MCMC fit
# code from p298 (code 9.28) of rethinking2
# https://fehiepsi.github.io/rethinking-numpyro/09-markov-chain-monte-carlo.html

kernel = NUTS(
    model,
    init_strategy=init_to_value(values={"a": 10.0, "bl": 0.0, "br": 0.1, "sigma": 1.0}),