def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    """Sample from an unnormalized Gaussian potential and check posterior moments.

    Also verifies the sample dtype is float64 when JAX x64 mode is enabled.
    """
    true_mean, true_std = 1., 0.5
    warmup_steps, num_samples = 1000, 8000

    def potential_fn(z):
        # Negative log-density of Normal(true_mean, true_std), up to a constant.
        return 0.5 * np.sum(((z - true_mean) / true_std) ** 2)

    init_params = np.array(0.)
    kernel = kernel_cls(potential_fn=potential_fn, trajectory_length=8,
                        dense_mass=dense_mass)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    hmc_states = mcmc.get_samples()
    assert_allclose(np.mean(hmc_states), true_mean, rtol=0.05)
    assert_allclose(np.std(hmc_states), true_std, rtol=0.05)
    # BUG FIX: the env var is JAX_ENABLE_X64 (uppercase X64), matching the
    # other x64 tests in this file; the lowercase spelling never matched, so
    # the dtype assertion was silently skipped.
    if 'JAX_ENABLE_X64' in os.environ:
        assert hmc_states.dtype == np.float64
def test_reuse_mcmc_run(jit_args, shape):
    """An MCMC object can be re-run on new data and should adapt to it."""
    y1 = np.random.normal(3, 0.1, (100,))
    y2 = np.random.normal(-3, 0.1, (shape,))

    def model(y_obs):
        mu = numpyro.sample('mu', dist.Normal(0., 1.))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(3.))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y_obs)

    # First fit on the initial dataset.
    sampler = MCMC(NUTS(model), 300, 500, jit_model_args=jit_args)
    sampler.run(random.PRNGKey(32), y1)
    # Re-run on new data - should be much faster (model args may be jitted).
    sampler.run(random.PRNGKey(32), y2)
    assert_allclose(sampler.get_samples()['mu'].mean(), -3., atol=0.1)
def test_predictive_with_improper():
    """Predictive sampling should work for a model containing a param site."""
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.param('loc', 0., constraint=constraints.interval(0., alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    observed = true_coef + random.normal(random.PRNGKey(0), (1000,))
    sampler = MCMC(NUTS(model=model), num_warmup=1000, num_samples=1000)
    sampler.run(random.PRNGKey(0), observed)
    posterior = sampler.get_samples()
    # Draw posterior-predictive observations (data=None makes 'obs' latent).
    obs_pred = Predictive(model, posterior)(random.PRNGKey(1), data=None)["obs"]
    assert_allclose(np.mean(obs_pred), true_coef, atol=0.05)
def test_improper_prior():
    """NUTS should recover mean/std under masked and ImproperUniform priors."""
    true_mean, true_std = 1., 2.
    num_warmup, num_samples = 1000, 8000

    def model(data):
        # mask(False) removes the prior's log-density => improper flat prior.
        mean = numpyro.sample('mean', dist.Normal(0, 1).mask(False))
        std = numpyro.sample('std', dist.ImproperUniform(dist.constraints.positive, (), ()))
        return numpyro.sample('obs', dist.Normal(mean, std), obs=data)

    observations = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000,))
    sampler = MCMC(NUTS(model=model), num_warmup, num_samples)
    sampler.warmup(random.PRNGKey(2), observations)
    sampler.run(random.PRNGKey(2), observations)
    posterior = sampler.get_samples()
    assert_allclose(jnp.mean(posterior['mean']), true_mean, rtol=0.05)
    assert_allclose(jnp.mean(posterior['std']), true_std, rtol=0.05)
def test_prior_with_sample_shape():
    """theta sampled with sample_shape=(J,) should give (num_samples, J) draws."""
    data = {
        "J": 8,
        "y": jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
        "sigma": jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
    }

    def schools_model():
        mu = numpyro.sample('mu', dist.Normal(0, 5))
        tau = numpyro.sample('tau', dist.HalfCauchy(5))
        theta = numpyro.sample('theta', dist.Normal(mu, tau), sample_shape=(data['J'],))
        numpyro.sample('obs', dist.Normal(theta, data['sigma']), obs=data['y'])

    n_draws = 500
    sampler = MCMC(NUTS(schools_model), num_warmup=500, num_samples=n_draws)
    sampler.run(random.PRNGKey(0))
    assert sampler.get_samples()['theta'].shape == (n_draws, data['J'])
