Example #1
    def _batch_chol_inv(self, mat_chol: Tensor) -> Tensor:
        r"""Wrapper to perform (batched) cholesky inverse"""
        # TODO: get rid of this once cholesky_inverse supports batch mode
        batch_eye = torch.eye(mat_chol.shape[-1], **self.tkwargs)

        if len(mat_chol.shape) == 2:
            mat_inv = torch.cholesky_inverse(mat_chol)
        elif len(mat_chol.shape) > 2 and (mat_chol.shape[-1]
                                          == mat_chol.shape[-2]):
            batch_eye = batch_eye.repeat(*(mat_chol.shape[:-2]), 1, 1)
            chol_inv = torch.triangular_solve(batch_eye, mat_chol,
                                              upper=False).solution
            mat_inv = chol_inv.transpose(-1, -2) @ chol_inv

        return mat_inv
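A quick standalone check (not from the original module) of the identity this wrapper relies on, K^{-1} = L^{-T} L^{-1} for K = L L^T, on a single SPD matrix:

import torch

A = torch.randn(5, 5)
K = A @ A.t() + 5 * torch.eye(5)    # a well-conditioned SPD matrix
L = torch.linalg.cholesky(K)        # lower Cholesky factor, K = L L^T
eye = torch.eye(5)
L_inv = torch.triangular_solve(eye, L, upper=False).solution
assert torch.allclose(L_inv.t() @ L_inv, torch.cholesky_inverse(L), atol=1e-4)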
Example #2
    def inverse_no_cache(self, inputs, full_jacobian=False):
        """Cost:
            output = O(D^2N)
            logabsdet = O(D)
        where:
            D = num of features
            N = num of inputs
        """
        lower, upper = self._create_lower_upper()
        outputs = inputs - self.bias
        outputs, _ = torch.triangular_solve(outputs.t(),
                                            lower,
                                            upper=False,
                                            unitriangular=True)
        outputs, _ = torch.triangular_solve(outputs,
                                            upper,
                                            upper=True,
                                            unitriangular=False)
        outputs = outputs.t()

        logabsdet = -self.logabsdet()
        logabsdet = logabsdet * inputs.new_ones(outputs.shape[0])

        return outputs, logabsdet
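The two solves above apply (LU)^{-1} without ever forming the inverse; a minimal sketch of the same pattern with standalone tensors (all names here are illustrative):

import torch

D, N = 4, 3
L = torch.tril(torch.randn(D, D), -1) + torch.eye(D)   # unit lower triangular
U = torch.triu(torch.randn(D, D)) + D * torch.eye(D)   # upper triangular
b = torch.randn(D, N)

y = torch.triangular_solve(b, L, upper=False, unitriangular=True).solution
x = torch.triangular_solve(y, U, upper=True).solution
assert torch.allclose(L @ U @ x, b, atol=1e-4)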
Example #3
    def forward(self,
                x,
                inducing_points,
                inducing_values,
                variational_inducing_covar=None):
        # Compute full prior distribution
        full_inputs = torch.cat([inducing_points, x], dim=-2)
        full_output = self.model.forward(full_inputs)
        full_covar = full_output.lazy_covariance_matrix

        # Covariance terms
        num_induc = inducing_points.size(-2)
        test_mean = full_output.mean[..., num_induc:]
        induc_induc_covar = full_covar[
            ..., :num_induc, :num_induc].add_jitter()
        induc_data_covar = full_covar[..., :num_induc, num_induc:].evaluate()
        data_data_covar = full_covar[..., num_induc:, num_induc:]

        # Compute interpolation terms
        # K_ZZ^{-1/2} K_ZX
        # K_ZZ^{-1/2} \mu_Z
        L = self._cholesky_factor(induc_induc_covar)
        interp_term = torch.triangular_solve(induc_data_covar.double(),
                                             L,
                                             upper=False)[0].to(
                                                 full_inputs.dtype)

        # Compute the mean of q(f)
        # k_XZ K_ZZ^{-1/2} (m - K_ZZ^{-1/2} \mu_Z) + \mu_X
        predictive_mean = (torch.matmul(
            interp_term.transpose(-1, -2),
            (inducing_values -
             self.prior_distribution.mean).unsqueeze(-1)).squeeze(-1) +
                           test_mean)

        # Compute the covariance of q(f)
        # K_XX + k_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} k_ZX
        middle_term = self.prior_distribution.lazy_covariance_matrix.mul(-1)
        if variational_inducing_covar is not None:
            middle_term = SumLazyTensor(variational_inducing_covar,
                                        middle_term)
        predictive_covar = SumLazyTensor(
            data_data_covar.add_jitter(1e-4),
            MatmulLazyTensor(interp_term.transpose(-1, -2),
                             middle_term @ interp_term))

        # Return the distribution
        return MultivariateNormal(predictive_mean, predictive_covar)
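The `interp_term` above is the whitened cross-covariance K_ZZ^{-1/2} K_ZX; a small sketch (with made-up kernel matrices) of why a triangular solve against the Cholesky factor reproduces K_XZ K_ZZ^{-1} K_ZX:

import torch

m, n = 4, 6
A = torch.randn(m, m)
Kzz = A @ A.t() + m * torch.eye(m)   # inducing covariance (SPD)
Kzx = torch.randn(m, n)              # cross covariance

L = torch.linalg.cholesky(Kzz)       # Kzz = L L^T
interp = torch.triangular_solve(Kzx, L, upper=False).solution   # L^{-1} Kzx
explicit = Kzx.t() @ torch.linalg.solve(Kzz, Kzx)               # Kxz Kzz^{-1} Kzx
assert torch.allclose(interp.t() @ interp, explicit, atol=1e-4)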
Example #4
    def _inducing_inv_root(self):
        if not self.training and hasattr(self, "_cached_kernel_inv_root"):
            return self._cached_kernel_inv_root
        else:
            chol = psd_safe_cholesky(self._inducing_mat,
                                     upper=True,
                                     jitter=settings.cholesky_jitter.value())
            eye = torch.eye(chol.size(-1),
                            device=chol.device,
                            dtype=chol.dtype)
            inv_root = torch.triangular_solve(eye, chol)[0]

            res = inv_root
            if not self.training:
                self._cached_kernel_inv_root = res
            return res
Example #5
def expand_cholesky_upper(U, B, C):
    """
       U : upper cholesky factor of a positive-definite matrix K (n,n)
       B : (n,m)
       C : (m,m)
       
       Assumes that [K b, b^T c] will be positive-definite
       
       returns : upper cholesky factor of the matrix
           K b^T
           b c
    """
    S11 = U
    S21 = torch.triangular_solve(B, S11, upper=True, transpose=True)[0]  #(n,m)
    S22 = torch.cholesky(C - torch.matmul(S21.t(), S21), upper=True)
    return expand_upper(S11, S21, S22)
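Since `expand_upper` is not shown here, the sketch below assembles the expanded factor directly and checks it against the block matrix; it is a consistency check of the update, not the original helper:

import torch

n, m = 5, 2
A = torch.randn(n, n)
K = A @ A.t() + n * torch.eye(n)
B = 0.1 * torch.randn(n, m)
C = m * torch.eye(m) + 0.05 * torch.randn(m, m)
C = 0.5 * (C + C.t())

U = torch.linalg.cholesky(K).t()                                    # upper factor of K
S21 = torch.triangular_solve(B, U, upper=True, transpose=True)[0]   # U^{-T} B
S22 = torch.linalg.cholesky(C - S21.t() @ S21).t()
S = torch.cat([torch.cat([U, S21], dim=1),
               torch.cat([torch.zeros(m, n), S22], dim=1)], dim=0)  # what expand_upper presumably stacks

full = torch.cat([torch.cat([K, B], dim=1),
                  torch.cat([B.t(), C], dim=1)], dim=0)
assert torch.allclose(S.t() @ S, full, atol=1e-4)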
Example #6
        def run_test(n, k, upper, unitriangular, transpose):
            triangle_function = torch.triu if upper else torch.tril
            A = make_tensor((n, n), dtype=dtype, device=device)
            A = triangle_function(A)
            A_sparse = A.to_sparse_csr()
            B = make_tensor((n, k), dtype=dtype, device=device)

            expected = torch.triangular_solve(B, A, upper=upper, unitriangular=unitriangular, transpose=transpose)
            expected_X = expected.solution

            actual = torch.triangular_solve(B, A_sparse, upper=upper, unitriangular=unitriangular, transpose=transpose)
            actual_X = actual.solution
            actual_A_clone = actual.cloned_coefficient
            self.assertTrue(actual_A_clone.numel() == 0)
            self.assertEqual(actual_X, expected_X)

            # test out with C contiguous strides
            out = torch.empty_strided((n, k), (k, 1), dtype=dtype, device=device)
            torch.triangular_solve(
                B, A_sparse,
                upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
            )
            self.assertEqual(out, expected_X)

            # test out with F contiguous strides
            out = torch.empty_strided((n, k), (1, n), dtype=dtype, device=device)
            torch.triangular_solve(
                B, A_sparse,
                upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
            )
            self.assertEqual(out, expected_X)
            self.assertEqual(out.stride(), (1, n))

            # test out with discontiguous strides
            out = torch.empty_strided((2 * n, k), (1, 2 * n), dtype=dtype, device=device)[::2]
            if n > 0 and k > 0:
                self.assertFalse(out.is_contiguous())
                self.assertFalse(out.t().is_contiguous())
            before_stride = out.stride()
            torch.triangular_solve(
                B, A_sparse,
                upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
            )
            self.assertEqual(out, expected_X)
            self.assertEqual(out.stride(), before_stride)
Example #7
def _kl_multivariatenormal_multivariatenormal(p, q):
    # From https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence
    if p.event_shape != q.event_shape:
        raise ValueError("KL-divergence between two Multivariate Normals with\
                          different event shapes cannot be computed")

    half_term1 = (q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) -
                  p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1))
    combined_batch_shape = torch._C._infer_size(q._unbroadcasted_scale_tril.shape[:-2],
                                                p._unbroadcasted_scale_tril.shape[:-2])
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    term2 = _batch_trace_XXT(torch.triangular_solve(p_scale_tril, q_scale_tril, upper=False)[0])
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    return half_term1 + 0.5 * (term2 + term3 - n)
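This appears to be the rule that `torch.distributions.kl_divergence` dispatches to for two `MultivariateNormal` instances; a quick numerical smoke test with ad-hoc distributions:

import torch
from torch.distributions import MultivariateNormal, kl_divergence

n = 3
Ap = torch.randn(n, n)
Aq = torch.randn(n, n)
p = MultivariateNormal(torch.randn(n), covariance_matrix=Ap @ Ap.t() + n * torch.eye(n))
q = MultivariateNormal(torch.randn(n), covariance_matrix=Aq @ Aq.t() + n * torch.eye(n))

kl = kl_divergence(p, q)   # non-negative, zero only when p == q
assert kl.item() >= 0.0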
Example #8
def trtrs(b, A, upper=True, transpose=False, unitriangular=False, out=None):
    r"""Solves a system of equations with a triangular coefficient matrix :math:`A`
    and multiple right-hand sides :attr:`b`.

    In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular
    with the default keyword arguments.

    For more information regarding :func:`torch.trtrs`, please check :func:`torch.triangular_solve`.

    .. warning::
        :func:`torch.trtrs` is deprecated in favour of :func:`torch.triangular_solve` and will be
        removed in the next release. Please use :func:`torch.triangular_solve` instead.
    """
    warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be "
                  "removed in the next release. Please use torch.triangular_solve instead.", stacklevel=2)
    return torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular, out=out)
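For context, `torch.triangular_solve` has itself since been superseded in newer PyTorch releases by `torch.linalg.solve_triangular`, which takes the matrix first and returns only the solution. A small migration sketch, assuming a PyTorch version that provides `torch.linalg.solve_triangular`:

import torch

A = torch.triu(torch.randn(4, 4)) + 4 * torch.eye(4)
b = torch.randn(4, 2)

x_old = torch.triangular_solve(b, A, upper=True).solution
x_new = torch.linalg.solve_triangular(A, b, upper=True)
assert torch.allclose(x_old, x_new, atol=1e-6)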
Example #9
def uv_to_raydir(uv_grid, projection_matrix):
    # make coordinates homogeneous
    uvw_grid = torch.cat(
        [uv_grid,
         torch.ones(*uv_grid.shape[:2], 1, dtype=torch.double)],
        dim=-1)

    M = projection_matrix[:, :3]
    q, r = torch.qr(M.transpose(0, 1))
    r_sign = r.diag().prod().sign()

    back_sub = torch.triangular_solve(uvw_grid.unsqueeze(-1).cuda(),
                                      r,
                                      transpose=True)[0]  # (u, v, 3, 1)
    raydirs = r_sign * torch.matmul(q, back_sub)[..., 0]
    norm_raydirs = raydirs / torch.norm(raydirs, dim=-1, keepdim=True)
    return norm_raydirs
Example #10
def quadratic_mean_lsq(X, y):
    """
        Fit sum((x-c)**2/l**2)
    """
    d = X.shape[1]
    A = torch.cat([torch.ones_like(y), X, X**2], dim=1)
    Q, R = torch.qr(A)
    coefs = torch.triangular_solve(torch.matmul(Q.transpose(1, 0), y),
                                   R,
                                   upper=True)[0].flatten()
    a = coefs[d + 1:]
    b = coefs[1:d + 1]
    c = coefs[0]
    lengthscales = torch.sqrt(-1.0 / (2 * a))
    center = b * lengthscales**2
    constant = c + 0.5 * torch.sum(center**2 / (lengthscales**2))
    return lengthscales, center, constant
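A round-trip check of the parameterization recovered above (assuming `quadratic_mean_lsq` from this example is in scope; the numbers are illustrative): build y exactly from known lengthscales, center, and constant, and confirm the fit returns them.

import torch

torch.manual_seed(0)
n, d = 200, 2
X = torch.randn(n, d)
true_l = torch.tensor([1.5, 0.7])
true_c = torch.tensor([0.3, -0.2])
true_const = 2.0
y = true_const - 0.5 * ((X - true_c) ** 2 / true_l ** 2).sum(dim=1, keepdim=True)

lengthscales, center, constant = quadratic_mean_lsq(X, y)
assert torch.allclose(lengthscales, true_l, atol=1e-3)
assert torch.allclose(center, true_c, atol=1e-3)
assert abs(constant.item() - true_const) < 1e-3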
Example #11
def integral_vector(X, theta, l, mu, cov):
    """
        X : (n,d) tensor
        theta : 0d tensor or float
        l : (d,) tensor
        mu : (d,) tensor
        cov : (d,d) tensor
        outputs (n,) tensor
    """
    C = cov + torch.diag(l**2)
    L = torch.cholesky(C, upper=False)
    Xm = X - mu  #nxd#
    LX = torch.triangular_solve(Xm.transpose(1, 0), L, upper=False)[0]  #d x n
    expoent = -0.5 * torch.sum(LX**2, dim=0)  #(n,)
    det = torch.prod(1 / l**2) * torch.prod(torch.diag(L))**2  #|I + A^-1B|
    vec = theta / torch.sqrt(det) * torch.exp(expoent)  #(n,)
    return vec
Example #12
def _batch_mahalanobis(bL, bx):
    r"""
    Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
    for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.

    Accepts batches for both bL and bx. They are not necessarily assumed to have the same batch
    shape; the batch shape of `bL` only needs to be broadcastable to that of `bx`.
    """
    n = bx.size(-1)
    bx_batch_shape = bx.shape[:-1]

    # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
    # we are going to make bx have shape (..., 1, j,  i, 1, n) to apply batched tri.solve
    bx_batch_dims = len(bx_batch_shape)
    bL_batch_dims = bL.dim() - 2
    outer_batch_dims = bx_batch_dims - bL_batch_dims
    old_batch_dims = outer_batch_dims + bL_batch_dims
    new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
    # Reshape bx with the shape (..., 1, i, j, 1, n)
    bx_new_shape = bx.shape[:outer_batch_dims]
    for (sL, sx) in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (n, )
    bx = bx.reshape(bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n)
    permute_dims = (list(range(outer_batch_dims)) +
                    list(range(outer_batch_dims, new_batch_dims, 2)) +
                    list(range(outer_batch_dims + 1, new_batch_dims, 2)) +
                    [new_batch_dims])
    bx = bx.permute(permute_dims)

    flat_L = bL.reshape(-1, n, n)  # shape = b x n x n
    flat_x = bx.reshape(-1, flat_L.size(0), n)  # shape = c x b x n
    flat_x_swap = flat_x.permute(1, 2, 0)  # shape = b x n x c
    M_swap = torch.triangular_solve(
        flat_x_swap, flat_L, upper=False)[0].pow(2).sum(-2)  # shape = b x c
    M = M_swap.t()  # shape = c x b

    # Now we revert the above reshape and permute operators.
    permuted_M = M.reshape(bx.shape[:-1])  # shape = (..., 1, j, i, 1)
    permute_inv_dims = list(range(outer_batch_dims))
    for i in range(bL_batch_dims):
        permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
    reshaped_M = permuted_M.permute(
        permute_inv_dims)  # shape = (..., 1, i, j, 1)
    return reshaped_M.reshape(bx_batch_shape)
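Batching aside, the key identity is that solving L y = x and summing y**2 gives x^T (L L^T)^{-1} x; a compact non-batched check (names invented for the sketch):

import torch

n = 4
A = torch.randn(n, n)
M = A @ A.t() + n * torch.eye(n)
L = torch.linalg.cholesky(M)
x = torch.randn(n)

y = torch.triangular_solve(x.view(-1, 1), L, upper=False).solution
maha_solve = y.pow(2).sum()
maha_direct = x @ torch.linalg.solve(M, x)
assert torch.allclose(maha_solve, maha_direct, atol=1e-4)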
Example #13
def estimate_smallest_singular_value(U) -> Tuple[Tensor, Tensor]:
    """Given upper triangular matrix ``U`` estimate the smallest singular
    value and the corresponding right singular vector in O(n**2) operations.

    A vector `e` with components selected from {+1, -1}
    is selected so that the solution `w` to the system
    `U.T w = e` is as large as possible. Implementation
    based on algorithm 3.5.1, p. 142, from reference [1]_
    adapted for lower triangular matrix.

    References
    ----------
    .. [1] G.H. Golub, C.F. Van Loan. "Matrix computations".
           Fourth Edition. JHU Press. pp. 140-142.
    """

    U = torch.atleast_2d(U)
    UT = U.T
    m, n = U.shape
    if m != n:
        raise ValueError("A square triangular matrix should be provided.")

    p = torch.zeros(n, dtype=U.dtype, device=U.device)
    w = torch.empty(n, dtype=U.dtype, device=U.device)

    for k in range(n):
        wp = (1 - p[k]) / UT[k, k]
        wm = (-1 - p[k]) / UT[k, k]
        pp = p[k + 1:] + UT[k + 1:, k] * wp
        pm = p[k + 1:] + UT[k + 1:, k] * wm

        if wp.abs() + norm(pp, 1) >= wm.abs() + norm(pm, 1):
            w[k] = wp
            p[k + 1:] = pp
        else:
            w[k] = wm
            p[k + 1:] = pm

    # The system `U v = w` is solved using backward substitution.
    v = torch.triangular_solve(w.view(-1, 1), U)[0].view(-1)
    v_norm = norm(v)

    s_min = norm(w) / v_norm  # Smallest singular value
    z_min = v / v_norm  # Associated vector

    return s_min, z_min
Example #14
 def memoized_cholesky_decomposition(self):
     "memoize the cholesky decomposition to avoid recomputing it when we are not training or when we are fitting it"
     if (self._L_train_train is None) or (self.training):
         # covariance between training samples
         cov_train_train = self.kernel.matrix(
             (self.train_input_cat, self.train_input_cont),
             (self.train_input_cat, self.train_input_cont))
         # cholesky decompositions (accelerate solving of linear systems)
         self._L_train_train = psd_safe_cholesky(cov_train_train).detach(
         )  # we drop the gradient for the cholesky decomposition
         # outputs for the training data with prior correction
         train_outputs = self.train_outputs - self.prior(
             self.train_input_cat, self.train_input_cont)
         # weights for the predicted mean
         self._output_weights, _ = torch.triangular_solve(
             train_outputs, self._L_train_train, upper=False)
     return self._L_train_train, self._output_weights
Example #15
    def KL(self):
        Lu_expand = self.Lu.expand([self.num_outputs, -1, -1])

        KL = -0.5 * self.num_outputs * self.num_inducing
        KL -= 0.5 * torch.sum(
            torch.log(
                torch.diagonal(self.variational_covar, dim1=-2, dim2=-1)**2))

        KL += torch.sum(torch.log(torch.diag(self.Lu))) * self.num_outputs
        KL += 0.5 * torch.sum(
            torch.pow(
                torch.triangular_solve(
                    self.variational_covar, Lu_expand, upper=False)[0], 2))
        Kinv_m = torch.cholesky_solve(self.variational_mean, self.Lu)
        KL += 0.5 * torch.sum(self.variational_mean * Kinv_m)

        return KL
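The summed squares of the triangular solve above evaluate the trace term tr(K_uu^{-1} S) when `variational_covar` holds a Cholesky factor of S; a standalone check of that identity (matrix names made up for the sketch):

import torch

m = 5
A = torch.randn(m, m)
B = torch.randn(m, m)
Kuu = A @ A.t() + m * torch.eye(m)
S = B @ B.t() + m * torch.eye(m)

Lu = torch.linalg.cholesky(Kuu)
Ls = torch.linalg.cholesky(S)
trace_solve = torch.triangular_solve(Ls, Lu, upper=False).solution.pow(2).sum()
trace_direct = torch.trace(torch.linalg.solve(Kuu, S))
assert torch.allclose(trace_solve, trace_direct, atol=1e-3)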
Example #16
 def negative_log_likelihood(self, hyper=None):
     if hyper is not None:
         # Record original params
         param_original = self.model.param_to_vec()
         # Update with new params
         self.cholesky_update(hyper)
     # Cholesky decomposition of mean vector
     mean_vec_sol = torch.triangular_solve(self.mean_vec,
                                           self.cholesky.float(),
                                           upper=False)[0]
     # Negative log likelihood
     nll = 0.5 * torch.sum(mean_vec_sol**2) + torch.sum(
         torch.log(torch.diag(self.cholesky))
     ) + 0.5 * self.train_y.size(0) * np.log(2 * np.pi)
     if hyper is not None:
         # Put original params back
         self.cholesky_update(param_original)
     return nll
Example #17
def update_precond_dense(Q, dxs, dgs, step=0.01):
    """
    update dense preconditioner P = Q^T*Q
    Q: Cholesky factor of preconditioner with positive diagonal entries
    dxs: list of perturbations of parameters
    dgs: list of perturbations of gradients
    step: normalized step size in [0, 1]
    """
    dx = torch.cat([torch.reshape(x, [-1, 1]) for x in dxs])
    dg = torch.cat([torch.reshape(g, [-1, 1]) for g in dgs])

    a = Q.mm(dg)
    b = torch.triangular_solve(dx, Q.t(), upper=False)[0]

    grad = torch.triu(a.mm(a.t()) - b.mm(b.t()))
    step0 = step / (grad.abs().max() + _tiny)

    return Q - step0 * grad.mm(Q)
Example #18
 def inverse_no_cache(self, inputs):
     """Cost:
         output = O(D^2N + KDN)
         logabsdet = O(D)
     where:
         K = num of householder transforms
         D = num of features
         N = num of inputs
     """
     upper = self._create_upper()
     outputs = inputs - self.bias
     outputs, _ = self.orthogonal.inverse(
         outputs)  # Ignore logabsdet since we know it's zero.
     outputs, _ = torch.triangular_solve(outputs.t(), upper, upper=True)
     outputs = outputs.t()
     logabsdet = -self.logabsdet()
     logabsdet = logabsdet * torch.ones(outputs.shape[0])
     return outputs, logabsdet
Example #19
 def inv_matmul(self,
                right_tensor: Tensor,
                left_tensor: Optional[Tensor] = None) -> Tensor:
     if isinstance(self._tensor, NonLazyTensor):
         res = torch.triangular_solve(right_tensor,
                                      self.evaluate(),
                                      upper=self.upper).solution
     elif isinstance(self._tensor, BatchRepeatLazyTensor):
         res = self._tensor.base_lazy_tensor.inv_matmul(
             right_tensor, left_tensor)
         # TODO: Proper broadcasting
         res = res.expand(self._tensor.batch_repeat + res.shape[-2:])
     else:
         # TODO: Can we be smarter here?
         res = self._tensor.inv_matmul(right_tensor=right_tensor,
                                       left_tensor=left_tensor)
     if left_tensor is not None:
         res = left_tensor @ res
     return res
Example #20
def _update_precond_norm_dense(ql, Qr, dX, dG, step=0.01, _tiny=1.2e-38):
    # type: (Tensor, Tensor, Tensor, Tensor, float, float) -> Tuple[Tensor, Tensor]
    """
    update (normalization, dense) Kronecker product preconditioner P = kron_prod(Qr^T*Qr, Ql^T*Ql), where
    dX and dG have shape (M, N)
    ql has shape (2, M)
    Qr has shape (N, N)
    ql[0] is the diagonal part of Ql
    ql[1,0:-1] is the last column of Ql, excluding the last entry
    dX is perturbation of (matrix) parameter
    dG is perturbation of (matrix) gradient
    step: update step size normalized to range [0, 1] 
    _tiny: an offset to avoid division by zero  
    """
    # make sure that Ql and Qr have similar dynamic range
    max_l = torch.max(ql[0])
    max_r = torch.max(torch.diag(Qr))  
    rho = torch.sqrt(max_l/max_r)
    ql /= rho
    Qr *= rho
    
    # refer to https://arxiv.org/abs/1512.04202 for details
    A = ql[0:1].t()*dG + ql[1:].t().mm( dG[-1:] ) # Ql*dG 
    A = A.mm(Qr.t())
    
    Bt = dX/ql[0:1].t()
    Bt[-1:] -= (ql[1:]/(ql[0:1]*ql[0,-1])).mm(dX)
    Bt = torch.triangular_solve(Bt.t(), Qr, upper=True, transpose=True)[0].t()
    
    grad1_diag = torch.sum(A*A, dim=1) - torch.sum(Bt*Bt, dim=1)
    grad1_bias = A[:-1].mm(A[-1:].t()) - Bt[:-1].mm(Bt[-1:].t()) 
    grad1_bias = torch.cat([torch.squeeze(grad1_bias), grad1_bias.new_zeros(1)])  

    step1 = step/(torch.max(torch.max(torch.abs(grad1_diag)), 
                            torch.max(torch.abs(grad1_bias))) + _tiny)
    new_ql0 = ql[0] - step1*grad1_diag*ql[0]
    new_ql1 = ql[1] - step1*(grad1_diag*ql[1] + ql[0,-1]*grad1_bias)
    
    grad2 = torch.triu(A.t().mm(A) - Bt.t().mm(Bt))
    step2 = step/(torch.max(torch.abs(grad2)) + _tiny)
    
    return torch.stack((new_ql0, new_ql1)), Qr - step2*grad2.mm(Qr)
Example #21
    def posterior(self, t, x, obs_t, obs_y, generator=None, noise=False):
        # print(self.name, self.kernel, 'posterior', noise)
        nobs = obs_t.shape[0]
        if noise:
            kernel = self.kernel_noise
        else:
            kernel = self.kernel
        cov = self.kernel(obs_t, t)
        sigma = torch.cat([
            torch.cat([self.kernel_noise(obs_t), cov], dim=2),
            torch.cat([cov.transpose(1, 2), kernel(t)], dim=2)
        ],
                          dim=1)
        cho = cholesky(sigma)
        cho_cross = cho[:, nobs:, :nobs]
        cho_obs = cho[:, :nobs, :nobs]
        cho_bar = cho[:, nobs:, nobs:]

        cho_solve = torch.triangular_solve(obs_y, cho_obs, upper=False)[0]
        return cho_cross.matmul(cho_solve) + torch.matmul(cho_bar, x)
Example #22
    def forward(self, input):
        "Compute covariance"
        if self.demean:
            if self.mu is None:
                temp = input - torch.mean(input, 0)
            else:
                temp = input - self.mu
        else:
            temp = input

        if self.R is None:
            cov = temp.transpose(
                1, 0) @ temp / temp.shape[0]  # compute covariance matrix
            cov = (cov + cov.transpose(1, 0)) / 2 + 1e-5 * torch.eye(
                cov.shape[0], device=self.device)
            R = torch.cholesky(cov)  # returns the lower cholesky matrix
        else:
            R = self.R
        Y, _ = torch.triangular_solve(temp.transpose(1, 0), R, upper=False)
        return Y.transpose(1, 0)
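What the solve accomplishes here is whitening: Y = L^{-1}(x - mean) has approximately identity covariance. A short sketch of that property on synthetic data (not part of the original module):

import torch

torch.manual_seed(0)
n, d = 4000, 3
A = torch.tensor([[2.0, 0.0, 0.0],
                  [0.5, 1.5, 0.0],
                  [0.3, -0.2, 1.0]])
data = torch.randn(n, d) @ A.t()          # correlated samples

centered = data - data.mean(0)
cov = centered.t() @ centered / n
L = torch.linalg.cholesky(cov + 1e-5 * torch.eye(d))
Y = torch.triangular_solve(centered.t(), L, upper=False).solution.t()

white_cov = Y.t() @ Y / n
assert torch.allclose(white_cov, torch.eye(d), atol=1e-2)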
Example #23
def update_precond_dense(Q, dxs, dgs, step=0.01, _tiny=1.2e-38):
    # type: (Tensor, List[Tensor], List[Tensor], float, float) -> Tensor
    """
    update dense preconditioner P = Q^T*Q
    Q: Cholesky factor of preconditioner with positive diagonal entries 
    dxs: list of perturbations of parameters
    dgs: list of perturbations of gradients
    step: update step size normalized to range [0, 1] 
    _tiny: an offset to avoid division by zero 
    """
    dx = torch.cat([torch.reshape(x, [-1, 1]) for x in dxs])
    dg = torch.cat([torch.reshape(g, [-1, 1]) for g in dgs])
    
    a = Q.mm(dg)
    b = torch.triangular_solve(dx, Q, upper=True, transpose=True)[0]

    grad = torch.triu(a.mm(a.t()) - b.mm(b.t()))
    step0 = step/(grad.abs().max() + _tiny)        
        
    return Q - step0*grad.mm(Q)
Example #24
    def backward(ctx, grad_output):
        jitter = 1.0e-8  # do i really need this?
        z, epsilon, L = ctx.saved_tensors

        dim = L.shape[0]
        g = grad_output
        loc_grad = sum_leftmost(grad_output, -1)

        identity = eye_like(g, dim)
        R_inv = torch.triangular_solve(identity,
                                       L.t(),
                                       transpose=False,
                                       upper=True)[0]

        z_ja = z.unsqueeze(-1)
        g_R_inv = torch.matmul(g, R_inv).unsqueeze(-2)
        epsilon_jb = epsilon.unsqueeze(-2)
        g_ja = g.unsqueeze(-1)
        diff_L_ab = 0.5 * sum_leftmost(g_ja * epsilon_jb + g_R_inv * z_ja, -2)

        Sigma_inv = torch.mm(R_inv, R_inv.t())
        V, D, _ = torch.svd(Sigma_inv + jitter)
        D_outer = D.unsqueeze(-1) + D.unsqueeze(0)

        expand_tuple = tuple([-1] * (z.dim() - 1) + [dim, dim])
        z_tilde = identity * torch.matmul(
            z, V).unsqueeze(-1).expand(*expand_tuple)
        g_tilde = identity * torch.matmul(
            g, V).unsqueeze(-1).expand(*expand_tuple)

        Y = sum_leftmost(
            torch.matmul(z_tilde, torch.matmul(1.0 / D_outer, g_tilde)), -2)
        Y = torch.mm(V, torch.mm(Y, V.t()))
        Y = Y + Y.t()

        Tr_xi_Y = torch.mm(torch.mm(Sigma_inv, Y), R_inv) - torch.mm(
            Y, torch.mm(Sigma_inv, R_inv))
        diff_L_ab += 0.5 * Tr_xi_Y
        L_grad = torch.tril(diff_L_ab)

        return loc_grad, L_grad, None
Example #25
    def covar_cache(self):
        # Here, the covar_cache is going to be the inverse of K_{XX} + \sigma^2 I
        # This is easily computed using Woodbury
        train_train_covar = self.lik_train_train_covar.evaluate_kernel()

        # Get terms needed for woodbury
        root = train_train_covar._lazy_tensor.root_decomposition().root
        inv_diag = train_train_covar._diag_tensor.inverse()

        # Form LT using woodbury
        ones = torch.tensor(1.0, dtype=inv_diag.dtype, device=inv_diag.device)
        chol_factor = (root.transpose(-1, -2)
                       @ root).add_diag(ones).cholesky().evaluate()
        woodbury_term = torch.triangular_solve(
            inv_diag.diag().unsqueeze(-2) * root.evaluate().transpose(-1, -2),
            chol_factor,
            upper=False)[0]
        inverse = AddedDiagLazyTensor(
            MatmulLazyTensor(woodbury_term.transpose(-1, -2), -woodbury_term),
            inv_diag)
        return inverse
Example #26
def sequential_thinning_dpp_init(K):
    N = K.size(0)
    try:
        L = torch.cholesky(torch.eye(N - 1, device=mydevice) - K[:-1, :-1],
                           upper=False)
        B = torch.triu(K[0:-1, 1:])
        q = K.diag()
        q[1:].add_(
            torch.sum(
                torch.triu(torch.triangular_solve(B, L, upper=False)[0])**2,
                0))
    except RuntimeError:  #slower procedure if I-K is singular
        q = torch.ones([N])
        ImK = torch.eye(K.size(0), device=mydevice) - K
        q[0] = K[0, 0]
        for k in range(1, N):
            q[k] = K[k, k] + torch.mm(
                K[[k], :k], torch.cholesky_solve(K[:k, [k]], ImK[:k, :k]))
            if q[k] == 1:
                break
    return (q)
Example #27
def update_precond_scaw(Ql, qr, dX, dG, step=0.01):
    """
    update scaling-and-whitening preconditioner
    """
    max_l = torch.max(torch.abs(Ql))
    max_r = torch.max(torch.abs(qr))

    rho = torch.sqrt(max_l / max_r)
    Ql = Ql / rho
    qr = rho * qr

    A = Ql.mm(dG * qr)
    Bt = torch.triangular_solve(dX / qr, Ql.t(), upper=False)[0]

    grad1 = torch.triu(A.mm(A.t()) - Bt.mm(Bt.t()))
    grad2 = torch.sum(A * A, dim=0, keepdim=True) - torch.sum(Bt * Bt, dim=0, keepdim=True)

    step1 = step / (torch.max(torch.abs(grad1)) + _tiny)
    step2 = step / (torch.max(torch.abs(grad2)) + _tiny)

    return Ql - step1 * grad1.mm(Ql), qr - step2 * grad2 * qr
Example #28
    def eKxz_parallel(self, Z, Xmean, Xcov):
        # TODO: add test
        """Parallel implementation (needs more space, but less time)
        Refer to GPflow implementation

        Args:
            Z (Variable): m x q inducing input
            Xmean (Variable): n x q mean of input X
            Xcov (Variable): posterior covariance of X
                two sizes are accepted:
                    n x q x q: each q(x_i) has full covariance
                    n x q: each q(x_i) has diagonal covariance (uncorrelated),
                        stored in each row
        Returns:
            (Variable): n x m
        """

        # Revisit later, check for backward support for n-D tensor
        n = Xmean.size(0)
        q = Xmean.size(1)
        m = Z.size(0)
        if Xcov.dim() == 2:
            # from flattened diagonal to full matrix
            cov = Variable(th.Tensor(n, q, q).type(float_type))
            for i in range(Xmean.size(0)):
                cov[i] = Xcov[i].diag()
            Xcov = cov
            del cov
        length_scales = self.length_scales.transform()
        Lambda = length_scales.pow(2).diag().unsqueeze(0).expand_as(Xcov)
        L = cholesky(Lambda + Xcov)
        xz = Xmean.unsqueeze(2).expand(n, q, m) - Z.unsqueeze(0).expand(
            n, q, m)
        Lxz = th.triangular_solve(xz, L, upper=False)[0]
        half_log_dets = L.diag().log().sum(1) \
                        - length_scales.log().sum().expand(n)

        return self.variance.transform().expand(n, m) \
               * th.exp(-0.5 * Lxz.pow(2).sum(1) - half_log_dets.expand(n, m))
Example #29
def process_gaussian(A: torch.Tensor, scale: float, inverted: bool):
    """
    Convert a Gaussian x = scale*A z, with z ~ N(0, I), into a linear autoregressive model
    Calling this function with (A, scale=s) is equivalent to calling it with (s*A, scale=1),
    except with possibly better numerical stability.
    """
    assert isinstance(A, torch.Tensor) and len(
        A.shape) == 2 and A.shape[0] == A.shape[1]
    dim = A.shape[0]
    if inverted:
        # Here, the goal is to convert x = A^{-1} z into a linear AR model.
        # If R (upper triangular) comes from the QR decomposition of A, then
        # R^T R = A^T A, so A^{-1} A^{-1}^T = R^{-1} R^{-1}^T,
        # which implies that R^{-1} z has the same distribution as A^{-1} z.
        # So, if x solves the equation R x = z, then x will have the distribution we desire.
        # Because R is triangular, x can be obtained via back substitution, which is equivalent to sampling from
        # a linear AR model driven by Gaussian noise z. The coefficients of the linear functions that
        # define the conditionals of this AR model can be read off in the rows of R; the ordering of the
        # variables goes backwards because R is upper triangular.
        _, R = torch.qr(A)
        stds = (1.0 / scale) / R.diag().abs()  # standard deviations of conditionals
        R /= -R.diag()[:, None]  # coefficients of linear functions defining the means
        mean_coefs = R
    else:
        # Here, the goal is to convert x = A z into a linear AR model
        L = torch.qr(A.t())[1].t()  # Cholesky decomposition of AA^T
        stds = L.diag().abs() * scale  # standard deviations of conditionals
        Linv, _ = torch.triangular_solve(
            torch.eye(dim, dtype=L.dtype, device=L.device), L,
            upper=False)  # invert L
        Linv *= -L.diag()[:, None]
        mean_coefs = Linv
    # set diagonal to zero; AR conditionals don't depend on the current timestep
    mean_coefs[range(dim), range(dim)] = 0
    return mean_coefs, stds
Example #30
def _kl_multivariatenormal_lowrankmultivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError("KL-divergence between two (Low Rank) Multivariate Normals with\
                          different event shapes cannot be computed")

    term1 = (_batch_lowrank_logdet(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
                                   q._capacitance_tril) -
             2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1))
    term3 = _batch_lowrank_mahalanobis(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
                                       q.loc - p.loc,
                                       q._capacitance_tril)
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T
    #                  = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
    qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) /
                 q._unbroadcasted_cov_diag.unsqueeze(-2))
    A = torch.triangular_solve(qWt_qDinv, q._capacitance_tril, upper=False)[0]
    term21 = _batch_trace_XXT(p._unbroadcasted_scale_tril *
                              q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
    term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
    term2 = term21 - term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])