Exemplo n.º 1
0
def test_arrowhead_mass():
    """Adapt dense, diagonal and arrowhead mass matrices on a correlated Gaussian.

    The loop is a smoke test that only checks the rank of the adapted inverse
    mass matrix for ``full_mass=True`` / ``False``. The second half runs a long
    warmup with :class:`ArrowheadMassMatrix` and compares the adapted mass
    matrix against the true precision matrix of the target.
    """
    def latent(name, size):
        # Broad Normal(0, 1000) prior over a vector with ``size`` elements.
        return pyro.sample(name, dist.Normal(0, 1000).expand([size]).to_event(1))

    def model(precision):
        w = latent("w", 2)
        x = latent("x", 1)
        y = latent("y", 1)
        z = latent("z", 2)
        # Concatenate in the order w, y, x, z (not alphabetical) so that rows
        # 0-3 of the precision belong to w, y, x and rows 4-5 to z; the final
        # assertions compare against precision[:4] and precision.diag()[4:].
        joint = torch.cat([w, y, x, z])
        likelihood = dist.MultivariateNormal(torch.zeros(6), precision_matrix=precision)
        pyro.sample("obs", likelihood, obs=joint)

    root = torch.randn(6, 12)
    precision = (root @ root.t()) * 0.1

    # Smoke tests: dense adaptation yields a 2-d matrix, diagonal a 1-d vector.
    for use_dense in (True, False):
        smoke_kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True, full_mass=use_dense)
        MCMC(smoke_kernel, num_samples=1, warmup_steps=1).run(precision)
        inverse_mass = smoke_kernel.inverse_mass_matrix[("w", "x", "y", "z")]
        assert inverse_mass.dim() == (2 if use_dense else 1)

    # Arrowhead adaptation: head blocks ("w",) and ("y", "x"); z forms the tail.
    arrow_kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True,
                        full_mass=[("w",), ("y", "x")])
    arrow_kernel.mass_matrix_adapter = ArrowheadMassMatrix()
    MCMC(arrow_kernel, num_samples=1, warmup_steps=1000).run(precision)

    sites = ("w", "y", "x", "z")
    assert sites in arrow_kernel.inverse_mass_matrix
    arrowhead = arrow_kernel.mass_matrix_adapter.mass_matrix[sites]
    # Head rows cover the 4 elements of w, y, x; the diagonal tail covers z.
    assert arrowhead.top.shape == (4, 6)
    assert arrowhead.bottom_diag.shape == (2,)
    assert_close(arrowhead.top, precision[:4], atol=0.2, rtol=0.2)
    assert_close(arrowhead.bottom_diag, precision.diag()[4:], atol=0.2, rtol=0.2)
Exemplo n.º 2
0
def test_dirichlet_categorical_grad_adapt():
    """Recover categorical probabilities with NUTS + arrowhead mass adaptation.

    Draws 2000 observations from a known categorical distribution, fits a
    Dirichlet-Categorical model, and checks the posterior mean of the
    probabilities is within 0.02 of the truth.
    """
    def model(observations):
        # Flat Dirichlet(1, 1, 1) prior over the three category probabilities.
        p_latent = pyro.sample("p_latent", dist.Dirichlet(torch.tensor([1.0, 1.0, 1.0])))
        pyro.sample("obs", dist.Categorical(p_latent), obs=observations)
        return p_latent

    true_probs = torch.tensor([0.1, 0.6, 0.3])
    observed = dist.Categorical(true_probs).sample(torch.Size([2000]))

    kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)
    kernel.mass_matrix_adapter = ArrowheadMassMatrix()
    sampler = MCMC(kernel, num_samples=200, warmup_steps=100)
    sampler.run(observed)

    posterior_mean = sampler.get_samples()["p_latent"].mean(0)
    assert_equal(posterior_mean, true_probs, prec=0.02)