def test_improper_normal():
    """TransformReparam on a masked-base TransformedDistribution recovers loc."""
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample('loc', dist.TransformedDistribution(
                dist.Uniform(0, 1).mask(False), AffineTransform(0, alpha)))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    observed = true_coef + random.normal(random.PRNGKey(0), (1000,))
    sampler = MCMC(NUTS(model=model), num_warmup=1000, num_samples=1000)
    sampler.run(random.PRNGKey(0), observed)
    posterior = sampler.get_samples()
    assert_allclose(jnp.mean(posterior['loc'], 0), true_coef, atol=0.05)
def inference(belief_sequences, obs, mask, rng_key):
    """Run vectorized-chain NUTS and return (samples, mean potential energy)."""
    sampler = MCMC(
        NUTS(model, dense_mass=True),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        chain_method="vectorized",
        progress_bar=False,
    )
    sampler.run(
        rng_key,
        belief_sequences,
        obs,
        mask,
        extra_fields=('potential_energy',)
    )
    posterior = sampler.get_samples()
    mean_potential_energy = sampler.get_extra_fields()['potential_energy'].mean()
    # sampler.print_summary()
    return posterior, mean_potential_energy
def run_inference(model, at_bats, hits, rng_key, args):
    """Run MCMC with the kernel selected by ``args.algo`` and return samples."""
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    elif args.algo == "SA":
        kernel = SA(model)
    # Hide the progress bar during Sphinx doc builds or when asked explicitly.
    hide_progress = "NUMPYRO_SPHINXBUILD" in os.environ or args.disable_progbar
    sampler = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=not hide_progress,
    )
    sampler.run(rng_key, at_bats, hits)
    return sampler.get_samples()
def get_samples(rng_key, data, step_size, trajectory_length, target_accept_prob):
    """Build a kernel with the given HMC settings, run two chains, return samples."""
    hmc_kernel = kernel_cls(
        model,
        step_size=step_size,
        trajectory_length=trajectory_length,
        target_accept_prob=target_accept_prob,
    )
    sampler = MCMC(
        hmc_kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=2,
        chain_method=chain_method,
        progress_bar=False,
    )
    sampler.run(rng_key, data)
    return sampler.get_samples()
def all_bands(priors, num_samples=500, num_warmup=500, num_chains=4, chain_method='parallel'):
    """Fit the SPIRE model across all bands and return the fitted MCMC object."""
    # Request 4 host devices so parallel chains can actually run in parallel.
    numpyro.set_host_device_count(4)
    sampler = MCMC(NUTS(spire_model),
                   num_samples=num_samples,
                   num_warmup=num_warmup,
                   num_chains=num_chains,
                   chain_method=chain_method)
    sampler.run(random.PRNGKey(0), priors,
                extra_fields=('potential_energy', 'energy'))
    return sampler
def test_initial_inverse_mass_matrix_ndarray(dense_mass):
    """An ndarray inverse mass matrix should be kept as-is when adaptation is off."""
    def model():
        numpyro.sample("z", dist.Normal(0, 1).expand([2]))
        numpyro.sample("x", dist.Normal(0, 1).expand([3]))

    # 5 entries: 3 for "x" plus 2 for "z" (sites grouped into one block).
    init_mm = jnp.arange(1, 6.0)
    sampler = MCMC(
        NUTS(
            model,
            dense_mass=dense_mass,
            inverse_mass_matrix=init_mm,
            adapt_mass_matrix=False,
        ),
        num_warmup=1,
        num_samples=1,
    )
    sampler.run(random.PRNGKey(0))
    final_mm = sampler.last_state.adapt_state.inverse_mass_matrix
    assert set(final_mm.keys()) == {("x", "z")}
    expected = jnp.diag(init_mm) if dense_mass else init_mm
    assert_allclose(final_mm[("x", "z")], expected)
