def exactsolve(A: LinearOperator, B: torch.Tensor,
               E: Union[torch.Tensor, None],
               M: Union[LinearOperator, None]):
    """
    Solve the linear equation by constructing the full matrix of the
    LinearOperators.

    Warnings
    --------
    * As this method constructs the linear operators explicitly, it might
      require a large amount of memory.
    """
    # A: (*BA, na, na)
    # B: (*BB, na, ncols)
    # E: (*BE, ncols)
    # M: (*BM, na, na)
    if E is None:
        Amatrix = A.fullmatrix()  # (*BA, na, na)
        x = torch.linalg.solve(Amatrix, B)  # (*BAB, na, ncols)
    elif M is None:
        Amatrix = A.fullmatrix()
        x = _solve_ABE(Amatrix, B, E)
    else:
        Mmatrix = M.fullmatrix()  # (*BM, na, na)
        L = torch.linalg.cholesky(Mmatrix)  # (*BM, na, na)
        Linv = torch.inverse(L)  # (*BM, na, na)
        LinvT = Linv.transpose(-2, -1)  # (*BM, na, na)
        A2 = torch.matmul(Linv, A.mm(LinvT))  # (*BAM, na, na)
        B2 = torch.matmul(Linv, B)  # (*BBM, na, ncols)

        X2 = _solve_ABE(A2, B2, E)  # (*BABEM, na, ncols)
        x = torch.matmul(LinvT, X2)  # (*BABEM, na, ncols)
    return x
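# NOTE: ``_solve_ABE`` is defined elsewhere in this module. Conceptually it
# solves the shifted system A @ X - X @ diag(E) = B column by column, i.e.
# (A - E[i] * I) @ x_i = b_i for each column i. The sketch below is an
# illustrative dense reference of that idea (the helper name is hypothetical;
# the actual implementation also broadcasts over batch dimensions):
def _solve_ABE_dense_sketch(A: torch.Tensor, B: torch.Tensor,
                            E: torch.Tensor) -> torch.Tensor:
    # A: (na, na), B: (na, ncols), E: (ncols,)
    # solve the shifted system (A - E[i] * I) @ x_i = b_i per column of B
    na, ncols = B.shape[-2:]
    eye = torch.eye(na, dtype=A.dtype, device=A.device)
    cols = [torch.linalg.solve(A - E[i] * eye, B[..., i:i + 1])
            for i in range(ncols)]
    return torch.cat(cols, dim=-1)  # (na, ncols)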
def _setup_linear_problem(A: LinearOperator, B: torch.Tensor,
                          E: Optional[torch.Tensor],
                          M: Optional[LinearOperator],
                          batchdims: Sequence[int],
                          posdef: Optional[bool],
                          need_hermit: bool) -> \
        Tuple[Callable[[torch.Tensor], torch.Tensor],
              Callable[[torch.Tensor], torch.Tensor],
              torch.Tensor, bool]:

    # get the linear operator (including the MXE part)
    if E is None:
        A_fcn = lambda x: A.mm(x)
        AT_fcn = lambda x: A.rmm(x)
        B_new = B
        col_swapped = False
    else:
        # A: (*BA, nr, nr) linop
        # B: (*BB, nr, ncols)
        # E: (*BE, ncols)
        # M: (*BM, nr, nr) linop
        if M is None:
            BAs, BBs, BEs = normalize_bcast_dims(A.shape[:-2], B.shape[:-2],
                                                 E.shape[:-1])
        else:
            BAs, BBs, BEs, BMs = normalize_bcast_dims(A.shape[:-2], B.shape[:-2],
                                                      E.shape[:-1], M.shape[:-2])
        E = E.reshape(*BEs, *E.shape[-1:])
        E_new = E.unsqueeze(0).transpose(-1, 0).unsqueeze(-1)  # (ncols, *BEs, 1, 1)
        B = B.reshape(*BBs, *B.shape[-2:])  # (*BBs, nr, ncols)
        B_new = B.unsqueeze(0).transpose(-1, 0)  # (ncols, *BBs, nr, 1)

        def A_fcn(x):
            # x: (ncols, *BX, nr, 1)
            Ax = A.mm(x)  # (ncols, *BAX, nr, 1)
            Mx = M.mm(x) if M is not None else x  # (ncols, *BMX, nr, 1)
            MxE = Mx * E_new  # (ncols, *BMXE, nr, 1)
            return Ax - MxE

        def AT_fcn(x):
            # x: (ncols, *BX, nr, 1)
            ATx = A.rmm(x)
            MTx = M.rmm(x) if M is not None else x
            MTxE = MTx * E_new
            return ATx - MTxE

        col_swapped = True

    # estimate if it's posdef with power iteration
    if need_hermit:
        is_hermit = A.is_hermitian and (M is None or M.is_hermitian)
        if not is_hermit:
            posdef = False

    if posdef is None:
        nr, ncols = B.shape[-2:]
        x0shape = (ncols, *batchdims, nr, 1) if col_swapped else (*batchdims, nr, ncols)
        x0 = torch.randn(x0shape, dtype=A.dtype, device=A.device)
        x0 = x0 / x0.norm(dim=-2, keepdim=True)
        largest_eival = _get_largest_eival(A_fcn, x0)  # (*, 1, nc)
        negeival = largest_eival <= 0

        # if the largest eigenvalue is negative, then it's not posdef
        if torch.all(negeival):
            posdef = False
        else:
            offset = torch.clamp(largest_eival, min=0.0)
            A_fcn2 = lambda x: A_fcn(x) - offset * x
            mostneg_eival = _get_largest_eival(A_fcn2, x0)  # (*, 1, nc)
            posdef = bool(
                torch.all(torch.logical_or(-mostneg_eival <= offset, negeival)).item())

    # get the linear operation if it is not a posdef (A -> AT.A)
    if posdef:
        return A_fcn, AT_fcn, B_new, col_swapped
    else:
        def A_new_fcn(x):
            return AT_fcn(A_fcn(x))
        B2 = AT_fcn(B_new)
        return A_new_fcn, A_new_fcn, B2, col_swapped
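# NOTE: ``_get_largest_eival`` is defined elsewhere in this module. A sketch of
# the underlying idea is power iteration with a Rayleigh-quotient readout,
# which yields a signed estimate of the dominant eigenvalue (consistent with
# the ``largest_eival <= 0`` test above). This is an assumption-laden
# illustration, not the actual helper:
def _power_iteration_sketch(A_fcn: Callable[[torch.Tensor], torch.Tensor],
                            x: torch.Tensor, niter: int = 20) -> torch.Tensor:
    # x: (*batch, nr, ncols); each column is an independent starting vector
    for _ in range(niter):
        x = A_fcn(x)
        x = x / x.norm(dim=-2, keepdim=True)
    # Rayleigh quotient x^T (A x) per column; x is already normalized, so for
    # a symmetric operator this approaches the largest-magnitude eigenvalue
    # with its sign preserved
    return (x * A_fcn(x)).sum(dim=-2, keepdim=True)  # (*batch, 1, ncols)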
def svd(A: LinearOperator, k: Optional[int] = None,
        mode: str = "uppest", bck_options: Mapping[str, Any] = {},
        method: Union[str, Callable, None] = None,
        **fwd_options) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Perform the singular value decomposition (SVD):

    .. math::

        \mathbf{A} = \mathbf{U\Sigma V}^H

    where :math:`\mathbf{U}` and :math:`\mathbf{V}` are semi-unitary matrices
    and :math:`\mathbf{\Sigma}` is a diagonal matrix containing real
    non-negative numbers.

    This function can handle derivatives for degenerate singular values by
    setting non-zero ``degen_atol`` and ``degen_rtol`` in the backward option,
    using the expressions in [1]_.

    Arguments
    ---------
    A: xitorch.LinearOperator
        The linear operator to be decomposed. It has a shape of ``(*BA, m, n)``
        where ``(*BA)`` is the batched dimension of ``A``.
    k: int or None
        The number of decompositions obtained. If ``None``, it will be
        ``min(*A.shape[-2:])``.
    mode: str
        ``"lowest"`` or ``"uppermost"``/``"uppest"``. If ``"lowest"``,
        it will take the lowest ``k`` decompositions.
        If ``"uppest"``, it will take the uppermost ``k``.
    bck_options: dict
        Method-specific options for :func:`solve`, which is used in the
        backpropagation calculation, with some additional arguments for
        computing the backward derivatives:

        * ``degen_atol`` (``float`` or None): Minimum absolute difference
          between two singular values to be treated as degenerate. If None, it
          is ``torch.finfo(dtype).eps**0.6``. If 0.0, no special treatment on
          degeneracy is applied. (default: None)
        * ``degen_rtol`` (``float`` or None): Minimum relative difference
          between two singular values to be treated as degenerate. If None, it
          is ``torch.finfo(dtype).eps**0.4``. If 0.0, no special treatment on
          degeneracy is applied. (default: None)

        Note: the default values of ``degen_atol`` and ``degen_rtol`` are going
        to change in the future. So, for future compatibility, please specify
        the specific values.
    method: str or callable or None
        Method for the svd (same options as :func:`symeig`). If ``None``,
        it will choose ``"exacteig"``.
    **fwd_options
        Method-specific options (see method section below).

    Returns
    -------
    tuple of tensors (u, s, vh)
        It will return ``u, s, vh`` with shapes respectively
        ``(*BA, m, k)``, ``(*BA, k)``, and ``(*BA, k, n)``.

    Note
    ----
    This is a naive implementation of symmetric eigendecomposition of
    ``A.H @ A`` or ``A @ A.H`` (depending on which one is cheaper).

    Warnings
    --------
    * If ``s`` contains very small numbers or degenerate values, the
      calculation and its gradient might be inaccurate.
    * The second derivative through ``U`` or ``V`` might be unstable.
      Extra care must be taken.

    References
    ----------
    .. [1] Muhammad F. Kasim,
           "Derivatives of partial eigendecomposition of a real symmetric
           matrix for degenerate cases".
           arXiv:2011.04366 (2020)
           `https://arxiv.org/abs/2011.04366 <https://arxiv.org/abs/2011.04366>`_
    """
    # A: (*BA, m, n)
    # adapted from scipy.sparse.linalg.svds
    if is_debug_enabled():
        A.check()

    BA = A.shape[:-2]
    m = A.shape[-2]
    n = A.shape[-1]
    if m < n:
        AAsym = A.matmul(A.H, is_hermitian=True)
        min_nm = m
    else:
        AAsym = A.H.matmul(A, is_hermitian=True)
        min_nm = n
    eivals, eivecs = symeig(AAsym, k, mode, bck_options=bck_options,
                            method=method, **fwd_options)  # (*BA, k) and (*BA, min(mn), k)

    # clamp the eigenvalues to small positive values to avoid numerical
    # instability
    eivals = torch.clamp(eivals, min=0.0)
    s = torch.sqrt(eivals)  # (*BA, k)
    sdiv = torch.clamp(s, min=1e-12).unsqueeze(-2)  # (*BA, 1, k)
    if m < n:
        u = eivecs  # (*BA, m, k)
        v = A.rmm(u) / sdiv  # (*BA, n, k)
    else:
        v = eivecs  # (*BA, n, k)
        u = A.mm(v) / sdiv  # (*BA, m, k)
    vh = v.transpose(-2, -1)
    return u, s, vh
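# A minimal usage sketch of ``svd`` above (assumes xitorch's dense
# ``LinearOperator.m`` constructor; shapes shown for a batched 5 x 4 operator):
def _svd_usage_example():
    # illustrative only: wrap a dense batched matrix and cross-check the
    # singular values against a dense reference
    Amat = torch.randn(3, 5, 4, dtype=torch.float64)  # (*BA, m, n)
    A = LinearOperator.m(Amat)
    u, s, vh = svd(A)  # u: (3, 5, 4), s: (3, 4), vh: (3, 4, 4)
    s_ref = torch.linalg.svdvals(Amat)  # sorted descending
    assert torch.allclose(s.sort(dim=-1, descending=True).values, s_ref)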
def davidson(A: LinearOperator, neig: int,
             mode: str,
             M: Optional[LinearOperator] = None,
             max_niter: int = 1000,
             nguess: Optional[int] = None,
             v_init: str = "randn",
             max_addition: Optional[int] = None,
             min_eps: float = 1e-6,
             verbose: bool = False,
             **unused) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Use the Davidson method for large sparse matrix eigendecomposition [1]_.

    Arguments
    ---------
    max_niter: int
        Maximum number of iterations
    v_init: str
        Mode of the initial guess (``"randn"``, ``"rand"``, ``"eye"``)
    max_addition: int or None
        Maximum number of new guesses to be added to the collected vectors.
        If None, set to ``neig``.
    min_eps: float
        Minimum residual error at which the iteration stops
    verbose: bool
        Option to be verbose

    References
    ----------
    .. [1] P. Arbenz, "Lecture Notes on Solving Large Scale Eigenvalue Problems"
           http://people.inf.ethz.ch/arbenz/ewp/Lnotes/chapter12.pdf
    """
    # TODO: optimize for large linear operator and strict min_eps
    # Ideas:
    # (1) use a better strategy to get the estimate on the eigenvalues
    # (2) use a restart strategy

    if nguess is None:
        nguess = neig
    if max_addition is None:
        max_addition = neig

    # get the shape of the transformation
    na = A.shape[-1]
    if M is None:
        bcast_dims = A.shape[:-2]
    else:
        bcast_dims = get_bcasted_dims(A.shape[:-2], M.shape[:-2])
    dtype = A.dtype
    device = A.device

    prev_eigvals = None
    prev_eigvalT = None
    stop_reason = "max_niter"
    shift_is_eigvalT = False
    idx = torch.arange(neig).unsqueeze(-1)  # (neig, 1)

    # set up the initial guess
    V = _set_initial_v(v_init.lower(), dtype, device,
                       bcast_dims, na, nguess,
                       M=M)  # (*BAM, na, nguess)

    best_resid: Union[float, torch.Tensor] = float("inf")
    AV = A.mm(V)
    for i in range(max_niter):
        VT = V.transpose(-2, -1)  # (*BAM,nguess,na)
        # Can be optimized by saving AV from the previous iteration and only
        # computing AV for the new V. This works because the old V has already
        # been orthogonalized, so it stays the same.
        # AV = A.mm(V)  # (*BAM,na,nguess)
        T = torch.matmul(VT, AV)  # (*BAM,nguess,nguess)

        # eigvals are sorted from the lowest
        # eval: (*BAM, nguess), evec: (*BAM, nguess, nguess)
        eigvalT, eigvecT = torch.linalg.eigh(T)
        eigvalT, eigvecT = _take_eigpairs(eigvalT, eigvecT, neig,
                                          mode)  # (*BAM, neig) and (*BAM, nguess, neig)

        # calculate the eigenvectors of A
        eigvecA = torch.matmul(V, eigvecT)  # (*BAM, na, neig)

        # calculate the residual
        AVs = torch.matmul(AV, eigvecT)  # (*BAM, na, neig)
        LVs = eigvalT.unsqueeze(-2) * eigvecA  # (*BAM, na, neig)
        if M is not None:
            LVs = M.mm(LVs)
        resid = AVs - LVs  # (*BAM, na, neig)

        # print information and check convergence
        max_resid = resid.abs().max()
        if prev_eigvalT is not None:
            deigval = eigvalT - prev_eigvalT
            max_deigval = deigval.abs().max()
            if verbose:
                print("Iter %3d (guess size: %d): resid: %.3e, devals: %.3e" %
                      (i + 1, nguess, max_resid, max_deigval))  # type: ignore

        if max_resid < best_resid:
            best_resid = max_resid
            best_eigvals = eigvalT
            best_eigvecs = eigvecA
        if max_resid < min_eps:
            break
        if AV.shape[-1] == AV.shape[-2]:
            break
        prev_eigvalT = eigvalT

        # apply the preconditioner
        t = -resid  # (*BAM, na, neig)

        # orthogonalize t with the rest of the V
        t = to_fortran_order(t)
        Vnew = torch.cat((V, t), dim=-1)
        if Vnew.shape[-1] > Vnew.shape[-2]:
            Vnew = Vnew[..., :Vnew.shape[-2]]
        nadd = Vnew.shape[-1] - V.shape[-1]
        nguess = nguess + nadd
        if M is not None:
            MV_ = M.mm(Vnew)
            V, R = tallqr(Vnew, MV=MV_)
        else:
            V, R = tallqr(Vnew)
        AVnew = A.mm(V[..., -nadd:])  # (*BAM,na,nadd)
        AVnew = to_fortran_order(AVnew)
        AV = torch.cat((AV, AVnew), dim=-1)

    eigvals = best_eigvals  # (*BAM, neig)
    eigvecs = best_eigvecs  # (*BAM, na, neig)
    return eigvals, eigvecs
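# A minimal usage sketch of the Davidson iteration (assumes the public
# ``symeig`` entry point dispatches to ``davidson`` above when
# ``method="davidson"``; forward options are passed through as ``fwd_options``):
def _davidson_usage_example():
    # illustrative only: build a symmetric matrix and wrap it as a Hermitian
    # LinearOperator, then ask for the lowest 3 eigenpairs
    mat = torch.randn(100, 100, dtype=torch.float64)
    mat = (mat + mat.T) * 0.5  # symmetrize
    A = LinearOperator.m(mat, is_hermitian=True)
    eigvals, eigvecs = symeig(A, neig=3, mode="lowest", method="davidson",
                              min_eps=1e-7)
    # cross-check against a dense reference (eigvalsh sorts ascending)
    assert torch.allclose(eigvals, torch.linalg.eigvalsh(mat)[:3], atol=1e-6)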