Exemplo n.º 1
0
def exactsolve(A: LinearOperator, B: torch.Tensor,
               E: Union[torch.Tensor, None], M: Union[LinearOperator, None]):
    """
    Solve the linear equation by constructing the full matrix of LinearOperators.

    The branch taken depends on the optional arguments:

    * ``E is None``: solve ``A @ x = B`` directly (``M`` is ignored).
    * ``E`` given, ``M is None``: delegate to ``_solve_ABE(A, B, E)``.
    * ``E`` and ``M`` given: with the lower Cholesky factor ``L`` of ``M``
      (``M = L @ L^T``), compute ``_solve_ABE(L^{-1} A L^{-T}, L^{-1} B, E)``
      and map the result back with ``L^{-T}``.
      NOTE(review): presumably this solves ``(A - E M) x = B``; confirm
      against ``_solve_ABE``.

    Arguments
    ---------
    A: LinearOperator
        Left-hand-side operator, ``(*BA, na, na)``.
    B: torch.Tensor
        Right-hand side, ``(*BB, na, ncols)``.
    E: torch.Tensor or None
        ``(*BE, ncols)``, forwarded to ``_solve_ABE``.
    M: LinearOperator or None
        ``(*BM, na, na)``. It is only used when ``E`` is given, and its full
        matrix must be positive definite for the Cholesky factorization.

    Returns
    -------
    torch.Tensor
        The solution ``x`` with shape ``(*batch, na, ncols)``. The batch
        dimensions are broadcast from the inputs.

    Warnings
    --------
    * As this method constructs the linear operators explicitly, it might
      require a large amount of memory.
    """
    # A: (*BA, na, na)
    # B: (*BB, na, ncols)
    # E: (*BE, ncols)
    # M: (*BM, na, na)
    if E is None:
        Amatrix = A.fullmatrix()  # (*BA, na, na)
        # The removed torch.solve always treated B as a batch of matrices.
        # torch.linalg.solve instead treats B as a batch of vectors when
        # B.ndim == A.ndim - 1 and B.shape == A.shape[:-1]. Prepending
        # singleton batch dims rules that out without changing broadcasting.
        if B.ndim < Amatrix.ndim:
            B = B.reshape((1,) * (Amatrix.ndim - B.ndim) + tuple(B.shape))
        x = torch.linalg.solve(Amatrix, B)  # (*BAB, na, ncols)
    elif M is None:
        Amatrix = A.fullmatrix()  # (*BA, na, na)
        x = _solve_ABE(Amatrix, B, E)
    else:
        Mmatrix = M.fullmatrix()  # (*BM, na, na)
        L = torch.linalg.cholesky(Mmatrix)  # lower-triangular, (*BM, na, na)
        Linv = torch.inverse(L)  # (*BM, na, na)
        LinvT = Linv.transpose(-2, -1)  # (*BM, na, na)
        A2 = torch.matmul(Linv, A.mm(LinvT))  # (*BAM, na, na)
        B2 = torch.matmul(Linv, B)  # (*BBM, na, ncols)

        X2 = _solve_ABE(A2, B2, E)  # (*BABEM, na, ncols)
        # undo the change of variables: x = L^{-T} X2
        x = torch.matmul(LinvT, X2)  # (*BABEM, na, ncols)
    return x
Exemplo n.º 2
0
def exacteig(A: LinearOperator, neig: int,
             mode: str, M: Optional[LinearOperator]) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Eigendecomposition using explicit matrix construction.
    No additional option for this method.

    If ``M`` is given, the generalized problem ``A v = lambda M v`` is
    reduced to a standard symmetric one. The reduction uses the lower
    Cholesky factor ``L`` of ``M`` (``M = L @ L^T``): the eigenpairs of
    ``L^{-1} A L^{-T}`` are computed and the eigenvectors are mapped back
    with ``L^{-T}``, so they are normalized in M-space.

    Arguments
    ---------
    A: LinearOperator
        Operator whose full matrix ``(*BA, q, q)`` is decomposed. Only its
        upper triangle is read (as with the old ``torch.symeig`` default),
        so it is assumed to be symmetric/Hermitian.
    neig: int
        Number of eigenpairs to keep, forwarded to ``_take_eigpairs``.
    mode: str
        Selection mode, forwarded to ``_take_eigpairs``.
    M: LinearOperator or None
        Optional ``(*BM, q, q)`` operator of the generalized problem. Its full
        matrix must be positive definite for the Cholesky factorization.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        The eigenvalues and eigenvectors selected by ``_take_eigpairs``. In
        the ``M`` case these are ``(*BAM, neig)`` and ``(*BAM, q, neig)``.

    Warnings
    --------
    * As this method constructs the linear operators explicitly, it might
      require a large amount of memory.
    """
    Amatrix = A.fullmatrix()  # (*BA, q, q)
    if M is None:
        # UPLO="U" matches the default (upper=True) of the removed torch.symeig
        evals, evecs = torch.linalg.eigh(Amatrix, UPLO="U")  # (*BA, q), (*BA, q, q)
        return _take_eigpairs(evals, evecs, neig, mode)

    Mmatrix = M.fullmatrix()  # (*BM, q, q)

    # Decompose M so that the reduced matrix is symmetric. This is done
    # instead of forming M^{-1} A for numerical stability: it avoids complex
    # eigenvalues in the (near-)degenerate case.
    L = torch.linalg.cholesky(Mmatrix)  # lower-triangular, (*BM, q, q)
    Linv = torch.inverse(L)  # (*BM, q, q)
    LinvT = Linv.transpose(-2, -1)  # (*BM, q, q)
    A2 = torch.matmul(Linv, torch.matmul(Amatrix, LinvT))  # (*BAM, q, q)

    # calculate the eigenvalues and eigenvectors of the reduced problem
    evals, evecs = torch.linalg.eigh(A2, UPLO="U")  # (*BAM, q), (*BAM, q, q)
    evals, evecs = _take_eigpairs(evals, evecs, neig, mode)  # (*BAM, neig) and (*BAM, q, neig)
    # map back to the original variables; the eigvecs are then normalized in M-space
    evecs = torch.matmul(LinvT, evecs)  # (*BAM, q, neig)
    return evals, evecs