def sample_posterior_with_predictive(rng_key: random.PRNGKey, model, data: np.ndarray,
                                     Nsamples: int = 1000, alpha: float = 1,
                                     sigma: float = 0, T: int = 10):
    """Fit ``model`` with NUTS, then draw posterior-predictive samples of site ``z``."""
    sampler = MCMC(NUTS(model), num_samples=Nsamples, num_warmup=NUM_WARMUP)
    sampler.run(rng_key, data=data, alpha=alpha, sigma=sigma, T=T)
    posterior = sampler.get_samples()
    predictive = Predictive(model, posterior_samples=posterior, return_sites=["z"])
    return predictive(rng_key, data=data, alpha=alpha, sigma=sigma, T=T)["z"]
def test_discrete_gibbs_multiple_sites_chain(kernel, inner_kernel, kwargs, num_chains):
    """Gibbs sampling over two discrete sites should match the prior means."""
    def model():
        numpyro.sample("x", dist.Bernoulli(0.7).expand([3]))
        numpyro.sample("y", dist.Binomial(10, 0.3))

    gibbs_kernel = kernel(inner_kernel(model), **kwargs)
    sampler = MCMC(
        gibbs_kernel,
        num_warmup=1000,
        num_samples=10000,
        num_chains=num_chains,
        progress_bar=False,
    )
    sampler.run(random.PRNGKey(0))
    sampler.print_summary()
    draws = sampler.get_samples()
    assert_allclose(jnp.mean(draws["x"], 0), 0.7 * jnp.ones(3), atol=0.01)
    assert_allclose(jnp.mean(draws["y"], 0), 0.3 * 10, atol=0.1)
def test_diverging(kernel_cls, adapt_step_size):
    """With a huge step size, every transition diverges unless step size adapts."""
    data = random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        loc = numpyro.sample('loc', dist.Normal(0., 1.))
        numpyro.sample('obs', dist.Normal(loc, 1), obs=data)

    kernel = kernel_cls(model, step_size=10., adapt_step_size=adapt_step_size,
                        adapt_mass_matrix=False)
    num_warmup = num_samples = 1000
    sampler = MCMC(kernel, num_warmup, num_samples)
    # Collect divergence flags during warmup as well as sampling.
    sampler.warmup(random.PRNGKey(1), data, extra_fields=['diverging'],
                   collect_warmup=True)
    warmup_divergences = sampler.get_extra_fields()['diverging'].sum()
    sampler.run(random.PRNGKey(2), data, extra_fields=['diverging'])
    total_divergences = warmup_divergences + sampler.get_extra_fields()['diverging'].sum()
    if adapt_step_size:
        assert total_divergences <= num_warmup
    else:
        # Step size stays huge: every single transition diverges.
        assert_allclose(total_divergences, num_warmup + num_samples)
