def test_sym_rinverse(A, use_sym):
    # Check against torch.inverse on a single matrix, then on a batch created
    # by broadcasting A over two leading dimensions.
    d = A.shape[-1]
    assert_equal(rinverse(A, sym=use_sym), torch.inverse(A), prec=1e-8)
    assert_equal(torch.mm(A, rinverse(A, sym=use_sym)), torch.eye(d), prec=1e-8)
    batched_A = A.unsqueeze(0).unsqueeze(0).expand(5, 4, d, d)
    expected_A = torch.inverse(A).unsqueeze(0).unsqueeze(0).expand(5, 4, d, d)
    assert_equal(rinverse(batched_A, sym=use_sym), expected_A, prec=1e-8)
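# A minimal usage sketch of the behavior the test above exercises (the import
# path pyro.ops.linalg is an assumption here): rinverse inverts the rightmost
# matrix dimensions and broadcasts over any leading batch dimensions, with
# ``sym=True`` signalling that the input is symmetric.
def _demo_rinverse():
    import torch
    from pyro.ops.linalg import rinverse  # assumed import location

    A = torch.tensor([[2.0, 1.0], [1.0, 2.0]])  # symmetric positive definite
    Ainv = rinverse(A, sym=True)
    assert torch.allclose(torch.mm(A, Ainv), torch.eye(2), atol=1e-6)

    batched = A.expand(5, 4, 2, 2)  # broadcast over leading batch dims
    assert rinverse(batched, sym=True).shape == (5, 4, 2, 2)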
def newton_step_3d(loss, x, trust_radius=None):
    """
    Performs a Newton update step to minimize loss on a batch of 3-dimensional
    variables, optionally regularizing to constrain to a trust region.

    See :func:`newton_step` for details.

    :param torch.Tensor loss: A scalar function of ``x`` to be minimized.
    :param torch.Tensor x: A dependent variable with rightmost size of 3.
    :param float trust_radius: An optional trust region radius. The updated
        value ``mode`` of this function will be within ``trust_radius`` of the
        input ``x``.
    :return: A pair ``(mode, cov)`` where ``mode`` is an updated tensor
        of the same shape as the original value ``x``, and ``cov`` is an
        estimate of the 3x3 covariance matrix with
        ``cov.shape == x.shape[:-1] + (3, 3)``.
    :rtype: tuple
    """
    if loss.shape != ():
        raise ValueError(
            "Expected loss to be a scalar, actual shape {}".format(loss.shape))
    if x.dim() < 1 or x.shape[-1] != 3:
        raise ValueError(
            "Expected x to have rightmost size 3, actual shape {}".format(
                x.shape))

    # compute derivatives
    g = grad(loss, [x], create_graph=True)[0]
    H = torch.stack(
        [
            grad(g[..., 0].sum(), [x], create_graph=True)[0],
            grad(g[..., 1].sum(), [x], create_graph=True)[0],
            grad(g[..., 2].sum(), [x], create_graph=True)[0],
        ],
        -1,
    )
    assert g.shape[-1:] == (3,)
    assert H.shape[-2:] == (3, 3)
    warn_if_nan(g, "g")
    warn_if_nan(H, "H")

    if trust_radius is not None:
        # regularize to keep update within ball of given trust_radius
        # calculate eigenvalues of symmetric matrix
        min_eig, _, _ = eig_3d(H)
        regularizer = (g.pow(2).sum(-1).sqrt() / trust_radius - min_eig).clamp_(min=1e-8)
        warn_if_nan(regularizer, "regularizer")
        H = H + regularizer.unsqueeze(-1).unsqueeze(-1) * torch.eye(
            3, dtype=H.dtype, device=H.device)

    # compute newton update
    Hinv = rinverse(H, sym=True)
    warn_if_nan(Hinv, "Hinv")

    # apply update
    x_new = x.detach() - (Hinv * g.unsqueeze(-2)).sum(-1)
    assert x_new.shape == x.shape, "{} {}".format(x_new.shape, x.shape)
    return x_new, Hinv
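# A usage sketch (an addition, not from the original source): for a quadratic
# loss the Hessian is constant, so a single unconstrained Newton step lands
# exactly at the minimizer, and the returned ``cov`` is the inverse Hessian.
def _demo_newton_step_3d():
    import torch

    target = torch.randn(5, 3)
    x = torch.zeros(5, 3, requires_grad=True)
    loss = 0.5 * (x - target).pow(2).sum()  # Hessian is the identity

    mode, cov = newton_step_3d(loss, x)
    assert torch.allclose(mode, target, atol=1e-6)  # one step reaches the optimum
    assert cov.shape == (5, 3, 3)                   # batched 3x3 covariance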
def newton_step_2d(loss, x, trust_radius=None):
    """
    Performs a Newton update step to minimize loss on a batch of 2-dimensional
    variables, optionally regularizing to constrain to a trust region.

    See :func:`newton_step` for details.

    :param torch.Tensor loss: A scalar function of ``x`` to be minimized.
    :param torch.Tensor x: A dependent variable with rightmost size of 2.
    :param float trust_radius: An optional trust region radius. The updated
        value ``mode`` of this function will be within ``trust_radius`` of the
        input ``x``.
    :return: A pair ``(mode, cov)`` where ``mode`` is an updated tensor
        of the same shape as the original value ``x``, and ``cov`` is an
        estimate of the 2x2 covariance matrix with
        ``cov.shape == x.shape[:-1] + (2, 2)``.
    :rtype: tuple
    """
    if loss.shape != ():
        raise ValueError(
            'Expected loss to be a scalar, actual shape {}'.format(loss.shape))
    if x.dim() < 1 or x.shape[-1] != 2:
        raise ValueError(
            'Expected x to have rightmost size 2, actual shape {}'.format(
                x.shape))

    # compute derivatives
    g = grad(loss, [x], create_graph=True)[0]
    H = torch.stack([
        grad(g[..., 0].sum(), [x], create_graph=True)[0],
        grad(g[..., 1].sum(), [x], create_graph=True)[0],
    ], -1)
    assert g.shape[-1:] == (2,)
    assert H.shape[-2:] == (2, 2)
    warn_if_nan(g, 'g')
    warn_if_nan(H, 'H')

    if trust_radius is not None:
        # regularize to keep update within ball of given trust_radius,
        # using the closed-form eigenvalues of the symmetric 2x2 Hessian
        detH = H[..., 0, 0] * H[..., 1, 1] - H[..., 0, 1] * H[..., 1, 0]
        mean_eig = (H[..., 0, 0] + H[..., 1, 1]) / 2
        min_eig = mean_eig - (mean_eig ** 2 - detH).clamp(min=0).sqrt()
        regularizer = (g.pow(2).sum(-1).sqrt() / trust_radius - min_eig).clamp_(min=1e-8)
        warn_if_nan(regularizer, 'regularizer')
        H = H + regularizer.unsqueeze(-1).unsqueeze(-1) * torch.eye(
            2, dtype=H.dtype, device=H.device)

    # compute newton update
    Hinv = rinverse(H, sym=True)
    warn_if_nan(Hinv, 'Hinv')

    # apply update
    x_new = x.detach() - (Hinv * g.unsqueeze(-2)).sum(-1)
    assert x_new.shape == x.shape
    return x_new, Hinv
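# A small verification sketch (an addition, not in the original source) of the
# closed form used in the 2x2 trust-region branch above: a symmetric 2x2
# matrix with mean eigenvalue m = trace/2 and determinant d has eigenvalues
# m +/- sqrt(m^2 - d), so the minimum eigenvalue needs no eigensolver.
def _demo_min_eig_2d():
    import torch

    H = torch.tensor([[3.0, 1.0], [1.0, 2.0]])  # symmetric 2x2
    detH = H[0, 0] * H[1, 1] - H[0, 1] * H[1, 0]
    mean_eig = (H[0, 0] + H[1, 1]) / 2
    min_eig = mean_eig - (mean_eig ** 2 - detH).clamp(min=0).sqrt()
    assert torch.allclose(min_eig, torch.linalg.eigvalsh(H)[0], atol=1e-6)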
def linear_model_formula(self, y, design, target_labels):
    # Closed-form posterior mean for a linear model with a Tikhonov (ridge)
    # regularizer: mu = (X^T X + Lambda)^{-1} X^T y.
    tikhonov_diag = torch.diag(self.softplus(self.tikhonov_diag))
    xtx = torch.matmul(design.transpose(-1, -2), design) + tikhonov_diag
    xtxi = rinverse(xtx, sym=True)
    mu = rmv(xtxi, rmv(design.transpose(-1, -2), y))

    # Extract sub-indices
    mu = tensor_to_dict(self.w_sizes, mu, subset=target_labels)
    scale_tril = {l: rtril(self.scale_tril[l]) for l in target_labels}

    return mu, scale_tril
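# A plain-torch sketch of the same closed form, stripped of the class plumbing
# (all names here are illustrative, not from the original source): ridge-
# regularized least squares, mu = (X^T X + Lambda)^{-1} X^T y.
def _demo_ridge_formula():
    import torch

    X = torch.randn(100, 4)                   # design matrix
    w_true = torch.randn(4)
    y = X @ w_true + 0.01 * torch.randn(100)  # noisy observations
    lam = 1e-3 * torch.eye(4)                 # Tikhonov regularizer

    mu = torch.linalg.solve(X.t() @ X + lam, X.t() @ y)
    assert torch.allclose(mu, w_true, atol=0.1)  # close for small noise/regularizer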
def finalize(self, loss, target_labels):
    """
    Compute the Hessian of the parameters wrt ``loss``.

    :param torch.Tensor loss: the output of evaluating a loss function such as
        ``pyro.infer.Trace_ELBO().differentiable_loss`` on the model, guide
        and design.
    :param list target_labels: list indicating the sample sites that are
        targets, i.e. for which information gain should be measured.
    """
    # set self.training = False
    self.eval()
    for l, mu_l in self.means.items():
        if l not in target_labels:
            continue
        # Laplace approximation: take the covariance to be the inverse
        # Hessian of the loss at the mean, and store its Cholesky factor
        hess_l = self._hessian_diag(loss, mu_l, event_shape=(self.w_sizes[l],))
        cov_l = rinverse(hess_l)
        self.scale_trils[l] = cov_l.cholesky(upper=False)
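# A minimal self-contained sketch of the Laplace step performed above (the
# names and the toy loss are illustrative): differentiate twice to get the
# Hessian at the mean, invert it to get a covariance, then take its
# lower-triangular Cholesky factor as the posterior scale_tril.
def _demo_laplace_step():
    import torch
    from torch.autograd import grad

    mu = torch.tensor([1.0, -0.5], requires_grad=True)
    loss = 0.5 * (mu * torch.tensor([2.0, 4.0])).dot(mu)  # toy quadratic loss

    g = grad(loss, [mu], create_graph=True)[0]
    hess = torch.stack([grad(g[i], [mu], retain_graph=True)[0] for i in range(2)])
    cov = torch.inverse(hess)                # covariance = H^{-1}
    scale_tril = torch.linalg.cholesky(cov)  # lower-triangular factor
    assert torch.allclose(scale_tril @ scale_tril.t(), cov, atol=1e-6)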