def test_logistic_regression_x64(kernel_cls):
    """Check that ``kernel_cls`` recovers logistic-regression coefficients.

    Samples 3000 synthetic rows with coefficients ``[1, 2, 3]`` and compares
    the posterior mean of ``coefs`` with reference MAP values. When
    ``JAX_ENABLE_X64`` is set, the samples must also be float64.
    """
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    labels = dist.Bernoulli(logits=jnp.sum(true_coefs * data, axis=-1)).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic('logits', jnp.sum(coefs * data, axis=-1))
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    # Each kernel gets its own chain budget; SA and BarkerMH need longer chains.
    if kernel_cls is SA:
        n_warmup, n_draws = 100000, 100000
        kernel = SA(model=model, adapt_state_size=9)
    elif kernel_cls is BarkerMH:
        n_warmup, n_draws = 2000, 12000
        kernel = BarkerMH(model=model)
    else:
        n_warmup, n_draws = 1000, 8000
        kernel = kernel_cls(model=model, trajectory_length=8, find_heuristic_step_size=True)

    mcmc = MCMC(kernel, num_warmup=n_warmup, num_samples=n_draws, progress_bar=False)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples['logits'].shape == (n_draws, N)

    # those coefficients are found by doing MAP inference using AutoDelta
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(samples['coefs'], 0), expected_coefs, atol=0.1)
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['coefs'].dtype == jnp.float64
def run_inference(model, args, rng_key):
    """Run NUTS on ``model`` and return the posterior samples.

    Chain sizes come from ``args.num_warmup``, ``args.num_samples`` and
    ``args.num_chains``. The progress bar is disabled when
    ``NUMPYRO_SPHINXBUILD`` is set in the environment (presumably during
    documentation builds).
    """
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ
    mcmc = MCMC(
        NUTS(model),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=show_progress,
    )
    mcmc.run(rng_key)
    mcmc.print_summary()
    return mcmc.get_samples()
def sample(model, num_samples, num_warmup, num_chains, seed=0, chain_method="parallel", summary=True, **kwargs):
    """Run the No-U-Turn sampler on ``model``.

    Args:
        model: NumPyro model to sample from.
        num_samples: number of samples, forwarded to ``MCMC``.
        num_warmup: number of warmup steps, forwarded to ``MCMC``.
        num_chains: number of chains, forwarded to ``MCMC``.
        seed: integer seed for ``random.PRNGKey``.
        chain_method: chain execution method, forwarded to ``MCMC``.
        summary: if true, print the MCMC summary after sampling.
        **kwargs: model arguments, forwarded to ``mcmc.run``.

    Returns:
        The fitted ``MCMC`` object.
    """
    # Note: sampling more than one chain doesn't show a progress bar
    fitted = MCMC(
        NUTS(model),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        chain_method=chain_method,
    )
    fitted.run(random.PRNGKey(seed), **kwargs)
    if summary:
        fitted.print_summary()
    return fitted