def main(args):
    """Simulate semi-supervised HMM data, run NUTS, print results and save a plot.

    :param args: parsed CLI namespace (num_categories, num_words, num_supervised,
        num_unsupervised, num_warmup, num_samples, num_chains, unroll_loop).
    """
    print('Simulating data...')
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words, unsupervised_words) = simulate_data(
        random.PRNGKey(1),
        num_categories=args.num_categories,
        num_words=args.num_words,
        num_supervised_data=args.num_supervised,
        num_unsupervised_data=args.num_unsupervised,
    )
    print('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    kernel = NUTS(semi_supervised_hmm)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples,
                num_chains=args.num_chains,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(rng_key, transition_prior, emission_prior, supervised_categories,
             supervised_words, unsupervised_words, args.unroll_loop)
    samples = mcmc.get_samples()
    print_results(samples, transition_prob, emission_prob)
    print('\nMCMC elapsed time:', time.time() - start)

    # make plots: KDE of each posterior transition probability
    fig, ax = plt.subplots(1, 1)
    x = np.linspace(0, 1, 101)
    for i in range(transition_prob.shape[0]):
        for j in range(transition_prob.shape[1]):
            ax.plot(x, gaussian_kde(samples['transition_prob'][:, i, j])(x),
                    label="trans_prob[{}, {}], true value = {:.2f}".format(
                        i, j, transition_prob[i, j]))
    ax.set(xlabel="Probability", ylabel="Frequency",
           title="Transition probability posterior")
    ax.legend()
    # BUG FIX: tight_layout must be applied BEFORE savefig; calling it after
    # saving meant the exported PDF never got the adjusted layout.
    plt.tight_layout()
    plt.savefig("hmm_plot.pdf")
def test_scan():
    """Test MCMC + Predictive on a model whose time loop uses numpyro's scan.

    Checks that scanned sample/deterministic sites appear in the posterior,
    and that Predictive (with condition + substitute composed, parallel=True)
    can extrapolate the scan beyond the training horizon.
    """
    def model(T=10, q=1, r=1, phi=0.0, beta=0.0):
        # One step of the latent AR process; carries (x, mu), emits (x, y).
        def transition(state, i):
            x0, mu0 = state
            x1 = numpyro.sample("x", dist.Normal(phi * x0, q))
            mu1 = beta * mu0 + x1
            y1 = numpyro.sample("y", dist.Normal(mu1, r))
            numpyro.deterministic("y2", y1 * 2)
            return (x1, mu1), (x1, y1)

        mu0 = x0 = numpyro.sample("x_0", dist.Normal(0, q))
        y0 = numpyro.sample("y_0", dist.Normal(mu0, r))
        _, xy = scan(transition, (x0, mu0), jnp.arange(T))
        x, y = xy
        return jnp.append(x0, x), jnp.append(y0, y)

    T = 10
    num_samples = 100
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), T=T)
    # All scanned sites (including the deterministic "y2") must be collected.
    assert set(mcmc.get_samples()) == {"x", "y", "y2", "x_0", "y_0"}
    mcmc.print_summary()
    samples = mcmc.get_samples()
    x = samples.pop("x")[0]  # take 1 sample of x
    # this tests for the composition of condition and substitute
    # this also tests if we can use `vmap` for predictive.
    future = 5
    predictive = Predictive(
        numpyro.handlers.condition(model, {"x": x}),
        samples,
        return_sites=["x", "y", "y2"],
        parallel=True,
    )
    result = predictive(random.PRNGKey(1), T=T + future)
    expected_shape = (num_samples, T + future)
    assert result["x"].shape == expected_shape
    assert result["y"].shape == expected_shape
    assert result["y2"].shape == expected_shape
    # The conditioned prefix of "x" must equal the single draw we conditioned on.
    assert_allclose(result["x"][:, :T], jnp.broadcast_to(x, (num_samples, T)))
    assert_allclose(result["y"][:, :T], samples["y"])
def fit_numpyro(progress_bar=False, model=None, num_warmup=1000, n_draws=200,
                num_chains=4, sampler=NUTS, use_gpu=False, **kwargs):
    """Fit a numpyro model with MCMC and return an (arviz trace, MCMC) pair.

    ``kwargs`` are forwarded to ``mcmc.run`` as model arguments.
    In test mode (env var ``bayes_window_test_mode``) settings are minimized.
    """
    if 'bayes_window_test_mode' in os.environ:
        # Override settings with minimal
        use_gpu = False
        num_warmup = 5
        n_draws = 5
        num_chains = 1
    select_device(use_gpu, num_chains)
    # Fall back to the default hierarchical model when none is supplied.
    model = model or models.model_hierarchical
    mcmc = MCMC(
        sampler(
            model=model,
            find_heuristic_step_size=True,
            target_accept_prob=0.99,
            # init_strategy=numpyro.infer.init_to_uniform
        ),
        num_warmup=num_warmup,
        num_samples=n_draws,
        num_chains=num_chains,
        progress_bar=progress_bar,
        chain_method='parallel')
    mcmc.run(jax.random.PRNGKey(16), **kwargs)
    # arviz convert; fall back to a plain dict if numpyro conversion fails.
    try:
        trace = az.from_numpyro(mcmc)
    except AttributeError:
        trace = az.from_dict(mcmc.get_samples())
    print(trace.posterior)
    # Print diagnostics
    if 'sample_stats' in trace:
        if trace.sample_stats.diverging.sum(['chain', 'draw']).values > 0:
            print(
                f"n(Divergences) = {trace.sample_stats.diverging.sum(['chain', 'draw']).values}"
            )
    return trace, mcmc
