def test_null_model_with_hook(run_mcmc_cls, kernel, model, jit, num_chains):
    num_warmup, num_samples = 10, 10
    initial_params, potential_fn, transforms, _ = initialize_model(
        model, num_chains=num_chains)
    iters = []
    hook = partial(_hook, iters)
    mp_context = "spawn" if "CUDA_TEST" in os.environ else None
    kern = kernel(potential_fn=potential_fn, transforms=transforms, jit_compile=jit)
    samples, _ = run_mcmc_cls(
        data=None,
        kernel=kern,
        num_samples=num_samples,
        warmup_steps=num_warmup,
        initial_params=initial_params,
        hook_fn=hook,
        num_chains=num_chains,
        mp_context=mp_context,
    )
    assert samples == {}
    if num_chains == 1:
        expected = ([("Warmup", i) for i in range(num_warmup)]
                    + [("Sample", i) for i in range(num_samples)])
        assert iters == expected
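# A plausible implementation of the `_hook` helper used above (a hypothetical
# sketch consistent with the assertions: Pyro's MCMC invokes
# hook_fn(kernel, samples, stage, i) at every warmup and sampling iteration,
# with stage in {"Warmup", "Sample"}, so recording (stage, i) reproduces the
# expected list).
def _hook(iters, kernel, samples, stage, i):
    iters.append((stage, i))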
def test_mcmc_interface(num_draws, group_by_chain, num_chains):
    num_samples = 2000
    data = torch.tensor([1.0])
    initial_params, _, transforms, _ = initialize_model(
        normal_normal_model, model_args=(data,), num_chains=num_chains)
    kernel = PriorKernel(normal_normal_model)
    mcmc = MCMC(kernel=kernel, num_samples=num_samples, warmup_steps=100,
                num_chains=num_chains, mp_context="spawn",
                initial_params=initial_params, transforms=transforms)
    mcmc.run(data)
    samples = mcmc.get_samples(num_draws, group_by_chain=group_by_chain)

    # test sample shape
    expected_samples = num_draws if num_draws is not None else num_samples
    if group_by_chain:
        expected_shape = (mcmc.num_chains, expected_samples, 1)
    elif num_draws is not None:
        # FIXME: what is the expected behavior when num_draws is not None
        # and group_by_chain=False?
        expected_shape = (expected_samples, 1)
    else:
        expected_shape = (mcmc.num_chains * expected_samples, 1)
    assert samples['y'].shape == expected_shape

    # test sample stats
    if group_by_chain:
        samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}
    sample_mean = samples['y'].mean()
    sample_std = samples['y'].std()
    assert_close(sample_mean, torch.tensor(0.0), atol=0.05)
    assert_close(sample_std, torch.tensor(1.0), atol=0.05)
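# A model fixture consistent with the statistics asserted above (a hypothetical
# sketch; the test suite's actual `normal_normal_model` may differ). Since
# PriorKernel draws from the prior and ignores the likelihood, samples of "y"
# should be approximately standard normal.
import torch
import pyro
import pyro.distributions as dist

def normal_normal_model(data):
    x = torch.tensor([0.0])
    y = pyro.sample("y", dist.Normal(x, torch.ones(data.shape)))
    pyro.sample("obs", dist.Normal(y, torch.tensor([1.0])), obs=data)
    return y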
def test_num_chains(num_chains, cpu_count, default_init_params, monkeypatch):
    monkeypatch.setattr(torch.multiprocessing, "cpu_count", lambda: cpu_count)
    data = torch.tensor([1.0])
    initial_params, _, transforms, _ = initialize_model(
        normal_normal_model, model_args=(data,), num_chains=num_chains)
    if default_init_params:
        initial_params = None
    kernel = PriorKernel(normal_normal_model)
    available_cpu = max(1, cpu_count - 1)
    mp_context = "spawn"
    with optional(pytest.warns(UserWarning), available_cpu < num_chains):
        mcmc = MCMC(
            kernel,
            num_samples=10,
            warmup_steps=10,
            num_chains=num_chains,
            initial_params=initial_params,
            transforms=transforms,
            mp_context=mp_context,
        )
        mcmc.run(data)
    assert mcmc.num_chains == num_chains
    if mcmc.num_chains == 1 or available_cpu < num_chains:
        assert isinstance(mcmc.sampler, _UnarySampler)
    else:
        assert isinstance(mcmc.sampler, _MultiSampler)
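# The `optional` helper applies a context manager only when a condition holds,
# so the UserWarning is only required when chains outnumber available CPUs.
# A minimal sketch of such a helper (the test suite's actual implementation
# may differ):
from contextlib import nullcontext

def optional(context_manager, condition):
    return context_manager if condition else nullcontext()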
def test_predictive(num_samples, parallel):
    model, data, true_probs = beta_bernoulli()
    init_params, potential_fn, transforms, _ = initialize_model(
        model, model_args=(data,))
    nuts_kernel = NUTS(potential_fn=potential_fn, transforms=transforms)
    mcmc = MCMC(nuts_kernel, 100, initial_params=init_params, warmup_steps=100)
    mcmc.run(data)
    samples = mcmc.get_samples()
    with ignore_experimental_warning():
        with optional(pytest.warns(UserWarning), num_samples not in (None, 100)):
            predictive_samples = predictive(model, samples,
                                            num_samples=num_samples,
                                            return_sites=["beta", "obs"],
                                            parallel=parallel)

    # check shapes
    assert predictive_samples["beta"].shape == (100, 5)
    assert predictive_samples["obs"].shape == (100, 1000, 5)

    # check sample mean
    assert_close(predictive_samples["obs"].reshape([-1, 5]).mean(0),
                 true_probs, rtol=0.1)
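# A fixture consistent with the shapes asserted above (a hypothetical sketch;
# the actual `beta_bernoulli` may differ): five coin biases, 1000 observations.
import torch
import pyro
import pyro.distributions as dist

def beta_bernoulli():
    true_probs = torch.tensor([0.9, 0.1, 0.99, 0.4, 0.5])
    data = dist.Bernoulli(true_probs).sample(torch.Size((1000,)))

    def model(data=None):
        with pyro.plate("num_components", 5):
            beta = pyro.sample("beta", dist.Beta(1.0, 1.0))
            with pyro.plate("data", 1000):
                pyro.sample("obs", dist.Bernoulli(beta), obs=data)

    return model, data, true_probs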
def test_reparam_log_joint(model, kwargs):
    guide = AutoIAFNormal(model)
    guide(**kwargs)
    neutra = NeuTraReparam(guide)
    reparam_model = neutra.reparam(model)
    _, pe_fn, transforms, _ = initialize_model(model, model_kwargs=kwargs)
    init_params, pe_fn_neutra, _, _ = initialize_model(
        reparam_model, model_kwargs=kwargs)
    latent_x = list(init_params.values())[0]
    transformed_params = neutra.transform_sample(latent_x)
    pe_transformed = pe_fn_neutra(init_params)
    neutra_transform = ComposeTransform(guide.get_posterior(**kwargs).transforms)
    latent_y = neutra_transform(latent_x)
    log_det_jacobian = neutra_transform.log_abs_det_jacobian(latent_x, latent_y)
    pe = pe_fn({k: transforms[k](v) for k, v in transformed_params.items()})
    assert_close(pe_transformed, pe - log_det_jacobian)
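# The final assertion relies on the change-of-variables identity: for y = T(x),
# the potential in x-space equals the y-space potential minus log|det J_T(x)|.
# A tiny standalone illustration of the Jacobian term with a 1-D transform
# (a sketch, independent of the test above):
import torch
from torch.distributions.transforms import ExpTransform

t = ExpTransform()
x = torch.tensor(0.3)
# For y = exp(x), log|dy/dx| = x, which is what log_abs_det_jacobian computes.
assert torch.allclose(t.log_abs_det_jacobian(x, t(x)), x)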
def setup(self, warmup_steps, data):
    self.data = data
    init_params, potential_fn, transforms, model_trace = initialize_model(
        self.model, model_args=(data,))
    if self._initial_params is None:
        self._initial_params = init_params
    if self.transforms is None:
        self.transforms = transforms
    self._prototype_trace = model_trace
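# For reference: `initialize_model` returns (1) a dict of per-site initial
# values in unconstrained space, (2) a potential-energy function over such
# dicts, (3) a dict mapping site names to transforms into unconstrained space,
# and (4) a prototype model trace. A minimal self-contained sketch using a toy
# model (names here are illustrative only):
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc.util import initialize_model

def toy_model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

init_params, potential_fn, transforms, trace = initialize_model(
    toy_model, model_args=(torch.tensor([0.5]),))
energy = potential_fn(init_params)  # scalar potential energy at the initial point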
def test_potential_fn_pickling(jit):
    data = dist.Bernoulli(torch.tensor([0.8, 0.2])).sample(
        sample_shape=torch.Size((1000,)))
    _, potential_fn, _, _ = initialize_model(_beta_bernoulli, (data,),
                                             jit_compile=jit,
                                             skip_jit_warnings=True)
    test_data = {'p_latent': torch.tensor([0.2, 0.6])}
    assert_close(pickle.loads(pickle.dumps(potential_fn))(test_data),
                 potential_fn(test_data))
def test_potential_fn_pickling(jit):
    data = dist.Bernoulli(torch.tensor([0.8, 0.2])).sample(
        sample_shape=torch.Size((1000,)))
    _, potential_fn, _, _ = initialize_model(_beta_bernoulli, (data,),
                                             jit_compile=jit,
                                             skip_jit_warnings=True)
    test_data = {'p_latent': torch.tensor([0.2, 0.6])}
    buffer = io.BytesIO()
    torch.save(potential_fn, buffer)
    buffer.seek(0)
    deser_potential_fn = torch.load(buffer)
    assert_close(deser_potential_fn(test_data), potential_fn(test_data))
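# The BytesIO round-trip above avoids touching the filesystem; the same
# property can be checked through a real file (a sketch, using `potential_fn`,
# `test_data`, and `assert_close` as defined above):
import os
import tempfile
import torch

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "potential_fn.pt")
    torch.save(potential_fn, path)
    loaded_potential_fn = torch.load(path)
assert_close(loaded_potential_fn(test_data), potential_fn(test_data))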
def run(self, *args, **kwargs):
    self._args, self._kwargs = args, kwargs
    num_samples = [0] * self.num_chains
    z_flat_acc = [[] for _ in range(self.num_chains)]
    with pyro.validation_enabled(not self.disable_validation):
        for x, chain_id in self.sampler.run(*args, **kwargs):
            if num_samples[chain_id] == 0:
                num_samples[chain_id] += 1
                z_structure = x
            elif num_samples[chain_id] == self.num_samples + 1:
                self._diagnostics[chain_id] = x
            else:
                num_samples[chain_id] += 1
                if self.num_chains > 1:
                    x_cloned = x.clone()
                    del x
                else:
                    x_cloned = x
                z_flat_acc[chain_id].append(x_cloned)
    z_flat_acc = torch.stack([torch.stack(chain) for chain in z_flat_acc])

    # unpack latent
    pos = 0
    z_acc = z_structure.copy()
    for k in sorted(z_structure):
        shape = z_structure[k]
        next_pos = pos + shape.numel()
        z_acc[k] = z_flat_acc[:, :, pos:next_pos].reshape(
            (self.num_chains, self.num_samples) + shape)
        pos = next_pos
    assert pos == z_flat_acc.shape[-1]

    # If transforms is not explicitly provided, infer automatically using
    # model args, kwargs.
    if self.transforms is None:
        if getattr(self.kernel, "transforms", None) is not None:
            self.transforms = self.kernel.transforms
        elif self.kernel.model:
            _, _, self.transforms, _ = initialize_model(
                self.kernel.model, model_args=args, model_kwargs=kwargs)
        else:
            self.transforms = {}

    # transform samples back to constrained space
    for name, transform in self.transforms.items():
        z_acc[name] = transform.inv(z_acc[name])
    self._samples = z_acc

    # terminate the sampler (shut down worker processes)
    self.sampler.terminate(True)
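# The "unpack latent" block above inverts a simple packing scheme: each MCMC
# sample arrives as one flat tensor, and `z_structure` maps each site name to
# its original shape so the flat vector can be sliced back apart. A standalone
# sketch of the same idea (toy names, not part of the class above):
import torch

z = {"mu": torch.zeros(3), "sigma": torch.ones(2, 2)}
z_structure = {k: v.shape for k, v in z.items()}
flat = torch.cat([z[k].reshape(-1) for k in sorted(z)])

pos, unpacked = 0, {}
for k in sorted(z_structure):
    shape = z_structure[k]
    next_pos = pos + shape.numel()
    unpacked[k] = flat[pos:next_pos].reshape(shape)
    pos = next_pos
assert all(torch.equal(unpacked[k], z[k]) for k in z)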
def setup(self, warmup_steps, *args, **kwargs):
    """Sets up the sampler."""
    self._warmup_steps = warmup_steps
    init_params, _, _, _ = initialize_model(self.model, args, kwargs)
    if self._initial_params is None:
        self._initial_params = init_params
    self._model_args = args
    self._model_kwargs = kwargs
def _initialize_model_properties(self, model_args, model_kwargs):
    init_params, potential_fn, transforms, trace = initialize_model(
        self.model,
        model_args,
        model_kwargs,
        transforms=self.transforms,
        max_plate_nesting=self._max_plate_nesting,
        jit_compile=self._jit_compile,
        jit_options=self._jit_options,
        skip_jit_warnings=self._ignore_jit_warnings,
    )
    self.potential_fn = potential_fn
    self.transforms = transforms
    if self._initial_params is None:
        self.initial_params = init_params
    self._prototype_trace = trace
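# In HMC-style kernels structured like the above, a helper of this shape is
# typically invoked from `setup` when the kernel was built from a model rather
# than a raw potential_fn; a plausible sketch (hypothetical, details vary by
# kernel):
def setup(self, warmup_steps, *args, **kwargs):
    self._warmup_steps = warmup_steps
    if self.model is not None:
        self._initialize_model_properties(args, kwargs)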
def test_mcmc_diagnostics(num_chains):
    data = torch.tensor([2.0]).repeat(3)
    initial_params, _, transforms, _ = initialize_model(
        normal_normal_model, model_args=(data,), num_chains=num_chains)
    kernel = PriorKernel(normal_normal_model)
    mcmc = MCMC(kernel, num_samples=10, warmup_steps=10, num_chains=num_chains,
                mp_context="spawn", initial_params=initial_params,
                transforms=transforms)
    mcmc.run(data)
    if not torch.backends.mkl.is_available():
        pytest.skip()
    diagnostics = mcmc.diagnostics()
    assert diagnostics["y"]["n_eff"].shape == data.shape
    assert diagnostics["y"]["r_hat"].shape == data.shape
    assert diagnostics["dummy_key"] == {
        'chain {}'.format(i): 'dummy_value' for i in range(num_chains)
    }
def test_mcmc_diagnostics(run_mcmc_cls, num_chains):
    data = torch.tensor([2.0]).repeat(3)
    initial_params, _, transforms, _ = initialize_model(
        normal_normal_model, model_args=(data,), num_chains=num_chains)
    kernel = PriorKernel(normal_normal_model)
    if run_mcmc_cls == run_default_mcmc:
        mcmc = MCMC(
            kernel,
            num_samples=10,
            warmup_steps=10,
            num_chains=num_chains,
            mp_context="spawn",
            initial_params=initial_params,
            transforms=transforms,
        )
    else:
        mcmc = StreamingMCMC(
            kernel,
            num_samples=10,
            warmup_steps=10,
            num_chains=num_chains,
            initial_params=initial_params,
            transforms=transforms,
        )
    mcmc.run(data)
    if not torch.backends.mkl.is_available():
        pytest.skip()
    diagnostics = mcmc.diagnostics()
    if run_mcmc_cls == run_default_mcmc:  # TODO: n_eff for streaming MCMC
        assert diagnostics["y"]["n_eff"].shape == data.shape
        assert diagnostics["y"]["r_hat"].shape == data.shape
    assert diagnostics["dummy_key"] == {
        "chain {}".format(i): "dummy_value" for i in range(num_chains)
    }
def infer_sample(
    cond_model,
    n_steps,
    warmup_steps,
    n_chains=1,
    device="cpu",
    guidefile=None,
    guide_conf=None,
    mcmcfile=None,
):
    """Runs the NUTS HMC algorithm.

    Saves the samples and weights as well as a netcdf file for the run.

    Parameters
    ----------
    cond_model : callable
        Model conditioned on the observed images.
    """
    initial_params, potential_fn, transforms, prototype_trace = util.initialize_model(
        cond_model)

    if guidefile is not None:
        guide = init_guide(cond_model, guide_conf, guidefile=guidefile, device=device)
        sample = guide()
        for key in initial_params.keys():
            initial_params[key] = transforms[key](sample[key].detach())

    # FIXME: In the case of DiagonalNormal, results have to be mapped back
    # onto unpacked latents.
    if guide_conf["type"] == "DiagonalNormal":
        transform = guide.get_transform()
        unpack_fn = lambda u: guide.unpack_latent(u)
        potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn)
        initial_params = {"z": torch.zeros(guide.get_posterior().shape())}
        transforms = None

    def fun(*args, **kwargs):
        res = potential_fn(*args, **kwargs)
        return res

    nuts_kernel = NUTS(
        potential_fn=fun,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        full_mass=False,
        use_multinomial_sampling=True,
        jit_compile=False,
        max_tree_depth=10,
        transforms=transforms,
        step_size=1.0,
    )
    nuts_kernel.initial_params = initial_params

    # Run
    mcmc = MCMC(
        nuts_kernel,
        n_steps,
        warmup_steps=warmup_steps,
        initial_params=initial_params,
        num_chains=n_chains,
    )
    mcmc.run()

    # This block lets the posterior be pickled
    mcmc.sampler = None
    mcmc.kernel.potential_fn = None
    mcmc._cache = {}

    print(f"Saving MCMC object to {mcmcfile}")
    with open(mcmcfile, "wb") as f:
        pickle.dump(mcmc, f, pickle.HIGHEST_PROTOCOL)
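# A later session can restore the pickled posterior and query it; a minimal
# usage sketch (assumes the model code needed for unpickling is importable,
# and `mcmcfile` as above):
import pickle

with open(mcmcfile, "rb") as f:
    mcmc = pickle.load(f)
samples = mcmc.get_samples()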
def main(args):
    baseball_dataset = pd.read_csv(DATA_URL, sep="\t")
    train, _, player_names = train_test_split(baseball_dataset)
    at_bats, hits = train[:, 0], train[:, 1]
    logging.info("Original Dataset:")
    logging.info(baseball_dataset)

    # (1) Full Pooling Model
    init_params, potential_fn, transforms, _ = initialize_model(
        fully_pooled, model_args=(at_bats, hits), num_chains=args.num_chains)
    nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains,
                initial_params=init_params,
                transforms=transforms)
    mcmc.run(at_bats, hits)
    diagnostics = mcmc.diagnostics()
    samples_fully_pooled = mcmc.get_samples()
    logging.info("\nModel: Fully Pooled")
    logging.info("===================")
    logging.info("\nphi:")
    logging.info(summary(samples_fully_pooled, sites=["phi"],
                         player_names=player_names,
                         diagnostics=diagnostics)["phi"])
    num_divergences = sum(map(len, diagnostics["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(fully_pooled, samples_fully_pooled, baseball_dataset)
    evaluate_log_posterior_density(fully_pooled, samples_fully_pooled, baseball_dataset)

    # (2) No Pooling Model
    init_params, potential_fn, transforms, _ = initialize_model(
        not_pooled, model_args=(at_bats, hits), num_chains=args.num_chains)
    nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains,
                initial_params=init_params,
                transforms=transforms)
    mcmc.run(at_bats, hits)
    diagnostics = mcmc.diagnostics()
    samples_not_pooled = mcmc.get_samples()
    logging.info("\nModel: Not Pooled")
    logging.info("=================")
    logging.info("\nphi:")
    logging.info(summary(samples_not_pooled, sites=["phi"],
                         player_names=player_names,
                         diagnostics=diagnostics)["phi"])
    num_divergences = sum(map(len, diagnostics["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(not_pooled, samples_not_pooled, baseball_dataset)
    evaluate_log_posterior_density(not_pooled, samples_not_pooled, baseball_dataset)

    # (3) Partially Pooled Model
    init_params, potential_fn, transforms, _ = initialize_model(
        partially_pooled, model_args=(at_bats, hits), num_chains=args.num_chains)
    nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains,
                initial_params=init_params,
                transforms=transforms)
    mcmc.run(at_bats, hits)
    diagnostics = mcmc.diagnostics()
    samples_partially_pooled = mcmc.get_samples()
    logging.info("\nModel: Partially Pooled")
    logging.info("=======================")
    logging.info("\nphi:")
    logging.info(summary(samples_partially_pooled, sites=["phi"],
                         player_names=player_names,
                         diagnostics=diagnostics)["phi"])
    num_divergences = sum(map(len, diagnostics["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(partially_pooled, samples_partially_pooled,
                                baseball_dataset)
    evaluate_log_posterior_density(partially_pooled, samples_partially_pooled,
                                   baseball_dataset)

    # (4) Partially Pooled with Logit Model
    init_params, potential_fn, transforms, _ = initialize_model(
        partially_pooled_with_logit, model_args=(at_bats, hits),
        num_chains=args.num_chains)
    nuts_kernel = NUTS(potential_fn=potential_fn, transforms=transforms)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains,
                initial_params=init_params,
                transforms=transforms)
    mcmc.run(at_bats, hits)
    diagnostics = mcmc.diagnostics()
    samples_partially_pooled_logit = mcmc.get_samples()
    logging.info("\nModel: Partially Pooled with Logit")
    logging.info("==================================")
    logging.info("\nSigmoid(alpha):")
    logging.info(summary(samples_partially_pooled_logit, sites=["alpha"],
                         player_names=player_names,
                         transforms={"alpha": torch.sigmoid},
                         diagnostics=diagnostics)["alpha"])
    num_divergences = sum(map(len, diagnostics["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(partially_pooled_with_logit,
                                samples_partially_pooled_logit, baseball_dataset)
    evaluate_log_posterior_density(partially_pooled_with_logit,
                                   samples_partially_pooled_logit, baseball_dataset)
        irt_model = irt_model_2pl
    elif args.irt_model == '3pl':
        if args.hierarchical:
            irt_model = irt_model_3pl_hierarchical
        else:
            irt_model = irt_model_3pl
    else:
        raise Exception('irt_model {} not supported.'.format(args.irt_model))

    init_params, potential_fn, transforms, _ = initialize_model(
        irt_model,
        model_args=(
            args.ability_dim,
            num_person,
            num_item,
            device,
            response,
            mask,
            1,
        ),
        num_chains=args.num_chains,
    )

    start_time = time.time()
    nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(
        nuts_kernel,
        num_samples=args.num_samples,
        warmup_steps=args.num_warmup,
        num_chains=args.num_chains,
def main(args):
    baseball_dataset = pd.read_csv(DATA_URL, sep="\t")
    train, _, player_names = train_test_split(baseball_dataset)
    at_bats, hits = train[:, 0], train[:, 1]
    logging.info("Original Dataset:")
    logging.info(baseball_dataset)

    # (1) Full Pooling Model
    # In this model, we illustrate how to use MCMC with a general potential_fn.
    init_params, potential_fn, transforms, _ = initialize_model(
        fully_pooled, model_args=(at_bats, hits), num_chains=args.num_chains,
        jit_compile=args.jit, skip_jit_warnings=True)
    nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains,
                initial_params=init_params,
                transforms=transforms)
    mcmc.run(at_bats, hits)
    samples_fully_pooled = mcmc.get_samples()
    logging.info("\nModel: Fully Pooled")
    logging.info("===================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(fully_pooled, samples_fully_pooled, baseball_dataset)
    evaluate_pointwise_pred_density(fully_pooled, samples_fully_pooled,
                                    baseball_dataset)

    # (2) No Pooling Model
    nuts_kernel = NUTS(not_pooled, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_not_pooled = mcmc.get_samples()
    logging.info("\nModel: Not Pooled")
    logging.info("=================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(not_pooled, samples_not_pooled, baseball_dataset)
    evaluate_pointwise_pred_density(not_pooled, samples_not_pooled, baseball_dataset)

    # (3) Partially Pooled Model
    nuts_kernel = NUTS(partially_pooled, jit_compile=args.jit,
                       ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_partially_pooled = mcmc.get_samples()
    logging.info("\nModel: Partially Pooled")
    logging.info("=======================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(partially_pooled, samples_partially_pooled,
                                baseball_dataset)
    evaluate_pointwise_pred_density(partially_pooled, samples_partially_pooled,
                                    baseball_dataset)

    # (4) Partially Pooled with Logit Model
    nuts_kernel = NUTS(partially_pooled_with_logit, jit_compile=args.jit,
                       ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_partially_pooled_logit = mcmc.get_samples()
    logging.info("\nModel: Partially Pooled with Logit")
    logging.info("==================================")
    logging.info("\nSigmoid(alpha):")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["alpha"],
                                   player_names=player_names,
                                   transforms={"alpha": torch.sigmoid},
                                   diagnostics=True,
                                   group_by_chain=True)["alpha"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(partially_pooled_with_logit,
                                samples_partially_pooled_logit, baseball_dataset)
    evaluate_pointwise_pred_density(partially_pooled_with_logit,
                                    samples_partially_pooled_logit, baseball_dataset)
def run(self, *args, **kwargs):
    """
    Run MCMC to generate samples and populate `self._samples`.

    Example usage:

    .. code-block:: python

        def model(data):
            ...

        nuts_kernel = NUTS(model)
        mcmc = MCMC(nuts_kernel, num_samples=500)
        mcmc.run(data)
        samples = mcmc.get_samples()

    :param args: optional arguments taken by
        :meth:`MCMCKernel.setup <pyro.infer.mcmc.mcmc_kernel.MCMCKernel.setup>`.
    :param kwargs: optional keyword arguments taken by
        :meth:`MCMCKernel.setup <pyro.infer.mcmc.mcmc_kernel.MCMCKernel.setup>`.
    """
    self._args, self._kwargs = args, kwargs
    num_samples = [0] * self.num_chains
    z_flat_acc = [[] for _ in range(self.num_chains)]
    with optional(pyro.validation_enabled(not self.disable_validation),
                  self.disable_validation is not None):
        for x, chain_id in self.sampler.run(*args, **kwargs):
            if num_samples[chain_id] == 0:
                num_samples[chain_id] += 1
                z_structure = x
            elif num_samples[chain_id] == self.num_samples + 1:
                self._diagnostics[chain_id] = x
            else:
                num_samples[chain_id] += 1
                if self.num_chains > 1:
                    x_cloned = x.clone()
                    del x
                else:
                    x_cloned = x
                z_flat_acc[chain_id].append(x_cloned)
    z_flat_acc = torch.stack([torch.stack(chain) for chain in z_flat_acc])

    # unpack latent
    pos = 0
    z_acc = z_structure.copy()
    for k in sorted(z_structure):
        shape = z_structure[k]
        next_pos = pos + shape.numel()
        z_acc[k] = z_flat_acc[:, :, pos:next_pos].reshape(
            (self.num_chains, self.num_samples) + shape)
        pos = next_pos
    assert pos == z_flat_acc.shape[-1]

    # If transforms is not explicitly provided, infer automatically using
    # model args, kwargs.
    if self.transforms is None:
        # Try to initialize kernel.transforms using kernel.setup().
        if getattr(self.kernel, "transforms", None) is None:
            warmup_steps = 0
            self.kernel.setup(warmup_steps, *args, **kwargs)
        # Use `kernel.transforms` when available.
        if getattr(self.kernel, "transforms", None) is not None:
            self.transforms = self.kernel.transforms
        # Else, get transforms from the model (e.g. in multiprocessing).
        elif self.kernel.model:
            _, _, self.transforms, _ = initialize_model(
                self.kernel.model, model_args=args, model_kwargs=kwargs,
                initial_params={})
        # Assign a default value.
        else:
            self.transforms = {}

    # transform samples back to constrained space
    for name, transform in self.transforms.items():
        z_acc[name] = transform.inv(z_acc[name])
    self._samples = z_acc

    # terminate the sampler (shut down worker processes)
    self.sampler.terminate(True)