def test_arrowhead_mass():
    """Check NUTS with an ``ArrowheadMassMatrix`` adapter.

    Latent sites ``w`` (2 dims), ``x`` (1), ``y`` (1) and ``z`` (2) are
    concatenated in the order ``w, y, x, z`` and observed under a
    multivariate normal with precision ``prec``. The test does two things:

    * Smoke-tests the arrowhead adapter with ``full_mass=True`` and
      ``full_mass=False``.
    * Runs a long warmup with ``full_mass=[("w",), ("y", "x")]``. It then
      checks that the adapted mass matrix has a 4-row head and a 2-element
      diagonal tail, and that both approximately match ``prec``.
    """
    def model(prec):
        w = pyro.sample("w", dist.Normal(0, 1000).expand([2]).to_event(1))
        x = pyro.sample("x", dist.Normal(0, 1000).expand([1]).to_event(1))
        y = pyro.sample("y", dist.Normal(0, 1000).expand([1]).to_event(1))
        z = pyro.sample("z", dist.Normal(0, 1000).expand([2]).to_event(1))
        # y comes before x here, matching the ("w", "y", "x", "z") ordering
        # that the block-structured run below expects.
        wyxz = torch.cat([w, y, x, z])
        pyro.sample("obs", dist.MultivariateNormal(torch.zeros(6), precision_matrix=prec), obs=wyxz)

    A = torch.randn(6, 12)
    prec = A @ A.t() * 0.1

    # Smoke tests: the arrowhead adapter must handle both a fully dense and
    # a purely diagonal mass-matrix specification.
    for dense_mass in [True, False]:
        kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True, full_mass=dense_mass)
        kernel.mass_matrix_adapter = ArrowheadMassMatrix()
        mcmc = MCMC(kernel, num_samples=1, warmup_steps=1)
        mcmc.run(prec)
        # Dense -> 2-D matrix; diagonal -> 1-D vector.
        assert kernel.inverse_mass_matrix[("w", "x", "y", "z")].dim() == 1 + int(dense_mass)

    # Block-structured head: ("w",) and ("y", "x") are dense; the remaining
    # site z ends up in the diagonal tail.
    kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True, full_mass=[("w",), ("y", "x")])
    kernel.mass_matrix_adapter = ArrowheadMassMatrix()
    mcmc = MCMC(kernel, num_samples=1, warmup_steps=1000)
    mcmc.run(prec)
    assert ("w", "y", "x", "z") in kernel.inverse_mass_matrix
    mass_matrix = kernel.mass_matrix_adapter.mass_matrix[("w", "y", "x", "z")]
    assert mass_matrix.top.shape == (4, 6)
    assert mass_matrix.bottom_diag.shape == (2,)
    # After a long warmup the adapted mass matrix should approximate prec.
    assert_close(mass_matrix.top, prec[:4], atol=0.2, rtol=0.2)
    assert_close(mass_matrix.bottom_diag, prec.diag()[4:], atol=0.2, rtol=0.2)
def test_dirichlet_categorical_grad_adapt():
    """Recover Categorical probabilities with NUTS and an arrowhead adapter.

    Draws 2000 observations from a Categorical with known probabilities and
    places a flat Dirichlet(1, 1, 1) prior on them. It then runs NUTS using
    ``ArrowheadMassMatrix`` for mass adaptation and compares the posterior
    mean against the true probabilities with ``prec=0.02``.
    """
    def model(data):
        p_latent = pyro.sample("p_latent", dist.Dirichlet(torch.tensor([1.0, 1.0, 1.0])))
        pyro.sample("obs", dist.Categorical(p_latent), obs=data)
        return p_latent

    probs = torch.tensor([0.1, 0.6, 0.3])
    observations = dist.Categorical(probs).sample(torch.Size([2000]))

    kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)
    kernel.mass_matrix_adapter = ArrowheadMassMatrix()
    sampler = MCMC(kernel, num_samples=200, warmup_steps=100)
    sampler.run(observations)

    posterior_mean = sampler.get_samples()["p_latent"].mean(0)
    assert_equal(posterior_mean, probs, prec=0.02)