def infer(self, num_warmup=1000, num_samples=1000, num_chains=1, rng_key=PRNGKey(1), **args):
    '''Fit using MCMC'''
    # Extra keyword args override the instance defaults for this run.
    run_args = dict(self.args, **args)
    nuts_kernel = NUTS(self, init_strategy=numpyro.infer.util.init_to_median())
    sampler = MCMC(nuts_kernel, num_warmup=num_warmup, num_samples=num_samples,
                   num_chains=num_chains)
    sampler.run(rng_key, **self.obs, **run_args)
    sampler.print_summary()
    self.mcmc = sampler
    self.mcmc_samples = sampler.get_samples()
    return self.mcmc_samples
def run_inference(design_matrix: jnp.ndarray, outcome: jnp.ndarray,
                  rng_key: jnp.ndarray,
                  num_warmup: int, num_samples: int, num_chains: int,
                  interval_size: float = 0.95) -> None:
    """
    Estimate the effect size.
    """
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ
    sampler = MCMC(NUTS(model), num_warmup, num_samples, num_chains,
                   progress_bar=show_progress)
    sampler.run(rng_key, design_matrix, outcome)
    # 0th column is intercept (not getting called)
    # 1st column is effect of getting called
    # 2nd column is effect of gender (should be none since assigned at random)
    coef = sampler.get_samples()['coefficients']
    print_results(coef, interval_size)
def test_initial_inverse_mass_matrix(dense_mass):
    """A dict inverse mass matrix applies to listed sites; others get identity."""
    def model():
        numpyro.sample("x", dist.Normal(0, 1).expand([3]))
        numpyro.sample("z", dist.Normal(0, 1).expand([2]))

    init_mm = jnp.arange(1, 4.0)
    sampler = MCMC(
        NUTS(
            model,
            dense_mass=dense_mass,
            inverse_mass_matrix={("x",): init_mm},
            adapt_mass_matrix=False,
        ),
        1,
        1,
    )
    sampler.run(random.PRNGKey(0))
    final_mm = sampler.last_state.adapt_state.inverse_mass_matrix
    assert set(final_mm.keys()) == {("x",), ("z",)}
    expected = jnp.diag(init_mm) if dense_mass else init_mm
    assert_allclose(final_mm[("x",)], expected)
    # Unlisted site "z" falls back to an identity (diagonal of ones) matrix.
    assert_allclose(final_mm[("z",)], jnp.ones(2))
def _run_inference(self, rng_key=None, X=None, Y=None):
    '''
    Run inference on the model specified above with the supplied data.

    :param rng_key: optional PRNG key; derived from ``self.random_state`` if None.
    :param X: feature data passed to the model.
    :param Y: target data passed to the model.
    :return: dict of posterior samples keyed by site name.
    '''
    if rng_key is None:
        rng_key = random.PRNGKey(self.random_state)
    if self.num_chains > 1:
        rng_key_ = random.split(rng_key, self.num_chains)
    else:
        rng_key, rng_key_ = random.split(rng_key)
    # The following samples parameter settings with NUTS and MCMC to fit the
    # posterior based on the provided data (X, Y)
    start = time.time()
    kernel = NUTS(self._model)
    mcmc = MCMC(kernel, self.num_warmup, self.num_samples)
    mcmc.run(rng_key_, X=X, Y=Y)
    # BUG FIX: the original printed the literal characters "/n"; the newline
    # escape is '\n'.
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
def test_chain_smoke(chain_method, compile_args):
    """Smoke test: multi-chain warmup + run under each chain method / jit option."""
    def model(data):
        concentration = np.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    observed = dist.Categorical(np.array([0.1, 0.6, 0.3])).sample(
        random.PRNGKey(1), (2000,))
    # Tiny warmup/sample counts: we only verify nothing crashes.
    sampler = MCMC(NUTS(model), 2, 5, num_chains=2, chain_method=chain_method,
                   jit_model_args=compile_args)
    sampler.warmup(random.PRNGKey(0), observed)
    sampler.run(random.PRNGKey(1), observed)