def run_hmcecs(hmcecs_key, args, data, obs, inner_kernel):
    """Run HMCECS (subsampled HMC) on the module-level ``model``.

    First fits an ``AutoDelta`` guide with SVI to obtain reference parameters
    for the Taylor-expansion likelihood proxy, then samples with HMCECS
    wrapping ``inner_kernel``.

    Returns:
        Tuple of (SVI losses, MCMC posterior samples).
    """
    svi_key, mcmc_key = random.split(hmcecs_key)

    # find reference parameters for second order taylor expansion to estimate likelihood (taylor_proxy)
    guide = autoguide.AutoDelta(model)
    svi = SVI(model, guide, numpyro.optim.Adam(step_size=1e-3), loss=Trace_ELBO())
    svi_result = svi.run(svi_key, args.num_svi_steps, data, obs, args.subsample_size)
    ref_params = {"theta": svi_result.params["theta_auto_loc"]}

    # taylor proxy estimates log likelihood (ll) by
    # taylor_expansion(ll, theta_curr) +
    # sum_{i in subsample} ll_i(theta_curr) - taylor_expansion(ll_i, theta_curr) around ref_params
    ecs_kernel = HMCECS(inner_kernel, num_blocks=args.num_blocks, proxy=HMCECS.taylor_proxy(ref_params))

    mcmc = MCMC(ecs_kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(mcmc_key, data, obs, args.subsample_size)
    mcmc.print_summary()
    return svi_result.losses, mcmc.get_samples()
def sample_model(rng_key, model, model_args_dict, num_warmup=500, num_samples=500, num_chains=1):
    """Run NUTS on ``model`` and return posterior samples, including ``b_condition``.

    Args:
        rng_key: JAX PRNG key, used for MCMC and for ``Predictive``.
        model: NumPyro model function.
        model_args_dict: keyword arguments passed to the model.
        num_warmup, num_samples, num_chains: forwarded to ``MCMC``.

    Returns:
        Dict mapping site names to sample arrays. If ``b_condition`` is not
        among the MCMC samples, it is generated with ``Predictive`` from the
        posterior samples and added.
    """
    mcmc = MCMC(
        NUTS(model),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        progress_bar=True,
    )
    mcmc.run(rng_key, **model_args_dict)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    if 'b_condition' not in samples:
        predicted = numpyro.infer.Predictive(model, samples).get_samples(
            rng_key, **model_args_dict)
        # Predictive returns a dict of site name -> samples; store only the
        # b_condition array so ``samples`` stays a flat dict of arrays.
        samples['b_condition'] = predicted['b_condition']
    return samples
def run_inference(model, inputs, method=None):
    """Draw posterior samples for ``model`` given ``inputs``.

    Args:
        model: NumPyro model function; ``inputs`` are passed as keyword arguments.
        inputs: dict of keyword arguments for the model.
        method: ``None`` selects NUTS (300 warmup, 5000 samples); any other
            value selects SVI with an ``AutoDiagonalNormal`` guide
            (2000 steps, then 1000 posterior draws).

    Returns:
        Dict of posterior samples.
    """
    if method is not None:
        # SVI branch
        logger.info('Guide generation...')
        svi_key = random.PRNGKey(0)
        guide = AutoDiagonalNormal(model=model)
        logger.info('Optimizer generation...')
        optimizer = Adam(0.05)
        logger.info('SVI generation...')
        svi = SVI(model, guide, optimizer, AutoContinuousELBO(), **inputs)
        svi_state = svi.init(svi_key)
        logger.info('Scan...')
        # svi.update returns (new_state, loss), which matches lax.scan's (carry, output).
        svi_state, _ = lax.scan(lambda state, _step: svi.update(state), svi_state, np.zeros(2000))
        params = svi.get_params(svi_state)
        posterior = guide.sample_posterior(random.PRNGKey(1), params, (1000, ))
        logger.info(r'SVI summary for: {}'.format(model.__name__))
        numpyro.diagnostics.print_summary(posterior, prob=0.90, group_by_chain=False)
        return posterior

    # NUTS branch
    logger.info('NUTS sampling')
    mcmc = MCMC(NUTS(model), num_warmup=300, num_samples=5000)
    mcmc.run(random.PRNGKey(0), **inputs, extra_fields=('potential_energy', ))
    logger.info(r'MCMC summary for: {}'.format(model.__name__))
    mcmc.print_summary(exclude_deterministic=False)
    return mcmc.get_samples()
def infer(self, num_warmup=1000, num_samples=1000, num_chains=1, rng_key=PRNGKey(1), **args):
    '''Fit the model with NUTS and store the results on ``self``.

    Args:
        num_warmup, num_samples, num_chains: forwarded to ``MCMC``.
        rng_key: JAX PRNG key that seeds the sampler. The default
            ``PRNGKey(1)`` is created once, when the function is defined.
        **args: extra model arguments; they override entries in ``self.args``.

    Returns:
        Dict of posterior samples, also stored as ``self.mcmc_samples``.
        The fitted sampler is stored as ``self.mcmc``.
    '''
    # Merge stored model arguments with call-time overrides.
    model_args = dict(self.args, **args)
    kernel = NUTS(self)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains)
    # Use the caller-supplied key; it was previously discarded in favour of a fixed PRNGKey(0).
    mcmc.run(rng_key, **self.obs, **model_args)
    mcmc.print_summary()
    self.mcmc = mcmc
    self.mcmc_samples = mcmc.get_samples()
    return self.mcmc_samples
