def test_beta_binomial(hyperpriors):
    def model(data):
        # `total_count` is defined below and captured by the model's closure.
        with pyro.plate("plate_0", data.shape[-1]):
            alpha = (
                pyro.sample("alpha", dist.HalfCauchy(1.0))
                if hyperpriors
                else torch.tensor([1.0, 1.0])
            )
            beta = (
                pyro.sample("beta", dist.HalfCauchy(1.0))
                if hyperpriors
                else torch.tensor([1.0, 1.0])
            )
            beta_binom = BetaBinomialPair()
            with pyro.plate("plate_1", data.shape[-2]):
                probs = pyro.sample("probs", beta_binom.latent(alpha, beta))
                with pyro.plate("data", data.shape[0]):
                    pyro.sample(
                        "binomial",
                        beta_binom.conditional(probs=probs, total_count=total_count),
                        obs=data,
                    )

    true_probs = torch.tensor([[0.7, 0.4], [0.6, 0.4]])
    total_count = torch.tensor([[1000, 600], [400, 800]])
    num_samples = 80
    data = dist.Binomial(total_count=total_count, probs=true_probs).sample(
        sample_shape=torch.Size((10,))
    )
    slice_kernel = Slice(
        collapse_conjugate(model), jit_compile=True, ignore_jit_warnings=True
    )
    mcmc = MCMC(slice_kernel, num_samples=num_samples, warmup_steps=50)
    mcmc.run(data)
    samples = mcmc.get_samples()
    posterior = posterior_replay(model, samples, data, num_samples=num_samples)
    assert_equal(posterior["probs"].mean(0), true_probs, prec=0.05)
def test_logistic_regression(jit, num_chains):
    dim = 3
    data = torch.randn(2000, dim)
    true_coefs = torch.arange(1.0, dim + 1.0)
    labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample()

    def model(data):
        coefs_mean = torch.zeros(dim)
        coefs = pyro.sample("beta", dist.Normal(coefs_mean, torch.ones(dim)))
        y = pyro.sample("y", dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels)
        return y

    slice_kernel = Slice(model, jit_compile=jit, ignore_jit_warnings=True)
    mcmc = MCMC(
        slice_kernel,
        num_samples=500,
        warmup_steps=100,
        num_chains=num_chains,
        mp_context="fork",
        available_cpu=1,
    )
    mcmc.run(data)
    samples = mcmc.get_samples()
    assert_equal(rmse(true_coefs, samples["beta"].mean(0)).item(), 0.0, prec=0.1)
def test_gaussian_mixture_model(jit):
    K, N = 3, 1000

    def gmm(data):
        mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.ones(K)))
        with pyro.plate("num_clusters", K):
            cluster_means = pyro.sample(
                "cluster_means", dist.Normal(torch.arange(float(K)), 1.0)
            )
        with pyro.plate("data", data.shape[0]):
            assignments = pyro.sample("assignments", dist.Categorical(mix_proportions))
            pyro.sample("obs", dist.Normal(cluster_means[assignments], 1.0), obs=data)
        return cluster_means

    true_cluster_means = torch.tensor([1.0, 5.0, 10.0])
    true_mix_proportions = torch.tensor([0.1, 0.3, 0.6])
    cluster_assignments = dist.Categorical(true_mix_proportions).sample(
        torch.Size((N,))
    )
    data = dist.Normal(true_cluster_means[cluster_assignments], 1.0).sample()
    slice_kernel = Slice(
        gmm, max_plate_nesting=1, jit_compile=jit, ignore_jit_warnings=True
    )
    mcmc = MCMC(slice_kernel, num_samples=300, warmup_steps=100)
    mcmc.run(data)
    samples = mcmc.get_samples()
    assert_equal(samples["phi"].mean(0).sort()[0], true_mix_proportions, prec=0.05)
    assert_equal(samples["cluster_means"].mean(0).sort()[0], true_cluster_means, prec=0.2)
def test_gamma_normal(jit):
    def model(data):
        concentration = torch.tensor([1.0, 1.0])
        rate = torch.tensor([1.0, 1.0])
        # Gamma takes (concentration, rate); with both set to 1.0 this is the
        # same prior as before, but the argument order now matches the signature.
        p_latent = pyro.sample("p_latent", dist.Gamma(concentration, rate))
        pyro.sample("obs", dist.Normal(3, p_latent), obs=data)
        return p_latent

    true_std = torch.tensor([0.5, 2.0])
    data = dist.Normal(3, true_std).sample(sample_shape=torch.Size((2000,)))
    slice_kernel = Slice(model, jit_compile=jit, ignore_jit_warnings=True)
    mcmc = MCMC(slice_kernel, num_samples=200, warmup_steps=100)
    mcmc.run(data)
    samples = mcmc.get_samples()
    assert_equal(samples["p_latent"].mean(0), true_std, prec=0.05)
def test_dirichlet_categorical(jit):
    def model(data):
        concentration = torch.tensor([1.0, 1.0, 1.0])
        p_latent = pyro.sample("p_latent", dist.Dirichlet(concentration))
        pyro.sample("obs", dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = torch.tensor([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(sample_shape=torch.Size((2000,)))
    slice_kernel = Slice(model, jit_compile=jit, ignore_jit_warnings=True)
    mcmc = MCMC(slice_kernel, num_samples=200, warmup_steps=100)
    mcmc.run(data)
    samples = mcmc.get_samples()
    posterior = samples["p_latent"]
    assert_equal(posterior.mean(0), true_probs, prec=0.02)
def test_beta_bernoulli():
    def model(data):
        alpha = torch.tensor([1.1, 1.1])
        beta = torch.tensor([1.1, 1.1])
        p_latent = pyro.sample("p_latent", dist.Beta(alpha, beta))
        pyro.sample("obs", dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = torch.tensor([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(sample_shape=torch.Size((1000,)))
    slice_kernel = Slice(model)
    mcmc = MCMC(slice_kernel, num_samples=400, warmup_steps=200)
    mcmc.run(data)
    samples = mcmc.get_samples()
    assert_equal(samples["p_latent"].mean(0), true_probs, prec=0.02)
def test_bernoulli_latent_model(jit):
    @poutine.broadcast
    def model(data):
        y_prob = pyro.sample("y_prob", dist.Beta(1.0, 1.0))
        with pyro.plate("data", data.shape[0]):
            y = pyro.sample("y", dist.Bernoulli(y_prob))
            z = pyro.sample("z", dist.Bernoulli(0.65 * y + 0.1))
            pyro.sample("obs", dist.Normal(2.0 * z, 1.0), obs=data)

    N = 2000
    y_prob = torch.tensor(0.3)
    y = dist.Bernoulli(y_prob).sample(torch.Size((N,)))
    z = dist.Bernoulli(0.65 * y + 0.1).sample()
    data = dist.Normal(2.0 * z, 1.0).sample()
    slice_kernel = Slice(
        model, max_plate_nesting=1, jit_compile=jit, ignore_jit_warnings=True
    )
    mcmc = MCMC(slice_kernel, num_samples=600, warmup_steps=200)
    mcmc.run(data)
    samples = mcmc.get_samples()
    assert_equal(samples["y_prob"].mean(0), y_prob, prec=0.05)
def test_gamma_beta(jit):
    def model(data):
        alpha_prior = pyro.sample("alpha", dist.Gamma(concentration=1.0, rate=1.0))
        beta_prior = pyro.sample("beta", dist.Gamma(concentration=1.0, rate=1.0))
        pyro.sample(
            "x",
            dist.Beta(concentration1=alpha_prior, concentration0=beta_prior),
            obs=data,
        )

    true_alpha = torch.tensor(5.0)
    true_beta = torch.tensor(1.0)
    data = dist.Beta(concentration1=true_alpha, concentration0=true_beta).sample(
        torch.Size((5000,))
    )
    slice_kernel = Slice(model, jit_compile=jit, ignore_jit_warnings=True)
    mcmc = MCMC(slice_kernel, num_samples=500, warmup_steps=200)
    mcmc.run(data)
    samples = mcmc.get_samples()
    assert_equal(samples["alpha"].mean(0), true_alpha, prec=0.08)
    assert_equal(samples["beta"].mean(0), true_beta, prec=0.05)
def test_slice_conjugate_gaussian(
    fixture,
    num_samples,
    warmup_steps,
    expected_means,
    expected_precs,
    mean_tol,
    std_tol,
):
    pyro.get_param_store().clear()
    slice_kernel = Slice(fixture.model)
    mcmc = MCMC(slice_kernel, num_samples, warmup_steps, num_chains=3)
    mcmc.run(fixture.data)
    samples = mcmc.get_samples()
    for i in range(1, fixture.chain_len + 1):
        param_name = "loc_" + str(i)
        latent = samples[param_name]
        latent_loc = latent.mean(0)
        latent_std = latent.std(0)
        expected_mean = torch.ones(fixture.dim) * expected_means[i - 1]
        expected_std = 1 / torch.sqrt(torch.ones(fixture.dim) * expected_precs[i - 1])

        # Actual vs expected posterior means for the latents
        logger.debug("Posterior mean (actual) - {}".format(param_name))
        logger.debug(latent_loc)
        logger.debug("Posterior mean (expected) - {}".format(param_name))
        logger.debug(expected_mean)
        assert_equal(rmse(latent_loc, expected_mean).item(), 0.0, prec=mean_tol)

        # Actual vs expected posterior standard deviations for the latents
        logger.debug("Posterior std (actual) - {}".format(param_name))
        logger.debug(latent_std)
        logger.debug("Posterior std (expected) - {}".format(param_name))
        logger.debug(expected_std)
        assert_equal(rmse(latent_std, expected_std).item(), 0.0, prec=std_tol)
def test_gamma_poisson(hyperpriors):
    def model(data):
        with pyro.plate("latent_dim", data.shape[1]):
            alpha = (
                pyro.sample("alpha", dist.HalfCauchy(1.0))
                if hyperpriors
                else torch.tensor([1.0, 1.0])
            )
            beta = (
                pyro.sample("beta", dist.HalfCauchy(1.0))
                if hyperpriors
                else torch.tensor([1.0, 1.0])
            )
            gamma_poisson = GammaPoissonPair()
            rate = pyro.sample("rate", gamma_poisson.latent(alpha, beta))
            with pyro.plate("data", data.shape[0]):
                pyro.sample("obs", gamma_poisson.conditional(rate), obs=data)

    true_rate = torch.tensor([3.0, 10.0])
    num_samples = 100
    data = dist.Poisson(rate=true_rate).sample(sample_shape=torch.Size((100,)))
    slice_kernel = Slice(
        collapse_conjugate(model), jit_compile=True, ignore_jit_warnings=True
    )
    mcmc = MCMC(slice_kernel, num_samples=num_samples, warmup_steps=50)
    mcmc.run(data)
    samples = mcmc.get_samples()
    posterior = posterior_replay(model, samples, data, num_samples=num_samples)
    assert_equal(posterior["rate"].mean(0), true_rate, prec=0.3)
def run(
    task: Task,
    num_samples: int,
    num_observation: Optional[int] = None,
    observation: Optional[torch.Tensor] = None,
    num_chains: int = 10,
    num_warmup: int = 10000,
    kernel: str = "slice",
    kernel_parameters: Optional[Dict[str, Any]] = None,
    thinning: int = 1,
    diagnostics: bool = True,
    available_cpu: int = 1,
    mp_context: str = "fork",
    jit_compile: bool = False,
    automatic_transforms_enabled: bool = True,
    initial_params: Optional[torch.Tensor] = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Runs MCMC using Pyro on the potential function of a task.

    Produces `num_samples` after accounting for warmup (burn-in) and thinning.
    Note that the actual number of simulations is not controlled for with MCMC,
    since these algorithms are only used as a reference method in the benchmark.

    MCMC is run on the potential function, which returns the unnormalized
    negative log posterior probability. Note that this requires a tractable
    likelihood. Pyro is used to automatically construct the potential function.

    Args:
        task: Task instance
        num_samples: Number of samples to generate from the posterior
        num_observation: Observation number to load, alternative to `observation`
        observation: Observation, alternative to `num_observation`
        num_chains: Number of chains
        num_warmup: Warmup steps, during which parameters of the sampler are
            adapted. Warmup samples are not returned by the algorithm.
        kernel: HMC, NUTS, or Slice kernel
        kernel_parameters: Parameters passed to the kernel
        thinning: Amount of thinning to apply, in order to avoid drawing
            correlated samples from the chain
        diagnostics: Flag for diagnostics
        available_cpu: Number of CPUs used to parallelize chains
        mp_context: Multiprocessing context; only "fork" might work
        jit_compile: Just-in-time (JIT) compilation, can yield significant speed-ups
        automatic_transforms_enabled: Whether or not to use automatic transforms
        initial_params: Parameters to initialize at

    Returns:
        Samples from the posterior
    """
    assert not (num_observation is None and observation is None)
    assert not (num_observation is not None and observation is not None)

    tic = time.time()
    log = sbibm.get_logger(__name__)

    hook_fn = None
    if diagnostics:
        log.info(f"MCMC sampling for observation {num_observation}")
        tb_writer, tb_close = tb_make_writer(
            logger=log,
            basepath=f"tensorboard/pyro_{kernel.lower()}/observation_{num_observation}",
        )
        hook_fn = tb_make_hook_fn(tb_writer)

    if "num_simulations" in kwargs:
        warnings.warn(
            "`num_simulations` was passed as a keyword but will be ignored, "
            "see docstring for more info."
        )

    # Prepare model and transforms
    conditioned_model = task._get_pyro_model(
        num_observation=num_observation, observation=observation
    )
    transforms = task._get_transforms(
        num_observation=num_observation,
        observation=observation,
        automatic_transforms_enabled=automatic_transforms_enabled,
    )

    kernel_parameters = kernel_parameters if kernel_parameters is not None else {}
    kernel_parameters["jit_compile"] = jit_compile
    kernel_parameters["transforms"] = transforms
    log.info(
        "Using kernel: {name}({parameters})".format(
            name=kernel,
            parameters=",".join([f"{k}={v}" for k, v in kernel_parameters.items()]),
        )
    )
    if kernel.lower() == "nuts":
        mcmc_kernel = NUTS(model=conditioned_model, **kernel_parameters)
    elif kernel.lower() == "hmc":
        mcmc_kernel = HMC(model=conditioned_model, **kernel_parameters)
    elif kernel.lower() == "slice":
        mcmc_kernel = Slice(model=conditioned_model, **kernel_parameters)
    else:
        raise NotImplementedError

    if initial_params is not None:
        site_name = "parameters"
        initial_params = {site_name: transforms[site_name](initial_params)}
    else:
        initial_params = None

    mcmc_parameters = {
        "num_chains": num_chains,
        "num_samples": thinning * num_samples,
        "warmup_steps": num_warmup,
        "available_cpu": available_cpu,
        "initial_params": initial_params,
    }
    log.info(
        "Calling MCMC with: MCMC({name}_kernel, {parameters})".format(
            name=kernel,
            parameters=",".join([f"{k}={v}" for k, v in mcmc_parameters.items()]),
        )
    )

    mcmc = MCMC(mcmc_kernel, hook_fn=hook_fn, **mcmc_parameters)
    mcmc.run()
    toc = time.time()
    log.info(f"Finished MCMC after {toc - tic:.3f} seconds")
    log.info(f"Automatic transforms {mcmc.transforms}")

    log.info(f"Apply thinning of {thinning}")
    mcmc._samples = {"parameters": mcmc._samples["parameters"][:, ::thinning, :]}

    num_samples_available = (
        mcmc._samples["parameters"].shape[0] * mcmc._samples["parameters"].shape[1]
    )
    if num_samples_available < num_samples:
        warnings.warn("Some samples will be included multiple times")
        samples = mcmc.get_samples(num_samples=num_samples, group_by_chain=False)[
            "parameters"
        ].squeeze()
    else:
        samples = mcmc.get_samples(group_by_chain=False)["parameters"].squeeze()
        idx = torch.randperm(samples.shape[0])[:num_samples]
        samples = samples[idx, :]

    assert samples.shape[0] == num_samples

    if diagnostics:
        mcmc.summary()
        tb_ess(tb_writer, mcmc)
        tb_r_hat(tb_writer, mcmc)
        tb_marginals(tb_writer, mcmc)
        tb_acf(tb_writer, mcmc)
        tb_posteriors(tb_writer, mcmc)
        tb_plot_posterior(tb_writer, samples, tag="posterior/final")
        tb_close()

    return samples
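# Illustrative usage sketch for `run` (an assumption, not part of the benchmark
# code): the task name "gaussian_linear", the observation number, and the sample
# counts below are placeholders chosen for the example. The function is defined
# but never called here.
def _example_run_usage():
    import sbibm

    # Any sbibm task with a tractable likelihood would work analogously.
    task = sbibm.get_task("gaussian_linear")
    posterior_samples = run(
        task,
        num_samples=1000,
        num_observation=1,
        num_chains=2,
        num_warmup=1000,
        kernel="slice",
        thinning=10,
        diagnostics=False,
    )
    return posterior_samples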
def test_gaussian_hmm(num_steps):
    dim = 4

    def model(data):
        initialize = pyro.sample("initialize", dist.Dirichlet(torch.ones(dim)))
        with pyro.plate("states", dim):
            transition = pyro.sample("transition", dist.Dirichlet(torch.ones(dim, dim)))
            emission_loc = pyro.sample(
                "emission_loc", dist.Normal(torch.zeros(dim), torch.ones(dim))
            )
            emission_scale = pyro.sample(
                "emission_scale", dist.LogNormal(torch.zeros(dim), torch.ones(dim))
            )
        x = None
        with ignore_jit_warnings([("Iterating over a tensor", RuntimeWarning)]):
            for t, y in pyro.markov(enumerate(data)):
                x = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(initialize if x is None else transition[x]),
                    infer={"enumerate": "parallel"},
                )
                pyro.sample(
                    "y_{}".format(t),
                    dist.Normal(emission_loc[x], emission_scale[x]),
                    obs=y,
                )

    def _get_initial_trace():
        guide = AutoDelta(
            poutine.block(
                model,
                expose_fn=lambda msg: not msg["name"].startswith("x")
                and not msg["name"].startswith("y"),
            )
        )
        elbo = TraceEnum_ELBO(max_plate_nesting=1)
        svi = SVI(model, guide, optim.Adam({"lr": 0.01}), elbo)
        for _ in range(100):
            svi.step(data)
        return poutine.trace(guide).get_trace(data)

    def _generate_data():
        transition_probs = torch.rand(dim, dim)
        emissions_loc = torch.arange(dim, dtype=torch.Tensor().dtype)
        emissions_scale = 1.0
        state = torch.tensor(1)
        obs = [dist.Normal(emissions_loc[state], emissions_scale).sample()]
        for _ in range(num_steps):
            state = dist.Categorical(transition_probs[state]).sample()
            obs.append(dist.Normal(emissions_loc[state], emissions_scale).sample())
        return torch.stack(obs)

    data = _generate_data()
    slice_kernel = Slice(
        model, max_plate_nesting=1, jit_compile=True, ignore_jit_warnings=True
    )
    if num_steps == 30:
        slice_kernel.initial_trace = _get_initial_trace()
    mcmc = MCMC(slice_kernel, num_samples=5, warmup_steps=5)
    mcmc.run(data)