def mcmc_inference(model, num_warmup, num_samples, num_chains, rng_key, X, Y):
    """
    Helper function for doing NUTS inference.

    :param model: a parametric function proportional to the posterior (see gp_regression.likelihood).
    :param num_warmup: warmup steps.
    :param num_samples: number of samples.
    :param num_chains: number of Markov chains used for MCMC sampling.
    :param rng_key: random seed.
    :param X: X data.
    :param Y: Y data.
    :return: Dictionary key: name of parameter (from defined in model), value: list of samples.
    """
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup, num_samples, num_chains=num_chains)
    mcmc.run(rng_key, X, Y)
    print('\nMCMC time:', time.time() - start)
    # BUG FIX: print_summary() prints directly and returns None, so wrapping it
    # in print() emitted a spurious "None" line. (Also fixed the docstring's
    # stray fourth quote.)
    mcmc.print_summary()
    return mcmc.get_samples()
def test_uniform_normal():
    """Warmup collection yields num_warmup draws; run yields num_samples draws."""
    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    observed = true_coef + random.normal(random.PRNGKey(0), (1000,))
    sampler = MCMC(NUTS(model=model), num_warmup=num_warmup, num_samples=num_samples)
    sampler.warmup(random.PRNGKey(2), observed, collect_warmup=True)
    warmup_draws = sampler.get_samples()
    sampler.run(random.PRNGKey(3), observed)
    draws = sampler.get_samples()
    assert len(warmup_draws['loc']) == num_warmup
    assert len(draws['loc']) == num_samples
    assert_allclose(np.mean(draws['loc'], 0), true_coef, atol=0.05)
def test_dirichlet_categorical_x64(kernel_cls, dense_mass):
    """Dirichlet-categorical posterior mean should match the true probabilities."""
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = np.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.1, 0.6, 0.3])
    observed = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    kernel = kernel_cls(model, trajectory_length=1., dense_mass=dense_mass)
    sampler = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
    sampler.run(random.PRNGKey(2), observed)
    draws = sampler.get_samples()
    assert_allclose(np.mean(draws['p_latent'], 0), true_probs, atol=0.02)
    # Under x64 mode the samples should come back in double precision.
    if 'JAX_ENABLE_X64' in os.environ:
        assert draws['p_latent'].dtype == np.float64
def test_bernoulli_latent_model():
    """Discrete latents y, z are enumerated out; posterior of y_prob matches truth."""
    def model(data):
        y_prob = numpyro.sample("y_prob", dist.Beta(1., 1.))
        with numpyro.plate("data", data.shape[0]):
            y = numpyro.sample("y", dist.Bernoulli(y_prob))
            z = numpyro.sample("z", dist.Bernoulli(0.65 * y + 0.1))
            numpyro.sample("obs", dist.Normal(2. * z, 1.), obs=data)

    N = 2000
    y_prob = 0.3
    # Generate data from the same generative process the model describes.
    y = dist.Bernoulli(y_prob).sample(random.PRNGKey(0), (N,))
    z = dist.Bernoulli(0.65 * y + 0.1).sample(random.PRNGKey(1))
    observed = dist.Normal(2. * z, 1.0).sample(random.PRNGKey(2))
    sampler = MCMC(NUTS(model), num_warmup=500, num_samples=500)
    sampler.run(random.PRNGKey(3), observed)
    draws = sampler.get_samples()
    assert_allclose(draws["y_prob"].mean(0), y_prob, atol=0.05)