def fit_mcmc(self, **options):
    r"""
    Runs NUTS inference to generate posterior samples.

    This uses the :class:`~pyro.infer.mcmc.nuts.NUTS` kernel to run
    :class:`~pyro.infer.mcmc.api.MCMC`, setting the ``.samples``
    attribute on completion.

    This uses an asymptotically exact enumeration-based model when
    ``num_quant_bins > 1``, and a cheaper moment-matched approximate model
    when ``num_quant_bins == 1``.

    :param \*\*options: Options passed to
        :class:`~pyro.infer.mcmc.api.MCMC`. The remaining options are
        pulled out and have special meaning.
    :param int num_samples: Number of posterior samples to draw via mcmc.
        Defaults to 100.
    :param int num_chains: Number of MCMC chains. Defaults to 1.
    :param int max_tree_depth: (Default 5). Max tree depth of the
        :class:`~pyro.infer.mcmc.nuts.NUTS` kernel.
    :param full_mass: Specification of mass matrix of the
        :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. Defaults to full mass
        over global random variables.
    :param bool arrowhead_mass: Whether to treat ``full_mass`` as the head
        of an arrowhead matrix versus simply as a block. Defaults to False.
    :param int num_quant_bins: If greater than 1, use asymptotically exact
        inference via local enumeration over this many quantization bins.
        If equal to 1, use continuous-valued relaxed approximate inference.
        Note that computational cost is exponential in `num_quant_bins`.
        Defaults to 1 for relaxed inference.
    :param bool haar: Whether to use a Haar wavelet reparameterizer.
        Defaults to False.
    :param int haar_full_mass: Number of low frequency Haar components to
        include in the full mass matrix. If ``haar=False`` then this is
        ignored. Defaults to 10.
    :param bool jit_compile: Passed to the NUTS kernel. Defaults to False.
    :param dict jit_options: Passed to the NUTS kernel. Defaults to None.
    :param bool ignore_jit_warnings: Passed to the NUTS kernel. Defaults to
        True.
    :param float target_accept_prob: Passed to the NUTS kernel. Defaults to
        0.8.
    :param int heuristic_num_particles: Passed to :meth:`heuristic` as
        ``num_particles``. Defaults to 1024. More generally, every option
        named ``heuristic_<name>`` is removed from ``options`` and forwarded
        to the heuristic initializer as ``<name>``.
    :returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``.
    :rtype: ~pyro.infer.mcmc.api.MCMC
    """
    _require_double_precision()

    # Parse options, saving some for use in .predict().
    num_samples = options.setdefault("num_samples", 100)
    num_chains = options.setdefault("num_chains", 1)
    self.num_quant_bins = options.pop("num_quant_bins", 1)
    assert isinstance(self.num_quant_bins, int)
    assert self.num_quant_bins >= 1
    self.relaxed = self.num_quant_bins == 1

    # Setup Haar wavelet transform.
    haar = options.pop("haar", False)
    haar_full_mass = options.pop("haar_full_mass", 10)
    full_mass = options.pop("full_mass", self.full_mass)
    assert isinstance(haar, bool)
    assert isinstance(haar_full_mass, int) and haar_full_mass >= 0
    assert isinstance(full_mass, (bool, list))
    haar_full_mass = min(haar_full_mass, self.duration)
    if not haar:
        haar_full_mass = 0
    if full_mass is True:
        haar_full_mass = 0  # No need to split.
    elif haar_full_mass >= self.duration:
        full_mass = True  # Effectively full mass.
        haar_full_mass = 0
    if haar:
        time_dim = -2 if self.is_regional else -1
        dims = {"auxiliary": time_dim}
        supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
        for name, (fn, _) in self._non_compartmental.items():
            dims[name] = time_dim - fn.event_dim
            supports[name] = fn.support
        # ``haar`` is rebound from a bool to the reparameterizer object; it is
        # later passed to ``self._heuristic`` and used to wrap the model.
        haar = _HaarSplitReparam(haar_full_mass, self.duration, dims, supports)
        if haar_full_mass:
            # Append the "_haar_split_0" sites to the first dense block.
            assert full_mass and isinstance(full_mass, list)
            # Copy first: ``+=`` on element 0 would otherwise write back into
            # the caller's list (possibly ``self.full_mass``).
            full_mass = full_mass[:]
            full_mass[0] += tuple(name + "_haar_split_0" for name in sorted(dims))

    # Heuristically initialize to feasible latents. Strip only the leading
    # "heuristic_" prefix (str.replace would remove every occurrence).
    prefix = "heuristic_"
    heuristic_options = {
        k[len(prefix):]: options.pop(k)
        for k in list(options)
        if k.startswith(prefix)
    }
    init_strategy = init_to_generated(
        generate=functools.partial(self._heuristic, haar, **heuristic_options))

    # Configure a kernel.
    logger.info("Running inference...")
    model = self._relaxed_model if self.relaxed else self._quantized_model
    if haar:
        model = haar.reparam(model)
    kernel = NUTS(model,
                  full_mass=full_mass,
                  init_strategy=init_strategy,
                  max_plate_nesting=self.max_plate_nesting,
                  jit_compile=options.pop("jit_compile", False),
                  jit_options=options.pop("jit_options", None),
                  ignore_jit_warnings=options.pop("ignore_jit_warnings", True),
                  target_accept_prob=options.pop("target_accept_prob", 0.8),
                  max_tree_depth=options.pop("max_tree_depth", 5))
    if options.pop("arrowhead_mass", False):
        kernel.mass_matrix_adapter = ArrowheadMassMatrix()

    # Run mcmc. All remaining options are forwarded to MCMC; pass
    # disable_validation=None unless the caller supplied a value.
    options.setdefault("disable_validation", None)
    mcmc = MCMC(kernel, **options)
    mcmc.run()
    self.samples = mcmc.get_samples()
    if haar:
        haar.aux_to_user(self.samples)

    # Unsqueeze samples to align particle dim for use in poutine.condition.
    # TODO refactor to an align_samples or particle_dim kwarg to MCMC.get_samples().
    model = self._relaxed_model if self.relaxed else self._quantized_model
    self.samples = align_samples(self.samples, model,
                                 particle_dim=-1 - self.max_plate_nesting)
    assert all(v.size(0) == num_samples * num_chains for v in self.samples.values()), \
        {k: tuple(v.shape) for k, v in self.samples.items()}

    return mcmc  # E.g. so user can run mcmc.summary().