value = torch.randn(()).exp() * 10 kernel = NUTS(model, init_strategy=partial(init_to_value, values={"x": value})) kernel.setup(warmup_steps=10) assert_close(value, kernel.initial_params['x'].exp()) @pytest.mark.parametrize("init_strategy", [ init_to_feasible, init_to_mean, init_to_median, init_to_sample, init_to_uniform, init_to_value, init_to_feasible(), init_to_mean(), init_to_median(num_samples=4), init_to_sample(), init_to_uniform(radius=0.1), init_to_value(values={"x": torch.tensor(3.)}), init_to_generated( generate=lambda: init_to_value(values={"x": torch.rand(())})), ], ids=str_erase_pointers) def test_init_strategy_smoke(init_strategy): def model(): pyro.sample("x", dist.LogNormal(0, 1)) kernel = NUTS(model, init_strategy=init_strategy) kernel.setup(warmup_steps=10)
def fit_mcmc(self, **options):
    r"""
    Runs NUTS inference to generate posterior samples.

    This uses the :class:`~pyro.infer.mcmc.nuts.NUTS` kernel to run
    :class:`~pyro.infer.mcmc.api.MCMC`, setting the ``.samples``
    attribute on completion.

    This uses an asymptotically exact enumeration-based model when
    ``num_quant_bins > 1``, and a cheaper moment-matched approximate model
    when ``num_quant_bins == 1``.

    :param \*\*options: Options passed to
        :class:`~pyro.infer.mcmc.api.MCMC`. The remaining options are
        pulled out and have special meaning.
    :param int num_samples: Number of posterior samples to draw via mcmc.
        Defaults to 100.
    :param int num_chains: Number of MCMC chains, passed to
        :class:`~pyro.infer.mcmc.api.MCMC`. Defaults to 1.
    :param int max_tree_depth: (Default 5). Max tree depth of the
        :class:`~pyro.infer.mcmc.nuts.NUTS` kernel.
    :param float target_accept_prob: Target acceptance probability of the
        :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. Defaults to 0.8.
    :param full_mass: Specification of mass matrix of the
        :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. Defaults to
        ``self.full_mass`` (full mass over global random variables).
    :param bool arrowhead_mass: Whether to treat ``full_mass`` as the head
        of an arrowhead matrix versus simply as a block. Defaults to False.
    :param int num_quant_bins: If greater than 1, use asymptotically exact
        inference via local enumeration over this many quantization bins.
        If equal to 1, use continuous-valued relaxed approximate inference.
        Note that computational cost is exponential in ``num_quant_bins``.
        Defaults to 1 for relaxed inference.
    :param bool haar: Whether to use a Haar wavelet reparameterizer.
        Defaults to False.
    :param int haar_full_mass: Number of low frequency Haar components to
        include in the full mass matrix. If ``haar=False`` then this is
        ignored. Defaults to 10.
    :param bool jit_compile: Passed to the NUTS kernel. Defaults to False.
    :param dict jit_options: Passed to the NUTS kernel. Defaults to None.
    :param bool ignore_jit_warnings: Passed to the NUTS kernel.
        Defaults to True.
    :param int heuristic_num_particles: Passed to :meth:`heuristic` as
        ``num_particles``. Defaults to 1024. More generally, every option
        named ``heuristic_<name>`` is removed from ``options`` and passed to
        the heuristic initializer as ``<name>``.
    :returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``.
    :rtype: ~pyro.infer.mcmc.api.MCMC
    """
    _require_double_precision()

    # Parse options, saving some for use in .predict().
    num_samples = options.setdefault("num_samples", 100)
    num_chains = options.setdefault("num_chains", 1)
    self.num_quant_bins = options.pop("num_quant_bins", 1)
    assert isinstance(self.num_quant_bins, int)
    assert self.num_quant_bins >= 1
    self.relaxed = self.num_quant_bins == 1

    # Setup Haar wavelet transform.
    haar = options.pop("haar", False)
    haar_full_mass = options.pop("haar_full_mass", 10)
    full_mass = options.pop("full_mass", self.full_mass)
    assert isinstance(haar, bool)
    assert isinstance(haar_full_mass, int) and haar_full_mass >= 0
    assert isinstance(full_mass, (bool, list))
    haar_full_mass = min(haar_full_mass, self.duration)
    if not haar:
        haar_full_mass = 0
    if full_mass is True:
        haar_full_mass = 0  # No need to split.
    elif haar_full_mass >= self.duration:
        # Every Haar component would be in the full block anyway.
        full_mass = True  # Effectively full mass.
        haar_full_mass = 0
    if haar:
        time_dim = -2 if self.is_regional else -1
        dims = {"auxiliary": time_dim}
        supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
        for name, (fn, is_regional) in self._non_compartmental.items():
            dims[name] = time_dim - fn.event_dim
            supports[name] = fn.support
        # From here on ``haar`` is either False or a reparameterizer object.
        haar = _HaarSplitReparam(haar_full_mass, self.duration, dims, supports)
    if haar_full_mass:
        # Only reachable when haar=True, so ``dims`` is defined above.
        assert full_mass and isinstance(full_mass, list)
        # Copy so that the caller's list (possibly self.full_mass) is not mutated.
        full_mass = full_mass[:]
        full_mass[0] += tuple(name + "_haar_split_0" for name in sorted(dims))

    # Heuristically initialize to feasible latents.
    # Strip only the leading "heuristic_" prefix (str.replace would also
    # remove any later occurrence inside the option name).
    prefix = "heuristic_"
    heuristic_options = {k[len(prefix):]: options.pop(k)
                         for k in list(options) if k.startswith(prefix)}
    init_strategy = init_to_generated(
        generate=functools.partial(self._heuristic, haar, **heuristic_options))

    # Configure a kernel.
    logger.info("Running inference...")
    model = self._relaxed_model if self.relaxed else self._quantized_model
    if haar:
        model = haar.reparam(model)
    kernel = NUTS(model,
                  full_mass=full_mass,
                  init_strategy=init_strategy,
                  max_plate_nesting=self.max_plate_nesting,
                  jit_compile=options.pop("jit_compile", False),
                  jit_options=options.pop("jit_options", None),
                  ignore_jit_warnings=options.pop("ignore_jit_warnings", True),
                  target_accept_prob=options.pop("target_accept_prob", 0.8),
                  max_tree_depth=options.pop("max_tree_depth", 5))
    if options.pop("arrowhead_mass", False):
        kernel.mass_matrix_adapter = ArrowheadMassMatrix()

    # Run mcmc.
    options.setdefault("disable_validation", None)
    mcmc = MCMC(kernel, **options)
    mcmc.run()
    self.samples = mcmc.get_samples()
    if haar:
        # Map auxiliary Haar-space samples back to user-facing sites.
        haar.aux_to_user(self.samples)

    # Unsqueeze samples to align particle dim for use in poutine.condition.
    # TODO refactor to an align_samples or particle_dim kwarg to MCMC.get_samples().
    # Use the un-reparameterized model, since samples are now in user space.
    model = self._relaxed_model if self.relaxed else self._quantized_model
    self.samples = align_samples(self.samples, model,
                                 particle_dim=-1 - self.max_plate_nesting)
    assert all(v.size(0) == num_samples * num_chains for v in self.samples.values()), \
        {k: tuple(v.shape) for k, v in self.samples.items()}

    return mcmc  # E.g. so user can run mcmc.summary().