Example #1
def main(args):
    baseball_dataset = pd.read_csv(DATA_URL, sep="\t")
    train, _, player_names = train_test_split(baseball_dataset)
    at_bats, hits = train[:, 0], train[:, 1]
    logging.info("Original Dataset:")
    logging.info(baseball_dataset)

    # (1) Full Pooling Model
    # In this model, we illustrate how to use MCMC with general potential_fn.
    init_params, potential_fn, transforms, _ = initialize_model(
        fully_pooled, model_args=(at_bats, hits), num_chains=args.num_chains,
        jit_compile=args.jit, skip_jit_warnings=True)
    nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains,
                initial_params=init_params,
                transforms=transforms)
    mcmc.run(at_bats, hits)
    samples_fully_pooled = mcmc.get_samples()
    logging.info("\nModel: Fully Pooled")
    logging.info("===================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(fully_pooled, samples_fully_pooled, baseball_dataset)
    evaluate_pointwise_pred_density(fully_pooled, samples_fully_pooled, baseball_dataset)

    # (2) No Pooling Model
    nuts_kernel = NUTS(not_pooled, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_not_pooled = mcmc.get_samples()
    logging.info("\nModel: Not Pooled")
    logging.info("=================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(not_pooled, samples_not_pooled, baseball_dataset)
    evaluate_pointwise_pred_density(not_pooled, samples_not_pooled, baseball_dataset)

    # (3) Partially Pooled Model
    nuts_kernel = NUTS(partially_pooled, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_partially_pooled = mcmc.get_samples()
    logging.info("\nModel: Partially Pooled")
    logging.info("=======================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(partially_pooled, samples_partially_pooled, baseball_dataset)
    evaluate_pointwise_pred_density(partially_pooled, samples_partially_pooled, baseball_dataset)

    # (4) Partially Pooled with Logit Model
    nuts_kernel = NUTS(partially_pooled_with_logit, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_partially_pooled_logit = mcmc.get_samples()
    logging.info("\nModel: Partially Pooled with Logit")
    logging.info("==================================")
    logging.info("\nSigmoid(alpha):")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["alpha"],
                                   player_names=player_names,
                                   transforms={"alpha": torch.sigmoid},
                                   diagnostics=True,
                                   group_by_chain=True)["alpha"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(partially_pooled_with_logit, samples_partially_pooled_logit,
                                baseball_dataset)
    evaluate_pointwise_pred_density(partially_pooled_with_logit, samples_partially_pooled_logit,
                                    baseball_dataset)
Example #2
import torch
import numpy as np
import pyro
from pyro import sample
from pyro.infer import NUTS, MCMC
from pyro.distributions import Normal
from matplotlib import pyplot as plt


def bad():
    x = sample('x', pyro.distributions.Normal(0, 1))
    for i in range(10):
        # Re-using the site name 'x' is what makes this model "bad":
        # Pyro rejects duplicate sample sites when validation is enabled.
        x = sample('x', pyro.distributions.Normal(x, 3))
    return x


nuts_kernel = NUTS(bad)

mcmc = MCMC(nuts_kernel, num_samples=10, warmup_steps=10)
mcmc.run()
mcmc.summary()
samples = mcmc.get_samples()

print(samples.keys())

fig, ax = plt.subplots()
ax.hist(np.array(samples["x"]), bins=50)
plt.show()
Example #3
        sys.stderr.write("Requires Python 3\n")

    genr = Decoder()
    genr.load_state_dict(torch.load('gaae-decd-1024.tch'))
    genr.eval()

    data = qPCRData('second.txt', randomize=False, test=False)

    # Do it with CUDA if possible.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        genr.cuda()

    model = GeneratorModel(genr)

    nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=True)
    mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=2000)
    for batch in data.batches(btchsz=8192, randomize=False, test=False):
        obs = batch[:, 90:].to(device)
        mcmc.run(obs)
        z = mcmc.get_samples()['z']
        # Propagate forward and sample observable 'x'.
        with torch.no_grad():
            mu, sd = genr(z)
        for i in range(batch.shape[0]):
            x = Normal(mu[:, i, :90], sd[:, i, :90]).sample()
            orig = batch[i, 90:].expand([1000, 45])
            out = torch.cat([x, orig], dim=1)
            np.savetxt(sys.stdout, out.cpu().numpy(), fmt='%.4f')
Example #4
import argparse
import logging
import math

import pandas as pd
import torch

import pyro
from pyro.distributions import Beta, Binomial, HalfCauchy, Normal, Pareto, Uniform
from pyro.distributions.util import scalar_like
from pyro.infer import MCMC, NUTS, Predictive
from pyro.infer.mcmc.util import initialize_model, summary
from pyro.util import ignore_experimental_warning

"""
This example has been adapted from [1]. It demonstrates how to do Bayesian inference
using NUTS (or HMC) in Pyro, and the use of some common inference utilities.

As in the Stan tutorial, this uses the small baseball dataset of Efron and Morris [2]
to estimate players' batting averages, i.e. the fraction of times a player got a
base hit out of the number of times they went up at bat.

The dataset separates the initial 45 at-bats statistics from the remaining season.
We use the hits data from the initial 45 at-bats to estimate the batting average
for each player. We then use the remaining season's data to validate the predictions
from our models.

Three models are evaluated:
 - Complete pooling model: The success probability of scoring a hit is shared
     amongst all players.
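The fully_pooled model called in Example #1 is not included in this excerpt. As a rough
sketch only (the priors and exact site names here are assumptions for illustration, not
the verbatim code from [1]), a complete-pooling model for this dataset could look like:

import pyro
import pyro.distributions as dist


def fully_pooled(at_bats, hits=None):
    # A single success probability phi is shared across all players.
    phi = pyro.sample("phi", dist.Uniform(0.0, 1.0))
    num_players = at_bats.shape[0]
    with pyro.plate("num_players", num_players):
        # Each player's hit count is Binomial given their at-bats and phi.
        return pyro.sample("obs", dist.Binomial(at_bats, phi), obs=hits)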
Example #5
def linear():
    x_data = [1, 2, 3, 4, 5, 6]
    y_data = torch.tensor([2.2, 4.2, 5.5, 8.3, 9.9, 12.1])
    k = sample('k', pyro.distributions.Normal(0, 1))
    if k < 0:
        slope = sample('slope', Normal(0, 5))
    else:
        slope = sample('slope', pyro.distributions.Bernoulli(0.5))

    bias = sample('bias', Normal(0, 5))

    for i in range(len(x_data)):
        x = x_data[i]
        mu = x * slope + bias
        y = sample(f"y_{i}", Normal(mu, 1), obs=y_data[i])


nuts_kernel = NUTS(linear)

mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=10)
mcmc.run()
mcmc.summary()
samples = mcmc.get_samples()

print(samples)

fig, ax = plt.subplots()
ax.hist2d(np.array(samples["slope"]), np.array(samples["bias"]), bins=30)
plt.show()
Example #6
def run_hmc(args, model):
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, warmup_steps=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(args.param_a, args.param_b)
    mcmc.summary()
    return mcmc
svi_samples = {k: v.detach().cpu().numpy()
               for k, v in predictive(log_gdp, is_cont_africa, ruggedness).items()
               if k != "obs"}




for site, values in summary(svi_samples).items():
    print("Site: {}".format(site))
    print(values, "\n")



from pyro.infer import MCMC, NUTS


nuts_kernel = NUTS(model)

mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run(is_cont_africa, ruggedness, log_gdp)

hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}




for site, values in summary(hmc_samples).items():
    print("Site: {}".format(site))
    print(values, "\n")


Example #8
    def fit_mcmc(self, **options):
        r"""
        Runs NUTS inference to generate posterior samples.

        This uses the :class:`~pyro.infer.mcmc.nuts.NUTS` kernel to run
        :class:`~pyro.infer.mcmc.api.MCMC`, setting the ``.samples``
        attribute on completion.

        This uses an asymptotically exact enumeration-based model when
        ``num_quant_bins > 1``, and a cheaper moment-matched approximate model
        when ``num_quant_bins == 1``.

        :param \*\*options: Options passed to
            :class:`~pyro.infer.mcmc.api.MCMC`. The remaining options are
            pulled out and have special meaning.
        :param int num_samples: Number of posterior samples to draw via mcmc.
            Defaults to 100.
        :param int max_tree_depth: (Default 5). Max tree depth of the
            :class:`~pyro.infer.mcmc.nuts.NUTS` kernel.
        :param full_mass: Specification of mass matrix of the
            :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. Defaults to full mass
            over global random variables.
        :param bool arrowhead_mass: Whether to treat ``full_mass`` as the head
            of an arrowhead matrix versus simply as a block. Defaults to False.
        :param int num_quant_bins: If greater than 1, use asymptotically exact
            inference via local enumeration over this many quantization bins.
            If equal to 1, use continuous-valued relaxed approximate inference.
            Note that computational cost is exponential in `num_quant_bins`.
            Defaults to 1 for relaxed inference.
        :param bool haar: Whether to use a Haar wavelet reparameterizer.
            Defaults to True.
        :param int haar_full_mass: Number of low frequency Haar components to
            include in the full mass matrix. If ``haar=False`` then this is
            ignored. Defaults to 10.
        :param int heuristic_num_particles: Passed to :meth:`heuristic` as
            ``num_particles``. Defaults to 1024.
        :returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``.
        :rtype: ~pyro.infer.mcmc.api.MCMC
        """
        _require_double_precision()

        # Parse options, saving some for use in .predict().
        num_samples = options.setdefault("num_samples", 100)
        num_chains = options.setdefault("num_chains", 1)
        self.num_quant_bins = options.pop("num_quant_bins", 1)
        assert isinstance(self.num_quant_bins, int)
        assert self.num_quant_bins >= 1
        self.relaxed = self.num_quant_bins == 1

        # Setup Haar wavelet transform.
        haar = options.pop("haar", True)
        haar_full_mass = options.pop("haar_full_mass", 10)
        full_mass = options.pop("full_mass", self.full_mass)
        assert isinstance(haar, bool)
        assert isinstance(haar_full_mass, int) and haar_full_mass >= 0
        assert isinstance(full_mass, (bool, list))
        haar_full_mass = min(haar_full_mass, self.duration)
        if not haar:
            haar_full_mass = 0
        if full_mass is True:
            haar_full_mass = 0  # No need to split.
        elif haar_full_mass >= self.duration:
            full_mass = True  # Effectively full mass.
            haar_full_mass = 0
        if haar:
            time_dim = -2 if self.is_regional else -1
            dims = {"auxiliary": time_dim}
            supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
            for name, (fn, is_regional) in self._non_compartmental.items():
                dims[name] = time_dim - fn.event_dim
                supports[name] = fn.support
            haar = _HaarSplitReparam(haar_full_mass, self.duration, dims, supports)
        if haar_full_mass:
            assert full_mass and isinstance(full_mass, list)
            full_mass = full_mass[:]
            full_mass[0] += tuple(name + "_haar_split_0" for name in sorted(dims))

        # Heuristically initialize to feasible latents.
        heuristic_options = {k.replace("heuristic_", ""): options.pop(k)
                             for k in list(options)
                             if k.startswith("heuristic_")}
        init_strategy = init_to_generated(
            generate=functools.partial(self._heuristic, haar, **heuristic_options))

        # Configure a kernel.
        logger.info("Running inference...")
        model = self._relaxed_model if self.relaxed else self._quantized_model
        if haar:
            model = haar.reparam(model)
        kernel = NUTS(model,
                      full_mass=full_mass,
                      init_strategy=init_strategy,
                      max_plate_nesting=self.max_plate_nesting,
                      jit_compile=options.pop("jit_compile", False),
                      jit_options=options.pop("jit_options", None),
                      ignore_jit_warnings=options.pop("ignore_jit_warnings", True),
                      target_accept_prob=options.pop("target_accept_prob", 0.8),
                      max_tree_depth=options.pop("max_tree_depth", 5))
        if options.pop("arrowhead_mass", False):
            kernel.mass_matrix_adapter = ArrowheadMassMatrix()

        # Run mcmc.
        options.setdefault("disable_validation", None)
        mcmc = MCMC(kernel, **options)
        mcmc.run()
        self.samples = mcmc.get_samples()
        if haar:
            haar.aux_to_user(self.samples)

        # Unsqueeze samples to align particle dim for use in poutine.condition.
        # TODO refactor to an align_samples or particle_dim kwarg to MCMC.get_samples().
        model = self._relaxed_model if self.relaxed else self._quantized_model
        self.samples = align_samples(self.samples, model,
                                     particle_dim=-1 - self.max_plate_nesting)
        assert all(v.size(0) == num_samples * num_chains for v in self.samples.values()), \
            {k: tuple(v.shape) for k, v in self.samples.items()}

        return mcmc  # E.g. so user can run mcmc.summary().
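A minimal usage sketch, assuming `model` is an instance of the class that defines
fit_mcmc above (the model construction and the option values are illustrative only;
every option shown is one documented in the docstring):

# Hypothetical usage of fit_mcmc on an already-constructed model instance.
mcmc = model.fit_mcmc(
    num_samples=200,        # posterior samples to draw (default 100)
    warmup_steps=200,       # forwarded to pyro.infer.mcmc.api.MCMC
    max_tree_depth=5,       # NUTS option documented above
    haar=True,              # Haar wavelet reparameterizer
    haar_full_mass=10,      # low-frequency components given full mass
)
mcmc.summary()              # the returned MCMC object supports diagnostics
samples = model.samples     # posterior samples stored on completion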
Example #9
    def draw_posterior_samples(self, num_samples):
        nuts_kernel = NUTS(self.model, adapt_step_size=True)
        mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=300)
        mcmc.run()
        res = mcmc.get_samples()
        return res
Example #10
        rate = pyro.sample("small_spike_rate_post_spike",
                           dist.Uniform(0.4, 0.9))
        for k in range(12 - small_spike_peakStart - 5):
            sellprices.append(rate * basePrice)
            rate -= 0.03
            rate -= pyro.sample("small_spike_final_dec_%d" % k,
                                dist.Uniform(0., 0.02))
        sellprices = torch.ceil(torch.stack(sellprices))
        print("Small spike sellprices: ", sellprices)
    else:
        raise ValueError("Invalid nextPattern %d" % nextPattern)
    pyro.sample("obs", dist.Delta(sellprices).to_event(1))


if __name__ == "__main__":
    calculate_turnip_prices()

    conditioned_model = poutine.condition(calculate_turnip_prices,
                                          data={
                                              "obs":
                                              torch.tensor([
                                                  87., 83., 79., 75., 72., 68.,
                                                  64., 106., 115., 144., 185.,
                                                  138.
                                              ])
                                          })
    nuts_kernel = NUTS(conditioned_model)
    mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=100, num_chains=1)
    mcmc.run()
    mcmc.summary(prob=0.5)
Example #11
plt.show()
#%%
for i in range(d_latent):
    plt.plot(s_truth[n,:,i], color = "green")
    plt.ylim((-2,2))
    for _ in range(50):
        s_pred = (system.guide(X, actions)(X,actions)).detach()[:,:-1]
        
        plt.plot(s_pred[n,:,i], alpha = 0.2)
    plt.show()

#%%

#%% MCMC
from pyro.infer import NUTS, MCMC, EmpiricalMarginal
nuts_kernel = NUTS(system.model, adapt_step_size=True)
mcmc = MCMC(nuts_kernel, num_samples=100, warmup_steps=30)
mcmc_run = mcmc.run(X, actions)
#%%
hmc_samples = {k : v.detach().unsqueeze(1) for k, v in mcmc.get_samples().items()}
for key,val in hmc_samples.items():
    print(key)
#%%
#hmc_samples["network$$$a_emb.weight"][0]
hmc_samples["network$$$a_emb.weight"].mean(0)
#%%
for a in range(num_actions):
    plt.plot(hmc_samples["network$$$a_emb.weight"][:,0,a,:])

plt.plot(hmc_samples["network$$$a_emb.weight"][:,0,2,1])
#%%
Example #12
def do_fixed_structure_hmc_with_constraint_penalties(
        grammar,
        original_tree,
        num_samples=100,
        subsample_step=5,
        verbose=0,
        kernel_type="NUTS",
        fix_observeds=False,
        with_nonpenetration=False,
        with_static_stability=False,
        zmq_url=None,
        prefix="hmc_sample",
        constraints=None,
        structure_vis_kwargs={},
        **kwargs):
    ''' Given a scene tree, resample its continuous variables
    (i.e. the node poses) while keeping the root and observed
    node poses fixed, and trying to keep the constraints implied
    by the tree and grammar satisfied. Returns a population of trees sampled
    from the joint distribution over node poses given the fixed structure.

    Verbose = 0: Print nothing
    Verbose = 1: Print updates about accept rate and steps
    Verbose = 2: Print NLP output and scoring info
    '''

    # Strategy sketch:
    # - Initialize at the current scene tree, asserting that it's a feasible configuration.
    # - Form a probabilistic model that samples all of the node poses,
    #   and uses observe() statements to implement constraint factors as tightly peaked
    #   energy terms.
    # - Use Pyro HMC to sample from the model.

    # Make a bookkeeping copy of the tree
    scene_tree = deepcopy(original_tree)

    # Do steps of random-walk MCMC on those variables.
    initial_score = scene_tree.score()
    assert torch.isfinite(initial_score), "Bad initialization for MCMC."

    # Guard against mutating a shared default argument.
    if constraints is None:
        constraints = []
    if with_nonpenetration:
        constraints.append(NonpenetrationConstraint(0.00))

    # Form probabilistic model
    root = scene_tree.get_root()

    def model():
        # Resample the continuous structure of the tree.
        node_queue = [root]
        while len(node_queue) > 0:
            parent = node_queue.pop(0)
            children, rules = scene_tree.get_children_and_rules(parent)
            for child, rule in zip(children, rules):
                with scope(prefix=parent.name):
                    rule.sample_child(parent, child)
                node_queue.append(child)

        # Implement observation constraints
        if fix_observeds:
            xyz_observed_variance = 1E-2
            rot_observed_variance = 1E-2
            for node, original_node in zip(scene_tree.nodes,
                                           original_tree.nodes):
                if node.observed:
                    xyz_observed_dist = dist.Normal(original_node.translation,
                                                    xyz_observed_variance)
                    rot_observed_dist = dist.Normal(original_node.rotation,
                                                    rot_observed_variance)
                    pyro.sample("%s_xyz_observed" % node.name,
                                xyz_observed_dist,
                                obs=node.translation)
                    pyro.sample("%s_rotation_observed" % node.name,
                                rot_observed_dist,
                                obs=node.rotation)

        for k, constraint in enumerate(constraints):
            clamped_error_distribution = dist.Normal(0., 0.001)
            violation, _, _ = constraint.eval_violation(scene_tree)
            positive_violations = torch.clamp(violation, 0., np.inf)
            pyro.sample("%s_%d_err" % (type(constraint).__name__, k),
                        clamped_error_distribution,
                        obs=positive_violations)

    # Ugh, I shouldn't need to manually reproduce the site names here.
    # Can I rearrange how traces get assembled to extract these?
    initial_values = {}
    for parent in original_tree.nodes:
        children, rules = original_tree.get_children_and_rules(parent)
        for child, rule in zip(children, rules):
            for key, site_value in rule.get_site_values(parent, child).items():
                initial_values["%s/%s/%s" % (parent.name, child.name,
                                             key)] = site_value.value
    trace = pyro.poutine.trace(model).get_trace()
    for key in initial_values.keys():
        if key not in trace.nodes.keys():
            print("Trace keys: ", trace.nodes.keys())
            print("Initial values keys: ", initial_values.keys())
            raise ValueError("%s not in trace keys" % key)

    print("Initial trace log prob: ", trace.log_prob_sum())
    # If I let MCMC auto-tune its step size, it seems to do well,
    # but sometimes seems to get lost, and then gets stuck with big step size and
    # zero acceptances.
    if kernel_type == "NUTS":
        kernel = NUTS(model,
                      init_strategy=pyro.infer.autoguide.init_to_value(
                          values=initial_values),
                      **kwargs)
    elif kernel_type == "HMC":
        kernel = HMC(model,
                     init_strategy=pyro.infer.autoguide.init_to_value(
                         values=initial_values),
                     **kwargs)
    else:
        raise NotImplementedError(kernel_type)

    # Get MBP for viz
    if zmq_url is not None:
        builder, mbp, sg, node_to_free_body_ids_map, body_id_to_node_map = compile_scene_tree_to_mbp_and_sg(
            scene_tree)
        mbp.Finalize()
        visualizer = ConnectMeshcatVisualizer(builder,
                                              sg,
                                              zmq_url=zmq_url,
                                              prefix=prefix)
        diagram = builder.Build()
        diagram_context = diagram.CreateDefaultContext()
        mbp_context = diagram.GetMutableSubsystemContext(mbp, diagram_context)
        vis_context = diagram.GetMutableSubsystemContext(
            visualizer, diagram_context)
        visualizer.load()

        def hook_fn(kernel, samples, stage, i):
            # Set the MBP context to the current scene tree poses and publish for visualization.
            for node, body_ids in node_to_free_body_ids_map.items():
                for body_id in body_ids:
                    mbp.SetFreeBodyPose(mbp_context, mbp.get_body(body_id),
                                        torch_tf_to_drake_tf(node.tf))
            diagram.Publish(diagram_context)
            draw_scene_tree_structure_meshcat(scene_tree,
                                              zmq_url=zmq_url,
                                              prefix=prefix + "/structure",
                                              delete=False,
                                              **structure_vis_kwargs)
            time.sleep(0.1)
    else:
        hook_fn = None

    mcmc = MCMC(kernel,
                num_samples=num_samples,
                warmup_steps=min(int(num_samples / 2), 50),
                num_chains=1,
                disable_progbar=(verbose == -1),
                hook_fn=hook_fn)
    mcmc.run()
    if verbose > 1:
        mcmc.summary(prob=0.5)

    samples = mcmc.get_samples()
    sampled_trees = []
    for k in range(0, num_samples, subsample_step):
        condition = {key: value[k, ...] for key, value in samples.items()}
        with pyro.condition(data=condition):
            model()
        sampled_trees.append(deepcopy(scene_tree))

    return sampled_trees
Example #13
traces = [system.guide().detach().unsqueeze(0) for _ in range(100)]
traces = torch.cat(traces)
avgtrace = traces.mean(0)
for i in range(len(traces)):
    plt.plot(traces[i, :, 0], traces[i, :, 1], alpha=0.05)
plt.plot(avgtrace[:, 0], avgtrace[:, 1], color="red")
#plt.plot(pyro.param("z-mu").detach())

# %%
pyro.param("z-scale").mean()

# %% MCMC

from pyro.infer import MCMC, NUTS

nuts_kernel = NUTS(system.model)
#%%
mcmc = MCMC(nuts_kernel, num_samples=100, warmup_steps=50)

mcmc.run()

# %% VISUALIZE TRACES

hmc_samples = [v.detach().unsqueeze(1) for k, v in mcmc.get_samples().items()]
traces = torch.cat(hmc_samples, dim=1)
avgtrace = traces.mean(0)
for i in range(len(traces)):
    plt.plot(traces[i, :, 0], traces[i, :, 1], alpha=0.05)
plt.plot(avgtrace[:, 0], avgtrace[:, 1], color="red", alpha=0.9)

# %%
Example #14
def find_n_modes(args):
    # GMM with an arbitrary number of components
    var = 1
    d = args["z_dim"]
    args["locs"] = make_gaussians(args, d, var)
    args['covs'] = [var * torch.eye(d, dtype=torch.float32, device=device)] * args[
        'num_gauss']  # list of covariance matrices for each of these gaussians

    target = GMM_target2(kwargs=args)

    #######################################################################################
    ###### NUTS ########
    def energy(z):
        z = z['points']
        return -target.get_logdensity(z)

    prior = get_prior(args, target)

    kernel = NUTS(potential_fn=energy)
    n_stop = 1  # number of times we stop to check n_modes (to allow fair comparison with other tests, we take the greatest number of modes retrieved)
    warmup_steps = 4000
    n_chains = args['n_chains']
    num_samples = 8000 // n_chains
    nuts = torch.tensor([], device=device)
    nuts_ungrouped = torch.tensor([], device=device)
    if n_chains > 1:
        best_n_modes = 0.
    else:
        best_n_modes = np.zeros(n_chains)

#     pdb.set_trace()
    init_samples = prior.sample((n_chains, args.z_dim))
#     init_params = {'points': init_samples}
    
    ## First we run warmup
#     current_samples = torch.tensor([], device=args.device)
#     for ind in range(n_chains):
#         mcmc = MCMC(kernel=kernel, num_samples=1,
#                     initial_params={'points': init_samples[ind]},
#                     num_chains=1, warmup_steps=warmup_steps)
#         mcmc.run()
#         current_samples = torch.cat([current_samples, mcmc.get_samples(group_by_chain=True)['points'].view(1, -1, d)])
#     init_samples = current_samples.view(n_chains, -1)
    
#     init_samples = torch.cat([chain[None] for chain in args.locs])
#     init_params = {'points': init_samples}
    for i in range(n_stop): ## n_stop -- how often we check n modes
        current_samples = torch.tensor([], device=args.device)
        for ind in range(n_chains):
            mcmc = MCMC(kernel=kernel, num_samples=num_samples // n_stop,
                        initial_params={'points': init_samples[ind]},
                        num_chains=1, warmup_steps=warmup_steps)
            mcmc.run()
            current_samples = torch.cat([current_samples, mcmc.get_samples(group_by_chain=True)['points'].view(1, -1, d)])
#         pdb.set_trace()
        nuts = torch.cat([nuts, current_samples], dim=1)
        nuts_ungrouped = torch.cat([nuts_ungrouped, nuts.view(-1, d)],
                                   dim=0)
        init_samples = nuts[:, -1]  # last sample of each chain (shape = n_chains x z_dim)
#         init_params = {'points': init_samples}
        # pdb.set_trace()
        new_n_modes = n_modes(args, nuts_ungrouped, d, var)
        if new_n_modes > best_n_modes:
            best_n_modes = new_n_modes
        if best_n_modes == 8:
            break
    print(best_n_modes)

    return best_n_modes
Example #15
                                posterior_samples=posterior_samples)(y_obs,
                                                                     None)
    posterior_summary = get_summary_table(posterior_pred, sites=["y_obs"])
    return posterior_summary


def sample_prior_pred(model, numofsamples, y_obs):
    prior_pred = Predictive(model, {}, num_samples=numofsamples)(y_obs, None)
    prior_summary = get_summary_table(prior_pred, sites=["y_obs"])
    return prior_summary


np.random.seed(0)
Y = torch.Tensor(stats.bernoulli(0.7).rvs(20))

nuts_kernel = NUTS(partial_pooled)
mcmc = MCMC(nuts_kernel, 1000, num_chains=1)
mcmc.run(Y)
trace = mcmc.get_samples()
idata = az.from_pyro(trace)
observedY = partial_pooled(Y)
pred_dists = (sample_prior_pred(partial_pooled, 1000, observedY),
             sample_posterior_pred(partial_pooled, idata, observedY))

fig, ax = plt.subplots()
az.plot_dist(pred_dists[0].sum(1),
             hist_kwargs={
                 "color": "0.5",
                 "bins": range(0, 22)
             })
ax.set_title("Prior predictive distribution", fontweight='bold')