Ejemplo n.º 1
0
def test_sym_rinverse(A, use_sym):
    """
    Check ``rinverse`` against ``torch.inverse`` on a single matrix and on a
    broadcast batch of that same matrix.

    :param torch.Tensor A: A square matrix; its last dimension gives the
        size ``d`` used for the identity and the batched shapes.
    :param bool use_sym: Forwarded to ``rinverse`` as its ``sym`` argument.
    """
    size = A.shape[-1]
    tolerance = 1e-8
    batch_shape = (5, 4, size, size)

    # Single matrix: must agree with the reference inverse ...
    actual = rinverse(A, sym=use_sym)
    reference = torch.inverse(A)
    assert_equal(actual, reference, prec=tolerance)

    # ... and multiplying it by A must give the identity.
    product = torch.mm(A, rinverse(A, sym=use_sym))
    assert_equal(product, torch.eye(size), prec=tolerance)

    # Batched input: the same matrix repeated over two leading batch dims.
    batched_A = A[None, None].expand(*batch_shape)
    expected_A = torch.inverse(A)[None, None].expand(*batch_shape)
    assert_equal(rinverse(batched_A, sym=use_sym), expected_A, prec=tolerance)
Ejemplo n.º 2
0
def newton_step_3d(loss, x, trust_radius=None):
    """
    Performs a Newton update step to minimize loss on a batch of 3-dimensional
    variables, optionally regularizing to constrain to a trust region.

    See :func:`newton_step` for details.

    :param torch.Tensor loss: A scalar function of ``x`` to be minimized.
    :param torch.Tensor x: A dependent variable with rightmost size of 3.
    :param float trust_radius: An optional trust region trust_radius. The
        updated value ``mode`` of this function will be within
        ``trust_radius`` of the input ``x``.
    :return: A pair ``(mode, cov)`` where ``mode`` is an updated tensor
        of the same shape as the original value ``x``, and ``cov`` is an
        estimate of the covariance 3x3 matrix with
        ``cov.shape == x.shape[:-1] + (3,3)``.
    :rtype: tuple
    :raises ValueError: If ``loss`` is not a scalar or if ``x`` does not have
        rightmost size 3.
    """
    if loss.shape != ():
        raise ValueError(
            "Expected loss to be a scalar, actual shape {}".format(loss.shape))
    if x.dim() < 1 or x.shape[-1] != 3:
        raise ValueError(
            "Expected x to have rightmost size 3, actual shape {}".format(
                x.shape))

    # compute derivatives
    # create_graph=True keeps the gradient differentiable so that it can be
    # differentiated a second time to obtain the Hessian.
    g = grad(loss, [x], create_graph=True)[0]
    # Differentiate each gradient component (summed over the batch) w.r.t. x
    # and stack the three results along a new last dimension.
    H = torch.stack(
        [grad(g[..., i].sum(), [x], create_graph=True)[0] for i in range(3)],
        -1,
    )
    assert g.shape[-1:] == (3, )
    assert H.shape[-2:] == (3, 3)
    warn_if_nan(g, "g")
    warn_if_nan(H, "H")

    if trust_radius is not None:
        # regularize to keep update within ball of given trust_radius
        # calculate eigenvalues of symmetric matrix; only the first value
        # returned by eig_3d is used, as the minimum eigenvalue
        min_eig, _, _ = eig_3d(H)
        grad_norm = g.pow(2).sum(-1).sqrt()
        # Shift the spectrum so the smallest eigenvalue of the regularized
        # Hessian is at least |g| / trust_radius (and strictly positive).
        regularizer = (grad_norm / trust_radius - min_eig).clamp_(min=1e-8)
        warn_if_nan(regularizer, "regularizer")
        H = H + regularizer.unsqueeze(-1).unsqueeze(-1) * torch.eye(
            3, dtype=H.dtype, device=H.device)

    # compute newton update
    Hinv = rinverse(H, sym=True)
    warn_if_nan(Hinv, "Hinv")

    # apply update: batched matrix-vector product Hinv @ g, written as a
    # broadcast multiply followed by a sum over the last dimension
    x_new = x.detach() - (Hinv * g.unsqueeze(-2)).sum(-1)
    assert x_new.shape == x.shape, "{} {}".format(x_new.shape, x.shape)
    return x_new, Hinv