Exemplo n.º 3
0
    def fit_mcmc(self, **options):
        r"""
        Runs NUTS inference to generate posterior samples.

        This uses the :class:`~pyro.infer.mcmc.nuts.NUTS` kernel to run
        :class:`~pyro.infer.mcmc.api.MCMC`, setting the ``.samples``
        attribute on completion.

        This uses an asymptotically exact enumeration-based model when
        ``num_quant_bins > 1``, and a cheaper moment-matched approximate model
        when ``num_quant_bins == 1``.

        :param \*\*options: Options passed to
            :class:`~pyro.infer.mcmc.api.MCMC`. The remaining options are
            pulled out and have special meaning.
        :param int num_samples: Number of posterior samples to draw via mcmc.
            Defaults to 100. Also passed through to MCMC.
        :param int num_chains: Number of MCMC chains. Defaults to 1. Also
            passed through to MCMC.
        :param int max_tree_depth: (Default 5). Max tree depth of the
            :class:`~pyro.infer.mcmc.nuts.NUTS` kernel.
        :param full_mass: Specification of mass matrix of the
            :class:`~pyro.infer.mcmc.nuts.NUTS` kernel, either a bool or a
            list of tuples of site names. Defaults to ``self.full_mass``
            (full mass over global random variables).
        :param bool arrowhead_mass: Whether to treat ``full_mass`` as the head
            of an arrowhead matrix versus simply as a block. Defaults to False.
        :param int num_quant_bins: If greater than 1, use asymptotically exact
            inference via local enumeration over this many quantization bins.
            If equal to 1, use continuous-valued relaxed approximate inference.
            Note that computational cost is exponential in ``num_quant_bins``.
            Defaults to 1 for relaxed inference.
        :param bool haar: Whether to use a Haar wavelet reparameterizer.
            Defaults to False.
        :param int haar_full_mass: Number of low frequency Haar components to
            include in the full mass matrix. If ``haar=False`` then this is
            ignored. Defaults to 10. When positive (and smaller than the
            duration), ``full_mass`` must be a non-empty list.
        :param int heuristic_num_particles: Passed to :meth:`heuristic` as
            ``num_particles``. Defaults to 1024. More generally, every option
            named ``heuristic_<name>`` is forwarded as ``<name>``.
        :param bool jit_compile: Passed to NUTS. Defaults to False.
        :param dict jit_options: Passed to NUTS. Defaults to None.
        :param bool ignore_jit_warnings: Passed to NUTS. Defaults to True.
        :param float target_accept_prob: Passed to NUTS. Defaults to 0.8.
        :returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``.
        :rtype: ~pyro.infer.mcmc.api.MCMC
        """
        _require_double_precision()

        # Parse options, saving some for use in .predict().
        # num_samples / num_chains are left in ``options`` so MCMC sees them.
        num_samples = options.setdefault("num_samples", 100)
        num_chains = options.setdefault("num_chains", 1)
        self.num_quant_bins = options.pop("num_quant_bins", 1)
        assert isinstance(self.num_quant_bins, int)
        assert self.num_quant_bins >= 1
        self.relaxed = self.num_quant_bins == 1

        # Setup Haar wavelet transform.
        haar = options.pop("haar", False)
        haar_full_mass = options.pop("haar_full_mass", 10)
        full_mass = options.pop("full_mass", self.full_mass)
        assert isinstance(haar, bool)
        assert isinstance(haar_full_mass, int) and haar_full_mass >= 0
        assert isinstance(full_mass, (bool, list))
        haar_full_mass = min(haar_full_mass, self.duration)
        if not haar:
            haar_full_mass = 0
        if full_mass is True:
            haar_full_mass = 0  # No need to split.
        elif haar_full_mass >= self.duration:
            full_mass = True  # Effectively full mass.
            haar_full_mass = 0
        if haar:
            # The time axis is -2 for regional models and -1 otherwise.
            time_dim = -2 if self.is_regional else -1
            dims = {"auxiliary": time_dim}
            supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
            for name, (fn, _) in self._non_compartmental.items():
                dims[name] = time_dim - fn.event_dim
                supports[name] = fn.support
            haar = _HaarSplitReparam(haar_full_mass, self.duration, dims, supports)
        if haar_full_mass:
            # haar_full_mass can only be nonzero here when haar was True, so
            # ``dims`` is defined. Append the low-frequency split sites to the
            # first full-mass block, copying so self.full_mass is not mutated.
            assert full_mass and isinstance(full_mass, list), \
                "haar_full_mass > 0 requires full_mass to be a non-empty list, got {!r}".format(
                    full_mass)
            full_mass = full_mass[:]
            full_mass[0] += tuple(name + "_haar_split_0" for name in sorted(dims))

        # Heuristically initialize to feasible latents. Pull every
        # "heuristic_*" option out of ``options`` and strip only the prefix
        # (str.replace would also remove later occurrences of the substring).
        prefix = "heuristic_"
        heuristic_options = {}
        for key in list(options):  # iterate over a copy since we pop below
            if key.startswith(prefix):
                heuristic_options[key[len(prefix):]] = options.pop(key)
        init_strategy = init_to_generated(
            generate=functools.partial(self._heuristic, haar, **heuristic_options))

        # Configure a kernel.
        logger.info("Running inference...")
        model = self._relaxed_model if self.relaxed else self._quantized_model
        if haar:
            model = haar.reparam(model)
        kernel = NUTS(model,
                      full_mass=full_mass,
                      init_strategy=init_strategy,
                      max_plate_nesting=self.max_plate_nesting,
                      jit_compile=options.pop("jit_compile", False),
                      jit_options=options.pop("jit_options", None),
                      ignore_jit_warnings=options.pop("ignore_jit_warnings", True),
                      target_accept_prob=options.pop("target_accept_prob", 0.8),
                      max_tree_depth=options.pop("max_tree_depth", 5))
        if options.pop("arrowhead_mass", False):
            kernel.mass_matrix_adapter = ArrowheadMassMatrix()

        # Run mcmc. Whatever remains in ``options`` is forwarded to MCMC.
        options.setdefault("disable_validation", None)
        mcmc = MCMC(kernel, **options)
        mcmc.run()
        self.samples = mcmc.get_samples()
        if haar:
            haar.aux_to_user(self.samples)

        # Unsqueeze samples to align particle dim for use in poutine.condition.
        # TODO refactor to an align_samples or particle_dim kwarg to MCMC.get_samples().
        # Re-fetch the model: ``model`` above may have been wrapped by haar.reparam.
        model = self._relaxed_model if self.relaxed else self._quantized_model
        self.samples = align_samples(self.samples, model,
                                     particle_dim=-1 - self.max_plate_nesting)
        assert all(v.size(0) == num_samples * num_chains for v in self.samples.values()), \
            {k: tuple(v.shape) for k, v in self.samples.items()}

        return mcmc  # E.g. so user can run mcmc.summary().