Example #1
def test_null_model_with_hook(run_mcmc_cls, kernel, model, jit, num_chains):
    num_warmup, num_samples = 10, 10
    initial_params, potential_fn, transforms, _ = initialize_model(
        model, num_chains=num_chains)

    iters = []
    hook = partial(_hook, iters)

    mp_context = "spawn" if "CUDA_TEST" in os.environ else None

    kern = kernel(potential_fn=potential_fn,
                  transforms=transforms,
                  jit_compile=jit)
    samples, _ = run_mcmc_cls(
        data=None,
        kernel=kern,
        num_samples=num_samples,
        warmup_steps=num_warmup,
        initial_params=initial_params,
        hook_fn=hook,
        num_chains=num_chains,
        mp_context=mp_context,
    )
    assert samples == {}
    if num_chains == 1:
        expected = [("Warmup", i)
                    for i in range(num_warmup)] + [("Sample", i)
                                                   for i in range(num_samples)]
        assert iters == expected
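
The test relies on a module-level `_hook` helper that records the sampler's progress. A minimal sketch, assuming the sampler invokes `hook_fn(kernel, samples, stage, i)` at every iteration (the signature is an assumption, not confirmed by the snippet):

def _hook(iters, kernel, samples, stage, i):
    # Record which stage ("Warmup" or "Sample") and iteration the sampler
    # is on; `iters` is bound via functools.partial in the test above.
    iters.append((stage, i))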
Example #2
def test_mcmc_interface(num_draws, group_by_chain, num_chains):
    num_samples = 2000
    data = torch.tensor([1.0])
    initial_params, _, transforms, _ = initialize_model(normal_normal_model, model_args=(data,),
                                                        num_chains=num_chains)
    kernel = PriorKernel(normal_normal_model)
    mcmc = MCMC(kernel=kernel, num_samples=num_samples, warmup_steps=100, num_chains=num_chains,
                mp_context="spawn", initial_params=initial_params, transforms=transforms)
    mcmc.run(data)
    samples = mcmc.get_samples(num_draws, group_by_chain=group_by_chain)
    # test sample shape
    expected_samples = num_draws if num_draws is not None else num_samples
    if group_by_chain:
        expected_shape = (mcmc.num_chains, expected_samples, 1)
    elif num_draws is not None:
        # FIXME: what is the expected behavior when num_draws is not None and group_by_chain=False?
        expected_shape = (expected_samples, 1)
    else:
        expected_shape = (mcmc.num_chains * expected_samples, 1)
    assert samples['y'].shape == expected_shape

    # test sample stats
    if group_by_chain:
        samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}
    sample_mean = samples['y'].mean()
    sample_std = samples['y'].std()
    assert_close(sample_mean, torch.tensor(0.0), atol=0.05)
    assert_close(sample_std, torch.tensor(1.0), atol=0.05)
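
Several of these tests share a `normal_normal_model` fixture together with a `PriorKernel` that simply resamples from the prior. A minimal sketch of a model consistent with the checks above (the exact fixture definitions are assumptions):

import pyro
import pyro.distributions as dist
import torch

def normal_normal_model(data):
    # Standard-normal prior on "y": under PriorKernel the draws for "y"
    # should therefore have mean ~0 and std ~1, as asserted above.
    x = torch.tensor([0.0])
    y = pyro.sample("y", dist.Normal(x, torch.ones(data.shape)))
    pyro.sample("obs", dist.Normal(y, torch.tensor([1.0])), obs=data)
    return y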
Example #3
def test_num_chains(num_chains, cpu_count, default_init_params, monkeypatch):
    monkeypatch.setattr(torch.multiprocessing, "cpu_count", lambda: cpu_count)
    data = torch.tensor([1.0])
    initial_params, _, transforms, _ = initialize_model(normal_normal_model,
                                                        model_args=(data, ),
                                                        num_chains=num_chains)
    if default_init_params:
        initial_params = None
    kernel = PriorKernel(normal_normal_model)
    available_cpu = max(1, cpu_count - 1)
    mp_context = "spawn"
    with optional(pytest.warns(UserWarning), available_cpu < num_chains):
        mcmc = MCMC(
            kernel,
            num_samples=10,
            warmup_steps=10,
            num_chains=num_chains,
            initial_params=initial_params,
            transforms=transforms,
            mp_context=mp_context,
        )
    mcmc.run(data)
    assert mcmc.num_chains == num_chains
    if mcmc.num_chains == 1 or available_cpu < num_chains:
        assert isinstance(mcmc.sampler, _UnarySampler)
    else:
        assert isinstance(mcmc.sampler, _MultiSampler)
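
The `optional` helper used here (and in the next example) applies a context manager such as `pytest.warns` only when a condition holds. A minimal sketch of the idea; Pyro ships an equivalent helper in `pyro.util`:

from contextlib import contextmanager

@contextmanager
def optional(context_manager, condition):
    # Enter `context_manager` only when `condition` is true; otherwise no-op.
    if condition:
        with context_manager:
            yield
    else:
        yield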
Example #4
def test_predictive(num_samples, parallel):
    model, data, true_probs = beta_bernoulli()
    init_params, potential_fn, transforms, _ = initialize_model(
        model, model_args=(data, ))
    nuts_kernel = NUTS(potential_fn=potential_fn, transforms=transforms)
    mcmc = MCMC(nuts_kernel, 100, initial_params=init_params, warmup_steps=100)
    mcmc.run(data)
    samples = mcmc.get_samples()
    with ignore_experimental_warning():
        with optional(pytest.warns(UserWarning),
                      num_samples not in (None, 100)):
            predictive_samples = predictive(model,
                                            samples,
                                            num_samples=num_samples,
                                            return_sites=["beta", "obs"],
                                            parallel=parallel)

    # check shapes
    assert predictive_samples["beta"].shape == (100, 5)
    assert predictive_samples["obs"].shape == (100, 1000, 5)

    # check sample mean
    assert_close(predictive_samples["obs"].reshape([-1, 5]).mean(0),
                 true_probs,
                 rtol=0.1)
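
The shape assertions pin down the `beta_bernoulli` fixture: 100 posterior draws of a 5-dimensional "beta" site and 1000 Bernoulli observations per draw. A hypothetical fixture consistent with those shapes (names and priors are assumptions):

import pyro
import pyro.distributions as dist
import torch

def beta_bernoulli():
    true_probs = torch.tensor([0.2, 0.3, 0.4, 0.8, 0.5])
    data = dist.Bernoulli(true_probs).sample(sample_shape=torch.Size((1000,)))

    def model(data):
        # "beta" has shape (5,); "obs" broadcasts to (1000, 5).
        with pyro.plate("components", 5):
            beta = pyro.sample("beta", dist.Beta(1.0, 1.0))
            with pyro.plate("data", data.shape[0]):
                pyro.sample("obs", dist.Bernoulli(beta), obs=data)

    return model, data, true_probs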
Example #5
def test_reparam_log_joint(model, kwargs):
    guide = AutoIAFNormal(model)
    guide(**kwargs)
    neutra = NeuTraReparam(guide)
    reparam_model = neutra.reparam(model)
    _, pe_fn, transforms, _ = initialize_model(model, model_kwargs=kwargs)
    init_params, pe_fn_neutra, _, _ = initialize_model(
        reparam_model, model_kwargs=kwargs
    )
    latent_x = list(init_params.values())[0]
    transformed_params = neutra.transform_sample(latent_x)
    pe_transformed = pe_fn_neutra(init_params)
    neutra_transform = ComposeTransform(guide.get_posterior(**kwargs).transforms)
    latent_y = neutra_transform(latent_x)
    log_det_jacobian = neutra_transform.log_abs_det_jacobian(latent_x, latent_y)
    pe = pe_fn({k: transforms[k](v) for k, v in transformed_params.items()})
    assert_close(pe_transformed, pe - log_det_jacobian)
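
The final assertion is the change-of-variables identity for potential energies: with y = T(x) the NeuTra transform composed above, the reparameterized potential equals the original potential at T(x) minus the log-absolute-determinant of the Jacobian,

    U_{\text{neutra}}(x) = U(T(x)) - \log\lvert\det J_T(x)\rvert

which is exactly `pe_transformed == pe - log_det_jacobian`.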
Example #6
    def setup(self, warmup_steps, data):
        self.data = data
        init_params, potential_fn, transforms, model_trace = initialize_model(
            self.model, model_args=(data,))
        if self._initial_params is None:
            self._initial_params = init_params
        if self.transforms is None:
            self.transforms = transforms
        self._prototype_trace = model_trace
Example #7
def test_potential_fn_pickling(jit):
    data = dist.Bernoulli(torch.tensor([0.8, 0.2])).sample(
        sample_shape=torch.Size((1000,)))
    _, potential_fn, _, _ = initialize_model(_beta_bernoulli, (data, ),
                                             jit_compile=jit,
                                             skip_jit_warnings=True)
    test_data = {'p_latent': torch.tensor([0.2, 0.6])}
    assert_close(
        pickle.loads(pickle.dumps(potential_fn))(test_data),
        potential_fn(test_data))
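
Both pickling tests exercise the same `_beta_bernoulli` model with a two-component "p_latent" site. A minimal sketch consistent with that usage (an assumption, not the verbatim fixture):

import pyro
import pyro.distributions as dist
import torch

def _beta_bernoulli(data):
    # Beta prior over two Bernoulli success probabilities; `data` has
    # shape (1000, 2), so the plate runs over rows at dim=-2.
    alpha = torch.tensor([1.1, 1.1])
    beta = torch.tensor([1.1, 1.1])
    p_latent = pyro.sample("p_latent", dist.Beta(alpha, beta))
    with pyro.plate("data", data.shape[0], dim=-2):
        pyro.sample("obs", dist.Bernoulli(p_latent), obs=data)
    return p_latent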
Example #8
def test_potential_fn_pickling(jit):
    data = dist.Bernoulli(torch.tensor([0.8, 0.2])).sample(sample_shape=(torch.Size((1000,))))
    _, potential_fn, _, _ = initialize_model(_beta_bernoulli, (data,), jit_compile=jit,
                                             skip_jit_warnings=True)
    test_data = {'p_latent': torch.tensor([0.2, 0.6])}
    buffer = io.BytesIO()
    torch.save(potential_fn, buffer)
    buffer.seek(0)
    deser_potential_fn = torch.load(buffer)
    assert_close(deser_potential_fn(test_data), potential_fn(test_data))
Example #9
    def run(self, *args, **kwargs):
        self._args, self._kwargs = args, kwargs
        num_samples = [0] * self.num_chains
        z_flat_acc = [[] for _ in range(self.num_chains)]
        with pyro.validation_enabled(not self.disable_validation):
            for x, chain_id in self.sampler.run(*args, **kwargs):
                if num_samples[chain_id] == 0:
                    num_samples[chain_id] += 1
                    z_structure = x
                elif num_samples[chain_id] == self.num_samples + 1:
                    self._diagnostics[chain_id] = x
                else:
                    num_samples[chain_id] += 1
                    if self.num_chains > 1:
                        x_cloned = x.clone()
                        del x
                    else:
                        x_cloned = x
                    z_flat_acc[chain_id].append(x_cloned)

        z_flat_acc = torch.stack([torch.stack(l) for l in z_flat_acc])

        # unpack latent
        pos = 0
        z_acc = z_structure.copy()
        for k in sorted(z_structure):
            shape = z_structure[k]
            next_pos = pos + shape.numel()
            z_acc[k] = z_flat_acc[:, :, pos:next_pos].reshape(
                (self.num_chains, self.num_samples) + shape)
            pos = next_pos
        assert pos == z_flat_acc.shape[-1]

        # If transforms is not explicitly provided, infer automatically using
        # model args, kwargs.
        if self.transforms is None:
            if hasattr(self.kernel, 'transforms'):
                if self.kernel.transforms is not None:
                    self.transforms = self.kernel.transforms
            elif self.kernel.model:
                _, _, self.transforms, _ = initialize_model(
                    self.kernel.model, model_args=args, model_kwargs=kwargs)
            else:
                self.transforms = {}

        # transform samples back to constrained space
        for name, transform in self.transforms.items():
            z_acc[name] = transform.inv(z_acc[name])
        self._samples = z_acc

        # terminate the sampler (shut down worker processes)
        self.sampler.terminate(True)
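
The unpacking loop concatenates every latent site into one flat trailing dimension per draw, then slices each site back out and restores its shape. A self-contained toy illustration of that reshape logic (site names and shapes are made up):

import torch

z_structure = {"loc": torch.Size((2,)), "scale": torch.Size((1,))}
num_chains, num_samples = 2, 5
z_flat = torch.randn(num_chains, num_samples, 3)  # 3 == 2 + 1 flattened elements

pos, z_acc = 0, {}
for k in sorted(z_structure):
    shape = z_structure[k]
    next_pos = pos + shape.numel()
    z_acc[k] = z_flat[:, :, pos:next_pos].reshape((num_chains, num_samples) + shape)
    pos = next_pos

assert z_acc["loc"].shape == (2, 5, 2)
assert z_acc["scale"].shape == (2, 5, 1)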
Example #10
    def setup(self, warmup_steps, *args, **kwargs):
        """Sets up the sampler."""
        self._warmup_steps = warmup_steps

        init_params, _, _, _ = initialize_model(
            self.model,
            args,
            kwargs,
        )
        if self._initial_params is None:
            self._initial_params = init_params

        self._model_args = args
        self._model_kwargs = kwargs
Example #11
File: slice.py Project: boyali/sbi
    def _initialize_model_properties(self, model_args, model_kwargs):
        init_params, potential_fn, transforms, trace = initialize_model(
            self.model,
            model_args,
            model_kwargs,
            transforms=self.transforms,
            max_plate_nesting=self._max_plate_nesting,
            jit_compile=self._jit_compile,
            jit_options=self._jit_options,
            skip_jit_warnings=self._ignore_jit_warnings,
        )
        self.potential_fn = potential_fn
        self.transforms = transforms
        if self._initial_params is None:
            self.initial_params = init_params
        self._prototype_trace = trace
Example #12
def test_mcmc_diagnostics(num_chains):
    data = torch.tensor([2.0]).repeat(3)
    initial_params, _, transforms, _ = initialize_model(normal_normal_model,
                                                        model_args=(data,),
                                                        num_chains=num_chains)
    kernel = PriorKernel(normal_normal_model)
    mcmc = MCMC(kernel, num_samples=10, warmup_steps=10, num_chains=num_chains, mp_context="spawn",
                initial_params=initial_params, transforms=transforms)
    mcmc.run(data)
    if not torch.backends.mkl.is_available():
        pytest.skip()
    diagnostics = mcmc.diagnostics()
    assert diagnostics["y"]["n_eff"].shape == data.shape
    assert diagnostics["y"]["r_hat"].shape == data.shape
    assert diagnostics["dummy_key"] == {'chain {}'.format(i): 'dummy_value'
                                        for i in range(num_chains)}
Example #13
def test_mcmc_diagnostics(run_mcmc_cls, num_chains):
    data = torch.tensor([2.0]).repeat(3)
    initial_params, _, transforms, _ = initialize_model(normal_normal_model,
                                                        model_args=(data, ),
                                                        num_chains=num_chains)
    kernel = PriorKernel(normal_normal_model)
    if run_mcmc_cls == run_default_mcmc:
        mcmc = MCMC(
            kernel,
            num_samples=10,
            warmup_steps=10,
            num_chains=num_chains,
            mp_context="spawn",
            initial_params=initial_params,
            transforms=transforms,
        )
    else:
        mcmc = StreamingMCMC(
            kernel,
            num_samples=10,
            warmup_steps=10,
            num_chains=num_chains,
            initial_params=initial_params,
            transforms=transforms,
        )
    mcmc.run(data)
    if not torch.backends.mkl.is_available():
        pytest.skip()
    diagnostics = mcmc.diagnostics()
    if run_mcmc_cls == run_default_mcmc:  # TODO n_eff for streaming MCMC
        assert diagnostics["y"]["n_eff"].shape == data.shape
    assert diagnostics["y"]["r_hat"].shape == data.shape
    assert diagnostics["dummy_key"] == {
        "chain {}".format(i): "dummy_value"
        for i in range(num_chains)
    }
Example #14
def infer_sample(
    cond_model,
    n_steps,
    warmup_steps,
    n_chains=1,
    device="cpu",
    guidefile=None,
    guide_conf=None,
    mcmcfile=None,
):
    """Runs the NUTS HMC algorithm.

    Saves the samples and weights as well as a netcdf file for the run.

    Parameters
    ----------
    cond_model : callable
        Model conditioned on observed images.
    n_steps : int
        Number of posterior samples to draw.
    warmup_steps : int
        Number of warmup steps.
    """
    initial_params, potential_fn, transforms, prototype_trace = util.initialize_model(
        cond_model)

    if guidefile is not None:
        guide = init_guide(cond_model,
                           guide_conf,
                           guidefile=guidefile,
                           device=device)
        sample = guide()
        for key in initial_params.keys():
            initial_params[key] = transforms[key](sample[key].detach())

    # FIXME: In the case of DiagonalNormal, results have to be mapped back onto unpacked latents
    if guide_conf["type"] == "DiagonalNormal":
        transform = guide.get_transform()
        unpack_fn = lambda u: guide.unpack_latent(u)
        potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn)
        initial_params = {"z": torch.zeros(guide.get_posterior().shape())}
        transforms = None

    def fun(*args, **kwargs):
        res = potential_fn(*args, **kwargs)
        return res

    nuts_kernel = NUTS(
        potential_fn=fun,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        full_mass=False,
        use_multinomial_sampling=True,
        jit_compile=False,
        max_tree_depth=10,
        transforms=transforms,
        step_size=1.0,
    )
    nuts_kernel.initial_params = initial_params

    # Run
    mcmc = MCMC(
        nuts_kernel,
        n_steps,
        warmup_steps=warmup_steps,
        initial_params=initial_params,
        num_chains=n_chains,
    )
    mcmc.run()

    # This block lets the posterior be pickled
    mcmc.sampler = None
    mcmc.kernel.potential_fn = None
    mcmc._cache = {}

    print(f"Saving MCMC object to {mcmcfile}")
    with open(mcmcfile, "wb") as f:
        pickle.dump(mcmc, f, pickle.HIGHEST_PROTOCOL)
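
Because the sampler and the kernel's potential function are stripped out first, the pickled object round-trips with plain `pickle`. A sketch of loading it back, assuming the same `mcmcfile` path as above:

import pickle

with open(mcmcfile, "rb") as f:
    mcmc = pickle.load(f)
samples = mcmc.get_samples()  # samples survive; kernel.potential_fn does not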
Example #15
def main(args):
    baseball_dataset = pd.read_csv(DATA_URL, "\t")
    train, _, player_names = train_test_split(baseball_dataset)
    at_bats, hits = train[:, 0], train[:, 1]
    logging.info("Original Dataset:")
    logging.info(baseball_dataset)

    # (1) Full Pooling Model
    init_params, potential_fn, transforms, _ = initialize_model(fully_pooled, model_args=(at_bats, hits),
                                                                num_chains=args.num_chains)
    nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains,
                initial_params=init_params,
                transforms=transforms)
    mcmc.run(at_bats, hits)
    diagnostics = mcmc.diagnostics()
    samples_fully_pooled = mcmc.get_samples()
    logging.info("\nModel: Fully Pooled")
    logging.info("===================")
    logging.info("\nphi:")
    logging.info(summary(samples_fully_pooled,
                         sites=["phi"],
                         player_names=player_names,
                         diagnostics=diagnostics)["phi"])
    num_divergences = sum(map(len, diagnostics["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(fully_pooled, samples_fully_pooled, baseball_dataset)
    evaluate_log_posterior_density(fully_pooled, samples_fully_pooled, baseball_dataset)

    # (2) No Pooling Model
    init_params, potential_fn, transforms, _ = initialize_model(not_pooled, model_args=(at_bats, hits),
                                                                num_chains=args.num_chains)
    nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains,
                initial_params=init_params,
                transforms=transforms)
    mcmc.run(at_bats, hits)
    diagnostics = mcmc.diagnostics()
    samples_not_pooled = mcmc.get_samples()
    logging.info("\nModel: Not Pooled")
    logging.info("=================")
    logging.info("\nphi:")
    logging.info(summary(samples_not_pooled,
                         sites=["phi"],
                         player_names=player_names,
                         diagnostics=diagnostics)["phi"])
    num_divergences = sum(map(len, diagnostics["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(not_pooled, samples_not_pooled, baseball_dataset)
    evaluate_log_posterior_density(not_pooled, samples_not_pooled, baseball_dataset)

    # (3) Partially Pooled Model
    init_params, potential_fn, transforms, _ = initialize_model(partially_pooled, model_args=(at_bats, hits),
                                                                num_chains=args.num_chains)
    nuts_kernel = NUTS(potential_fn=potential_fn)

    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains,
                initial_params=init_params,
                transforms=transforms)
    mcmc.run(at_bats, hits)
    diagnostics = mcmc.diagnostics()
    samples_partially_pooled = mcmc.get_samples()
    logging.info("\nModel: Partially Pooled")
    logging.info("=======================")
    logging.info("\nphi:")
    logging.info(summary(samples_partially_pooled,
                         sites=["phi"],
                         player_names=player_names,
                         diagnostics=diagnostics)["phi"])
    num_divergences = sum(map(len, diagnostics["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(partially_pooled, samples_partially_pooled, baseball_dataset)
    evaluate_log_posterior_density(partially_pooled, samples_partially_pooled, baseball_dataset)

    # (4) Partially Pooled with Logit Model
    init_params, potential_fn, transforms, _ = initialize_model(partially_pooled_with_logit,
                                                                model_args=(at_bats, hits),
                                                                num_chains=args.num_chains)
    nuts_kernel = NUTS(potential_fn=potential_fn, transforms=transforms)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains,
                initial_params=init_params,
                transforms=transforms)
    mcmc.run(at_bats, hits)
    diagnostics = mcmc.diagnostics()
    samples_partially_pooled_logit = mcmc.get_samples()
    logging.info("\nModel: Partially Pooled with Logit")
    logging.info("==================================")
    logging.info("\nSigmoid(alpha):")
    logging.info(summary(samples_partially_pooled_logit,
                         sites=["alpha"],
                         player_names=player_names,
                         transforms={"alpha": torch.sigmoid},
                         diagnostics=diagnostics)["alpha"])
    num_divergences = sum(map(len, diagnostics["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(partially_pooled_with_logit, samples_partially_pooled_logit,
                                baseball_dataset)
    evaluate_log_posterior_density(partially_pooled_with_logit, samples_partially_pooled_logit,
                                   baseball_dataset)
Example #16
        irt_model = irt_model_2pl
    elif args.irt_model == '3pl':
        if args.hierarchical:
            irt_model = irt_model_3pl_hierarchical
        else:
            irt_model = irt_model_3pl
    else:
        raise Exception('irt_model {} not supported.'.format(args.irt_model))

    init_params, potential_fn, transforms, _ = initialize_model(
        irt_model,
        model_args=(
            args.ability_dim,
            num_person,
            num_item,
            device,
            response,
            mask,
            1,
        ),
        num_chains=args.num_chains,
    )

    start_time = time.time()

    nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(
        nuts_kernel,
        num_samples=args.num_samples,
        warmup_steps=args.num_warmup,
        num_chains=args.num_chains,
Example #17
def main(args):
    baseball_dataset = pd.read_csv(DATA_URL, "\t")
    train, _, player_names = train_test_split(baseball_dataset)
    at_bats, hits = train[:, 0], train[:, 1]
    logging.info("Original Dataset:")
    logging.info(baseball_dataset)

    # (1) Full Pooling Model
    # In this model, we illustrate how to use MCMC with general potential_fn.
    init_params, potential_fn, transforms, _ = initialize_model(
        fully_pooled, model_args=(at_bats, hits), num_chains=args.num_chains,
        jit_compile=args.jit, skip_jit_warnings=True)
    nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains,
                initial_params=init_params,
                transforms=transforms)
    mcmc.run(at_bats, hits)
    samples_fully_pooled = mcmc.get_samples()
    logging.info("\nModel: Fully Pooled")
    logging.info("===================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(fully_pooled, samples_fully_pooled, baseball_dataset)
    evaluate_pointwise_pred_density(fully_pooled, samples_fully_pooled, baseball_dataset)

    # (2) No Pooling Model
    nuts_kernel = NUTS(not_pooled, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_not_pooled = mcmc.get_samples()
    logging.info("\nModel: Not Pooled")
    logging.info("=================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(not_pooled, samples_not_pooled, baseball_dataset)
    evaluate_pointwise_pred_density(not_pooled, samples_not_pooled, baseball_dataset)

    # (3) Partially Pooled Model
    nuts_kernel = NUTS(partially_pooled, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_partially_pooled = mcmc.get_samples()
    logging.info("\nModel: Partially Pooled")
    logging.info("=======================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(partially_pooled, samples_partially_pooled, baseball_dataset)
    evaluate_pointwise_pred_density(partially_pooled, samples_partially_pooled, baseball_dataset)

    # (4) Partially Pooled with Logit Model
    nuts_kernel = NUTS(partially_pooled_with_logit, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_partially_pooled_logit = mcmc.get_samples()
    logging.info("\nModel: Partially Pooled with Logit")
    logging.info("==================================")
    logging.info("\nSigmoid(alpha):")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["alpha"],
                                   player_names=player_names,
                                   transforms={"alpha": torch.sigmoid},
                                   diagnostics=True,
                                   group_by_chain=True)["alpha"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(partially_pooled_with_logit, samples_partially_pooled_logit,
                                baseball_dataset)
    evaluate_pointwise_pred_density(partially_pooled_with_logit, samples_partially_pooled_logit,
                                    baseball_dataset)
Example #18
    def run(self, *args, **kwargs):
        """
        Run MCMC to generate samples and populate `self._samples`.

        Example usage:

        .. code-block:: python

            def model(data):
                ...

            nuts_kernel = NUTS(model)
            mcmc = MCMC(nuts_kernel, num_samples=500)
            mcmc.run(data)
            samples = mcmc.get_samples()

        :param args: optional arguments taken by
            :meth:`MCMCKernel.setup <pyro.infer.mcmc.mcmc_kernel.MCMCKernel.setup>`.
        :param kwargs: optional keywords arguments taken by
            :meth:`MCMCKernel.setup <pyro.infer.mcmc.mcmc_kernel.MCMCKernel.setup>`.
        """
        self._args, self._kwargs = args, kwargs
        num_samples = [0] * self.num_chains
        z_flat_acc = [[] for _ in range(self.num_chains)]
        with optional(pyro.validation_enabled(not self.disable_validation),
                      self.disable_validation is not None):
            for x, chain_id in self.sampler.run(*args, **kwargs):
                if num_samples[chain_id] == 0:
                    num_samples[chain_id] += 1
                    z_structure = x
                elif num_samples[chain_id] == self.num_samples + 1:
                    self._diagnostics[chain_id] = x
                else:
                    num_samples[chain_id] += 1
                    if self.num_chains > 1:
                        x_cloned = x.clone()
                        del x
                    else:
                        x_cloned = x
                    z_flat_acc[chain_id].append(x_cloned)

        z_flat_acc = torch.stack([torch.stack(l) for l in z_flat_acc])

        # unpack latent
        pos = 0
        z_acc = z_structure.copy()
        for k in sorted(z_structure):
            shape = z_structure[k]
            next_pos = pos + shape.numel()
            z_acc[k] = z_flat_acc[:, :, pos:next_pos].reshape(
                (self.num_chains, self.num_samples) + shape)
            pos = next_pos
        assert pos == z_flat_acc.shape[-1]

        # If transforms is not explicitly provided, infer automatically using
        # model args, kwargs.
        if self.transforms is None:
            # Try to initialize kernel.transforms using kernel.setup().
            if getattr(self.kernel, "transforms", None) is None:
                warmup_steps = 0
                self.kernel.setup(warmup_steps, *args, **kwargs)
            # Use `kernel.transforms` when available
            if getattr(self.kernel, "transforms", None) is not None:
                self.transforms = self.kernel.transforms
            # Else, get transforms from model (e.g. in multiprocessing).
            elif self.kernel.model:
                _, _, self.transforms, _ = initialize_model(
                    self.kernel.model,
                    model_args=args,
                    model_kwargs=kwargs,
                    initial_params={})
            # Assign default value
            else:
                self.transforms = {}

        # transform samples back to constrained space
        for name, transform in self.transforms.items():
            z_acc[name] = transform.inv(z_acc[name])
        self._samples = z_acc

        # terminate the sampler (shut down worker processes)
        self.sampler.terminate(True)
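
The closing loop maps each latent back to its constrained support: `transforms[name]` sends constrained values to the unconstrained space where sampling happens, so `.inv` undoes that. A toy illustration with a positivity constraint (the site and transform choice are illustrative):

import torch
from torch.distributions.transforms import ExpTransform

to_unconstrained = ExpTransform().inv        # (0, inf) -> R, e.g. a scale site
draws = torch.tensor([-1.0, 0.0, 1.0])       # unconstrained MCMC draws
constrained = to_unconstrained.inv(draws)    # exp(): back to (0, inf)
assert bool((constrained > 0).all())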