Example #1
0
    def laplace_approximation(self, *args, **kwargs):
        """
        Build an :class:`AutoMultivariateNormal` guide whose posterior `loc` and
        `scale_tril` come from a Laplace approximation.

        The negative log joint (guide log-prob sum minus model log-prob sum,
        with the model replayed against the guide trace) is differentiated via
        ``hessian`` with respect to the unconstrained ``{prefix}_loc`` param.
        The Cholesky factor of the inverse Hessian is written into the
        ``{prefix}_scale_tril`` param-store entry, and a new guide sharing this
        guide's prefix is returned.

        :return: a set-up :class:`AutoMultivariateNormal` instance.
        """
        traced_guide = poutine.trace(self).get_trace(*args, **kwargs)
        replayed_model = poutine.replay(self.model, trace=traced_guide)
        traced_model = poutine.trace(replayed_model).get_trace(*args, **kwargs)
        neg_log_joint = traced_guide.log_prob_sum() - traced_model.log_prob_sum()

        loc_param = pyro.param("{}_loc".format(self.prefix))
        hess = hessian(neg_log_joint, loc_param.unconstrained())
        # Inverse Hessian ~ posterior covariance; keep only its Cholesky factor.
        tril = hess.inverse().cholesky()

        tril_name = "{}_scale_tril".format(self.prefix)
        # Register the param together with its constraint in case it is new...
        pyro.param(tril_name, tril, constraint=constraints.lower_cholesky)
        # ...then overwrite unconditionally, because an existing entry would
        # otherwise keep its old value.
        pyro.get_param_store()[tril_name] = tril

        mvn_guide = AutoMultivariateNormal(self.model, prefix=self.prefix)
        mvn_guide._setup_prototype(*args, **kwargs)
        return mvn_guide
Example #2
0
    def laplace_approximation(self, *args, **kwargs):
        """
        Build an :class:`AutoMultivariateNormal` guide whose posterior `loc` and
        `scale_tril` come from a Laplace approximation at ``self.loc``.

        The inverse Hessian of the negative log joint is split into:

        * ``scale``: the square roots of its diagonal (marginal standard
          deviations);
        * ``scale_tril``: the Cholesky factor of the matrix normalized by those
          standard deviations (a correlation matrix).

        These values, together with a detached ``loc``, are installed on the new
        guide as buffers. They are therefore fixed values, not trainable
        parameters, and are not shared with this guide.
        """
        traced_guide = poutine.trace(self).get_trace(*args, **kwargs)
        replayed = poutine.replay(self.model, trace=traced_guide)
        traced_model = poutine.trace(replayed).get_trace(*args, **kwargs)
        neg_log_joint = traced_guide.log_prob_sum() - traced_model.log_prob_sum()

        hess = hessian(neg_log_joint, self.loc)
        with torch.no_grad():
            covariance = hess.inverse()
            stddev = covariance.diagonal().sqrt()
            # Divide rows, then columns, by the std devs: covariance -> correlation.
            correlation = covariance / stddev[:, None] / stddev[None, :]
            fixed = {
                "loc": self.loc.detach(),
                "scale": stddev,
                "scale_tril": torch.linalg.cholesky(correlation),
            }

        mvn_guide = AutoMultivariateNormal(self.model)
        mvn_guide._setup_prototype(*args, **kwargs)
        # Swap the freshly initialized params for the constant values above.
        for name in fixed:
            delattr(mvn_guide, name)
        for name, value in fixed.items():
            mvn_guide.register_buffer(name, value)
        return mvn_guide
Example #3
0
def test_hessian_mvn():
    """The Hessian of a Gaussian log-density equals minus its precision matrix."""
    factor = torch.randn(3, 10)
    covariance = factor @ factor.t()
    mvn = dist.MultivariateNormal(covariance.new_zeros(3), covariance)

    point = torch.randn(3, requires_grad=True)
    log_density = mvn.log_prob(point)
    actual = hessian(log_density, point)
    assert_equal(actual, -mvn.precision_matrix)
Example #4
0
def test_hessian_multi_variables():
    """A Hessian w.r.t. a tuple of inputs is the full block matrix over (x, z)."""
    x = torch.randn(3, requires_grad=True)
    z = torch.randn(3, requires_grad=True)
    objective = (x ** 2 * z + z ** 3).sum()

    actual = hessian(objective, (x, z))
    # Analytic second derivatives of sum(x^2 z + z^3); every block is diagonal.
    d2_xx = torch.diag(2 * z)
    d2_xz = torch.diag(2 * x)
    d2_zz = torch.diag(6 * z)
    expected = torch.cat(
        [
            torch.cat([d2_xx, d2_xz], dim=1),
            torch.cat([d2_xz, d2_zz], dim=1),
        ],
        dim=0,
    )
    assert_equal(actual, expected)
Example #5
0
    def laplace_approximation(self, *args, **kwargs):
        """
        Returns a :class:`AutoMultivariateNormal` instance whose posterior's `loc` and
        `scale_tril` are given by Laplace approximation.

        The Hessian of the negative log joint (guide log-prob sum minus model
        log-prob sum, with the model replayed against the guide trace) is taken
        with respect to ``self.loc``. Its inverse is used as the posterior
        covariance, and its Cholesky factor becomes the new guide's
        ``scale_tril``.

        The new guide receives its own copy of ``loc``. Previously the
        ``self.loc`` object itself was assigned, which made both guides share
        one parameter, so training either one mutated the other.
        """
        import torch

        guide_trace = poutine.trace(self).get_trace(*args, **kwargs)
        model_trace = poutine.trace(
            poutine.replay(self.model, trace=guide_trace)).get_trace(*args, **kwargs)
        loss = guide_trace.log_prob_sum() - model_trace.log_prob_sum()

        H = hessian(loss, self.loc)
        # The resulting values are constants for the new guide; there is no need
        # to keep them attached to the autograd graph of ``loss``.
        with torch.no_grad():
            loc = self.loc.detach().clone()
            cov = H.inverse()
            # torch.linalg.cholesky replaces the deprecated Tensor.cholesky().
            scale_tril = torch.linalg.cholesky(cov)

        gaussian_guide = AutoMultivariateNormal(self.model)
        gaussian_guide._setup_prototype(*args, **kwargs)
        # Set loc, scale_tril parameters as computed above. ``loc`` is wrapped
        # in a fresh Parameter so it is independent of ``self.loc``.
        gaussian_guide.loc = torch.nn.Parameter(loc)
        gaussian_guide.scale_tril = scale_tril
        return gaussian_guide