def Inference_MCMC(model, data, polls, n_samples=500, n_warmup=500,
                   n_chains=1, max_tree_depth=6):
    nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=True,
                       ignore_jit_warnings=True, max_tree_depth=max_tree_depth)
    mcmc = MCMC(nuts_kernel, num_samples=n_samples, warmup_steps=n_warmup,
                num_chains=n_chains)
    mcmc.run(data, polls)

    # The samples that were not rejected, i.e. actual samples from the posterior dist.
    posterior_samples = mcmc.get_samples()

    # Also return the samples as a dict of NumPy arrays.
    hmc_samples = {k: v.detach().cpu().numpy()
                   for k, v in posterior_samples.items()}

    return posterior_samples, hmc_samples
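# Hedged usage sketch (not from the original source): `poll_model`, `state_data`, and
# `polls` are illustrative names; any Pyro model with the signature model(data, polls)
# could be passed to Inference_MCMC in the same way.
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

def poll_model(data, polls):
    mu = pyro.sample("mu", dist.Normal(0., 1.))
    with pyro.plate("polls", polls.shape[0]):
        pyro.sample("obs", dist.Normal(mu, 1.), obs=polls)

state_data = torch.zeros(5)    # placeholder covariates, unused by this toy model
polls = torch.randn(20)        # observed poll results
posterior, hmc_np = Inference_MCMC(poll_model, state_data, polls,
                                   n_samples=200, n_warmup=200)
print(hmc_np["mu"].mean())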
def _infer_hmc(args, data, model, init_values={}):
    logging.info("Running inference...")
    kernel = NUTS(model,
                  full_mass=[("R0", "rho")],
                  max_tree_depth=args.max_tree_depth,
                  init_strategy=init_to_value(values=init_values),
                  jit_compile=args.jit,
                  ignore_jit_warnings=True)

    # We'll define a hook_fn to log potential energy values during inference.
    # This is helpful to diagnose whether the chain is mixing.
    energies = []

    def hook_fn(kernel, *unused):
        e = float(kernel._potential_energy_last)
        energies.append(e)
        if args.verbose:
            logging.info("potential = {:0.6g}".format(e))

    mcmc = MCMC(kernel,
                hook_fn=hook_fn,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps)
    mcmc.run(args, data)
    mcmc.summary()
    if args.plot:
        import matplotlib.pyplot as plt
        plt.figure(figsize=(6, 3))
        plt.plot(energies)
        plt.xlabel("MCMC step")
        plt.ylabel("potential energy")
        plt.title("MCMC energy trace")
        plt.tight_layout()

    samples = mcmc.get_samples()
    return samples
def mcmc(self, data: torch.Tensor, y: torch.Tensor, tensorboard: bool, log_dir: str):
    """
    Perform Markov chain Monte Carlo sampling on the (unknown) posterior.

    Parameters
    ----------
    data : torch.Tensor, shape=(n_samples, n_features)
        2-D tensor with the data input.
    y : torch.Tensor, shape=(n_samples,)
        1-D tensor with ground-truth labels (binary).
    tensorboard : bool
        If True, log each MCMC iteration to TensorBoard.
    log_dir : str
        Directory for the TensorBoard SummaryWriter.
    """
    if tensorboard:
        writer = SummaryWriter(log_dir=log_dir)
        distribution = defaultdict(list)

        def log(kernel, samples, stage, i):
            """ Log after each MCMC iteration """
            # Loop through all sites and log their value as well as the underlying
            # distribution, approximated by a Gaussian.
            for key, value in samples.items():
                distribution[key].append(value)
                stacked = torch.stack(distribution[key], dim=0)
                mean, scale = torch.mean(stacked, dim=0), torch.std(stacked, dim=0)
                for d, x in enumerate(value):
                    writer.add_scalar("%s_%s_%d" % (stage, key, d), x, i)
                    writer.add_scalar("%s_%s_mean_%d" % (stage, key, d), mean[d], i)
                    writer.add_scalar("%s_%s_scale_%d" % (stage, key, d), scale[d], i)
                    writer.add_histogram("%s_histogram_%s_%d" % (stage, key, d),
                                         stacked[:, d], i)
    # If logging is not requested, use a no-op hook.
    else:
        log = lambda kernel, samples, stage, i: None

    # Set up the MCMC kernel.
    kernel = NUTS(self.model)

    # Initialize the MCMC sampler and run the sampling algorithm.
    mcmc = MCMC(kernel, num_samples=self.mcmc_steps,
                warmup_steps=self.mcmc_warmup,
                num_chains=self.mcmc_chains, hook_fn=log)
    mcmc.run(data.float(), y.float())

    # Get samples from the MCMC chains and store them as the model weights.
    samples = mcmc.get_samples()
    self.mcmc_model = samples

    if tensorboard:
        writer.close()
def __init__(self, model, data, covariates=None, *,
             num_warmup=1000, num_samples=1000, num_chains=1,
             dense_mass=False, jit_compile=False, max_tree_depth=10):
    assert data.size(-2) == covariates.size(-2)
    super().__init__()
    self.model = model
    max_plate_nesting = _guess_max_plate_nesting(model, (data, covariates), {})
    self.max_plate_nesting = max(max_plate_nesting, 1)  # force a time plate

    kernel = NUTS(model, full_mass=dense_mass, jit_compile=jit_compile,
                  ignore_jit_warnings=True, max_tree_depth=max_tree_depth,
                  max_plate_nesting=max_plate_nesting)
    mcmc = MCMC(kernel, warmup_steps=num_warmup, num_samples=num_samples,
                num_chains=num_chains)
    mcmc.run(data, covariates)
    # Conditions under which r_hat can be computed.
    if (num_chains == 1 and num_samples >= 4) or (num_chains > 1 and num_samples >= 2):
        mcmc.summary()

    # Inspect the model with a particles plate of size 1, so that we can reshape
    # samples to add any missing plate dim in front.
    with poutine.trace() as tr:
        with pyro.plate("particles", 1, dim=-self.max_plate_nesting - 1):
            model(data, covariates)
    self._trace = tr.trace
    self._samples = mcmc.get_samples()
    self._num_samples = num_samples * num_chains
    for name, node in list(self._trace.nodes.items()):
        if name not in self._samples:
            del self._trace.nodes[name]
def test_multiple_observed_rv(self):
    import pyro.distributions as dist
    from pyro.infer import MCMC, NUTS

    y1 = torch.randn(10)
    y2 = torch.randn(10)

    def model_example_multiple_obs(y1=None, y2=None):
        x = pyro.sample("x", dist.Normal(1, 3))
        pyro.sample("y1", dist.Normal(x, 1), obs=y1)
        pyro.sample("y2", dist.Normal(x, 1), obs=y2)

    nuts_kernel = NUTS(model_example_multiple_obs)
    mcmc = MCMC(nuts_kernel, num_samples=10)
    mcmc.run(y1=y1, y2=y2)
    inference_data = from_pyro(mcmc)
    test_dict = {
        "posterior": ["x"],
        "sample_stats": ["diverging"],
        "log_likelihood": ["y1", "y2"],
        "observed_data": ["y1", "y2"],
    }
    fails = check_multiple_attrs(test_dict, inference_data)
    assert not fails
    assert not hasattr(inference_data.sample_stats, "log_likelihood")
def run_mcmc(self, x: torch.Tensor, y: torch.Tensor, *,
             num_samples: int = 1000, warmup_steps: int = 200) -> MCMC:
    nuts_kernel = NUTS(self)
    mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps)
    mcmc.run(x, y)
    return mcmc
def run_hmc(args, model):
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, warmup_steps=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(args.param_a, args.param_b)
    mcmc.summary()
    return mcmc
def main(args):
    nuts_kernel = NUTS(conditioned_model, jit_compile=args.jit)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(model, data.sigma, data.y)
    mcmc.summary(prob=0.5)
def infer(args, model, t, yt):
    nuts_kernel = NUTS(conditioned_model, jit_compile=args.jit)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.nsamples,
                warmup_steps=args.nwarmups,
                num_chains=args.nchains)
    mcmc.run(model, t, yt)
    mcmc.summary(prob=0.95)
    return mcmc
def pyro_noncentered_schools(data, draws, chains):
    """Non-centered eight schools implementation in Pyro."""
    import torch
    from pyro.infer import MCMC, NUTS

    y = torch.from_numpy(data["y"]).float()
    sigma = torch.from_numpy(data["sigma"]).float()

    nuts_kernel = NUTS(_pyro_noncentered_model)
    posterior = MCMC(nuts_kernel, num_samples=draws, warmup_steps=draws, num_chains=chains)
    posterior.run(data["J"], sigma, y)

    # This block lets the posterior be pickled.
    posterior.sampler = None
    return posterior
def tomtom_mcmc(data, seed, nsample=5000, burnin=1000):
    pyro.clear_param_store()
    pyro.set_rng_seed(seed)

    # # declare dataset to be modeled
    # dtname = 't{}_{}_{}_3d'.format(target, dtype, auto)
    # print("running MCMC with: {}".format(dtname))
    # data = globals()[dtname]

    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_samples=nsample, warmup_steps=burnin)
    mcmc.run(data)
    posterior_samples = mcmc.get_samples()
    return posterior_samples
def numpyro_schools_model(data, draws, chains):
    """Centered eight schools implementation in NumPyro."""
    from jax.random import PRNGKey
    from numpyro.infer import MCMC, NUTS

    mcmc = MCMC(
        NUTS(_numpyro_noncentered_model),
        num_warmup=draws,
        num_samples=draws,
        num_chains=chains,
        chain_method="sequential",
    )
    mcmc.run(PRNGKey(0), extra_fields=("num_steps", "energy"), **data)

    # This block lets the posterior be pickled.
    mcmc.sampler._sample_fn = None  # pylint: disable=protected-access
    return mcmc
def test_neals_funnel_smoke(jit):
    dim = 10

    guide = AutoIAFNormal(neals_funnel)
    svi = SVI(neals_funnel, guide, optim.Adam({"lr": 1e-10}), Trace_ELBO())
    for _ in range(1000):
        svi.step(dim)

    neutra = NeuTraReparam(guide.requires_grad_(False))
    model = neutra.reparam(neals_funnel)
    nuts = NUTS(model, jit_compile=jit)
    mcmc = MCMC(nuts, num_samples=50, warmup_steps=50)
    mcmc.run(dim)
    samples = mcmc.get_samples()
    # XXX: `MCMC.get_samples` adds a leftmost batch dim to all sites,
    # not uniformly at -max_plate_nesting-1; hence the unsqueeze.
    transformed_samples = neutra.transform_sample(
        samples['y_shared_latent'].unsqueeze(-2))
    assert 'x' in transformed_samples
    assert 'y' in transformed_samples
def main(args):
    # Define which MCMC algorithm to run (proposal, rejection, etc.);
    # this is captured by the notion of a "kernel".
    #
    # NUTS: No-U-Turn Sampler kernel, which provides an efficient and convenient way
    # to run Hamiltonian Monte Carlo. The number of steps taken by the
    # integrator is dynamically adjusted on each call to ``sample`` to ensure
    # an optimal length for the Hamiltonian trajectory [1]. As such, the samples
    # generated will typically have lower autocorrelation than those generated
    # by the :class:`~pyro.infer.mcmc.HMC` kernel.
    nuts_kernel = NUTS(conditioned_model)

    # MCMC is the wrapper around the actual algorithm variant; you call .run on it.
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    data = construct_data()
    mcmc.run(model(data["J"]), data["sigma"], data["y"])
    mcmc.summary(prob=0.5)
def sample_model(chat, mhat, varpihat, sigmac, sigmam, sigmavarpi,
                 dustco_c, dustco_m, theta_0_mcmc, nsamples=100, nwalkers=1):
    objective = Objective(chat, mhat, varpihat, sigmac, sigmam, sigmavarpi,
                          dustco_c, dustco_m)
    # print(objective.logjoint())
    objective.logjoint()
    # nuts_kernel = NUTS(objective.logjoint, jit_compile=True, ignore_jit_warnings=True)
    # mcmc = MCMC(nuts_kernel, num_samples=2000, warmup_steps=100, num_chains=2, mp_context='spawn')
    try:
        with open('savemcmc_{}.pkl'.format(ind), 'rb') as f:
            mcmc = pickle.load(f)
    except IOError:
        nuts_kernel = NUTS(objective.logjoint, jit_compile=True, ignore_jit_warnings=False)
        mcmc = MCMC(nuts_kernel, num_samples=nsamples, warmup_steps=100,
                    num_chains=nwalkers, initial_params=theta_0_mcmc,
                    mp_context='spawn')
        mcmc.run()
        with open('savemcmc_{}.pkl'.format(ind), 'wb') as f:
            mcmc.sampler = None
            mcmc.kernel.potential_fn = None
            pickle.dump(mcmc, f)
    return mcmc, objective
def run_hmc(x_data, y_data, model, num_samples=1000, warmup_steps=200):
    """
    Runs NUTS.

    returns: samples
    """
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps)
    mcmc.run(x_data, y_data)

    hmc_samples = {k: v.detach().cpu().numpy()
                   for k, v in mcmc.get_samples().items()}
    hmc_samples["linear.weight"] = hmc_samples["linear.weight"].reshape(num_samples, -1)
    return hmc_samples
def test_inference_data_constant_data(self):
    import pyro.distributions as dist
    from pyro.infer import MCMC, NUTS

    x1 = 10
    x2 = 12
    y1 = torch.randn(10)

    def model_constant_data(x, y1=None):
        _x = pyro.sample("x", dist.Normal(1, 3))
        pyro.sample("y1", dist.Normal(x * _x, 1), obs=y1)

    nuts_kernel = NUTS(model_constant_data)
    mcmc = MCMC(nuts_kernel, num_samples=10)
    mcmc.run(x=x1, y1=y1)
    posterior = mcmc.get_samples()
    posterior_predictive = Predictive(model_constant_data, posterior)(x1)
    predictions = Predictive(model_constant_data, posterior)(x2)
    inference_data = from_pyro(
        mcmc,
        posterior_predictive=posterior_predictive,
        predictions=predictions,
        constant_data={"x1": x1},
        predictions_constant_data={"x2": x2},
    )
    test_dict = {
        "posterior": ["x"],
        "posterior_predictive": ["y1"],
        "sample_stats": ["diverging"],
        "log_likelihood": ["y1"],
        "predictions": ["y1"],
        "observed_data": ["y1"],
        "constant_data": ["x1"],
        "predictions_constant_data": ["x2"],
    }
    fails = check_multiple_attrs(test_dict, inference_data)
    assert not fails
def test_neals_funnel_smoke(jit):
    dim = 10

    guide = AutoStructured(
        neals_funnel,
        conditionals={"y": "normal", "x": "mvn"},
        dependencies={"x": {"y": "linear"}},
    )
    Elbo = JitTrace_ELBO if jit else Trace_ELBO
    svi = SVI(neals_funnel, guide, optim.Adam({"lr": 1e-10}), Elbo())
    for _ in range(1000):
        try:
            svi.step(dim=dim)
        except SystemError as e:
            if "returned a result with an error set" in str(e):
                pytest.xfail(reason="PyTorch jit bug")
            else:
                raise e from None

    rep = StructuredReparam(guide)
    model = rep.reparam(neals_funnel)
    nuts = NUTS(model, max_tree_depth=3, jit_compile=jit)
    mcmc = MCMC(nuts, num_samples=50, warmup_steps=50)
    mcmc.run(dim)
    samples = mcmc.get_samples()
    # XXX: `MCMC.get_samples` adds a leftmost batch dim to all sites,
    # not uniformly at -max_plate_nesting-1; hence the unsqueeze.
    samples = {k: v.unsqueeze(1) for k, v in samples.items()}
    transformed_samples = rep.transform_samples(samples)
    assert isinstance(transformed_samples, dict)
    assert set(transformed_samples) == {"x", "y"}
def test_neals_funnel_smoke():
    dim = 10

    def model():
        y = pyro.sample('y', dist.Normal(0, 3))
        with pyro.plate("D", dim):
            pyro.sample('x', dist.Normal(0, torch.exp(y / 2)))

    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, optim.Adam({"lr": 1e-10}), Trace_ELBO())
    for _ in range(1000):
        svi.step()

    neutra = NeuTraReparam(guide)
    model = neutra.reparam(model)
    nuts = NUTS(model)
    mcmc = MCMC(nuts, num_samples=50, warmup_steps=50)
    mcmc.run()
    samples = mcmc.get_samples()
    # XXX: `MCMC.get_samples` adds a leftmost batch dim to all sites,
    # not uniformly at -max_plate_nesting-1; hence the unsqueeze.
    transformed_samples = neutra.transform_sample(samples['y_shared_latent'].unsqueeze(-2))
    assert 'x' in transformed_samples
    assert 'y' in transformed_samples
def main():
    start = time.time()
    pyro.clear_param_store()

    # The kernel we will use.
    hmc_kernel = HMC(conditioned_model, step_size=0.1)

    # The sampler which will run the kernel.
    mcmc = MCMC(hmc_kernel, num_samples=14000, warmup_steps=100)

    # The .run method accepts the same parameters that our model function uses.
    mcmc.run(model, data)

    end = time.time()
    print('Time taken ', end - start, ' seconds')

    sample_dict = mcmc.get_samples(num_samples=5000)

    plt.figure(figsize=(10, 7))
    sns.distplot(sample_dict['latent_fairness'].numpy(), color="orange")
    plt.xlabel("Observed probability value")
    plt.ylabel("Observed frequency")
    plt.show()

    mcmc.summary(prob=0.95)
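# Hedged sketch (not from the original source): a minimal coin-fairness setup showing
# why `mcmc.run(model, data)` works above -- MCMC forwards its positional arguments
# straight to the callable given to the kernel, here `conditioned_model(model, data)`.
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def model(data):
    f = pyro.sample("latent_fairness", dist.Beta(10.0, 10.0))
    with pyro.plate("N", len(data)):
        pyro.sample("obs", dist.Bernoulli(f))

def conditioned_model(model, data):
    # Condition the generative model on the observed flips, then run it.
    return poutine.condition(model, data={"obs": data})(data)

data = torch.tensor([1., 1., 0., 1., 1., 0., 1., 1., 1., 0.])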
def main(args):
    baseball_dataset = pd.read_csv(DATA_URL, "\t")
    train, _, player_names = train_test_split(baseball_dataset)
    at_bats, hits = train[:, 0], train[:, 1]
    logging.info("Original Dataset:")
    logging.info(baseball_dataset)

    # (1) Full Pooling Model
    # In this model, we illustrate how to use MCMC with a general potential_fn.
    init_params, potential_fn, transforms, _ = initialize_model(
        fully_pooled, model_args=(at_bats, hits), num_chains=args.num_chains,
        jit_compile=args.jit, skip_jit_warnings=True)
    nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains,
                initial_params=init_params,
                transforms=transforms)
    mcmc.run(at_bats, hits)
    samples_fully_pooled = mcmc.get_samples()
    logging.info("\nModel: Fully Pooled")
    logging.info("===================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(fully_pooled, samples_fully_pooled, baseball_dataset)
    evaluate_pointwise_pred_density(fully_pooled, samples_fully_pooled, baseball_dataset)

    # (2) No Pooling Model
    nuts_kernel = NUTS(not_pooled, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_not_pooled = mcmc.get_samples()
    logging.info("\nModel: Not Pooled")
    logging.info("=================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(not_pooled, samples_not_pooled, baseball_dataset)
    evaluate_pointwise_pred_density(not_pooled, samples_not_pooled, baseball_dataset)

    # (3) Partially Pooled Model
    nuts_kernel = NUTS(partially_pooled, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_partially_pooled = mcmc.get_samples()
    logging.info("\nModel: Partially Pooled")
    logging.info("=======================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(partially_pooled, samples_partially_pooled, baseball_dataset)
    evaluate_pointwise_pred_density(partially_pooled, samples_partially_pooled, baseball_dataset)

    # (4) Partially Pooled with Logit Model
    nuts_kernel = NUTS(partially_pooled_with_logit, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_partially_pooled_logit = mcmc.get_samples()
    logging.info("\nModel: Partially Pooled with Logit")
    logging.info("==================================")
    logging.info("\nSigmoid(alpha):")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["alpha"],
                                   player_names=player_names,
                                   transforms={"alpha": torch.sigmoid},
                                   diagnostics=True,
                                   group_by_chain=True)["alpha"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(partially_pooled_with_logit, samples_partially_pooled_logit,
                                baseball_dataset)
    evaluate_pointwise_pred_density(partially_pooled_with_logit, samples_partially_pooled_logit,
                                    baseball_dataset)
for site, values in summary(svi_samples).items():
    print("Site: {}".format(site))
    print(values, "\n")

from pyro.infer import MCMC, NUTS

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run(is_cont_africa, ruggedness, log_gdp)

hmc_samples = {k: v.detach().cpu().numpy()
               for k, v in mcmc.get_samples().items()}

for site, values in summary(hmc_samples).items():
    print("Site: {}".format(site))
    print(values, "\n")

sites = ["a", "bA", "bR", "bAR", "sigma"]
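# Hedged sketch (not from the original source): the `summary` helper used above is not
# shown here; it is assumed to compute per-site statistics from a dict of sample arrays.
# A minimal version along those lines:
import numpy as np
import pandas as pd

def summary(samples):
    """Return, for each sample site, a table of mean/std and selected quantiles."""
    site_stats = {}
    for site_name, values in samples.items():
        marginal = pd.DataFrame(values)
        describe = marginal.describe(percentiles=[0.05, 0.25, 0.5, 0.75, 0.95]).transpose()
        site_stats[site_name] = describe[["mean", "std", "5%", "25%", "50%", "75%", "95%"]]
    return site_stats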
    max_tree_depth=7)  # a shallower tree helps the algorithm run faster

# Employ the sampler in an MCMC sampling algorithm, and sample 3100 samples.
# Then discard the first 100.
my_mcmc1 = MCMC(my_kernel,
                num_samples=SAMPLE_NUMBER,
                warmup_steps=100)

# Let's time our execution as well.
start_time = time.time()

# Run the sampler.
my_mcmc1.run(X_train_torch, y_train_torch, california.feature_names)

end_time = time.time()

print(f'Inference ran for {round((end_time - start_time)/60.0, 2)} minutes')


# In[20]:

# TO DELETE
new_measurements = pd.Series({"model": "MCMC unscaled normal",
                              "measurement": "time",
                              "censored": True,
                              "value": round(end_time - start_time, 2)})
 tavg_norm_noauto_3d, tavg_raw_all_3d, tavg_raw_noauto_3d] = pickle.load(f)

tm.mtype = 'group'
tm.target = 'self'      # 'self', 'targ', 'avg'
tm.dtype = 'norm'       # 'norm', 'raw'
tm.auto = 'all'         # 'noauto', 'all'
tm.stickbreak = False
tm.optim = pyro.optim.Adam({'lr': 0.0005, 'betas': [0.8, 0.99]})
tm.elbo = TraceEnum_ELBO(max_plate_nesting=1)
tm.K = 3

pyro.clear_param_store()
pyro.set_rng_seed(99)

# # declare dataset to be modeled
# dtname = 't{}_{}_{}_3d'.format(target, dtype, auto)
# print("running MCMC with: {}".format(dtname))
# data = globals()[dtname]

nuts_kernel = NUTS(tm.model)
mcmc = MCMC(nuts_kernel, num_samples=5000, warmup_steps=1000)
mcmc.run(tself_norm_all_3d)
posterior_samples = mcmc.get_samples()

abc = az.from_pyro(mcmc, log_likelihood=True)
az.stats.waic(abc.posterior.weights)
def fit_mcmc(self, **options):
    r"""
    Runs NUTS inference to generate posterior samples.

    This uses the :class:`~pyro.infer.mcmc.nuts.NUTS` kernel to run
    :class:`~pyro.infer.mcmc.api.MCMC`, setting the ``.samples``
    attribute on completion.

    This uses an asymptotically exact enumeration-based model when
    ``num_quant_bins > 1``, and a cheaper moment-matched approximate model
    when ``num_quant_bins == 1``.

    :param \*\*options: Options passed to
        :class:`~pyro.infer.mcmc.api.MCMC`. The remaining options are
        pulled out and have special meaning.
    :param int num_samples: Number of posterior samples to draw via mcmc.
        Defaults to 100.
    :param int max_tree_depth: (Default 5). Max tree depth of the
        :class:`~pyro.infer.mcmc.nuts.NUTS` kernel.
    :param full_mass: Specification of mass matrix of the
        :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. Defaults to full mass
        over global random variables.
    :param bool arrowhead_mass: Whether to treat ``full_mass`` as the head
        of an arrowhead matrix versus simply as a block. Defaults to False.
    :param int num_quant_bins: If greater than 1, use asymptotically exact
        inference via local enumeration over this many quantization bins.
        If equal to 1, use continuous-valued relaxed approximate inference.
        Note that computational cost is exponential in `num_quant_bins`.
        Defaults to 1 for relaxed inference.
    :param bool haar: Whether to use a Haar wavelet reparameterizer.
        Defaults to True.
    :param int haar_full_mass: Number of low frequency Haar components to
        include in the full mass matrix. If ``haar=False`` then this is
        ignored. Defaults to 10.
    :param int heuristic_num_particles: Passed to :meth:`heuristic` as
        ``num_particles``. Defaults to 1024.
    :returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``.
    :rtype: ~pyro.infer.mcmc.api.MCMC
    """
    _require_double_precision()

    # Parse options, saving some for use in .predict().
    num_samples = options.setdefault("num_samples", 100)
    num_chains = options.setdefault("num_chains", 1)
    self.num_quant_bins = options.pop("num_quant_bins", 1)
    assert isinstance(self.num_quant_bins, int)
    assert self.num_quant_bins >= 1
    self.relaxed = self.num_quant_bins == 1

    # Setup Haar wavelet transform.
    haar = options.pop("haar", False)
    haar_full_mass = options.pop("haar_full_mass", 10)
    full_mass = options.pop("full_mass", self.full_mass)
    assert isinstance(haar, bool)
    assert isinstance(haar_full_mass, int) and haar_full_mass >= 0
    assert isinstance(full_mass, (bool, list))
    haar_full_mass = min(haar_full_mass, self.duration)
    if not haar:
        haar_full_mass = 0
    if full_mass is True:
        haar_full_mass = 0  # No need to split.
    elif haar_full_mass >= self.duration:
        full_mass = True  # Effectively full mass.
        haar_full_mass = 0
    if haar:
        time_dim = -2 if self.is_regional else -1
        dims = {"auxiliary": time_dim}
        supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
        for name, (fn, is_regional) in self._non_compartmental.items():
            dims[name] = time_dim - fn.event_dim
            supports[name] = fn.support
        haar = _HaarSplitReparam(haar_full_mass, self.duration, dims, supports)
    if haar_full_mass:
        assert full_mass and isinstance(full_mass, list)
        full_mass = full_mass[:]
        full_mass[0] += tuple(name + "_haar_split_0" for name in sorted(dims))

    # Heuristically initialize to feasible latents.
    heuristic_options = {k.replace("heuristic_", ""): options.pop(k)
                         for k in list(options) if k.startswith("heuristic_")}
    init_strategy = init_to_generated(
        generate=functools.partial(self._heuristic, haar, **heuristic_options))

    # Configure a kernel.
    logger.info("Running inference...")
    model = self._relaxed_model if self.relaxed else self._quantized_model
    if haar:
        model = haar.reparam(model)
    kernel = NUTS(model,
                  full_mass=full_mass,
                  init_strategy=init_strategy,
                  max_plate_nesting=self.max_plate_nesting,
                  jit_compile=options.pop("jit_compile", False),
                  jit_options=options.pop("jit_options", None),
                  ignore_jit_warnings=options.pop("ignore_jit_warnings", True),
                  target_accept_prob=options.pop("target_accept_prob", 0.8),
                  max_tree_depth=options.pop("max_tree_depth", 5))
    if options.pop("arrowhead_mass", False):
        kernel.mass_matrix_adapter = ArrowheadMassMatrix()

    # Run mcmc.
    options.setdefault("disable_validation", None)
    mcmc = MCMC(kernel, **options)
    mcmc.run()
    self.samples = mcmc.get_samples()
    if haar:
        haar.aux_to_user(self.samples)

    # Unsqueeze samples to align particle dim for use in poutine.condition.
    # TODO refactor to an align_samples or particle_dim kwarg to MCMC.get_samples().
    model = self._relaxed_model if self.relaxed else self._quantized_model
    self.samples = align_samples(self.samples, model,
                                 particle_dim=-1 - self.max_plate_nesting)
    assert all(v.size(0) == num_samples * num_chains
               for v in self.samples.values()), \
        {k: tuple(v.shape) for k, v in self.samples.items()}

    return mcmc  # E.g. so user can run mcmc.summary().
    # sigma = dist.Uniform(0., 5.).sample()
    # sigma = dist.Uniform(sigma_loc, 5.).sample()
    # sigma = dist.Normal(sigma_loc, 0.2).sample()
    pyro.sample("obs", dist.Normal(mean, sigma), obs=observations)


dims = 4
num_samples = 100

# Generate observations.
x = torch.rand(dims, num_samples)
noise = torch.distributions.Normal(torch.tensor([0.] * num_samples),
                                   torch.tensor([0.2] * num_samples)).rsample()
s, fm, Zn, Vr = x
a, b, c, d = 1.5, 1.8, 2.1, 2.3
# a, b, c, d = 1., 1., 1., 1.
observations = s * fm**a * Zn**b * c / Vr**d + noise[0]

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=200, warmup_steps=400)
mcmc.run(x, dims, observations)

hmc_samples = {k: v.detach().cpu().numpy()
               for k, v in mcmc.get_samples().items()}
for site, values in summary(hmc_samples).items():
    print("Site: {}".format(site))
    print(values, "\n")
import numpy as np
import torch
from matplotlib import pyplot as plt

import pyro
from pyro import sample
from pyro.distributions import Normal
from pyro.infer import NUTS, MCMC


def bad():
    # Deliberately ill-formed model: it re-uses the sample site name 'x'.
    x = sample('x', pyro.distributions.Normal(0, 1))
    for i in range(10):
        x = sample('x', pyro.distributions.Normal(x, 3))
    x


nuts_kernel = NUTS(bad)
mcmc = MCMC(nuts_kernel, num_samples=10, warmup_steps=10)
mcmc.run()
mcmc.summary()
samples = mcmc.get_samples()
print(samples.keys())

fig, ax = plt.subplots()
ax.hist(np.array(samples["x"]), bins=50)
plt.show()
class BayesianVSCalibrator:
    """
    This class implements the Bayesian VS calibrator, with bias.

    Performs inference using NUTS.
    """

    def __init__(self, prior_params, num_classes, **kwargs):
        self.num_classes = num_classes

        # Inference parameters
        self.NUTS_params = {
            'adapt_step_size': kwargs.pop('adapt_step_size', True),
            'target_accept_prob': kwargs.pop('target_accept_prob', 0.8),
            'max_plate_nesting': 1
        }
        self.mcmc_params = {
            'num_samples': kwargs.pop('num_samples', 250),
            'warmup_steps': kwargs.pop('num_warmup', 1000),
            'num_chains': kwargs.pop('num_chains', 4)
        }

        # Prior parameters on beta / delta; assumes each weight/bias is i.i.d.
        # from its respective distribution.
        self.prior_params = {
            'mu_beta': torch.empty(self.num_classes).fill_(prior_params['mu_beta']),
            'sigma_beta': torch.empty(self.num_classes).fill_(prior_params['sigma_beta']),
            'mu_delta': torch.empty(self.num_classes).fill_(prior_params['mu_delta']),
            'sigma_delta': torch.empty(self.num_classes).fill_(prior_params['sigma_delta'])
        }

        # Posterior parameters after ADF
        # TODO
        self.posterior_params = {'mu_beta': None, 'sigma_beta': None}

        # Drift parameters for sequential updating
        self.sigma_drift = kwargs.pop('sigma_drift', 0.0)

        # Tracking params
        # TODO: Prior/posterior trace
        self.timestep = 0
        self.mcmc = None  # Contains the most recent Pyro MCMC api object

        self.verbose = kwargs.pop('verbose', True)
        if self.verbose:
            print('\nInitializing VS model:\n'
                  '----| Prior: {} \n----| Inference Method: NUTS \n'
                  '----| MCMC parameters: {}'
                  ''.format(prior_params, self.mcmc_params))

    def update(self, logits, labels):
        """
        Performs an update given new observations.

        Args:
            logits: tensor ; shape (batch_size, num_classes)
            labels: tensor ; shape (batch_size, )
        """
        assert len(labels.shape) == 1, \
            'Got label tensor with shape {} -- labels must be dense'.format(labels.shape)
        assert len(logits.shape) == 2, \
            'Got logit tensor with shape {}'.format(logits.shape)
        assert labels.shape[0] == logits.shape[0], \
            'Shape mismatch between logits ({}) and labels ({})'.format(
                logits.shape[0], labels.shape[0])

        logits = logits.detach().clone().requires_grad_()
        labels = labels.detach().clone()
        batch_size = labels.shape[0]

        if self.verbose:
            print('----| Updating HBC model\n--------| Got a batch size of: {}'.format(batch_size))

        # TODO
        # self._update_prior_params()
        if self.verbose:
            print('--------| Updated priors: {}'.format(self.prior_params))
            print('--------| Running inference ')

        nuts_kernel = NUTS(bvs_model, **self.NUTS_params)
        self.mcmc = MCMC(nuts_kernel, **self.mcmc_params,
                         disable_progbar=not self.verbose)  # Progbar if verbose
        self.mcmc.run(self.prior_params, logits, labels)

        # TODO
        # self._update_posterior_params()
        self.timestep += 1

        return self.mcmc

    def _update_prior_params(self):
        """
        Updates the prior parameters using the ADF posterior from the previous
        timestep, plus the drift. If this is the first batch, i.e. timestep == 0,
        do nothing.
        """
        # TODO
        if self.timestep > 0:
            self.prior_params['mu_beta'] = self.posterior_params['mu_beta']
            self.prior_params['sigma_beta'] = self.posterior_params['sigma_beta'] + self.sigma_drift

    def _update_posterior_params(self):
        """
        Fits a normal distribution to the current beta samples using moment matching.
        """
        # TODO
        beta_samples = self.get_current_posterior_samples()
        self.posterior_params['mu_beta'] = beta_samples.mean().item()
        self.posterior_params['sigma_beta'] = beta_samples.std().item()

    def get_current_posterior_samples(self):
        """
        Returns the current posterior samples for beta.
        """
        if self.mcmc is None:
            return None
        return self.mcmc.get_samples()

    def calibrate(self, logit):
        """
        Calibrates the given batch of logits using the current posterior samples.

        Args:
            logit: tensor ; shape (batch_size, num_classes)
        """
        # Get beta and delta samples, each of shape (num_samples, num_classes).
        beta_samples = self.get_current_posterior_samples()['beta']
        delta_samples = self.get_current_posterior_samples()['delta']

        # Get a batch of logits for each sampled parameter vector.
        # Shape (num_samples, batch_size, num_classes)
        tempered_logit_samples = beta_samples.view(-1, 1, self.num_classes) * logit + \
            delta_samples.view(-1, 1, self.num_classes)

        # Softmax the sampled logits to get sampled probabilities.
        prob_samples = softmax(tempered_logit_samples, dim=2)  # Shape (num_samples, batch_size, num_classes)

        # Average over the sampled probabilities to get the Monte Carlo estimate.
        calibrated_probs = prob_samples.mean(dim=0)  # Shape (batch_size, num_classes)

        return calibrated_probs

    def get_MAP_temperature(self, logits, labels):
        """
        Performs MAP estimation using the current prior and given data.

        NB: This should only be called after .update() if used in a sequential
        setting, as this method does not update the prior with sigma_drift.
        See: https://pyro.ai/examples/mle_map.html
        """
        pyro.clear_param_store()
        svi = pyro.infer.SVI(model=bvs_model, guide=MAP_guide,
                             optim=pyro.optim.Adam({'lr': 0.001}),
                             loss=pyro.infer.Trace_ELBO())
        loss = []
        num_steps = 5000
        for _ in range(num_steps):
            loss.append(svi.step(self.prior_params, logits, labels))

        eps = 2e-2
        loss_sddev = np.std(loss[-25:])
        if loss_sddev > eps:
            warnings.warn('MAP optimization may not have converged ; sddev {}'.format(loss_sddev))

        beta_MAP = pyro.param('beta_MAP').detach()
        delta_MAP = pyro.param('delta_MAP').detach()
        return beta_MAP, delta_MAP
sys.stderr.write("Requires Python 3\n") genr = Decoder() genr.load_state_dict(torch.load('gaae-decd-1024.tch')) genr.eval() data = qPCRData('second.txt', randomize=False, test=False) # Do it with CUDA if possible. device = 'cuda' if torch.cuda.is_available() else 'cpu' if device == 'cuda': torch.set_default_tensor_type(torch.cuda.FloatTensor) genr.cuda() model = GeneratorModel(genr) nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=True) mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=2000) for batch in data.batches(btchsz=8192, randomize=False, test=False): obs = batch[:, 90:].to(device) mcmc.run(obs) z = mcmc.get_samples()['z'] # Propagate forward and sample observable 'x'. with torch.no_grad(): mu, sd = genr(z) for i in range(batch.shape[0]): x = Normal(mu[:, i, :90], sd[:, i, :90]).sample() orig = batch[i, 90:].expand([1000, 45]) out = torch.cat([x, orig], dim=1) np.savetxt(sys.stdout, out.cpu().numpy(), fmt='%.4f')
if __name__ == '__main__':
    # Create the parameters of the negative binomial distribution.
    alpha = torch.tensor(args.alpha)
    beta = torch.tensor(args.beta)
    r = torch.tensor(args.r)
    data = torch.tensor([12, 11, 6, 12, 11, 0, 4, 6, 5, 6])
    nb_post = NB_Post(alpha, beta, args.r)

    # Create the HMC kernel and MCMC object.
    hmc_kernel = HMC(nb_post.model, step_size=args.step_size, num_steps=args.num_steps)
    mcmc = MCMC(hmc_kernel, num_samples=args.num_samples, warmup_steps=args.warm_steps)

    # Sample the posterior.
    mcmc.run(data, args.logit)

    if args.logit:
        param = 'eta'
        posterior_samples = mcmc.get_samples()[param]
        # Apply the inverse-logit (sigmoid) transform to map eta back to a probability.
        posterior_samples = torch.exp(posterior_samples) / (1. + torch.exp(posterior_samples))
        # Plot the estimated posterior density.
        plot_logit_density(posterior_samples)
    else:
        param = 'p'
        posterior_samples = mcmc.get_samples()[param]
        poster_alpha = (alpha + data.sum()).numpy()
        poster_beta = (len(data) * r + beta).numpy()
        # Plot the estimated and ground-truth densities.
        plot_density(poster_alpha, poster_beta, posterior_samples)
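# Hedged note (not from the original source): the closed-form posterior used above
# follows from Beta-negative binomial conjugacy -- with p ~ Beta(alpha, beta) and each
# x_i ~ NB(r, p) counting successes before r failures, the posterior is
# Beta(alpha + sum(x_i), beta + n * r), matching poster_alpha / poster_beta.
# Hyperparameter values below are illustrative.
import numpy as np
import scipy.stats as st

x = np.array([12, 11, 6, 12, 11, 0, 4, 6, 5, 6])
alpha0, beta0, r0 = 2.0, 2.0, 5.0
analytic_posterior = st.beta(alpha0 + x.sum(), beta0 + len(x) * r0)
print(analytic_posterior.mean())  # analytic posterior mean of p, for comparison with MCMC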