def run_inference(model):
    """Run one NUTS chain (500 warmup, 500 samples) on ``model``.

    Prints a summary that includes deterministic sites and returns the
    posterior samples.
    """
    sampler = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=1)
    sampler.run(random.PRNGKey(0))
    sampler.print_summary(exclude_deterministic=False)
    return sampler.get_samples()
def main(args):
    """Fit ``models[args.model]`` to the JSB chorales training set with NUTS or HMC.

    Optionally keeps only the first ``args.num_sequences`` sequences and
    truncates them to ``args.truncate`` time steps. Logs the elapsed
    sampling time.
    """
    model = models[args.model]
    _, fetch = load_dataset(JSB_CHORALES, split='train', shuffle=False)
    lengths, sequences = fetch()
    if args.num_sequences:
        limit = args.num_sequences
        sequences, lengths = sequences[:limit], lengths[:limit]
    logger.info('-' * 40)
    logger.info('Training {} on {} sequences'.format(model.__name__, len(sequences)))

    # Keep only notes that are present at least once in the training set
    # (with default args this removes 37/88 notes).
    played_notes = (sequences == 1).sum(0).sum(0) > 0
    sequences = sequences[..., played_notes]

    if args.truncate:
        lengths = lengths.clip(0, args.truncate)
        sequences = sequences[:, :args.truncate]

    logger.info('Each sequence has shape {}'.format(sequences[0].shape))
    logger.info('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    kernel_cls = {'nuts': NUTS, 'hmc': HMC}[args.kernel]
    mcmc = MCMC(
        kernel_cls(model),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key, sequences, lengths, args=args)
    mcmc.print_summary()
    logger.info('\nMCMC elapsed time: {}'.format(time.time() - start))
def run_inference(model, args, rng_key, X, Y):
    """Run NUTS on ``model`` with the initialization strategy named in ``args``.

    Args:
        model: NumPyro model function, called with ``X`` and ``Y``.
        args: namespace with ``init_strategy``, ``num_warmup``,
            ``num_samples`` and ``num_chains``.
        rng_key: JAX PRNG key for sampling.
        X, Y: model arguments.

    Returns:
        Dict of posterior samples.

    Raises:
        ValueError: if ``args.init_strategy`` is not one of "value",
            "median", "feasible", "sample" or "uniform".
    """
    start = time.time()
    # demonstrate how to use different HMC initialization strategies
    if args.init_strategy == "value":
        init_strategy = init_to_value(values={
            "kernel_var": 1.0,
            "kernel_noise": 0.05,
            "kernel_length": 0.5
        })
    elif args.init_strategy == "median":
        init_strategy = init_to_median(num_samples=10)
    elif args.init_strategy == "feasible":
        init_strategy = init_to_feasible()
    elif args.init_strategy == "sample":
        init_strategy = init_to_sample()
    elif args.init_strategy == "uniform":
        init_strategy = init_to_uniform(radius=1)
    else:
        # Without this branch an unknown name left init_strategy unbound.
        raise ValueError(
            f"Unknown init_strategy {args.init_strategy!r}; expected one of "
            "'value', 'median', 'feasible', 'sample', 'uniform'."
        )
    kernel = NUTS(model, init_strategy=init_strategy)
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key, X, Y)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
def fit(self, df, iter=500, seed=42, **kwargs):
    """Fit the model to match results in ``df`` with NUTS.

    Args:
        df: DataFrame with "home_team", "away_team", "home_goals",
            "away_goals" and "date" columns.
        iter: number of posterior samples; ``iter // 2`` warmup steps are run.
        seed: seed for the JAX PRNG key.
        **kwargs: forwarded to ``MCMC``.

    Returns:
        ``self``, with ``samples``, ``team_to_index``, ``index_to_team``,
        ``n_teams`` and ``min_date`` populated.
    """
    teams = sorted(set(df["home_team"]).union(df["away_team"]))
    first_date = df["date"].min()
    # Week index of each match, counted from the earliest date.
    gameweek = ((df["date"] - first_date).dt.days // 7).values

    self.team_to_index = {team: idx for idx, team in enumerate(teams)}
    self.index_to_team = dict(enumerate(teams))
    self.n_teams = len(teams)
    self.min_date = first_date

    observed_goals = {
        "home_goals": df["home_goals"].values,
        "away_goals": df["away_goals"].values,
    }
    conditioned_model = condition(self.model, param_map=observed_goals)
    mcmc = MCMC(NUTS(conditioned_model), num_warmup=iter // 2, num_samples=iter, **kwargs)
    mcmc.run(random.PRNGKey(seed), df["home_team"].values, df["away_team"].values, gameweek)
    self.samples = mcmc.get_samples()
    mcmc.print_summary()
    return self
def test_logistic_regression():
    """NUTS over TFP-substrate distributions recovers reference logistic coefficients."""
    from tensorflow_probability.substrates.jax import distributions as tfd

    N, dim = 3000, 3
    num_warmup, num_samples = 1000, 1000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    true_logits = jnp.sum(true_coefs * data, axis=-1)
    labels = tfd.Bernoulli(logits=true_logits).sample(seed=random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample("coefs", tfd.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1))
        return numpyro.sample("obs", tfd.Bernoulli(logits=logits), obs=labels)

    mcmc = MCMC(NUTS(model), num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples["logits"].shape == (num_samples, N)

    reference_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(samples["coefs"], 0), reference_coefs, atol=0.22)
def test_logistic_regression():
    """NUTS over ``numpyro.contrib.tfp`` distributions recovers reference coefficients."""
    from numpyro.contrib.tfp import distributions as dist

    N, dim = 3000, 3
    num_warmup, num_samples = 1000, 1000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    # These distributions are callable with an rng_key to draw a sample.
    labels = dist.Bernoulli(logits=jnp.sum(true_coefs * data, axis=-1))(rng_key=random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic('logits', jnp.sum(coefs * data, axis=-1))
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    mcmc = MCMC(NUTS(model), num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples['logits'].shape == (num_samples, N)
    assert_allclose(jnp.mean(samples['coefs'], 0), jnp.array([0.97, 2.05, 3.18]), atol=0.22)
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    """Sample Normal(1.0, 0.5) from its potential function with ``kernel_cls``.

    Checks that the sample mean and standard deviation are recovered within
    7% relative tolerance. When ``JAX_ENABLE_X64`` is set, the samples must
    also be float64.
    """
    true_mean, true_std = 1.0, 0.5
    num_warmup, num_samples = (100000, 100000) if kernel_cls is SA else (1000, 8000)

    def potential_fn(z):
        # Negative log-density of Normal(true_mean, true_std), up to a constant.
        return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2)

    init_params = jnp.array(0.0)
    if kernel_cls is SA:
        kernel = SA(potential_fn=potential_fn, dense_mass=dense_mass)
    elif kernel_cls is BarkerMH:
        # Build BarkerMH itself so this parametrization really tests BarkerMH
        # (this branch used to construct an SA kernel by mistake).
        kernel = BarkerMH(potential_fn=potential_fn, dense_mass=dense_mass)
    else:
        kernel = kernel_cls(potential_fn=potential_fn, trajectory_length=8, dense_mass=dense_mass)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    mcmc.print_summary()
    hmc_states = mcmc.get_samples()
    assert_allclose(jnp.mean(hmc_states), true_mean, rtol=0.07)
    assert_allclose(jnp.std(hmc_states), true_std, rtol=0.07)

    if "JAX_ENABLE_X64" in os.environ:
        assert hmc_states.dtype == jnp.float64
def inference(
    model: Callable,
    num_categories: int,
    num_words: int,
    supervised_categories: jnp.ndarray,
    supervised_words: jnp.ndarray,
    unsupervised_words: jnp.ndarray,
    rng_key: np.ndarray,
    *,
    num_warmup: int = 500,
    num_samples: int = 1000,
    num_chains: int = 1,
    verbose: bool = True,
) -> Dict[str, jnp.ndarray]:
    """Run NUTS on ``model`` and return the posterior samples.

    The five data arguments are passed positionally to the model in the
    order listed in the signature. If ``verbose`` is true, the MCMC summary
    is printed.
    """
    model_args = (
        num_categories,
        num_words,
        supervised_categories,
        supervised_words,
        unsupervised_words,
    )
    sampler = MCMC(
        NUTS(model),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
    )
    sampler.run(rng_key, *model_args)
    if verbose:
        sampler.print_summary()
    return sampler.get_samples()
def main(args):
    """Fit ``model`` to the lynx-hare population data and plot posterior predictions.

    The model is fitted to log populations with dense-mass NUTS. The 80%
    (10th-90th percentile) predictive interval is plotted and saved to
    ``ode_plot.pdf``.
    """
    _, fetch = load_dataset(LYNXHARE, shuffle=False)
    year, data = fetch()  # data is in hare -> lynx order

    # use dense_mass for better mixing rate
    mcmc = MCMC(
        NUTS(model, dense_mass=True),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(PRNGKey(1), N=data.shape[0], y=jnp.log(data))
    mcmc.print_summary()

    # predict populations
    y_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2), data.shape[0])["y"]
    pop_pred = jnp.exp(y_pred)
    mu, pi = jnp.mean(pop_pred, 0), jnp.percentile(pop_pred, (10, 90), 0)
    plt.plot(year, data[:, 0], "ko", mfc="none", ms=4, label="true hare", alpha=0.67)
    plt.plot(year, data[:, 1], "bx", label="true lynx")
    plt.plot(year, mu[:, 0], "k-.", label="pred hare", lw=1, alpha=0.67)
    plt.plot(year, mu[:, 1], "b--", label="pred lynx")
    plt.fill_between(year, pi[0, :, 0], pi[1, :, 0], color="k", alpha=0.2)
    plt.fill_between(year, pi[0, :, 1], pi[1, :, 1], color="b", alpha=0.3)
    plt.gca().set(ylim=(0, 160), xlabel="year", ylabel="population (in thousands)")
    plt.title("Posterior predictive (80% CI) with predator-prey pattern.")
    plt.legend()

    # Lay out the figure *before* saving; tight_layout after savefig would not
    # affect the written file.
    plt.tight_layout()
    plt.savefig("ode_plot.pdf")
def test_beta_bernoulli_x64(kernel_cls):
    """Recover two Bernoulli probabilities (0.9, 0.1) under a Beta(1.1, 1.1) prior.

    When ``JAX_ENABLE_X64`` is set, the samples must also be float64.
    """
    def model(data):
        concentration = np.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(concentration, concentration))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))

    if kernel_cls is SA:
        warmup_steps, num_samples = 100000, 100000
        kernel = SA(model=model)
    else:
        warmup_steps, num_samples = 500, 20000
        kernel = kernel_cls(model=model, trajectory_length=1.)

    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    p_samples = mcmc.get_samples()['p_latent']
    assert_allclose(np.mean(p_samples, 0), true_probs, atol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert p_samples.dtype == np.float64
def run_hmc(mcmc_key, args, data, obs, kernel):
    """Sample with ``kernel`` given ``data`` and ``obs``; return the posterior samples."""
    # NOTE(review): the third model argument is passed as None; presumably this
    # means "no subsampling" (full-data likelihood) - confirm against the model signature.
    sampler = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    sampler.run(mcmc_key, data, obs, None)
    sampler.print_summary()
    return sampler.get_samples()
def main(args):
    """Fit the selected annotation model with NUTS and print per-item class histograms.

    After sampling, the discrete site ``c`` is recovered with
    ``Predictive(..., infer_discrete=True)``. For each item, the number of
    samples assigned to each of 4 classes is printed.
    """
    annotators, annotations = get_data()
    model = NAME_TO_MODEL[args.model]
    # multinomial and item_difficulty take only the annotations.
    if model in [multinomial, item_difficulty]:
        data = (annotations, )
    else:
        data = (annotators, annotations)

    mcmc = MCMC(
        NUTS(model),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(random.PRNGKey(0), *data)
    mcmc.print_summary()

    predictive = Predictive(model, mcmc.get_samples(), infer_discrete=True)
    discrete_samples = predictive(random.PRNGKey(1), *data)

    def count_classes(column):
        return jnp.bincount(column, length=4)

    item_class = vmap(count_classes, in_axes=1)(discrete_samples["c"].squeeze(-1))

    print("Histogram of the predicted class of each item:")
    row_format = "{:>10}" * 5
    header = ["c={}".format(k) for k in range(4)]
    print(row_format.format("", *header))
    for idx, counts in enumerate(item_class):
        print(row_format.format(f"item[{idx}]", *counts))
def run_nuts(mcmc_key, args, X, Y):
    """Run NUTS on the module-level ``model`` with ``X`` and ``Y``; return the samples."""
    sampler = MCMC(NUTS(model), num_warmup=args.num_warmup, num_samples=args.num_samples)
    sampler.run(mcmc_key, X, Y)
    sampler.print_summary()
    return sampler.get_samples()
def test_logistic_regression_x64(kernel_cls):
    """Check that ``kernel_cls`` recovers the true coefficients ``[1, 2, 3]``.

    When ``JAX_ENABLE_X64`` is set, the samples must also be float64.
    """
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    labels = dist.Bernoulli(logits=np.sum(true_coefs * data, axis=-1)).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        logits = numpyro.deterministic('logits', np.sum(coefs * data, axis=-1))
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    # SA needs a much longer chain than the gradient-based kernels.
    if kernel_cls is SA:
        warmup_steps, num_samples = 100000, 100000
        kernel = SA(model=model, adapt_state_size=9)
    else:
        warmup_steps, num_samples = 1000, 8000
        kernel = kernel_cls(model=model, trajectory_length=8)

    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples['logits'].shape == (num_samples, N)
    assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.22)
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['coefs'].dtype == np.float64
def run_inference(args, data):
    """Sample the module-level ``model`` with nested sampling and with MCMC.

    Args:
        args: namespace with ``enum``, ``num_samples`` and ``num_warmup``.
        data: dict of keyword arguments for the model.

    Returns:
        Tuple of the ``x`` draws from the nested sampler and from MCMC. The
        MCMC kernel is NUTS when ``args.enum`` is true, otherwise NUTS
        wrapped in ``DiscreteHMCGibbs``.
    """
    import re

    def _version_tuple(version):
        # Numeric (major, minor, patch) so that e.g. "0.2.3" < "0.2.21" and
        # "0.10.0" > "0.2.21"; non-numeric suffixes such as "dev1" are ignored.
        numbers = []
        for part in version.split(".")[:3]:
            match = re.match(r"\d+", part)
            numbers.append(int(match.group()) if match else 0)
        return tuple(numbers)

    print("=== Performing Nested Sampling ===")
    ns = NestedSampler(model)
    ns.run(random.PRNGKey(0), **data, enum=args.enum)
    # TODO: Remove this condition when jaxns is compatible with the latest jax version.
    # Compare versions numerically; a plain string comparison misorders them.
    if _version_tuple(jax.__version__) < (0, 2, 21):
        ns.print_summary()
    # samples obtained from nested sampler are weighted, so
    # we need to provide random key to resample from those weighted samples
    ns_samples = ns.get_samples(random.PRNGKey(1), num_samples=args.num_samples)

    print("\n=== Performing MCMC Sampling ===")
    if args.enum:
        kernel = NUTS(model)
    else:
        kernel = DiscreteHMCGibbs(NUTS(model))
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(random.PRNGKey(2), **data, enum=args.enum)
    mcmc.print_summary()
    mcmc_samples = mcmc.get_samples()

    return ns_samples["x"], mcmc_samples["x"]
class NutsHandler(Handler):
    """Run NUTS for a NumPyro model and keep prior, posterior and posterior-predictive results.

    Results are stored as the attributes ``prior``, ``posterior`` and
    ``posterior_predictive`` and can be fetched by name with ``_select``.
    """

    def __init__(
        self,
        model,
        posterior=None,
        num_warmup=2000,
        num_samples=10000,
        num_chains=1,
        key=0,
        *args,
        **kwargs,
    ):
        """Set up the handler.

        Args:
            model: NumPyro model function.
            posterior: an already-run ``MCMC`` object (as produced by
                ``from_dump``); when given, no new kernel is built and its
                samples become ``self.posterior``.
            num_warmup, num_samples, num_chains: MCMC settings, used only
                when ``posterior`` is None.
            key: integer seed for the PRNG keys.
            *args: accepted but unused.
            **kwargs: forwarded to ``NUTS``.
        """
        self.model = model
        self.rng_key, self.rng_key_ = random.split(random.PRNGKey(key))
        if posterior is None:
            self.kernel = NUTS(model, **kwargs)
            self.mcmc = MCMC(
                self.kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                num_chains=num_chains,
            )
            return
        self.mcmc = posterior
        self.posterior = posterior.get_samples()

    def _select(self, which):
        """Return the stored result named ``which``."""
        choices = ("prior", "posterior", "posterior_predictive")
        assert which in choices, "Please select from 'prior', 'posterior' or 'posterior_predictive'."
        assert hasattr(self, which), f"NutsHandler did not compute the {which} yet."
        return getattr(self, which)

    # NOTE(review): get_prior, fit and get_posterior_predictive all reuse
    # self.rng_key_; consider splitting keys if independent randomness is wanted.

    def get_prior(self, *args, **kwargs):
        """Draw prior-predictive samples (model called with ``args``/``kwargs``) into ``self.prior``."""
        prior_predictive = Predictive(self.model, num_samples=self.mcmc.num_samples)
        self.prior = prior_predictive(self.rng_key_, *args, **kwargs)

    def get_posterior_predictive(self, *args, **kwargs):
        """Draw posterior-predictive samples into ``self.posterior_predictive``.

        ``kwargs`` go to the ``Predictive`` constructor; ``args`` go to the model.
        """
        posterior_predictive = Predictive(self.model, self.posterior, **kwargs)
        self.posterior_predictive = posterior_predictive(self.rng_key_, *args)

    def fit(self, *args, **kwargs):
        """Run MCMC with the given model arguments and store the samples in ``self.posterior``."""
        self.mcmc.run(self.rng_key_, *args, **kwargs)
        self.posterior = self.mcmc.get_samples()

    def summary(self, *args, **kwargs):
        """Print the MCMC summary; arguments are forwarded to ``print_summary``."""
        self.mcmc.print_summary(*args, **kwargs)

    def dump(self, path):
        """Serialize the ``MCMC`` object to ``path`` with dill."""
        with open(path, "wb") as handle:
            dill.dump(self.mcmc, handle)

    @staticmethod
    def from_dump(model, path):
        """Build a handler from an ``MCMC`` object previously written by ``dump``."""
        # SECURITY: dill.load, like pickle, can execute arbitrary code; only
        # load files from trusted sources.
        with open(path, "rb") as handle:
            loaded = dill.load(handle)
        return NutsHandler(model, posterior=loaded)
def run_inference(model, args, rng_key, X, Y, D_H):
    """Run NUTS on ``model`` with ``X``, ``Y`` and ``D_H``; print a summary and the elapsed time.

    ``D_H`` is forwarded unchanged to the model (presumably a hidden-layer
    width - confirm). Returns the posterior samples.
    """
    t0 = time.time()
    sampler = MCMC(
        NUTS(model),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
    )
    sampler.run(rng_key, X, Y, D_H)
    sampler.print_summary()
    print('\nMCMC elapsed time:', time.time() - t0)
    return sampler.get_samples()
def run_inference(model, args, rng_key, X, Y):
    """Run NUTS on ``model`` with ``X`` and ``Y``; print a summary and the elapsed time.

    The progress bar is hidden when ``NUMPYRO_SPHINXBUILD`` is set in the
    environment. Returns the posterior samples.
    """
    t0 = time.time()
    sampler = MCMC(
        NUTS(model),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    sampler.run(rng_key, X, Y)
    sampler.print_summary()
    print('\nMCMC elapsed time:', time.time() - t0)
    return sampler.get_samples()
def test_random_module_mcmc(backend, init):
    """NUTS over a Bayesian 1-unit linear layer (flax or haiku) recovers the true weights.

    ``init`` selects how the module is initialised: from an input shape
    ("shape") or from example inputs passed by keyword ("kwargs").
    """
    if backend == "flax":
        import flax

        linear_module = flax.linen.Dense(features=1)
        bias_name, weight_name = "bias", "kernel"
        random_module = random_flax_module
        kwargs_name = "inputs"
    elif backend == "haiku":
        import haiku as hk

        linear_module = hk.transform(lambda x: hk.Linear(1)(x))
        bias_name, weight_name = "linear.b", "linear.w"
        random_module = random_haiku_module
        kwargs_name = "x"

    N, dim = 3000, 3
    num_warmup, num_samples = 1000, 1000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1.0, dim + 1.0)
    labels = dist.Bernoulli(logits=np.sum(true_coefs * data, axis=-1)).sample(random.PRNGKey(1))

    if init == "shape":
        kwargs = {"input_shape": (3,)}
    elif init == "kwargs":
        kwargs = {kwargs_name: data}

    bias_site = "nn/{}".format(bias_name)
    weight_site = "nn/{}".format(weight_name)

    def model(data, labels):
        priors = {bias_name: dist.Cauchy(), weight_name: dist.Normal()}
        nn = random_module("nn", linear_module, priors, **kwargs)
        logits = nn(data).squeeze(-1)
        numpyro.sample("y", dist.Bernoulli(logits=logits), obs=labels)

    mcmc = MCMC(NUTS(model=model), num_warmup=num_warmup, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data, labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert set(samples.keys()) == {bias_site, weight_site}
    assert_allclose(
        np.mean(samples[weight_site].squeeze(-1), 0),
        true_coefs,
        atol=0.22,
    )
def run_inference(model, capture_history, sex, rng_key, args):
    """Sample ``model`` with the kernel named by ``args.algo``.

    Args:
        model: NumPyro model, called with ``capture_history`` and ``sex``.
        capture_history, sex: model arguments.
        rng_key: JAX PRNG key for sampling.
        args: namespace with ``algo`` ("NUTS" or "HMC"), ``num_warmup``,
            ``num_samples`` and ``num_chains``.

    Returns:
        Dict of posterior samples.

    Raises:
        ValueError: if ``args.algo`` is neither "NUTS" nor "HMC".
    """
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    else:
        # Without this branch an unknown algo left `kernel` unbound.
        raise ValueError(f"Unknown algo {args.algo!r}; expected 'NUTS' or 'HMC'.")
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key, capture_history, sex)
    mcmc.print_summary()
    return mcmc.get_samples()
def benchmark_hmc(args, features, labels):
    """Time sampling of the module-level ``model`` on ``features``/``labels``, without warmup.

    The trajectory length is ``args.num_steps`` times a step size of
    ``sqrt(0.5 / features.shape[0])``. The summary and wall-clock time are
    printed.
    """
    step_size = np.sqrt(0.5 / features.shape[0])
    trajectory_length = step_size * args.num_steps
    rng_key = random.PRNGKey(1)
    start = time.time()
    # NOTE(review): step_size is only used to derive trajectory_length and is
    # not passed to the kernel. NumPyro's NUTS may ignore trajectory_length -
    # confirm whether HMC was intended for this benchmark.
    sampler = MCMC(NUTS(model, trajectory_length=trajectory_length), num_warmup=0, num_samples=args.num_samples)
    sampler.run(rng_key, features, labels)
    sampler.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
def run_hmc(rng_key, model, data, num_mix_comp, args, bvm_init_locs):
    """Run NUTS initialised at ``bvm_init_locs`` and return the posterior samples.

    Tree depth is capped at 7. The model is called with ``data``,
    ``len(data)`` and ``num_mix_comp``.
    """
    init_strategy = init_to_value(values=bvm_init_locs)
    sampler = MCMC(
        NUTS(model, init_strategy=init_strategy, max_tree_depth=7),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
    )
    sampler.run(rng_key, data, len(data), num_mix_comp)
    sampler.print_summary()
    return sampler.get_samples()
def run_mcmc(model, args, X, Y):
    """Run NUTS on ``model`` with ``X`` and ``Y`` (seeded with ``PRNGKey(1)``); return the samples.

    The progress bar is hidden when ``NUMPYRO_SPHINXBUILD`` is set in the
    environment.
    """
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ
    sampler = MCMC(
        NUTS(model),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=show_progress,
    )
    sampler.run(PRNGKey(1), X, Y)
    sampler.print_summary()
    return sampler.get_samples()