Example #1
def test_init_to_value():
    def model():
        pyro.sample("x", dist.LogNormal(0, 1))

    value = torch.randn(()).exp() * 10
    kernel = NUTS(model,
                  init_strategy=partial(init_to_value, values={"x": value}))
    kernel.setup(warmup_steps=10)
    assert_close(value, kernel.initial_params['x'].exp())


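# The cases below mix bare strategy callables (e.g. init_to_feasible) with
# configured instances (e.g. init_to_uniform(radius=0.1)); NUTS accepts either
# form for init_strategy.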
@pytest.mark.parametrize("init_strategy", [
    init_to_feasible,
    init_to_mean,
    init_to_median,
    init_to_sample,
    init_to_uniform,
    init_to_value,
    init_to_feasible(),
    init_to_mean(),
    init_to_median(num_samples=4),
    init_to_sample(),
    init_to_uniform(radius=0.1),
    init_to_value(values={"x": torch.tensor(3.)}),
    init_to_generated(
        generate=lambda: init_to_value(values={"x": torch.rand(())})),
],
                         ids=str_erase_pointers)
def test_init_strategy_smoke(init_strategy):
    def model():
        pyro.sample("x", dist.LogNormal(0, 1))

    kernel = NUTS(model, init_strategy=init_strategy)
    kernel.setup(warmup_steps=10)
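The smoke test above only exercises kernel.setup(). In normal use the same init
strategy is handed to NUTS and then driven through MCMC; a minimal sketch,
assuming the same toy LogNormal model and the public pyro.infer API (sample
counts are illustrative):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
from pyro.infer.autoguide import init_to_value


def model():
    pyro.sample("x", dist.LogNormal(0, 1))


# Start the chain at x = 3 (in constrained space) rather than at a random point.
kernel = NUTS(model, init_strategy=init_to_value(values={"x": torch.tensor(3.0)}))
mcmc = MCMC(kernel, num_samples=50, warmup_steps=50)
mcmc.run()
samples = mcmc.get_samples()  # {"x": tensor of shape (50,)}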
Example #2
    def fit_mcmc(self, **options):
        r"""
        Runs NUTS inference to generate posterior samples.

        This uses the :class:`~pyro.infer.mcmc.nuts.NUTS` kernel to run
        :class:`~pyro.infer.mcmc.api.MCMC`, setting the ``.samples``
        attribute on completion.

        This uses an asymptotically exact enumeration-based model when
        ``num_quant_bins > 1``, and a cheaper moment-matched approximate model
        when ``num_quant_bins == 1``.

        :param \*\*options: Options passed to
            :class:`~pyro.infer.mcmc.api.MCMC`. The options documented below
            receive special handling; any others are forwarded unchanged.
        :param int num_samples: Number of posterior samples to draw via MCMC.
            Defaults to 100.
        :param int max_tree_depth: Max tree depth of the
            :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. Defaults to 5.
        :param full_mass: Specification of mass matrix of the
            :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. Defaults to full mass
            over global random variables.
        :param bool arrowhead_mass: Whether to treat ``full_mass`` as the head
            of an arrowhead matrix versus simply as a block. Defaults to False.
        :param int num_quant_bins: If greater than 1, use asymptotically exact
            inference via local enumeration over this many quantization bins.
            If equal to 1, use continuous-valued relaxed approximate inference.
            Note that computational cost is exponential in ``num_quant_bins``.
            Defaults to 1 for relaxed inference.
        :param bool haar: Whether to use a Haar wavelet reparameterizer.
            Defaults to True.
        :param int haar_full_mass: Number of low frequency Haar components to
            include in the full mass matrix. If ``haar=False`` then this is
            ignored. Defaults to 10.
        :param int heuristic_num_particles: Passed to :meth:`heuristic` as
            ``num_particles``. Defaults to 1024.
        :returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``.
        :rtype: ~pyro.infer.mcmc.api.MCMC
        """
        _require_double_precision()

        # Parse options, saving some for use in .predict().
        num_samples = options.setdefault("num_samples", 100)
        num_chains = options.setdefault("num_chains", 1)
        self.num_quant_bins = options.pop("num_quant_bins", 1)
        assert isinstance(self.num_quant_bins, int)
        assert self.num_quant_bins >= 1
        self.relaxed = self.num_quant_bins == 1

        # Setup Haar wavelet transform.
        haar = options.pop("haar", True)
        haar_full_mass = options.pop("haar_full_mass", 10)
        full_mass = options.pop("full_mass", self.full_mass)
        assert isinstance(haar, bool)
        assert isinstance(haar_full_mass, int) and haar_full_mass >= 0
        assert isinstance(full_mass, (bool, list))
        haar_full_mass = min(haar_full_mass, self.duration)
        if not haar:
            haar_full_mass = 0
        if full_mass is True:
            haar_full_mass = 0  # No need to split.
        elif haar_full_mass >= self.duration:
            full_mass = True  # Effectively full mass.
            haar_full_mass = 0
        if haar:
            time_dim = -2 if self.is_regional else -1
            dims = {"auxiliary": time_dim}
            supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
            for name, (fn, is_regional) in self._non_compartmental.items():
                dims[name] = time_dim - fn.event_dim
                supports[name] = fn.support
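            # _HaarSplitReparam applies a Haar wavelet transform along the time
            # dimension of each site above, splitting the haar_full_mass
            # lowest-frequency coefficients into a separate "*_haar_split_0" site.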
            haar = _HaarSplitReparam(haar_full_mass, self.duration, dims, supports)
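        # Promote the low-frequency Haar split sites into the first dense block
        # of the structured mass matrix.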
        if haar_full_mass:
            assert full_mass and isinstance(full_mass, list)
            full_mass = full_mass[:]
            full_mass[0] += tuple(name + "_haar_split_0" for name in sorted(dims))

        # Heuristically initialize to feasible latents.
        heuristic_options = {k.replace("heuristic_", ""): options.pop(k)
                             for k in list(options)
                             if k.startswith("heuristic_")}
        init_strategy = init_to_generated(
            generate=functools.partial(self._heuristic, haar, **heuristic_options))

        # Configure a kernel.
        logger.info("Running inference...")
        model = self._relaxed_model if self.relaxed else self._quantized_model
        if haar:
            model = haar.reparam(model)
        kernel = NUTS(model,
                      full_mass=full_mass,
                      init_strategy=init_strategy,
                      max_plate_nesting=self.max_plate_nesting,
                      jit_compile=options.pop("jit_compile", False),
                      jit_options=options.pop("jit_options", None),
                      ignore_jit_warnings=options.pop("ignore_jit_warnings", True),
                      target_accept_prob=options.pop("target_accept_prob", 0.8),
                      max_tree_depth=options.pop("max_tree_depth", 5))
        if options.pop("arrowhead_mass", False):
            kernel.mass_matrix_adapter = ArrowheadMassMatrix()

        # Run mcmc.
        options.setdefault("disable_validation", None)
        mcmc = MCMC(kernel, **options)
        mcmc.run()
        self.samples = mcmc.get_samples()
        if haar:
            haar.aux_to_user(self.samples)

        # Unsqueeze samples to align particle dim for use in poutine.condition.
        # TODO refactor to an align_samples or particle_dim kwarg to MCMC.get_samples().
        model = self._relaxed_model if self.relaxed else self._quantized_model
        self.samples = align_samples(self.samples, model,
                                     particle_dim=-1 - self.max_plate_nesting)
        assert all(v.size(0) == num_samples * num_chains for v in self.samples.values()), \
            {k: tuple(v.shape) for k, v in self.samples.items()}

        return mcmc  # E.g. so user can run mcmc.summary().
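
The method above is not standalone; it reads like CompartmentalModel.fit_mcmc
from pyro.contrib.epidemiology. A minimal usage sketch, assuming the bundled
SimpleSIRModel subclass and toy data (all numbers below are illustrative, not
from the source):

import torch
from pyro.contrib.epidemiology.models import SimpleSIRModel

torch.set_default_dtype(torch.float64)  # fit_mcmc requires double precision

# Toy daily new-infection counts; real data would come from observations.
data = torch.tensor([1.0, 2.0, 4.0, 6.0, 5.0, 3.0, 2.0])
model = SimpleSIRModel(population=100, recovery_time=7.0, data=data)

mcmc = model.fit_mcmc(num_samples=50,      # forwarded to MCMC
                      warmup_steps=50,     # forwarded to MCMC
                      max_tree_depth=4,    # popped and given to the NUTS kernel
                      haar_full_mass=5)    # low-frequency Haar sites in the dense mass block
mcmc.summary()            # diagnostics on the returned MCMC object
samples = model.samples   # posterior samples are also stored on the model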