def test_predictive(parallel):
    """Predictive should honor return_sites and produce correct shapes and means."""
    model, data, true_probs = beta_bernoulli()
    sampler = MCMC(NUTS(model), num_warmup=100, num_samples=100)
    sampler.run(random.PRNGKey(0), data)
    posterior = sampler.get_samples()
    predictive = Predictive(model, posterior, parallel=parallel)
    predictive_samples = predictive(random.PRNGKey(1))
    # By default only non-observed/deterministic downstream sites are returned.
    assert predictive_samples.keys() == {"beta_sq", "obs"}

    predictive.return_sites = ["beta", "beta_sq", "obs"]
    predictive_samples = predictive(random.PRNGKey(1))
    # check shapes
    assert predictive_samples["beta"].shape == (100,) + true_probs.shape
    assert predictive_samples["beta_sq"].shape == (100,) + true_probs.shape
    assert predictive_samples["obs"].shape == (100,) + data.shape
    # check sample mean
    obs = predictive_samples["obs"].reshape((-1,) + true_probs.shape).astype(np.float32)
    assert_allclose(obs.mean(0), true_probs, rtol=0.1)
def test_logistic_regression_x64(kernel_cls):
    """Logistic regression accuracy test parameterized over SA, BarkerMH and HMC/NUTS.

    Checks the posterior mean of the coefficients against a MAP reference and,
    under x64 mode, that samples are double precision.
    """
    N, dim = 3000, 3
    # Slower-mixing kernels need (many) more draws to hit the tolerance below.
    if kernel_cls is SA:
        num_warmup, num_samples = (100000, 100000)
    elif kernel_cls is BarkerMH:
        num_warmup, num_samples = (2000, 12000)
    else:
        num_warmup, num_samples = (1000, 8000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1))
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    # Kernel construction mirrors the draw-count branching above.
    if kernel_cls is SA:
        kernel = SA(model=model, adapt_state_size=9)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model=model)
    else:
        kernel = kernel_cls(model=model, trajectory_length=8,
                            find_heuristic_step_size=True)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples["logits"].shape == (num_samples, N)
    # those coefficients are found by doing MAP inference using AutoDelta
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(samples["coefs"], 0), expected_coefs, atol=0.1)
    if "JAX_ENABLE_X64" in os.environ:
        assert samples["coefs"].dtype == jnp.float64
def test_beta_bernoulli():
    """TFP-wrapped Beta/Bernoulli distributions should work with NumPyro's NUTS."""
    from numpyro.contrib.tfp import distributions as dist
    warmup_steps, num_samples = (500, 2000)

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    observed = dist.Bernoulli(true_probs)(rng_key=random.PRNGKey(1),
                                          sample_shape=(1000, 2))
    sampler = MCMC(NUTS(model=model, trajectory_length=0.1),
                   num_warmup=warmup_steps, num_samples=num_samples)
    sampler.run(random.PRNGKey(2), observed)
    sampler.print_summary()
    draws = sampler.get_samples()
    assert_allclose(jnp.mean(draws['p_latent'], 0), true_probs, atol=0.05)
def main(args):
    """Simulate semi-supervised HMM data, run NUTS, and print the results."""
    print('Simulating data...')
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words, unsupervised_words) = simulate_data(
        random.PRNGKey(1),
        num_categories=args.num_categories,
        num_words=args.num_words,
        num_supervised_data=args.num_supervised,
        num_unsupervised_data=args.num_unsupervised,
    )
    print('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    sampler = MCMC(NUTS(semi_supervised_hmm), args.num_warmup, args.num_samples)
    sampler.run(rng_key, transition_prior, emission_prior, supervised_categories,
                supervised_words, unsupervised_words)
    draws = sampler.get_samples()
    print_results(draws, transition_prob, emission_prob)
    print('\nMCMC elapsed time:', time.time() - start)