def test_mcmc_diagnostics(num_chains):
    """Run a short prior-kernel MCMC and sanity-check the diagnostics dict."""
    data = torch.tensor([2.0]).repeat(3)
    init_params, _, transforms, _ = initialize_model(
        normal_normal_model, model_args=(data,), num_chains=num_chains
    )
    sampler = MCMC(
        PriorKernel(normal_normal_model),
        num_samples=10,
        warmup_steps=10,
        num_chains=num_chains,
        mp_context="spawn",
        initial_params=init_params,
        transforms=transforms,
    )
    sampler.run(data)
    # Computing n_eff / r_hat requires FFT support, which needs MKL; the run
    # itself is still exercised above before skipping.
    if not torch.backends.mkl.is_available():
        pytest.skip()
    diag = sampler.diagnostics()
    assert diag["y"]["n_eff"].shape == data.shape
    assert diag["y"]["r_hat"].shape == data.shape
    assert diag["dummy_key"] == {
        "chain {}".format(i): "dummy_value" for i in range(num_chains)
    }
def test_save_params(save_params, Kernel, options):
    """Only sites listed in ``save_params`` should show up in the samples,
    diagnostics, and summary of an MCMC run."""
    save_params = list(save_params)

    def model():
        x = pyro.sample("x", dist.Normal(0, 1))
        with pyro.plate("plate", 2):
            y = pyro.sample("y", dist.Normal(x, 1))
            pyro.sample("obs", dist.Normal(y, 1), obs=torch.zeros(2))

    mcmc = MCMC(
        Kernel(model, **options),
        warmup_steps=2,
        num_samples=4,
        save_params=save_params,
    )
    mcmc.run()

    assert set(mcmc.get_samples().keys()) == set(save_params)
    # Restrict diagnostics to the latent sites before comparing.
    latent_diagnostics = {
        name: stats for name, stats in mcmc.diagnostics().items() if name in "xy"
    }
    assert set(latent_diagnostics.keys()) == set(save_params)
    mcmc.summary()  # smoke test
def test_mcmc_diagnostics(run_mcmc_cls, num_chains):
    """Check diagnostics for both the default and the streaming MCMC classes."""
    data = torch.tensor([2.0]).repeat(3)
    init_params, _, transforms, _ = initialize_model(
        normal_normal_model, model_args=(data,), num_chains=num_chains
    )
    kernel = PriorKernel(normal_normal_model)
    shared_kwargs = dict(
        num_samples=10,
        warmup_steps=10,
        num_chains=num_chains,
        initial_params=init_params,
        transforms=transforms,
    )
    if run_mcmc_cls == run_default_mcmc:
        mcmc = MCMC(kernel, mp_context="spawn", **shared_kwargs)
    else:
        mcmc = StreamingMCMC(kernel, **shared_kwargs)
    mcmc.run(data)
    # Diagnostics need MKL-backed FFTs; the sampling itself ran above.
    if not torch.backends.mkl.is_available():
        pytest.skip()
    diagnostics = mcmc.diagnostics()
    if run_mcmc_cls == run_default_mcmc:  # TODO n_eff for streaming MCMC
        assert diagnostics["y"]["n_eff"].shape == data.shape
    assert diagnostics["y"]["r_hat"].shape == data.shape
    assert diagnostics["dummy_key"] == {
        "chain {}".format(i): "dummy_value" for i in range(num_chains)
    }
def _sample_posterior(model, at_bats, hits, args, kernel_transforms=False):
    """Run NUTS on ``model`` and return ``(samples, diagnostics)``.

    ``kernel_transforms=True`` additionally passes the initialization
    transforms to the NUTS kernel (used by the logit-parameterized model,
    matching the original script's behavior).
    """
    init_params, potential_fn, transforms, _ = initialize_model(
        model, model_args=(at_bats, hits), num_chains=args.num_chains
    )
    if kernel_transforms:
        nuts_kernel = NUTS(potential_fn=potential_fn, transforms=transforms)
    else:
        nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(
        nuts_kernel,
        num_samples=args.num_samples,
        warmup_steps=args.warmup_steps,
        num_chains=args.num_chains,
        initial_params=init_params,
        transforms=transforms,
    )
    mcmc.run(at_bats, hits)
    return mcmc.get_samples(), mcmc.diagnostics()


def _report(
    model,
    model_name,
    site,
    site_label,
    samples,
    diagnostics,
    player_names,
    baseball_dataset,
    site_transforms=None,
):
    """Log the summary table, divergence count, posterior-predictive samples
    and log posterior density for one fitted model."""
    header = "Model: {}".format(model_name)
    logging.info("\n" + header)
    logging.info("=" * len(header))
    logging.info(site_label)
    # Only forward ``transforms`` when the caller supplied one, so models
    # without site transforms keep the original positional call.
    summary_kwargs = {} if site_transforms is None else {"transforms": site_transforms}
    logging.info(
        summary(
            samples,
            sites=[site],
            player_names=player_names,
            diagnostics=diagnostics,
            **summary_kwargs,
        )[site]
    )
    num_divergences = sum(map(len, diagnostics["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(model, samples, baseball_dataset)
    evaluate_log_posterior_density(model, samples, baseball_dataset)


def main(args):
    """Fit four pooling variants of the baseball model with NUTS and report
    posterior summaries, divergences, and predictive checks for each."""
    # ``sep`` must be passed by keyword: the positional form was deprecated in
    # pandas 1.x and removed in pandas 2.0.
    baseball_dataset = pd.read_csv(DATA_URL, sep="\t")
    train, _, player_names = train_test_split(baseball_dataset)
    at_bats, hits = train[:, 0], train[:, 1]
    logging.info("Original Dataset:")
    logging.info(baseball_dataset)

    # (1) Full Pooling Model
    samples, diagnostics = _sample_posterior(fully_pooled, at_bats, hits, args)
    _report(
        fully_pooled,
        "Fully Pooled",
        "phi",
        "\nphi:",
        samples,
        diagnostics,
        player_names,
        baseball_dataset,
    )

    # (2) No Pooling Model
    samples, diagnostics = _sample_posterior(not_pooled, at_bats, hits, args)
    _report(
        not_pooled,
        "Not Pooled",
        "phi",
        "\nphi:",
        samples,
        diagnostics,
        player_names,
        baseball_dataset,
    )

    # (3) Partially Pooled Model
    samples, diagnostics = _sample_posterior(partially_pooled, at_bats, hits, args)
    _report(
        partially_pooled,
        "Partially Pooled",
        "phi",
        "\nphi:",
        samples,
        diagnostics,
        player_names,
        baseball_dataset,
    )

    # (4) Partially Pooled with Logit Model
    samples, diagnostics = _sample_posterior(
        partially_pooled_with_logit, at_bats, hits, args, kernel_transforms=True
    )
    _report(
        partially_pooled_with_logit,
        "Partially Pooled with Logit",
        "alpha",
        "\nSigmoid(alpha):",
        samples,
        diagnostics,
        player_names,
        baseball_dataset,
        site_transforms={"alpha": torch.sigmoid},
    )