def laplace_approximation(self, *args, **kwargs):
    """
    Returns a :class:`AutoMultivariateNormal` instance whose posterior's `loc` and
    `scale_tril` are given by Laplace approximation.

    :param args: positional arguments forwarded to the model/guide traces.
    :param kwargs: keyword arguments forwarded to the model/guide traces.
    :return: an :class:`AutoMultivariateNormal` guide whose parameters are the
        Laplace estimates stored in the param store.
    """
    # Trace the guide, then replay the model under the guide's latent samples
    # so both traces share the same latent values.
    guide_trace = poutine.trace(self).get_trace(*args, **kwargs)
    model_trace = poutine.trace(
        poutine.replay(self.model, trace=guide_trace)).get_trace(*args, **kwargs)
    # Negative log-joint surrogate; its Hessian at the mode approximates the
    # posterior precision matrix.
    loss = guide_trace.log_prob_sum() - model_trace.log_prob_sum()

    loc = pyro.param("{}_loc".format(self.prefix))
    H = hessian(loss, loc.unconstrained())
    cov = H.inverse()
    # torch.linalg.cholesky replaces the deprecated Tensor.cholesky();
    # both return the lower-triangular factor.
    scale_tril = torch.linalg.cholesky(cov)

    # calculate scale_tril from self.guide()
    scale_tril_name = "{}_scale_tril".format(self.prefix)
    pyro.param(scale_tril_name, scale_tril,
               constraint=constraints.lower_cholesky)
    # force an update to scale_tril even if it already exists
    pyro.get_param_store()[scale_tril_name] = scale_tril

    gaussian_guide = AutoMultivariateNormal(self.model, prefix=self.prefix)
    gaussian_guide._setup_prototype(*args, **kwargs)
    return gaussian_guide
def laplace_approximation(self, *args, **kwargs):
    """
    Returns a :class:`AutoMultivariateNormal` instance whose posterior's `loc` and
    `scale_tril` are given by Laplace approximation.
    """
    # Trace the guide, then replay the model against the guide's latent
    # samples so the two traces agree on every latent value.
    guide_tr = poutine.trace(self).get_trace(*args, **kwargs)
    replayed_model = poutine.replay(self.model, trace=guide_tr)
    model_tr = poutine.trace(replayed_model).get_trace(*args, **kwargs)
    surrogate_loss = guide_tr.log_prob_sum() - model_tr.log_prob_sum()

    # The Hessian of the loss at the mode approximates the posterior precision.
    precision = hessian(surrogate_loss, self.loc)
    with torch.no_grad():
        mode = self.loc.detach()
        covariance = precision.inverse()
        # Pull out the marginal standard deviations so scale_tril factors the
        # rescaled (correlation-like) matrix, matching AutoMultivariateNormal's
        # (loc, scale, scale_tril) parameterization.
        marginal_scale = covariance.diagonal().sqrt()
        covariance = covariance / marginal_scale[:, None]
        covariance = covariance / marginal_scale[None, :]
        tril_factor = torch.linalg.cholesky(covariance)

    laplace_guide = AutoMultivariateNormal(self.model)
    laplace_guide._setup_prototype(*args, **kwargs)
    # Swap the freshly-initialized learnable parameters for detached buffers
    # holding the Laplace estimates computed above.
    for name, value in (("loc", mode),
                        ("scale", marginal_scale),
                        ("scale_tril", tril_factor)):
        delattr(laplace_guide, name)
        laplace_guide.register_buffer(name, value)
    return laplace_guide
def test_hessian_mvn():
    """The Hessian of a Gaussian log-density equals minus the precision matrix."""
    # Build a random positive-definite 3x3 covariance as A @ A.T.
    factor = torch.randn(3, 10)
    covariance = factor.matmul(factor.t())
    mvn = dist.MultivariateNormal(covariance.new_zeros(3), covariance)

    point = torch.randn(3, requires_grad=True)
    log_density = mvn.log_prob(point)
    assert_equal(hessian(log_density, point), -mvn.precision_matrix)
def test_hessian_multi_variables():
    """hessian() over a tuple of variables returns the full block matrix."""
    x = torch.randn(3, requires_grad=True)
    z = torch.randn(3, requires_grad=True)
    y = (x ** 2 * z + z ** 3).sum()

    # Analytic second derivatives of y = sum(x^2 * z + z^3):
    #   d2y/dx2 = 2z,  d2y/dxdz = 2x,  d2y/dz2 = 6z  (all diagonal blocks).
    block_xx = torch.diag(2 * z)
    block_xz = torch.diag(2 * x)
    block_zz = torch.diag(6 * z)
    left_column = torch.cat([block_xx, block_xz])
    right_column = torch.cat([block_xz, block_zz])
    expected = torch.cat([left_column, right_column], dim=1)
    assert_equal(hessian(y, (x, z)), expected)
def laplace_approximation(self, *args, **kwargs):
    """
    Returns a :class:`AutoMultivariateNormal` instance whose posterior's `loc` and
    `scale_tril` are given by Laplace approximation.

    :param args: positional arguments forwarded to the model/guide traces.
    :param kwargs: keyword arguments forwarded to the model/guide traces.
    :return: an :class:`AutoMultivariateNormal` guide carrying the Laplace
        estimates of ``loc`` and ``scale_tril``.
    """
    # Trace the guide, then replay the model under the guide's latent samples
    # so both traces share the same latent values.
    guide_trace = poutine.trace(self).get_trace(*args, **kwargs)
    model_trace = poutine.trace(
        poutine.replay(self.model, trace=guide_trace)).get_trace(*args, **kwargs)
    loss = guide_trace.log_prob_sum() - model_trace.log_prob_sum()

    # Hessian of the loss at the current loc approximates the posterior
    # precision; its inverse is the Laplace covariance estimate.
    H = hessian(loss, self.loc)
    cov = H.inverse()
    loc = self.loc
    # torch.linalg.cholesky replaces the deprecated Tensor.cholesky();
    # both return the lower-triangular factor.
    scale_tril = torch.linalg.cholesky(cov)

    gaussian_guide = AutoMultivariateNormal(self.model)
    gaussian_guide._setup_prototype(*args, **kwargs)
    # Set loc, scale_tril parameters as computed above.
    # NOTE(review): loc/scale_tril are assigned without detach(), so they stay
    # attached to this guide's autograd graph — confirm this is intended.
    gaussian_guide.loc = loc
    gaussian_guide.scale_tril = scale_tril
    return gaussian_guide