Example #1
    def update(self, logits, labels):
        """ Performs an update given new observations.

        Args:
            logits: tensor ; shape (batch_size, num_classes)
            labels: tensor ; shape (batch_size, )
        """
        assert len(
            labels.shape
        ) == 1, 'Got label tensor with shape {} -- labels must be dense'.format(
            labels.shape)
        assert len(logits.shape) == 2, 'Got logit tensor with shape {}'.format(
            logits.shape)
        assert (labels.shape[0] == logits.shape[0]), 'Shape mismatch between logits ({}) and labels ({})' \
            .format(logits.shape[0], labels.shape[0])

        logits = logits.detach().clone().requires_grad_()
        labels = labels.detach().clone()

        batch_size = labels.shape[0]
        if self.verbose:
            print(
                '----| Updating HBC model\n--------| Got a batch size of: {}'.
                format(batch_size))

        # TODO
        # self._update_prior_params()
        if self.verbose:
            print('--------| Updated priors: {}'.format(self.prior_params))
            print('--------| Running inference ')
        nuts_kernel = NUTS(bvs_model, **self.NUTS_params)
        self.mcmc = MCMC(
            nuts_kernel, **self.mcmc_params,
            disable_progbar=not self.verbose)  # Progbar if verbose
        self.mcmc.run(self.prior_params, logits, labels)

        # TODO
        # self._update_posterior_params()
        self.timestep += 1

        return self.mcmc
Example #2
def numpyro_schools_model(data, draws, chains):
    """Centered eight schools implementation in NumPyro."""
    from jax.random import PRNGKey
    from numpyro.infer import MCMC, NUTS

    mcmc = MCMC(
        NUTS(_numpyro_noncentered_model),
        num_warmup=draws,
        num_samples=draws,
        num_chains=chains,
        chain_method="sequential",
    )
    mcmc.run(PRNGKey(0), extra_fields=("num_steps", "energy"), **data)

    # This block lets the posterior be pickled
    mcmc.sampler._sample_fn = None  # pylint: disable=protected-access
    mcmc.sampler._init_fn = None  # pylint: disable=protected-access
    mcmc.sampler._postprocess_fn = None  # pylint: disable=protected-access
    mcmc.sampler._potential_fn = None  # pylint: disable=protected-access
    mcmc._cache = {}  # pylint: disable=protected-access
    return mcmc
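The block above clears the jitted closures so that the fitted NumPyro object can be serialized. A minimal round-trip sketch (not from the source; the eight-schools values are the classic dataset, and the argument names are assumed to match _numpyro_noncentered_model):

import pickle

import numpy as np

# Hypothetical eight-schools inputs; the keys must match the model's argument names.
schools_data = {
    "J": 8,
    "y": np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
    "sigma": np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
}

mcmc = numpyro_schools_model(schools_data, draws=500, chains=1)
blob = pickle.dumps(mcmc)        # works because the jitted closures were cleared
restored = pickle.loads(blob)
print(list(restored.get_samples().keys()))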
Example #3
def pyro_noncentered_schools(data, draws, chains):
    """Non-centered eight schools implementation in Pyro."""
    import torch
    from pyro.infer import MCMC, NUTS

    y = torch.from_numpy(data["y"]).float()
    sigma = torch.from_numpy(data["sigma"]).float()

    nuts_kernel = NUTS(_pyro_noncentered_model,
                       jit_compile=True,
                       ignore_jit_warnings=True)
    posterior = MCMC(nuts_kernel,
                     num_samples=draws,
                     warmup_steps=draws,
                     num_chains=chains)
    posterior.run(data["J"], sigma, y)

    # This block lets the posterior be pickled
    posterior.sampler = None
    posterior.kernel.potential_fn = None
    return posterior
Example #4
def test_neals_funnel_smoke(Guide, jit):
    dim = 10

    guide = Guide(neals_funnel)
    svi = SVI(neals_funnel, guide, optim.Adam({"lr": 1e-10}), Trace_ELBO())
    for _ in range(10):
        svi.step(dim)

    neutra = NeuTraReparam(guide.requires_grad_(False))
    model = neutra.reparam(neals_funnel)
    nuts = NUTS(model, jit_compile=jit, ignore_jit_warnings=True)
    mcmc = MCMC(nuts, num_samples=10, warmup_steps=10)
    mcmc.run(dim)
    samples = mcmc.get_samples()
    # XXX: `MCMC.get_samples` adds a leftmost batch dim to all sites, not uniformly at -max_plate_nesting-1;
    # hence the unsqueeze
    transformed_samples = neutra.transform_sample(
        samples["y_shared_latent"].unsqueeze(-2)
    )
    assert "x" in transformed_samples
    assert "y" in transformed_samples
Example #5
def Inference_MCMC(model,
                   data,
                   polls,
                   n_samples=500,
                   n_warmup=500,
                   n_chains=1,
                   max_tree_depth=6):

    nuts_kernel = NUTS(model,
                       adapt_step_size=True,
                       jit_compile=True,
                       ignore_jit_warnings=True,
                       max_tree_depth=max_tree_depth)

    mcmc = MCMC(nuts_kernel,
                num_samples=n_samples,
                warmup_steps=n_warmup,
                num_chains=n_chains)

    mcmc.run(data, polls)

    # samples drawn from the posterior distribution
    posterior_samples = mcmc.get_samples()

    # convert the samples to a plain dict of NumPy arrays
    hmc_samples = {
        k: v.detach().cpu().numpy()
        for k, v in posterior_samples.items()
    }

    return posterior_samples, hmc_samples
Example #6
    def __init__(self,
                 model,
                 data,
                 covariates=None,
                 *,
                 num_warmup=1000,
                 num_samples=1000,
                 num_chains=1,
                 time_reparam=None,
                 dense_mass=False,
                 jit_compile=False,
                 max_tree_depth=10):
        assert data.size(-2) == covariates.size(-2)
        super().__init__()
        if time_reparam == "haar":
            model = poutine.reparam(model, time_reparam_haar)
        elif time_reparam == "dct":
            model = poutine.reparam(model, time_reparam_dct)
        elif time_reparam is not None:
            raise ValueError("unknown time_reparam: {}".format(time_reparam))
        self.model = model
        max_plate_nesting = _guess_max_plate_nesting(model, (data, covariates),
                                                     {})
        self.max_plate_nesting = max(max_plate_nesting,
                                     1)  # force a time plate

        kernel = NUTS(model,
                      full_mass=dense_mass,
                      jit_compile=jit_compile,
                      ignore_jit_warnings=True,
                      max_tree_depth=max_tree_depth,
                      max_plate_nesting=max_plate_nesting)
        mcmc = MCMC(kernel,
                    warmup_steps=num_warmup,
                    num_samples=num_samples,
                    num_chains=num_chains)
        mcmc.run(data, covariates)
        # conditions to compute rhat
        if (num_chains == 1 and num_samples >= 4) or (num_chains > 1
                                                      and num_samples >= 2):
            mcmc.summary()

        # inspect the model with particles plate = 1, so that we can reshape samples to
        # add any missing plate dim in front.
        with poutine.trace() as tr:
            with pyro.plate("particles", 1, dim=-self.max_plate_nesting - 1):
                model(data, covariates)

        self._trace = tr.trace
        self._samples = mcmc.get_samples()
        self._num_samples = num_samples * num_chains
        for name, node in list(self._trace.nodes.items()):
            if name not in self._samples:
                del self._trace.nodes[name]
Example #7
def run_hmc(args, model):
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel,
                warmup_steps=args.num_warmup,
                num_samples=args.num_samples)
    mcmc.run(args.param_a, args.param_b)
    mcmc.summary()
    return mcmc
Example #8
def main(args):
    nuts_kernel = NUTS(conditioned_model, jit_compile=args.jit)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(model, data.sigma, data.y)
    mcmc.summary(prob=0.5)
Example #9
def sample_model(chat,
                 mhat,
                 varpihat,
                 sigmac,
                 sigmam,
                 sigmavarpi,
                 dustco_c,
                 dustco_m,
                 theta_0_mcmc,
                 nsamples=100,
                 nwalkers=1):
    objective = Objective(chat, mhat, varpihat, sigmac, sigmam, sigmavarpi,
                          dustco_c, dustco_m)
    #print(objective.logjoint())
    objective.logjoint()
    #nuts_kernel = NUTS(objective.logjoint, jit_compile=True, ignore_jit_warnings=True)
    #mcmc = MCMC(nuts_kernel, num_samples=2000, warmup_steps=100, num_chains=2, mp_context='spawn')

    try:
        with open('savemcmc_{}.pkl'.format(ind), 'rb') as f:
            mcmc = pickle.load(f)
    except IOError:
        nuts_kernel = NUTS(objective.logjoint,
                           jit_compile=True,
                           ignore_jit_warnings=False)
        mcmc = MCMC(nuts_kernel,
                    num_samples=nsamples,
                    warmup_steps=100,
                    num_chains=nwalkers,
                    initial_params=theta_0_mcmc,
                    mp_context='spawn')
        mcmc.run()

        with open('savemcmc_{}.pkl'.format(ind), 'wb') as f:
            mcmc.sampler = None
            mcmc.kernel.potential_fn = None
            pickle.dump(mcmc, f)
    return mcmc, objective
Example #10
def run_hmc(
    x_data,
    y_data,
    model,
    num_samples=1000,
    warmup_steps=200,
):
    """
    Runs NUTS
    returns: samples
    """
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel,
                num_samples=num_samples,
                warmup_steps=warmup_steps)
    mcmc.run(x_data, y_data)
    hmc_samples = {
        k: v.detach().cpu().numpy()
        for k, v in mcmc.get_samples().items()
    }
    hmc_samples["linear.weight"] = hmc_samples["linear.weight"].reshape(
        num_samples, -1)
    return hmc_samples
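The model argument passed to run_hmc is not shown here. One plausible model that produces a "linear.weight" site, sketched after Pyro's Bayesian regression tutorial (a hypothetical stand-in, not the original code):

import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample


class BayesianRegression(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(
            dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(
            dist.Normal(0., 10.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean


# Usage (hypothetical): hmc_samples = run_hmc(x_data, y_data, BayesianRegression(3, 1))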
Example #11
    def update(self, logits, labels):
        """ Performs an update given new observations.

        Args:
            logits: tensor ; shape (batch_size, num_classes)
            labels: tensor ; shape (batch_size, )
        """
        assert len(labels.shape) == 1, 'Got label tensor with shape {} -- labels must be dense'.format(labels.shape)
        assert len(logits.shape) == 2, 'Got logit tensor with shape {}'.format(logits.shape)
        assert (labels.shape[0] == logits.shape[0]), 'Shape mismatch between logits ({}) and labels ({})' \
            .format(logits.shape[0], labels.shape[0])

        print('delta constraint::')
        print(self.delta_constraint)
        self.timestep += 1

        logits = logits.detach().clone().requires_grad_()
        labels = labels.detach().clone()

        batch_size = labels.shape[0]
        print('----| Updating HBC model\n--------| Got a batch size of: {}'.format(batch_size))

        # TODO: Update prior (for sequential)
        # self._update_prior()
        # print('--------| Updated priors: {}'.format(self.prior_params))

        print('--------| Running inference ')
        nuts_kernel = NUTS(hbc_model, **self.NUTS_params)
        self.mcmc = MCMC(nuts_kernel, **self.mcmc_params, disable_progbar=False)
        print('.')
        self.mcmc.run(self.prior_params, logits, labels, self.delta_constraint)
        print('..')

        #  TODO: update posterior (for sequential)
        # self._update_posterior(posterior_samples)

        return self.mcmc
Example #12
def infer(args, model, t, yt):
    nuts_kernel = NUTS(conditioned_model, jit_compile=args.jit)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.nsamples,
                warmup_steps=args.nwarmups,
                num_chains=args.nchains)
    mcmc.run(model, t, yt)
    mcmc.summary(prob=0.95)
    return mcmc
Example #13
    def test_inference_data_constant_data(self):
        import pyro.distributions as dist
        from pyro.infer import MCMC, NUTS

        x1 = 10
        x2 = 12
        y1 = torch.randn(10)

        def model_constant_data(x, y1=None):
            _x = pyro.sample("x", dist.Normal(1, 3))
            pyro.sample("y1", dist.Normal(x * _x, 1), obs=y1)

        nuts_kernel = NUTS(model_constant_data)
        mcmc = MCMC(nuts_kernel, num_samples=10)
        mcmc.run(x=x1, y1=y1)
        posterior = mcmc.get_samples()
        posterior_predictive = Predictive(model_constant_data, posterior)(x1)
        predictions = Predictive(model_constant_data, posterior)(x2)
        inference_data = from_pyro(
            mcmc,
            posterior_predictive=posterior_predictive,
            predictions=predictions,
            constant_data={"x1": x1},
            predictions_constant_data={"x2": x2},
        )
        test_dict = {
            "posterior": ["x"],
            "posterior_predictive": ["y1"],
            "sample_stats": ["diverging"],
            "log_likelihood": ["y1"],
            "predictions": ["y1"],
            "observed_data": ["y1"],
            "constant_data": ["x1"],
            "predictions_constant_data": ["x2"],
        }
        fails = check_multiple_attrs(test_dict, inference_data)
        assert not fails
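Once built, the InferenceData object can be inspected with ArviZ in the usual way; a brief sketch (not part of the test; group and variable names follow test_dict above):

import arviz as az

# assuming inference_data was constructed as in the test above
print(inference_data)                               # lists the groups named in test_dict
print(az.summary(inference_data, var_names=["x"]))  # posterior summary for site "x"
az.plot_trace(inference_data, var_names=["x"])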
Example #14
def test_neals_funnel_smoke():
    dim = 10

    def model():
        y = pyro.sample('y', dist.Normal(0, 3))
        with pyro.plate("D", dim):
            pyro.sample('x', dist.Normal(0, torch.exp(y/2)))

    guide = AutoIAFNormal(model)
    svi = SVI(model, guide,  optim.Adam({"lr": 1e-10}), Trace_ELBO())
    for _ in range(1000):
        svi.step()

    neutra = NeuTraReparam(guide)
    model = neutra.reparam(model)
    nuts = NUTS(model)
    mcmc = MCMC(nuts, num_samples=50, warmup_steps=50)
    mcmc.run()
    samples = mcmc.get_samples()
    # XXX: `MCMC.get_samples` adds a leftmost batch dim to all sites, not uniformly at -max_plate_nesting-1;
    # hence the unsqueeze
    transformed_samples = neutra.transform_sample(samples['y_shared_latent'].unsqueeze(-2))
    assert 'x' in transformed_samples
    assert 'y' in transformed_samples
Example #15
def test_neals_funnel_smoke(jit):
    dim = 10

    guide = AutoStructured(
        neals_funnel,
        conditionals={
            "y": "normal",
            "x": "mvn"
        },
        dependencies={"x": {
            "y": "linear"
        }},
    )
    Elbo = JitTrace_ELBO if jit else Trace_ELBO
    svi = SVI(neals_funnel, guide, optim.Adam({"lr": 1e-10}), Elbo())
    for _ in range(1000):
        try:
            svi.step(dim=dim)
        except SystemError as e:
            if "returned a result with an error set" in str(e):
                pytest.xfail(reason="PyTorch jit bug")
            else:
                raise e from None

    rep = StructuredReparam(guide)
    model = rep.reparam(neals_funnel)
    nuts = NUTS(model, max_tree_depth=3, jit_compile=jit)
    mcmc = MCMC(nuts, num_samples=50, warmup_steps=50)
    mcmc.run(dim)
    samples = mcmc.get_samples()
    # XXX: `MCMC.get_samples` adds a leftmost batch dim to all sites,
    # not uniformly at -max_plate_nesting-1; hence the unsqueeze.
    samples = {k: v.unsqueeze(1) for k, v in samples.items()}
    transformed_samples = rep.transform_samples(samples)
    assert isinstance(transformed_samples, dict)
    assert set(transformed_samples) == {"x", "y"}
Example #16
def _infer_hmc(args, data, model, init_values={}):
    logging.info("Running inference...")
    kernel = NUTS(
        model,
        full_mass=[("R0", "rho")],
        max_tree_depth=args.max_tree_depth,
        init_strategy=init_to_value(values=init_values),
        jit_compile=args.jit,
        ignore_jit_warnings=True,
    )

    # We'll define a hook_fn to log potential energy values during inference.
    # This is helpful to diagnose whether the chain is mixing.
    energies = []

    def hook_fn(kernel, *unused):
        e = float(kernel._potential_energy_last)
        energies.append(e)
        if args.verbose:
            logging.info("potential = {:0.6g}".format(e))

    mcmc = MCMC(
        kernel,
        hook_fn=hook_fn,
        num_samples=args.num_samples,
        warmup_steps=args.warmup_steps,
    )
    mcmc.run(args, data)
    mcmc.summary()
    if args.plot:
        import matplotlib.pyplot as plt

        plt.figure(figsize=(6, 3))
        plt.plot(energies)
        plt.xlabel("MCMC step")
        plt.ylabel("potential energy")
        plt.title("MCMC energy trace")
        plt.tight_layout()

    samples = mcmc.get_samples()
    return samples
Example #17
def main():
    start = time.time()
    pyro.clear_param_store()

    # the kernel we will use
    hmc_kernel = HMC(conditioned_model, step_size=0.1)

    # the sampler which will run the kernel
    mcmc = MCMC(hmc_kernel, num_samples=14000, warmup_steps=100)

    # the .run method takes the same arguments that our model function expects
    mcmc.run(model, data)
    end = time.time()
    print('Time taken ', end - start, ' seconds')

    sample_dict = mcmc.get_samples(num_samples=5000)

    plt.figure(figsize=(10, 7))
    sns.distplot(sample_dict['latent_fairness'].numpy(), color="orange")
    plt.xlabel("Observed probability value")
    plt.ylabel("Observed frequency")
    plt.show()

    mcmc.summary(prob=0.95)
Example #18
def main(args):
    # define which MCMC algorithm to run (proposal, rejection, etc.)
    # this is captured by the notion of a "kernel"
    # NUTS: No-U-Turn Sampler kernel, which provides an efficient and convenient way
    # to run Hamiltonian Monte Carlo. The number of steps taken by the
    # integrator is dynamically adjusted on each call to ``sample`` to ensure
    # an optimal length for the Hamiltonian trajectory [1]. As such, the samples
    # generated will typically have lower autocorrelation than those generated
    # by the :class:`~pyro.infer.mcmc.HMC` kernel.

    nuts_kernel = NUTS(conditioned_model)

    # MCMC is the wrapper around the chosen kernel; call .run on it to draw samples
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)

    data = construct_data()

    mcmc.run(model(data["J"]), data["sigma"], data["y"])
    mcmc.summary(prob=0.5)
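conditioned_model is not included in Example #8 or #18. A plausible definition, following the call signature used in Example #8 and the style of Pyro's eight-schools example (the value of J is an assumption):

import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

J = 8  # number of schools (hypothetical)


def model(sigma):
    eta = pyro.sample("eta", dist.Normal(torch.zeros(J), torch.ones(J)))
    mu = pyro.sample("mu", dist.Normal(torch.zeros(1), 10 * torch.ones(1)))
    tau = pyro.sample("tau", dist.HalfCauchy(scale=25 * torch.ones(1)))
    theta = mu + tau * eta
    return pyro.sample("obs", dist.Normal(theta, sigma))


def conditioned_model(model, sigma, y):
    # condition the "obs" site on the observed effects y, then run the model
    return poutine.condition(model, data={"obs": y})(sigma)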
Example #19
        tavg_norm_noauto_3d, tavg_raw_all_3d, tavg_raw_noauto_3d
    ] = pickle.load(f)

tm.mtype = 'group'
tm.target = 'self'  # 'self','targ','avg'
tm.dtype = 'norm'  # 'norm','raw'
tm.auto = 'all'  # 'noauto','all'
tm.stickbreak = False
tm.optim = pyro.optim.Adam({'lr': 0.0005, 'betas': [0.8, 0.99]})
tm.elbo = TraceEnum_ELBO(max_plate_nesting=1)

tm.K = 3

pyro.clear_param_store()
pyro.set_rng_seed(99)

# #declare dataset to be modeled
# dtname = 't{}_{}_{}_3d'.format(target, dtype, auto)
# print("running MCMC with: {}".format(dtname))
# data = globals()[dtname]

nuts_kernel = NUTS(tm.model)

mcmc = MCMC(nuts_kernel, num_samples=5000, warmup_steps=1000)
mcmc.run(tself_norm_all_3d)

posterior_samples = mcmc.get_samples()

abc = az.from_pyro(mcmc, log_likelihood=True)
az.stats.waic(abc.posterior.weights)
Example #20
    # sigma = dist.Uniform(0., 5.).sample()
    # sigma = dist.Uniform(sigma_loc, 5.).sample()
    # sigma = dist.Normal(sigma_loc, 0.2).sample()
    pyro.sample("obs", dist.Normal(mean, sigma), obs=obserations)


dims = 4
num_samples = 100

# generate observations
x = torch.rand(dims, num_samples)
noise = torch.distributions.Normal(torch.tensor([0.] * num_samples),
                                   torch.tensor([0.2] *
                                                num_samples)).rsample()
s, fm, Zn, Vr = x
a, b, c, d = 1.5, 1.8, 2.1, 2.3
# a, b, c, d = 1., 1., 1., 1.
observations = s * fm**a * Zn**b * c / Vr**d + noise[0]

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=200, warmup_steps=400)
mcmc.run(x, dims, observations)

hmc_samples = {
    k: v.detach().cpu().numpy()
    for k, v in mcmc.get_samples().items()
}
for site, values in summary(hmc_samples).items():
    print("Site: {}".format(site))
    print(values, "\n")
Example #21
import torch
import numpy as np
import pyro
from pyro import sample
from pyro.infer import NUTS, MCMC
from pyro.distributions import Normal
from matplotlib import pyplot as plt


def bad():
    # NOTE: every iteration reuses the sample site name 'x'; Pyro requires
    # unique sample site names within a single trace, hence the name "bad".
    x = sample('x', pyro.distributions.Normal(0, 1))
    for i in range(10):
        x = sample('x', pyro.distributions.Normal(x, 3))
    return x


nuts_kernel = NUTS(bad)

mcmc = MCMC(nuts_kernel, num_samples=10, warmup_steps=10)
mcmc.run()
mcmc.summary()
samples = mcmc.get_samples()

print(samples.keys())

fig, ax = plt.subplots()
ax.hist(np.array(samples["x"]), bins=50)
plt.show()
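For contrast with bad above, a variant that gives every sample site a unique name, as Pyro requires within a single trace, might look like this (illustrative only):

def good():
    x = sample('x_0', Normal(0, 1))
    for i in range(10):
        x = sample('x_{}'.format(i + 1), Normal(x, 3))
    return x

Running the same NUTS/MCMC setup on good would then yield distinct posterior sites x_0 through x_10.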
Example #22
    plt.ylabel('Density')
    plt.tight_layout()
    plt.savefig('./assets/logit_trans.pdf', dpi=600)


if __name__ == '__main__':
    # create the params of NB distribution
    alpha = torch.tensor(args.alpha)
    beta = torch.tensor(args.beta)
    r = torch.tensor(args.r)
    data = torch.tensor([12, 11, 6, 12, 11, 0, 4, 6, 5, 6])

    nb_post = NB_Post(alpha, beta, args.r)
    # create hmc and mcmc object and sample
    hmc_kernel = HMC(nb_post.model, step_size=args.step_size, num_steps=args.num_steps)
    mcmc = MCMC(hmc_kernel, num_samples=args.num_samples, warmup_steps=args.warm_steps)

    # sample the posterior
    mcmc.run(data, args.logit)
    if args.logit:
        param = 'eta'
        posterior_samples = mcmc.get_samples()[param]
        # inverse-logit (sigmoid) transform back to the probability scale
        posterior_samples = torch.sigmoid(posterior_samples)
        # plot the estimated posterior density
        plot_logit_density(posterior_samples)
    else:
        param = 'p'
        posterior_samples = mcmc.get_samples()[param]
        poster_alpha = (alpha + data.sum()).numpy()
        poster_beta = (len(data) * r + beta).numpy()
Example #23
    def fit_mcmc(self, **options):
        r"""
        Runs NUTS inference to generate posterior samples.

        This uses the :class:`~pyro.infer.mcmc.nuts.NUTS` kernel to run
        :class:`~pyro.infer.mcmc.api.MCMC`, setting the ``.samples``
        attribute on completion.

        This uses an asymptotically exact enumeration-based model when
        ``num_quant_bins > 1``, and a cheaper moment-matched approximate model
        when ``num_quant_bins == 1``.

        :param \*\*options: Options passed to
            :class:`~pyro.infer.mcmc.api.MCMC`. The remaining options are
            pulled out and have special meaning.
        :param int num_samples: Number of posterior samples to draw via mcmc.
            Defaults to 100.
        :param int max_tree_depth: (Default 5). Max tree depth of the
            :class:`~pyro.infer.mcmc.nuts.NUTS` kernel.
        :param full_mass: Specification of mass matrix of the
            :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. Defaults to full mass
            over global random variables.
        :param bool arrowhead_mass: Whether to treat ``full_mass`` as the head
            of an arrowhead matrix versus simply as a block. Defaults to False.
        :param int num_quant_bins: If greater than 1, use asymptotically exact
            inference via local enumeration over this many quantization bins.
            If equal to 1, use continuous-valued relaxed approximate inference.
            Note that computational cost is exponential in `num_quant_bins`.
            Defaults to 1 for relaxed inference.
        :param bool haar: Whether to use a Haar wavelet reparameterizer.
            Defaults to True.
        :param int haar_full_mass: Number of low frequency Haar components to
            include in the full mass matrix. If ``haar=False`` then this is
            ignored. Defaults to 10.
        :param int heuristic_num_particles: Passed to :meth:`heuristic` as
            ``num_particles``. Defaults to 1024.
        :returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``.
        :rtype: ~pyro.infer.mcmc.api.MCMC
        """
        _require_double_precision()

        # Parse options, saving some for use in .predict().
        num_samples = options.setdefault("num_samples", 100)
        num_chains = options.setdefault("num_chains", 1)
        self.num_quant_bins = options.pop("num_quant_bins", 1)
        assert isinstance(self.num_quant_bins, int)
        assert self.num_quant_bins >= 1
        self.relaxed = self.num_quant_bins == 1

        # Setup Haar wavelet transform.
        haar = options.pop("haar", False)
        haar_full_mass = options.pop("haar_full_mass", 10)
        full_mass = options.pop("full_mass", self.full_mass)
        assert isinstance(haar, bool)
        assert isinstance(haar_full_mass, int) and haar_full_mass >= 0
        assert isinstance(full_mass, (bool, list))
        haar_full_mass = min(haar_full_mass, self.duration)
        if not haar:
            haar_full_mass = 0
        if full_mass is True:
            haar_full_mass = 0  # No need to split.
        elif haar_full_mass >= self.duration:
            full_mass = True  # Effectively full mass.
            haar_full_mass = 0
        if haar:
            time_dim = -2 if self.is_regional else -1
            dims = {"auxiliary": time_dim}
            supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
            for name, (fn, is_regional) in self._non_compartmental.items():
                dims[name] = time_dim - fn.event_dim
                supports[name] = fn.support
            haar = _HaarSplitReparam(haar_full_mass, self.duration, dims, supports)
        if haar_full_mass:
            assert full_mass and isinstance(full_mass, list)
            full_mass = full_mass[:]
            full_mass[0] += tuple(name + "_haar_split_0" for name in sorted(dims))

        # Heuristically initialize to feasible latents.
        heuristic_options = {k.replace("heuristic_", ""): options.pop(k)
                             for k in list(options)
                             if k.startswith("heuristic_")}
        init_strategy = init_to_generated(
            generate=functools.partial(self._heuristic, haar, **heuristic_options))

        # Configure a kernel.
        logger.info("Running inference...")
        model = self._relaxed_model if self.relaxed else self._quantized_model
        if haar:
            model = haar.reparam(model)
        kernel = NUTS(model,
                      full_mass=full_mass,
                      init_strategy=init_strategy,
                      max_plate_nesting=self.max_plate_nesting,
                      jit_compile=options.pop("jit_compile", False),
                      jit_options=options.pop("jit_options", None),
                      ignore_jit_warnings=options.pop("ignore_jit_warnings", True),
                      target_accept_prob=options.pop("target_accept_prob", 0.8),
                      max_tree_depth=options.pop("max_tree_depth", 5))
        if options.pop("arrowhead_mass", False):
            kernel.mass_matrix_adapter = ArrowheadMassMatrix()

        # Run mcmc.
        options.setdefault("disable_validation", None)
        mcmc = MCMC(kernel, **options)
        mcmc.run()
        self.samples = mcmc.get_samples()
        if haar:
            haar.aux_to_user(self.samples)

        # Unsqueeze samples to align particle dim for use in poutine.condition.
        # TODO refactor to an align_samples or particle_dim kwarg to MCMC.get_samples().
        model = self._relaxed_model if self.relaxed else self._quantized_model
        self.samples = align_samples(self.samples, model,
                                     particle_dim=-1 - self.max_plate_nesting)
        assert all(v.size(0) == num_samples * num_chains for v in self.samples.values()), \
            {k: tuple(v.shape) for k, v in self.samples.items()}

        return mcmc  # E.g. so user can run mcmc.summary().
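A hypothetical end-to-end sketch (not from the source), assuming a CompartmentalModel subclass such as SimpleSIRModel from pyro.contrib.epidemiology is available; the option names follow the docstring above:

import torch
from pyro.contrib.epidemiology.models import SimpleSIRModel

torch.set_default_dtype(torch.float64)  # fit_mcmc requires double precision

# Hypothetical daily new-infection counts.
new_cases = torch.tensor([1., 2., 3., 5., 8., 13., 8., 5., 3., 2.])
model = SimpleSIRModel(population=1000, recovery_time=7.0, data=new_cases)

mcmc = model.fit_mcmc(num_samples=100,
                      warmup_steps=100,
                      max_tree_depth=5,
                      haar=True,
                      haar_full_mass=5)
mcmc.summary()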
Example #24
        sys.stderr.write("Requires Python 3\n")

    genr = Decoder()
    genr.load_state_dict(torch.load('gaae-decd-1024.tch'))
    genr.eval()

    data = qPCRData('second.txt', randomize=False, test=False)

    # Do it with CUDA if possible.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        genr.cuda()

    model = GeneratorModel(genr)

    nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=True)
    mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=2000)
    for batch in data.batches(btchsz=8192, randomize=False, test=False):
        obs = batch[:, 90:].to(device)
        mcmc.run(obs)
        z = mcmc.get_samples()['z']
        # Propagate forward and sample observable 'x'.
        with torch.no_grad():
            mu, sd = genr(z)
        for i in range(batch.shape[0]):
            x = Normal(mu[:, i, :90], sd[:, i, :90]).sample()
            orig = batch[i, 90:].expand([1000, 45])
            out = torch.cat([x, orig], dim=1)
            np.savetxt(sys.stdout, out.cpu().numpy(), fmt='%.4f')
Example #25
        rate = pyro.sample("small_spike_rate_post_spike",
                           dist.Uniform(0.4, 0.9))
        for k in range(12 - small_spike_peakStart - 5):
            sellprices.append(rate * basePrice)
            rate -= 0.03
            rate -= pyro.sample("small_spike_final_dec_%d" % k,
                                dist.Uniform(0., 0.02))
        sellprices = torch.ceil(torch.stack(sellprices))
        print("Small spike sellprices: ", sellprices)
    else:
        raise ValueError("Invalid nextPattern %d" % nextPattern)
    pyro.sample("obs", dist.Delta(sellprices).to_event(1))


if __name__ == "__main__":
    calculate_turnip_prices()

    conditioned_model = poutine.condition(calculate_turnip_prices,
                                          data={
                                              "obs":
                                              torch.tensor([
                                                  87., 83., 79., 75., 72., 68.,
                                                  64., 106., 115., 144., 185.,
                                                  138.
                                              ])
                                          })
    nuts_kernel = NUTS(conditioned_model)
    mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=100, num_chains=1)
    mcmc.run()
    mcmc.summary(prob=0.5)
Example #26
class BayesianVSCalibrator:
    """ This class implements the Bayesian VS calibrator, with bias.
    Performs inference using NUTS.
    """
    def __init__(self, prior_params, num_classes, **kwargs):
        self.num_classes = num_classes
        # Inference parameters
        self.NUTS_params = {
            'adapt_step_size': kwargs.pop('adapt_step_size', True),
            'target_accept_prob': kwargs.pop('target_accept_prob', 0.8),
            'max_plate_nesting': 1
        }
        self.mcmc_params = {
            'num_samples': kwargs.pop('num_samples', 250),
            'warmup_steps': kwargs.pop('num_warmup', 1000),
            'num_chains': kwargs.pop('num_chains', 4)
        }

        # Prior parameters on beta / delta ; assumes each weight/bias is i.i.d from its respective distribution.
        self.prior_params = {
            'mu_beta':
            torch.empty(self.num_classes).fill_(prior_params['mu_beta']),
            'sigma_beta':
            torch.empty(self.num_classes).fill_(prior_params['sigma_beta']),
            'mu_delta':
            torch.empty(self.num_classes).fill_(prior_params['mu_delta']),
            'sigma_delta':
            torch.empty(self.num_classes).fill_(prior_params['sigma_delta'])
        }

        # Posterior parameters after ADF
        # TODO
        self.posterior_params = {'mu_beta': None, 'sigma_beta': None}

        # Drift parameters for sequential updating
        self.sigma_drift = kwargs.pop('sigma_drift', 0.0)

        # Tracking params
        # TODO: Prior/posterior trace
        self.timestep = 0
        self.mcmc = None  # Contains the most recent Pyro MCMC api object
        self.verbose = kwargs.pop('verbose', True)

        if self.verbose:
            print('\nInitializing VS model:\n'
                  '----| Prior: {} \n----| Inference Method: NUTS \n'
                  '----| MCMC parameters: {}'
                  ''.format(prior_params, self.mcmc_params))

    def update(self, logits, labels):
        """ Performs an update given new observations.

        Args:
            logits: tensor ; shape (batch_size, num_classes)
            labels: tensor ; shape (batch_size, )
        """
        assert len(
            labels.shape
        ) == 1, 'Got label tensor with shape {} -- labels must be dense'.format(
            labels.shape)
        assert len(logits.shape) == 2, 'Got logit tensor with shape {}'.format(
            logits.shape)
        assert (labels.shape[0] == logits.shape[0]), 'Shape mismatch between logits ({}) and labels ({})' \
            .format(logits.shape[0], labels.shape[0])

        logits = logits.detach().clone().requires_grad_()
        labels = labels.detach().clone()

        batch_size = labels.shape[0]
        if self.verbose:
            print(
                '----| Updating HBC model\n--------| Got a batch size of: {}'.
                format(batch_size))

        # TODO
        # self._update_prior_params()
        if self.verbose:
            print('--------| Updated priors: {}'.format(self.prior_params))
            print('--------| Running inference ')
        nuts_kernel = NUTS(bvs_model, **self.NUTS_params)
        self.mcmc = MCMC(
            nuts_kernel, **self.mcmc_params,
            disable_progbar=not self.verbose)  # Progbar if verbose
        self.mcmc.run(self.prior_params, logits, labels)

        # TODO
        # self._update_posterior_params()
        self.timestep += 1

        return self.mcmc

    def _update_prior_params(self):
        """ Updates the prior parameters using the ADF posterior from the previous timestep, plus the drift.

        If this is the first batch, i.e. timestep == 0, do nothing.
        """
        # TODO
        if self.timestep > 0:
            self.prior_params['mu_beta'] = self.posterior_params['mu_beta']
            self.prior_params['sigma_beta'] = self.posterior_params[
                'sigma_beta'] + self.sigma_drift

    def _update_posterior_params(self):
        """ Fits a normal distribution to the current beta samples using moment matching.
        """
        # TODO
        beta_samples = self.get_current_posterior_samples()
        self.posterior_params['mu_beta'] = beta_samples.mean().item()
        self.posterior_params['sigma_beta'] = beta_samples.std().item()

    def get_current_posterior_samples(self):
        """ Returns the current posterior samples for beta.
        """
        if self.mcmc is None:
            return None

        return self.mcmc.get_samples()

    def calibrate(self, logit):
        """ Calibrates the given batch of logits using the current posterior samples.

        Args:
            logit: tensor ; shape (batch_size, num_classes)
        """
        # Get beta samples
        beta_samples = self.get_current_posterior_samples()[
            'beta']  # Shape (num_samples, num_classes)
        delta_samples = self.get_current_posterior_samples()[
            'delta']  # Shape (num_samples, num_classes)

        # Get a batch of logits for each sampled parameter vector
        # Shape (num_samples, batch_size, num_classes)
        tempered_logit_samples = beta_samples.view(-1, 1, self.num_classes) * logit + \
                                 delta_samples.view(-1, 1, self.num_classes)

        # Softmax the sampled logits to get sampled probabilities
        prob_samples = softmax(
            tempered_logit_samples,
            dim=2)  # Shape (num_samples, batch_size, num_classes)

        # Average over the sampled probabilities to get Monte Carlo estimate
        calibrated_probs = prob_samples.mean(
            dim=0)  # Shape (batch_size, num_classes)

        return calibrated_probs

    def get_MAP_temperature(self, logits, labels):
        """ Performs MAP estimation using the current prior and given data.
         NB: This should only be called after .update() if used in a sequential setting, as this method
         does not update the prior with sigma_drift.

         See: https://pyro.ai/examples/mle_map.html
         """
        pyro.clear_param_store()
        svi = pyro.infer.SVI(model=bvs_model,
                             guide=MAP_guide,
                             optim=pyro.optim.Adam({'lr': 0.001}),
                             loss=pyro.infer.Trace_ELBO())

        loss = []
        num_steps = 5000
        for _ in range(num_steps):
            loss.append(svi.step(self.prior_params, logits, labels))

        eps = 2e-2
        loss_sddev = np.std(loss[-25:])
        if loss_sddev > eps:
            warnings.warn(
                'MAP optimization may not have converged ; sddev {}'.format(
                    loss_sddev))

        beta_MAP = pyro.param('beta_MAP').detach()
        delta_MAP = pyro.param('delta_MAP').detach()
        return beta_MAP, delta_MAP
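A hypothetical usage sketch for the calibrator above (not part of the source); it assumes bvs_model and the surrounding Pyro imports are available, and that prior_params supplies the four scalar hyperparameters read in __init__:

import torch

prior_params = {'mu_beta': 1.0, 'sigma_beta': 0.5,
                'mu_delta': 0.0, 'sigma_delta': 0.5}
calibrator = BayesianVSCalibrator(prior_params, num_classes=10,
                                  num_samples=250, num_warmup=1000,
                                  num_chains=1)

logits = torch.randn(128, 10)           # (batch_size, num_classes)
labels = torch.randint(0, 10, (128,))   # dense integer labels, shape (batch_size,)

calibrator.update(logits, labels)       # runs NUTS and stores the fitted MCMC object
probs = calibrator.calibrate(logits)    # Monte Carlo averaged calibrated probabilities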
Example #27
y_train_torch = torch.tensor(y_train)



# Clear the parameter storage
pyro.clear_param_store()

# Initialize our No U-Turn Sampler
my_kernel = NUTS(model_normal, 
                 max_tree_depth=7) # a shallower tree helps the algorithm run faster

# Wrap the kernel in an MCMC sampler: draw SAMPLE_NUMBER posterior samples
# after discarding the first 100 steps as warmup
my_mcmc1 = MCMC(my_kernel,
                num_samples=SAMPLE_NUMBER,
                warmup_steps=100)


# Let's time our execution as well
start_time = time.time()

# Run the sampler
my_mcmc1.run(X_train_torch, 
             y_train_torch,
             california.feature_names)

end_time = time.time()

print(f'Inference ran for {round((end_time -  start_time)/60.0, 2)} minutes')
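The names X_train_torch, y_train, california, and SAMPLE_NUMBER are not defined in this snippet; a hypothetical preparation step consistent with how they are used above (scikit-learn's California housing dataset) could be:

import torch
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

california = fetch_california_housing()
X_train, X_test, y_train, y_test = train_test_split(
    california.data, california.target, test_size=0.2, random_state=0)

X_train_torch = torch.tensor(X_train, dtype=torch.float)
y_train_torch = torch.tensor(y_train, dtype=torch.float)
# SAMPLE_NUMBER is also needed; its original value is not shown in the snippet.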
Example #28
    # data = circles
    data = regression

    pyro.set_rng_seed(1)

    params = {
        'in_features': data.tensors[0].reshape(N_SAMPLES, -1).shape[1],
        'out_features': 1,
        'hidden_features': 50,
        'n_layers': 1,
        'dropout': None,
        'device': DEVICE
    }
    target_std = 1e-1
    model_b = MultilayerBayesian(**params, target_std=target_std)

    DO_LOAD = False

    samples = None
    if DO_LOAD:
        samples = torch.load('samples_reg.pth', map_location='cpu')
    else:
        pyro.clear_param_store()
        nuts_kernel = NUTS(model_b)

        mcmc = MCMC(nuts_kernel, num_samples=300, num_chains=5)
        mcmc.run(data.tensors[0].reshape(N_SAMPLES, -1).float().to(DEVICE),
                 data.tensors[1].reshape(N_SAMPLES, -1).float().to(DEVICE))
        samples = mcmc.get_samples()
        torch.save(samples, 'samples_reg.pth')
Example #29
def main(args):
    baseball_dataset = pd.read_csv(DATA_URL, "\t")
    train, _, player_names = train_test_split(baseball_dataset)
    at_bats, hits = train[:, 0], train[:, 1]
    logging.info("Original Dataset:")
    logging.info(baseball_dataset)

    # (1) Full Pooling Model
    # In this model, we illustrate how to use MCMC with general potential_fn.
    init_params, potential_fn, transforms, _ = initialize_model(
        fully_pooled, model_args=(at_bats, hits), num_chains=args.num_chains,
        jit_compile=args.jit, skip_jit_warnings=True)
    nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains,
                initial_params=init_params,
                transforms=transforms)
    mcmc.run(at_bats, hits)
    samples_fully_pooled = mcmc.get_samples()
    logging.info("\nModel: Fully Pooled")
    logging.info("===================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(fully_pooled, samples_fully_pooled, baseball_dataset)
    evaluate_pointwise_pred_density(fully_pooled, samples_fully_pooled, baseball_dataset)

    # (2) No Pooling Model
    nuts_kernel = NUTS(not_pooled, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_not_pooled = mcmc.get_samples()
    logging.info("\nModel: Not Pooled")
    logging.info("=================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(not_pooled, samples_not_pooled, baseball_dataset)
    evaluate_pointwise_pred_density(not_pooled, samples_not_pooled, baseball_dataset)

    # (3) Partially Pooled Model
    nuts_kernel = NUTS(partially_pooled, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_partially_pooled = mcmc.get_samples()
    logging.info("\nModel: Partially Pooled")
    logging.info("=======================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(partially_pooled, samples_partially_pooled, baseball_dataset)
    evaluate_pointwise_pred_density(partially_pooled, samples_partially_pooled, baseball_dataset)

    # (4) Partially Pooled with Logit Model
    nuts_kernel = NUTS(partially_pooled_with_logit, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_partially_pooled_logit = mcmc.get_samples()
    logging.info("\nModel: Partially Pooled with Logit")
    logging.info("==================================")
    logging.info("\nSigmoid(alpha):")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["alpha"],
                                   player_names=player_names,
                                   transforms={"alpha": torch.sigmoid},
                                   diagnostics=True,
                                   group_by_chain=True)["alpha"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(partially_pooled_with_logit, samples_partially_pooled_logit,
                                baseball_dataset)
    evaluate_pointwise_pred_density(partially_pooled_with_logit, samples_partially_pooled_logit,
                                    baseball_dataset)



for site, values in summary(svi_samples).items():
    print("Site: {}".format(site))
    print(values, "\n")



from pyro.infer import MCMC, NUTS


nuts_kernel = NUTS(model)

mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run(is_cont_africa, ruggedness, log_gdp)

hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}




for site, values in summary(hmc_samples).items():
    print("Site: {}".format(site))
    print(values, "\n")




sites = ["a", "bA", "bR", "bAR", "sigma"]