コード例 #1
0
def test_mcmc_diagnostics(num_chains):
    """Check diagnostics shapes and per-chain dummy entries for MCMC."""
    observed = torch.tensor([2.0]).repeat(3)
    initial_params, _, transforms, _ = initialize_model(
        normal_normal_model,
        model_args=(observed,),
        num_chains=num_chains,
    )
    sampler = MCMC(
        PriorKernel(normal_normal_model),
        num_samples=10,
        warmup_steps=10,
        num_chains=num_chains,
        mp_context="spawn",
        initial_params=initial_params,
        transforms=transforms,
    )
    sampler.run(observed)
    # NOTE(review): the skip is after `run` on purpose, presumably so sampling
    # itself is still exercised on platforms without MKL — only the
    # diagnostics assertions below are skipped.
    if not torch.backends.mkl.is_available():
        pytest.skip()
    diagnostics = sampler.diagnostics()
    assert diagnostics["y"]["n_eff"].shape == observed.shape
    assert diagnostics["y"]["r_hat"].shape == observed.shape
    expected_dummies = {f"chain {i}": "dummy_value" for i in range(num_chains)}
    assert diagnostics["dummy_key"] == expected_dummies
コード例 #2
0
def test_save_params(save_params, Kernel, options):
    """Verify MCMC stores and reports only the sites listed in ``save_params``."""
    save_params = list(save_params)

    def model():
        loc = pyro.sample("x", dist.Normal(0, 1))
        with pyro.plate("plate", 2):
            latent = pyro.sample("y", dist.Normal(loc, 1))
            pyro.sample("obs", dist.Normal(latent, 1), obs=torch.zeros(2))

    mcmc = MCMC(
        Kernel(model, **options),
        warmup_steps=2,
        num_samples=4,
        save_params=save_params,
    )
    mcmc.run()

    expected = set(save_params)
    assert set(mcmc.get_samples()) == expected

    # Restrict diagnostics to the latent sites (substring test against "xy"
    # matches the single-letter site names) before comparing.
    site_diagnostics = {
        name: stats for name, stats in mcmc.diagnostics().items() if name in "xy"
    }
    assert set(site_diagnostics) == expected

    mcmc.summary()  # smoke test
コード例 #3
0
def test_mcmc_diagnostics(run_mcmc_cls, num_chains):
    """Diagnostics shape checks for both the default and streaming MCMC APIs."""
    observed = torch.tensor([2.0]).repeat(3)
    initial_params, _, transforms, _ = initialize_model(
        normal_normal_model,
        model_args=(observed,),
        num_chains=num_chains,
    )
    kernel = PriorKernel(normal_normal_model)
    shared_kwargs = dict(
        num_samples=10,
        warmup_steps=10,
        num_chains=num_chains,
        initial_params=initial_params,
        transforms=transforms,
    )
    if run_mcmc_cls == run_default_mcmc:
        sampler = MCMC(kernel, mp_context="spawn", **shared_kwargs)
    else:
        sampler = StreamingMCMC(kernel, **shared_kwargs)
    sampler.run(observed)
    # Diagnostics assertions are skipped (post-run, so sampling still ran)
    # when MKL is unavailable.
    if not torch.backends.mkl.is_available():
        pytest.skip()
    diagnostics = sampler.diagnostics()
    if run_mcmc_cls == run_default_mcmc:  # TODO n_eff for streaming MCMC
        assert diagnostics["y"]["n_eff"].shape == observed.shape
    assert diagnostics["y"]["r_hat"].shape == observed.shape
    assert diagnostics["dummy_key"] == {
        f"chain {i}": "dummy_value" for i in range(num_chains)
    }