Ejemplo n.º 3
0
def newton_step_2d(loss, x, trust_radius=None):
    """
    Take a single Newton step that decreases ``loss`` for a batch of
    2-dimensional variables, with optional trust-region regularization.

    See :func:`newton_step` for details.

    :param torch.Tensor loss: A scalar function of ``x`` to be minimized.
    :param torch.Tensor x: A dependent variable with rightmost size of 2.
    :param float trust_radius: An optional trust region radius. When given,
        the Hessian is regularized so that the updated value ``mode`` stays
        within ``trust_radius`` of the input ``x``.
    :return: A pair ``(mode, cov)`` where ``mode`` is an updated tensor
        of the same shape as the original value ``x``, and ``cov`` is the
        inverse of the (possibly regularized) Hessian, an estimate of the
        covariance 2x2 matrix with ``cov.shape == x.shape[:-1] + (2,2)``.
    :rtype: tuple
    :raises ValueError: If ``loss`` is not a scalar or if ``x`` does not have
        rightmost size 2.
    """
    if loss.shape != ():
        raise ValueError(
            'Expected loss to be a scalar, actual shape {}'.format(loss.shape))
    if not (x.dim() >= 1 and x.shape[-1] == 2):
        raise ValueError(
            'Expected x to have rightmost size 2, actual shape {}'.format(
                x.shape))

    # First derivative; create_graph=True keeps it differentiable.
    g = grad(loss, [x], create_graph=True)[0]
    # Second derivative: differentiate each gradient component (summed over
    # the batch) and stack the results along a new last dimension.
    hessian_parts = [
        grad(g[..., i].sum(), [x], create_graph=True)[0] for i in range(2)
    ]
    H = torch.stack(hessian_parts, -1)
    assert g.shape[-1:] == (2, )
    assert H.shape[-2:] == (2, 2)
    warn_if_nan(g, 'g')
    warn_if_nan(H, 'H')

    if trust_radius is not None:
        # Closed-form smallest eigenvalue of a 2x2 matrix:
        # half-trace minus sqrt(half-trace^2 - determinant).
        h00, h01 = H[..., 0, 0], H[..., 0, 1]
        h10, h11 = H[..., 1, 0], H[..., 1, 1]
        det = h00 * h11 - h01 * h10
        half_trace = (h00 + h11) / 2
        # Clamp so rounding cannot produce a negative value under the sqrt.
        discriminant = (half_trace**2 - det).clamp(min=0)
        smallest_eig = half_trace - discriminant.sqrt()

        # Shift the spectrum so the smallest eigenvalue of the regularized
        # Hessian is at least |g| / trust_radius (and strictly positive).
        grad_norm = g.pow(2).sum(-1).sqrt()
        regularizer = (grad_norm / trust_radius - smallest_eig).clamp_(
            min=1e-8)
        warn_if_nan(regularizer, 'regularizer')
        identity = torch.eye(2, dtype=H.dtype, device=H.device)
        H = H + regularizer[..., None, None] * identity

    # Invert the (possibly regularized) Hessian.
    Hinv = rinverse(H, sym=True)
    warn_if_nan(Hinv, 'Hinv')

    # Newton update: batched Hinv @ g written as broadcast multiply + sum.
    step = (Hinv * g[..., None, :]).sum(-1)
    x_new = x.detach() - step
    assert x_new.shape == x.shape
    return x_new, Hinv
Ejemplo n.º 4
0
Archivo: guides.py Proyecto: zyxue/pyro
    def linear_model_formula(self, y, design, target_labels):
        """
        Compute a Tikhonov-regularized least-squares estimate of the
        coefficients and collect the scale factors for the target sites.

        The mean is ``(X^T X + D)^{-1} X^T y`` where ``X`` is ``design`` and
        ``D`` is the diagonal matrix built from
        ``self.softplus(self.tikhonov_diag)``.

        :param torch.Tensor y: Observations, multiplied by the transposed
            design via ``rmv``.
        :param torch.Tensor design: Design matrix; its last two dimensions
            are transposed to form ``X^T``.
        :param list target_labels: Labels selecting which entries of the mean
            and of ``self.scale_tril`` are returned.
        :return: A pair ``(mu, scale_tril)`` of dicts keyed by target label.
        :rtype: tuple
        """
        design_t = design.transpose(-1, -2)

        # Regularized Gram matrix X^T X + D and its inverse.
        ridge = torch.diag(self.softplus(self.tikhonov_diag))
        gram = torch.matmul(design_t, design) + ridge
        gram_inv = rinverse(gram, sym=True)

        # NOTE(review): rmv is presumably a batched matrix-vector product;
        # confirm against its definition.
        coefficients = rmv(gram_inv, rmv(design_t, y))

        # Extract sub-indices
        mu = tensor_to_dict(self.w_sizes, coefficients, subset=target_labels)

        # NOTE(review): rtril presumably keeps the lower-triangular part of
        # each stored factor; confirm against its definition.
        scale_tril = {}
        for label in target_labels:
            scale_tril[label] = rtril(self.scale_tril[label])

        return mu, scale_tril
Ejemplo n.º 5
0
    def finalize(self, loss, target_labels):
        """
        Compute the Hessian of the target parameters wrt ``loss`` and store
        the lower Cholesky factor of its inverse for each target site.

        For every entry of ``self.means`` whose label is in
        ``target_labels``, the Hessian is obtained from
        ``self._hessian_diag``, inverted with ``rinverse``, and the lower
        Cholesky factor of the result is written to ``self.scale_trils``.
        Other sites are left untouched.

        :param torch.Tensor loss: the output of evaluating a loss function such as
                                  `pyro.infer.Trace_ELBO().differentiable_loss` on the model, guide and design.
        :param list target_labels: list indicating the sample sites that are targets, i.e. for which information gain
                                   should be measured.
        """
        # Leave training mode (sets self.training = False).
        self.eval()
        for label, mean in self.means.items():
            if label in target_labels:
                event_shape = (self.w_sizes[label],)
                hessian = self._hessian_diag(loss, mean, event_shape=event_shape)
                covariance = rinverse(hessian)
                self.scale_trils[label] = covariance.cholesky(upper=False)