def exactsolve(A: LinearOperator, B: torch.Tensor,
               E: Union[torch.Tensor, None],
               M: Union[LinearOperator, None]) -> torch.Tensor:
    """
    Solve the linear equation by constructing the full matrix of the
    LinearOperators.

    Arguments
    ---------
    A: LinearOperator
        Square operator with shape ``(*BA, na, na)``.
    B: torch.Tensor
        Right-hand side with shape ``(*BB, na, ncols)``.
    E: torch.Tensor or None
        Shape ``(*BE, ncols)``. If ``None``, the plain system ``A x = B`` is
        solved and ``M`` is ignored. Otherwise the solve is delegated to
        ``_solve_ABE`` (presumably solving ``(A - E M) x = B`` column by
        column, with ``M`` taken as identity when it is ``None`` -- confirm
        against ``_solve_ABE``).
    M: LinearOperator or None
        Shape ``(*BM, na, na)``. Only used when ``E`` is given. It is
        Cholesky-factorized, so it must be symmetric positive definite.

    Returns
    -------
    torch.Tensor
        The solution ``x`` with shape ``(*BABEM, na, ncols)``, where the batch
        dimensions are broadcast across the inputs.

    Warnings
    --------
    * As this method constructs the linear operators explicitly, it might
      require a large amount of memory.
    """
    if E is None:
        Amatrix = A.fullmatrix()  # (*BA, na, na)
        # torch.solve has been removed from recent PyTorch releases;
        # torch.linalg.solve takes (A, B) in the opposite order and returns
        # the solution tensor directly.
        # Pad B with leading singleton dims so it always has at least as many
        # dims as A: with one fewer dim, linalg.solve may interpret B as a
        # batch of vectors instead of a matrix. Broadcasting is unaffected.
        ndim_missing = Amatrix.ndim - B.ndim
        if ndim_missing > 0:
            B = B.reshape((1,) * ndim_missing + tuple(B.shape))
        x = torch.linalg.solve(Amatrix, B)  # (*BAB, na, ncols)
    elif M is None:
        Amatrix = A.fullmatrix()  # (*BA, na, na)
        x = _solve_ABE(Amatrix, B, E)
    else:
        Mmatrix = M.fullmatrix()  # (*BM, na, na)
        # Factor M = L L^T and move to the standard form:
        # A2 = L^-1 A L^-T, B2 = L^-1 B, then map back with x = L^-T X2.
        L = torch.linalg.cholesky(Mmatrix)  # lower triangular, (*BM, na, na)
        Linv = torch.inverse(L)  # (*BM, na, na)
        LinvT = Linv.transpose(-2, -1)  # (*BM, na, na)
        A2 = torch.matmul(Linv, A.mm(LinvT))  # (*BAM, na, na)
        B2 = torch.matmul(Linv, B)  # (*BBM, na, ncols)
        X2 = _solve_ABE(A2, B2, E)  # (*BABEM, na, ncols)
        x = torch.matmul(LinvT, X2)  # (*BABEM, na, ncols)
    return x
def exacteig(A: LinearOperator, neig: int, mode: str,
             M: Optional[LinearOperator]) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Eigendecomposition using explicit matrix construction.
    No additional option for this method.

    Arguments
    ---------
    A: LinearOperator
        Operator with shape ``(*BA, q, q)``. The eigensolver only reads the
        upper triangle of the matrix it decomposes, so ``A`` is assumed to be
        symmetric (Hermitian).
    neig: int
        Number of eigenpairs to keep; forwarded to ``_take_eigpairs``.
    mode: str
        Which eigenpairs to keep; forwarded to ``_take_eigpairs``.
    M: LinearOperator or None
        If given, shape ``(*BM, q, q)``. The generalized problem
        ``A v = lambda M v`` is solved instead. ``M`` is Cholesky-factorized,
        so it must be symmetric positive definite.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        Eigenvalues ``(*BAM, neig)`` and eigenvectors ``(*BAM, q, neig)``.
        When ``M`` is given, the eigenvectors are normalized in M-space
        (``v^T M v = 1``).

    Warnings
    --------
    * As this method constructs the linear operators explicitly, it might
      require a large amount of memory.
    """
    Amatrix = A.fullmatrix()  # (*BA, q, q)

    if M is None:
        # torch.symeig has been removed from recent PyTorch releases.
        # torch.linalg.eigh with UPLO="U" matches symeig's default (upper=True)
        # and also returns eigenvalues in ascending order.
        evals, evecs = torch.linalg.eigh(Amatrix, UPLO="U")  # (*BA, q), (*BA, q, q)
        return _take_eigpairs(evals, evecs, neig, mode)

    Mmatrix = M.fullmatrix()  # (*BM, q, q)
    # Decompose M = L L^T to turn the generalized problem into a symmetric
    # standard one, A2 = L^-1 A L^-T. This is numerically more stable than
    # inverting M directly and avoids spurious complex eigenvalues in the
    # (near-)degenerate case.
    L = torch.linalg.cholesky(Mmatrix)  # lower triangular, (*BM, q, q)
    Linv = torch.inverse(L)  # (*BM, q, q)
    LinvT = Linv.transpose(-2, -1)  # (*BM, q, q)
    A2 = torch.matmul(Linv, torch.matmul(Amatrix, LinvT))  # (*BAM, q, q)

    # Eigenpairs of A2 are y with A2 y = lambda y; mapping back with
    # v = L^-T y gives eigenvectors of the original problem that are
    # normalized in M-space.
    evals, evecs = torch.linalg.eigh(A2, UPLO="U")  # (*BAM, q), (*BAM, q, q)
    evals, evecs = _take_eigpairs(evals, evecs, neig, mode)  # (*BAM, neig), (*BAM, q, neig)
    evecs = torch.matmul(LinvT, evecs)  # (*BAM, q, neig)
    return evals, evecs