コード例 #4
0
def _fit_nuts(model, args, at_bats, hits, kernel_transforms=False):
    """Run NUTS inference for ``model`` and return ``(samples, diagnostics)``.

    :param model: the Pyro model to condition on ``(at_bats, hits)``.
    :param args: parsed CLI args providing ``num_chains``, ``num_samples``,
        ``warmup_steps``.
    :param at_bats: at-bats column of the training split.
    :param hits: hits column of the training split.
    :param kernel_transforms: when True, also pass the initializer's
        transforms to the NUTS kernel (used by the logit model).
    """
    init_params, potential_fn, transforms, _ = initialize_model(
        model, model_args=(at_bats, hits), num_chains=args.num_chains)
    if kernel_transforms:
        nuts_kernel = NUTS(potential_fn=potential_fn, transforms=transforms)
    else:
        nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains,
                initial_params=init_params,
                transforms=transforms)
    mcmc.run(at_bats, hits)
    return mcmc.get_samples(), mcmc.diagnostics()


def _log_divergences(diagnostics):
    """Log the total number of divergent transitions across all chains."""
    num_divergences = sum(map(len, diagnostics["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))


def main(args):
    """Fit and report all four baseball models (pooled to hierarchical)."""
    # ``sep`` must be a keyword: positional args beyond the filepath were
    # deprecated in pandas 1.x and removed in pandas 2.0.
    baseball_dataset = pd.read_csv(DATA_URL, sep="\t")
    train, _, player_names = train_test_split(baseball_dataset)
    at_bats, hits = train[:, 0], train[:, 1]
    logging.info("Original Dataset:")
    logging.info(baseball_dataset)

    # (1) Full Pooling Model
    samples_fully_pooled, diagnostics = _fit_nuts(fully_pooled, args, at_bats, hits)
    logging.info("\nModel: Fully Pooled")
    logging.info("===================")
    logging.info("\nphi:")
    logging.info(summary(samples_fully_pooled,
                         sites=["phi"],
                         player_names=player_names,
                         diagnostics=diagnostics)["phi"])
    _log_divergences(diagnostics)
    sample_posterior_predictive(fully_pooled, samples_fully_pooled, baseball_dataset)
    evaluate_log_posterior_density(fully_pooled, samples_fully_pooled, baseball_dataset)

    # (2) No Pooling Model
    samples_not_pooled, diagnostics = _fit_nuts(not_pooled, args, at_bats, hits)
    logging.info("\nModel: Not Pooled")
    logging.info("=================")
    logging.info("\nphi:")
    logging.info(summary(samples_not_pooled,
                         sites=["phi"],
                         player_names=player_names,
                         diagnostics=diagnostics)["phi"])
    _log_divergences(diagnostics)
    sample_posterior_predictive(not_pooled, samples_not_pooled, baseball_dataset)
    evaluate_log_posterior_density(not_pooled, samples_not_pooled, baseball_dataset)

    # (3) Partially Pooled Model
    samples_partially_pooled, diagnostics = _fit_nuts(partially_pooled, args, at_bats, hits)
    logging.info("\nModel: Partially Pooled")
    logging.info("=======================")
    logging.info("\nphi:")
    logging.info(summary(samples_partially_pooled,
                         sites=["phi"],
                         player_names=player_names,
                         diagnostics=diagnostics)["phi"])
    _log_divergences(diagnostics)
    sample_posterior_predictive(partially_pooled, samples_partially_pooled, baseball_dataset)
    evaluate_log_posterior_density(partially_pooled, samples_partially_pooled, baseball_dataset)

    # (4) Partially Pooled with Logit Model
    # This model alone passes the transforms to the NUTS kernel as well.
    samples_partially_pooled_logit, diagnostics = _fit_nuts(
        partially_pooled_with_logit, args, at_bats, hits, kernel_transforms=True)
    logging.info("\nModel: Partially Pooled with Logit")
    logging.info("==================================")
    logging.info("\nSigmoid(alpha):")
    logging.info(summary(samples_partially_pooled_logit,
                         sites=["alpha"],
                         player_names=player_names,
                         transforms={"alpha": torch.sigmoid},
                         diagnostics=diagnostics)["alpha"])
    _log_divergences(diagnostics)
    sample_posterior_predictive(partially_pooled_with_logit, samples_partially_pooled_logit,
                                baseball_dataset)
    evaluate_log_posterior_density(partially_pooled_with_logit, samples_partially_pooled_logit,
                                   baseball_dataset)