def solve(A: LinearOperator, B: torch.Tensor,
          E: Union[torch.Tensor, None] = None,
          M: Optional[LinearOperator] = None,
          bck_options: Optional[Mapping[str, Any]] = None,
          method: Union[str, Callable, None] = None,
          **fwd_options) -> torch.Tensor:
    r"""
    Solve the linear equation

    .. math::

        \mathbf{AX=B}

    or

    .. math::

        \mathbf{AX-MXE=B}

    where :math:`\mathbf{E}` is a diagonal matrix.
    This function can also solve batched multiple inverse equations at the
    same time by applying :math:`\mathbf{A}` to a tensor :math:`\mathbf{X}`
    with shape ``(...,na,ncols)``.
    The applied :math:`\mathbf{E}` are not necessarily identical for each
    column.

    Arguments
    ---------
    A: xitorch.LinearOperator
        A linear operator that takes an input ``X`` and produces vectors in
        the same space as ``B``.
        It should have the shape of ``(*BA, na, na)``
    B: torch.Tensor
        The tensor on the right hand side with shape ``(*BB, na, ncols)``
    E: torch.Tensor or None
        If a tensor, it will solve :math:`\mathbf{AX-MXE = B}`.
        It will be regarded as the diagonal of the matrix.
        Otherwise, it just solves :math:`\mathbf{AX = B}` and ``M`` is ignored.
        If it is a tensor, it should have shape of ``(*BE, ncols)``.
    M: xitorch.LinearOperator or None
        The transformation on the ``E`` side. If ``E`` is ``None``,
        then this argument is ignored (a warning is emitted).
        If E is not ``None`` and ``M`` is ``None``, then ``M=I``.
        If LinearOperator, it must be Hermitian with shape ``(*BM, na, na)``.
    bck_options: dict or None
        Options of the iterative solver in the backward calculation.
        ``None`` (the default) is treated as an empty dict.
    method: str or callable or None
        The method of linear equation solver.
        If ``None``, it chooses ``"exactsolve"`` when ``A`` (and ``M``, if
        given) are ``MatrixLinearOperator``; otherwise ``"cg"`` or
        ``"bicgstab"`` depending on whether the operators are Hermitian.
        `Note`: default method will be changed quite frequently, so if you
        want future compatibility, please specify a method.
    **fwd_options
        Method-specific options (see method below)

    Returns
    -------
    torch.Tensor
        The tensor :math:`\mathbf{X}` that satisfies :math:`\mathbf{AX-MXE=B}`.
    """
    # Fresh dict per call: a shared ``{}`` default would leak any mutation
    # made by the backward machinery into subsequent calls (and copying also
    # protects a caller-supplied dict).
    bck_options = {} if bck_options is None else dict(bck_options)

    assert_runtime(A.shape[-1] == A.shape[-2],
                   "The linear operator A must have a square shape")
    assert_runtime(
        A.shape[-1] == B.shape[-2],
        "Mismatch shape of A & B (A: %s, B: %s)" % (A.shape, B.shape))
    assert_runtime(
        not torch.is_grad_enabled() or A.is_getparamnames_implemented,
        "The _getparamnames(self, prefix) of linear operator A must be "
        "implemented if using solve with grad enabled")
    if M is not None:
        assert_runtime(M.shape[-1] == M.shape[-2],
                       "The linear operator M must have a square shape")
        assert_runtime(
            M.shape[-1] == A.shape[-1],
            "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape))
        assert_runtime(M.is_hermitian,
                       "The linear operator M must be a Hermitian matrix")
        assert_runtime(
            not torch.is_grad_enabled() or M.is_getparamnames_implemented,
            "The _getparamnames(self, prefix) of linear operator M must be "
            "implemented if using solve with grad enabled")
    if E is not None:
        assert_runtime(
            E.shape[-1] == B.shape[-1],
            "The last dimension of E & B must match (E: %s, B: %s)" %
            (E.shape, B.shape))
    if E is None and M is not None:
        warnings.warn(
            "M is supplied but will be ignored because E is not supplied")

    # perform expensive check if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if method is None:
        if isinstance(A, MatrixLinearOperator) and \
                (M is None or isinstance(M, MatrixLinearOperator)):
            method = "exactsolve"
        else:
            is_hermit = A.is_hermitian and (M is None or M.is_hermitian)
            method = "cg" if is_hermit else "bicgstab"

    if method == "exactsolve":
        return exactsolve(A, B, E, M)

    # get the unique parameters of A & M; ``na`` tells the autograd function
    # where A's parameters end and M's begin in the flattened *params list
    params = A.getlinopparams()
    mparams = M.getlinopparams() if M is not None else []
    na = len(params)
    return solve_torchfcn.apply(A, B, E, M, method, fwd_options, bck_options,
                                na, *params, *mparams)
def symeig(A: LinearOperator, neig: Optional[int] = None,
           mode: str = "lowest", M: Optional[LinearOperator] = None,
           bck_options: Optional[Mapping[str, Any]] = None,
           method: Union[str, Callable, None] = None,
           **fwd_options) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Obtain ``neig`` lowest (or uppermost, see ``mode``) eigenvalues and
    eigenvectors of a linear operator,

    .. math::

        \mathbf{AX = MXE}

    where :math:`\mathbf{A}, \mathbf{M}` are linear operators,
    :math:`\mathbf{E}` is a diagonal matrix containing the eigenvalues, and
    :math:`\mathbf{X}` is a matrix containing the eigenvectors.

    Arguments
    ---------
    A: xitorch.LinearOperator
        The linear operator object on which the eigenpairs are constructed.
        It must be a Hermitian linear operator with shape ``(*BA, q, q)``
    neig: int or None
        The number of eigenpairs to be retrieved. If ``None``, all eigenpairs
        are retrieved
    mode: str
        ``"lowest"`` or ``"uppermost"``/``"uppest"`` (case-insensitive).
        If ``"lowest"``, it will take the lowest ``neig`` eigenpairs.
        If ``"uppest"``, it will take the uppermost ``neig``.
    M: xitorch.LinearOperator
        The transformation on the right hand side. If ``None``, then ``M=I``.
        If specified, it must be a Hermitian with shape ``(*BM, q, q)``.
    bck_options: dict or None
        Method-specific options for :func:`solve` which are used in the
        backpropagation calculation. ``None`` (the default) is treated as an
        empty dict.
    method: str or callable or None
        Method for the eigendecomposition.
        If ``None``, it will choose ``"exacteig"``.
    **fwd_options
        Method-specific options (see method section below).

    Returns
    -------
    tuple of tensors (eigenvalues, eigenvectors)
        It will return eigenvalues and eigenvectors with shapes respectively
        ``(*BAM, neig)`` and ``(*BAM, q, neig)``, where ``*BAM`` is the
        broadcasted shape of ``*BA`` and ``*BM``.
    """
    # Fresh dict per call instead of a shared mutable ``{}`` default.
    bck_options = {} if bck_options is None else dict(bck_options)

    assert_runtime(A.is_hermitian, "The linear operator A must be Hermitian")
    assert_runtime(
        not torch.is_grad_enabled() or A.is_getparamnames_implemented,
        "The _getparamnames(self, prefix) of linear operator A must be "
        "implemented if using symeig with grad enabled")
    if M is not None:
        assert_runtime(M.is_hermitian,
                       "The linear operator M must be Hermitian")
        assert_runtime(
            M.shape[-1] == A.shape[-1],
            "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape))
        assert_runtime(
            not torch.is_grad_enabled() or M.is_getparamnames_implemented,
            "The _getparamnames(self, prefix) of linear operator M must be "
            "implemented if using symeig with grad enabled")

    # normalise the mode aliases
    mode = mode.lower()
    if mode == "uppermost":
        mode = "uppest"

    if method is None:
        if isinstance(A, MatrixLinearOperator) and \
                (M is None or isinstance(M, MatrixLinearOperator)):
            method = "exacteig"
        else:
            # TODO: implement robust LOBPCG and put it here
            method = "exacteig"

    if neig is None:
        neig = A.shape[-1]

    # perform expensive check if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if method == "exacteig":
        return exacteig(A, neig, mode, M)

    # fwd_options is this call's own **kwargs dict, so mutating it is safe
    fwd_options["method"] = method
    # get the unique parameters of A & M; ``na`` marks the split point
    params = A.getlinopparams()
    mparams = M.getlinopparams() if M is not None else []
    na = len(params)
    return symeig_torchfcn.apply(A, neig, mode, M,
                                 fwd_options, bck_options,
                                 na, *params, *mparams)
def svd(A: LinearOperator, k: Optional[int] = None, mode: str = "uppest",
        bck_options: Optional[Mapping[str, Any]] = None,
        method: Union[str, Callable, None] = None,
        **fwd_options) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Perform the singular value decomposition (SVD):

    .. math::

        \mathbf{A} = \mathbf{U\Sigma V}^H

    where :math:`\mathbf{U}` and :math:`\mathbf{V}` are semi-unitary matrices
    and :math:`\mathbf{\Sigma}` is a diagonal matrix containing real
    non-negative numbers.

    Arguments
    ---------
    A: xitorch.LinearOperator
        The linear operator to be decomposed. It has a shape of
        ``(*BA, m, n)`` where ``(*BA)`` is the batched dimension of ``A``.
    k: int or None
        The number of decomposition obtained. If ``None``, it will be
        ``min(*A.shape[-2:])``
    mode: str
        ``"lowest"`` or ``"uppermost"``/``"uppest"``.
        If ``"lowest"``, it will take the lowest ``k`` decomposition.
        If ``"uppest"``, it will take the uppermost ``k``.
    bck_options: dict or None
        Method-specific options for :func:`solve` which are used in the
        backpropagation calculation. ``None`` (the default) is treated as an
        empty dict.
    method: str or callable or None
        Method for the svd (same options for :func:`symeig`). If ``None``,
        it will choose ``"exacteig"``.
    **fwd_options
        Method-specific options (see method section below).

    Returns
    -------
    tuple of tensors (u, s, vh)
        It will return ``u, s, vh`` with shapes respectively
        ``(*BA, m, k)``, ``(*BA, k)``, and ``(*BA, k, n)``.
        ``vh`` is the conjugate transpose of ``v``.

    Note
    ----
    It is a naive implementation of symmetric eigendecomposition of ``A.H @ A``
    or ``A @ A.H`` (depending which one is cheaper)

    Warnings
    --------
    * If ``s`` contains very small numbers or degenerate values, the
      calculation and its gradient might be inaccurate.
    * The second derivative through U or V might be unstable.
      Extra care must be taken.
    """
    # A: (*BA, m, n)
    # adapted from scipy.sparse.linalg.svds

    # Fresh dict per call instead of a shared mutable ``{}`` default; it is
    # forwarded to symeig as a concrete dict.
    bck_options = {} if bck_options is None else dict(bck_options)

    if is_debug_enabled():
        A.check()

    m = A.shape[-2]
    n = A.shape[-1]
    # decompose the smaller of the two Hermitian products
    if m < n:
        AAsym = A.matmul(A.H, is_hermitian=True)
    else:
        AAsym = A.H.matmul(A, is_hermitian=True)

    eivals, eivecs = symeig(AAsym, k, mode,
                            bck_options=bck_options, method=method,
                            **fwd_options)  # (*BA, k) and (*BA, min(mn), k)

    # clamp the eigenvalues at zero: tiny negative values from round-off would
    # otherwise make sqrt return NaN
    eivals = torch.clamp(eivals, min=0.0)
    s = torch.sqrt(eivals)  # (*BA, k)
    # guard against division by (near-)zero singular values
    sdiv = torch.clamp(s, min=1e-12).unsqueeze(-2)  # (*BA, 1, k)
    if m < n:
        u = eivecs  # (*BA, m, k)
        v = A.rmm(u) / sdiv  # (*BA, n, k)
    else:
        v = eivecs  # (*BA, n, k)
        u = A.mm(v) / sdiv  # (*BA, m, k)
    # V^H is the conjugate transpose; conj() is a no-op for real tensors
    vh = v.transpose(-2, -1).conj()
    return u, s, vh
def symeig(A: LinearOperator, neig: Optional[int] = None,
           mode: str = "lowest", M: Optional[LinearOperator] = None,
           bck_options: Optional[Mapping[str, Any]] = None,
           method: Union[str, Callable, None] = None,
           **fwd_options) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Obtain ``neig`` lowest (or uppermost, see ``mode``) eigenvalues and
    eigenvectors of a linear operator,

    .. math::

        \mathbf{AX = MXE}

    where :math:`\mathbf{A}, \mathbf{M}` are linear operators,
    :math:`\mathbf{E}` is a diagonal matrix containing the eigenvalues, and
    :math:`\mathbf{X}` is a matrix containing the eigenvectors.
    This function can handle derivatives for degenerate cases by setting
    non-zero ``degen_atol`` and ``degen_rtol`` in the backward option using
    the expressions in [1]_.

    Arguments
    ---------
    A: xitorch.LinearOperator
        The linear operator object on which the eigenpairs are constructed.
        It must be a Hermitian linear operator with shape ``(*BA, q, q)``
    neig: int or None
        The number of eigenpairs to be retrieved. If ``None``, all eigenpairs
        are retrieved
    mode: str
        ``"lowest"`` or ``"uppermost"``/``"uppest"`` (case-insensitive).
        If ``"lowest"``, it will take the lowest ``neig`` eigenpairs.
        If ``"uppest"``, it will take the uppermost ``neig``.
    M: xitorch.LinearOperator
        The transformation on the right hand side. If ``None``, then ``M=I``.
        If specified, it must be a Hermitian with shape ``(*BM, q, q)``.
    bck_options: dict or None
        Method-specific options for :func:`solve` which are used in the
        backpropagation calculation, with some additional arguments for
        computing the backward derivatives. ``None`` (the default) is treated
        as an empty dict.

        * ``degen_atol`` (``float`` or None): Minimum absolute difference
          between two eigenvalues to be treated as degenerate.
          If None, it is ``torch.finfo(dtype).eps**0.6``.
          If 0.0, no special treatment on degeneracy is applied.
          (default: None)
        * ``degen_rtol`` (``float`` or None): Minimum relative difference
          between two eigenvalues to be treated as degenerate.
          If None, it is ``torch.finfo(dtype).eps**0.4``.
          If 0.0, no special treatment on degeneracy is applied.
          (default: None)

        Note: the default values of ``degen_atol`` and ``degen_rtol`` are
        going to change in the future. So, for future compatibility, please
        specify the specific values.

    method: str or callable or None
        Method for the eigendecomposition.
        If ``None``, it will choose ``"exacteig"``.
    **fwd_options
        Method-specific options (see method section below).

    Returns
    -------
    tuple of tensors (eigenvalues, eigenvectors)
        It will return eigenvalues and eigenvectors with shapes respectively
        ``(*BAM, neig)`` and ``(*BAM, q, neig)``, where ``*BAM`` is the
        broadcasted shape of ``*BA`` and ``*BM``.

    References
    ----------
    .. [1] Muhammad F. Kasim,
           "Derivatives of partial eigendecomposition of a real symmetric
           matrix for degenerate cases". arXiv:2011.04366 (2020)
           `https://arxiv.org/abs/2011.04366 <https://arxiv.org/abs/2011.04366>`_
    """
    # Fresh dict per call instead of a shared mutable ``{}`` default.
    bck_options = {} if bck_options is None else dict(bck_options)

    assert_runtime(A.is_hermitian, "The linear operator A must be Hermitian")
    assert_runtime(
        not torch.is_grad_enabled() or A.is_getparamnames_implemented,
        "The _getparamnames(self, prefix) of linear operator A must be "
        "implemented if using symeig with grad enabled")
    if M is not None:
        assert_runtime(M.is_hermitian,
                       "The linear operator M must be Hermitian")
        assert_runtime(
            M.shape[-1] == A.shape[-1],
            "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape))
        assert_runtime(
            not torch.is_grad_enabled() or M.is_getparamnames_implemented,
            "The _getparamnames(self, prefix) of linear operator M must be "
            "implemented if using symeig with grad enabled")

    # normalise the mode aliases
    mode = mode.lower()
    if mode == "uppermost":
        mode = "uppest"

    if method is None:
        if isinstance(A, MatrixLinearOperator) and \
                (M is None or isinstance(M, MatrixLinearOperator)):
            method = "exacteig"
        else:
            # TODO: implement robust LOBPCG and put it here
            method = "exacteig"

    if neig is None:
        neig = A.shape[-1]

    # perform expensive check if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if method == "exacteig":
        return exacteig(A, neig, mode, M)

    # fwd_options is this call's own **kwargs dict, so mutating it is safe
    fwd_options["method"] = method
    # get the unique parameters of A & M; ``na`` marks the split point
    params = A.getlinopparams()
    mparams = M.getlinopparams() if M is not None else []
    na = len(params)
    return symeig_torchfcn.apply(A, neig, mode, M,
                                 fwd_options, bck_options,
                                 na, *params, *mparams)
def svd(A: LinearOperator, k: Optional[int] = None, mode: str = "uppest",
        bck_options: Optional[Mapping[str, Any]] = None,
        method: Union[str, Callable, None] = None,
        **fwd_options) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Perform the singular value decomposition (SVD):

    .. math::

        \mathbf{A} = \mathbf{U\Sigma V}^H

    where :math:`\mathbf{U}` and :math:`\mathbf{V}` are semi-unitary matrices
    and :math:`\mathbf{\Sigma}` is a diagonal matrix containing real
    non-negative numbers.
    This function can handle derivatives for degenerate singular values by
    setting non-zero ``degen_atol`` and ``degen_rtol`` in the backward option
    using the expressions in [1]_.

    Arguments
    ---------
    A: xitorch.LinearOperator
        The linear operator to be decomposed. It has a shape of
        ``(*BA, m, n)`` where ``(*BA)`` is the batched dimension of ``A``.
    k: int or None
        The number of decomposition obtained. If ``None``, it will be
        ``min(*A.shape[-2:])``
    mode: str
        ``"lowest"`` or ``"uppermost"``/``"uppest"``.
        If ``"lowest"``, it will take the lowest ``k`` decomposition.
        If ``"uppest"``, it will take the uppermost ``k``.
    bck_options: dict or None
        Method-specific options for :func:`solve` which are used in the
        backpropagation calculation, with some additional arguments for
        computing the backward derivatives. ``None`` (the default) is treated
        as an empty dict.

        * ``degen_atol`` (``float`` or None): Minimum absolute difference
          between two singular values to be treated as degenerate.
          If None, it is ``torch.finfo(dtype).eps**0.6``.
          If 0.0, no special treatment on degeneracy is applied.
          (default: None)
        * ``degen_rtol`` (``float`` or None): Minimum relative difference
          between two singular values to be treated as degenerate.
          If None, it is ``torch.finfo(dtype).eps**0.4``.
          If 0.0, no special treatment on degeneracy is applied.
          (default: None)

        Note: the default values of ``degen_atol`` and ``degen_rtol`` are
        going to change in the future. So, for future compatibility, please
        specify the specific values.

    method: str or callable or None
        Method for the svd (same options for :func:`symeig`). If ``None``,
        it will choose ``"exacteig"``.
    **fwd_options
        Method-specific options (see method section below).

    Returns
    -------
    tuple of tensors (u, s, vh)
        It will return ``u, s, vh`` with shapes respectively
        ``(*BA, m, k)``, ``(*BA, k)``, and ``(*BA, k, n)``.
        ``vh`` is the conjugate transpose of ``v``.

    Note
    ----
    It is a naive implementation of symmetric eigendecomposition of ``A.H @ A``
    or ``A @ A.H`` (depending which one is cheaper)

    References
    ----------
    .. [1] Muhammad F. Kasim,
           "Derivatives of partial eigendecomposition of a real symmetric
           matrix for degenerate cases". arXiv:2011.04366 (2020)
           `https://arxiv.org/abs/2011.04366 <https://arxiv.org/abs/2011.04366>`_
    """
    # A: (*BA, m, n)
    # adapted from scipy.sparse.linalg.svds

    # Fresh dict per call instead of a shared mutable ``{}`` default; it is
    # forwarded to symeig as a concrete dict.
    bck_options = {} if bck_options is None else dict(bck_options)

    if is_debug_enabled():
        A.check()

    m = A.shape[-2]
    n = A.shape[-1]
    # decompose the smaller of the two Hermitian products
    if m < n:
        AAsym = A.matmul(A.H, is_hermitian=True)
    else:
        AAsym = A.H.matmul(A, is_hermitian=True)

    eivals, eivecs = symeig(AAsym, k, mode,
                            bck_options=bck_options, method=method,
                            **fwd_options)  # (*BA, k) and (*BA, min(mn), k)

    # clamp the eigenvalues at zero: tiny negative values from round-off would
    # otherwise make sqrt return NaN
    eivals = torch.clamp(eivals, min=0.0)
    s = torch.sqrt(eivals)  # (*BA, k)
    # guard against division by (near-)zero singular values
    sdiv = torch.clamp(s, min=1e-12).unsqueeze(-2)  # (*BA, 1, k)
    if m < n:
        u = eivecs  # (*BA, m, k)
        v = A.rmm(u) / sdiv  # (*BA, n, k)
    else:
        v = eivecs  # (*BA, n, k)
        u = A.mm(v) / sdiv  # (*BA, m, k)
    # V^H is the conjugate transpose; conj() is a no-op for real tensors
    vh = v.transpose(-2, -1).conj()
    return u, s, vh
def solve(A: LinearOperator, B: torch.Tensor,
          E: Union[torch.Tensor, None] = None,
          M: Union[LinearOperator, None] = None,
          posdef=False,
          bck_options: Optional[Mapping[str, Any]] = None,
          method: Union[str, None] = None,
          **fwd_options) -> torch.Tensor:
    r"""
    Solve the linear equation

    .. math::

        \mathbf{AX=B}

    or

    .. math::

        \mathbf{AX-MXE=B}

    where :math:`\mathbf{E}` is a diagonal matrix.
    This function can also solve batched multiple inverse equations at the
    same time by applying :math:`\mathbf{A}` to a tensor :math:`\mathbf{X}`
    with shape ``(...,na,ncols)``.
    The applied :math:`\mathbf{E}` are not necessarily identical for each
    column.

    Arguments
    ---------
    A: xitorch.LinearOperator
        A linear operator that takes an input ``X`` and produces vectors in
        the same space as ``B``.
        It should have the shape of ``(*BA, na, na)``
    B: torch.Tensor
        The tensor on the right hand side with shape ``(*BB, na, ncols)``
    E: torch.Tensor or None
        If a tensor, it will solve :math:`\mathbf{AX-MXE = B}`.
        It will be regarded as the diagonal of the matrix.
        Otherwise, it just solves :math:`\mathbf{AX = B}` and ``M`` is ignored.
        If it is a tensor, it should have shape of ``(*BE, ncols)``.
    M: xitorch.LinearOperator or None
        The transformation on the ``E`` side. If ``E`` is ``None``,
        then this argument is ignored (a warning is emitted).
        If E is not ``None`` and ``M`` is ``None``, then ``M=I``.
        If LinearOperator, it must be Hermitian with shape ``(*BM, na, na)``.
    posdef: bool
        Flag forwarded to the iterative backend (presumably marking ``A`` as
        positive definite -- confirm against ``solve_torchfcn``). It is not
        used by ``"exactsolve"``.
    bck_options: dict or None
        Options of the iterative solver in the backward calculation.
        ``None`` (the default) is treated as an empty dict.
    method: str or None
        Indicating the method of solve. If None, it will select
        ``exactsolve``.
    **fwd_options
        Method-specific options (see method below)

    Returns
    -------
    torch.Tensor
        The tensor :math:`\mathbf{X}` that satisfies :math:`\mathbf{AX-MXE=B}`.
    """
    # Fresh dict per call instead of a shared mutable ``{}`` default.
    bck_options = {} if bck_options is None else dict(bck_options)

    assert_runtime(A.shape[-1] == A.shape[-2],
                   "The linear operator A must have a square shape")
    assert_runtime(A.shape[-1] == B.shape[-2],
                   "Mismatch shape of A & B (A: %s, B: %s)" % (A.shape, B.shape))
    if M is not None:
        assert_runtime(M.shape[-1] == M.shape[-2],
                       "The linear operator M must have a square shape")
        assert_runtime(M.shape[-1] == A.shape[-1],
                       "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape))
        assert_runtime(M.is_hermitian,
                       "The linear operator M must be a Hermitian matrix")
    if E is not None:
        assert_runtime(E.shape[-1] == B.shape[-1],
                       "The last dimension of E & B must match (E: %s, B: %s)" % (E.shape, B.shape))
    if E is None and M is not None:
        warnings.warn("M is supplied but will be ignored because E is not supplied")

    # perform expensive check if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if method is None:
        method = "exactsolve"  # TODO: do a proper method selection based on the size

    if method == "exactsolve":
        return exactsolve(A, B, E, M)

    # fwd_options is this call's own **kwargs dict, so mutating it is safe
    fwd_options["method"] = method
    # get the unique parameters of A & M; ``na`` marks the split point
    params = A.getlinopparams()
    mparams = M.getlinopparams() if M is not None else []
    na = len(params)
    return solve_torchfcn.apply(
        A, B, E, M, posdef,
        fwd_options, bck_options,
        na, *params, *mparams)