def svd(A: LinearOperator, k: Optional[int] = None,
        mode: str = "uppest", bck_options: Mapping[str, Any] = {},
        method: Union[str, Callable, None] = None,
        **fwd_options) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Perform the singular value decomposition (SVD):

    .. math::

        \mathbf{A} = \mathbf{U\Sigma V}^H

    where :math:`\mathbf{U}` and :math:`\mathbf{V}` are semi-unitary matrices
    and :math:`\mathbf{\Sigma}` is a diagonal matrix containing real
    non-negative numbers.

    Arguments
    ---------
    A: xitorch.LinearOperator
        The linear operator to be decomposed. It has a shape of ``(*BA, m, n)``
        where ``(*BA)`` is the batched dimension of ``A``.
    k: int or None
        The number of decomposition obtained. If ``None``, it will be
        ``min(*A.shape[-2:])``
    mode: str
        ``"lowest"`` or ``"uppermost"``/``"uppest"``. If ``"lowest"``,
        it will take the lowest ``k`` decomposition.
        If ``"uppest"``, it will take the uppermost ``k``.
    bck_options: dict
        Method-specific options for :func:`solve` which is used in the
        backpropagation calculation. It is forwarded unchanged to
        :func:`symeig`.
    method: str or callable or None
        Method for the svd (same options for :func:`symeig`). If ``None``,
        it will choose ``"exacteig"``.
    **fwd_options
        Method-specific options (see method section below).

    Returns
    -------
    tuple of tensors (u, s, vh)
        It will return ``u, s, vh`` with shapes respectively
        ``(*BA, m, k)``, ``(*BA, k)``, and ``(*BA, k, n)``.
        ``vh`` is the conjugate transpose of the right singular vectors,
        so that ``A ≈ u @ diag(s) @ vh`` also holds for complex operators.

    Note
    ----
    It is a naive implementation of symmetric eigendecomposition of ``A.H @ A``
    or ``A @ A.H`` (depending which one is cheaper)

    Warnings
    --------
    * If ``s`` contains very small numbers or degenerate values, the
      calculation and its gradient might be inaccurate.
    * The second derivative through U or V might be unstable.
      Extra care must be taken.
    """
    # A: (*BA, m, n)
    # adapted from scipy.sparse.linalg.svds

    if is_debug_enabled():
        A.check()

    m = A.shape[-2]
    n = A.shape[-1]
    # eigendecompose whichever Gram operator is smaller: (m, m) or (n, n)
    if m < n:
        AAsym = A.matmul(A.H, is_hermitian=True)
    else:
        AAsym = A.H.matmul(A, is_hermitian=True)

    eivals, eivecs = symeig(AAsym, k, mode,
                            bck_options=bck_options,
                            method=method,
                            **fwd_options)  # (*BA, k) and (*BA, min(mn), k)

    # the Gram operator is positive semi-definite, so any negative eigenvalue
    # is numerical noise; clamp to zero before taking the square root
    eivals = torch.clamp(eivals, min=0.0)
    s = torch.sqrt(eivals)  # (*BA, k)
    # floor the divisor so (near-)zero singular values do not divide by zero
    sdiv = torch.clamp(s, min=1e-12).unsqueeze(-2)  # (*BA, 1, k)
    if m < n:
        u = eivecs  # (*BA, m, k)
        v = A.rmm(u) / sdiv  # (*BA, n, k)
    else:
        v = eivecs  # (*BA, n, k)
        u = A.mm(v) / sdiv  # (*BA, m, k)
    # V^H is the *conjugate* transpose; .conj() leaves real tensors unchanged
    vh = v.transpose(-2, -1).conj()  # (*BA, k, n)
    return u, s, vh
def _setup_linear_problem(A: LinearOperator, B: torch.Tensor,
                          E: Optional[torch.Tensor], M: Optional[LinearOperator],
                          batchdims: Sequence[int],
                          posdef: Optional[bool],
                          need_hermit: bool) -> \
        Tuple[Callable[[torch.Tensor], torch.Tensor],
              Callable[[torch.Tensor], torch.Tensor],
              torch.Tensor, bool]:
    """
    Build the forward/adjoint callables and right-hand side for a linear solve.

    Without ``E`` the operator is ``A`` itself. With ``E`` the operator is
    ``x -> A.mm(x) - (M.mm(x) if M is given else x) * E`` and the column axis
    of ``B`` is moved to the front so every column gets its own ``E`` entry.

    Arguments
    ---------
    A: LinearOperator
        Operator of shape ``(*BA, nr, nr)``.
    B: torch.Tensor
        Right-hand side of shape ``(*BB, nr, ncols)``.
    E: torch.Tensor or None
        Per-column values of shape ``(*BE, ncols)``, or ``None``.
    M: LinearOperator or None
        Operator of shape ``(*BM, nr, nr)``; only used when ``E`` is given.
    batchdims: sequence of ints
        Batch dimensions used for the random probe tensor.
    posdef: bool or None
        Whether the operator is positive definite. If ``None`` it is estimated
        with ``_get_largest_eival``.
    need_hermit: bool
        If True and ``A`` (or ``M``) is not flagged Hermitian, ``posdef`` is
        forced to ``False``.

    Returns
    -------
    tuple (A_fcn, AT_fcn, B_new, col_swapped)
        If ``posdef`` ends up truthy: the operator, its adjoint, and the
        (possibly column-swapped) right-hand side. Otherwise both callables
        are ``x -> AT_fcn(A_fcn(x))`` and the right-hand side is
        ``AT_fcn(B_new)``. ``col_swapped`` is True exactly when ``E`` is given.
    """
    col_swapped = E is not None

    if E is None:
        rhs = B
        rhs_new = B

        def A_fcn(x):
            return A.mm(x)

        def AT_fcn(x):
            return A.rmm(x)
    else:
        # A: (*BA, nr, nr) linop
        # B: (*BB, nr, ncols)
        # E: (*BE, ncols)
        # M: (*BM, nr, nr) linop
        shapes = [A.shape[:-2], B.shape[:-2], E.shape[:-1]]
        if M is not None:
            shapes.append(M.shape[:-2])
        # only the normalized batch dims of B and E are needed below
        _, b_dims, e_dims, *_ = normalize_bcast_dims(*shapes)

        e_cols = E.reshape(*e_dims, *E.shape[-1:])
        e_cols = e_cols.unsqueeze(0).transpose(-1, 0).unsqueeze(-1)  # (ncols, *BEs, 1, 1)
        rhs = B.reshape(*b_dims, *B.shape[-2:])  # (*BBs, nr, ncols)
        rhs_new = rhs.unsqueeze(0).transpose(-1, 0)  # (ncols, *BBs, nr, 1)

        def A_fcn(x):
            # x: (ncols, *BX, nr, 1)
            a_x = A.mm(x)  # (ncols, *BAX, nr, 1)
            m_x = M.mm(x) if M is not None else x  # (ncols, *BMX, nr, 1)
            return a_x - m_x * e_cols

        def AT_fcn(x):
            # x: (ncols, *BX, nr, 1)
            at_x = A.rmm(x)
            mt_x = M.rmm(x) if M is not None else x
            return at_x - mt_x * e_cols

    # a solver that needs a Hermitian operator cannot use a non-Hermitian one
    # directly, so mark it as not posdef to fall through to AT.A below
    if need_hermit and not (A.is_hermitian and (M is None or M.is_hermitian)):
        posdef = False

    # estimate if it's posdef with power iteration
    if posdef is None:
        nr, ncols = rhs.shape[-2:]
        if col_swapped:
            probe_shape = (ncols, *batchdims, nr, 1)
        else:
            probe_shape = (*batchdims, nr, ncols)
        probe = torch.randn(probe_shape, dtype=A.dtype, device=A.device)
        probe = probe / probe.norm(dim=-2, keepdim=True)
        top_eival = _get_largest_eival(A_fcn, probe)  # (*, 1, nc)
        nonpositive = top_eival <= 0

        if torch.all(nonpositive):
            # if the largest eigenvalue is negative, then it's not posdef
            posdef = False
        else:
            shift = torch.clamp(top_eival, min=0.0)

            def shifted_fcn(x):
                return A_fcn(x) - shift * x

            # presumably the dominant eigenvalue of the shifted operator
            # exposes the most negative eigenvalue of the original one
            shifted_eival = _get_largest_eival(shifted_fcn, probe)  # (*, 1, nc)
            acceptable = torch.logical_or(-shifted_eival <= shift, nonpositive)
            posdef = bool(torch.all(acceptable).item())

    if posdef:
        return A_fcn, AT_fcn, rhs_new, col_swapped

    # get the linear operation if it is not a posdef (A -> AT.A)
    def normal_fcn(x):
        return AT_fcn(A_fcn(x))

    return normal_fcn, normal_fcn, AT_fcn(rhs_new), col_swapped
def svd(A: LinearOperator, k: Optional[int] = None,
        mode: str = "uppest", bck_options: Mapping[str, Any] = {},
        method: Union[str, Callable, None] = None,
        **fwd_options) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Perform the singular value decomposition (SVD):

    .. math::

        \mathbf{A} = \mathbf{U\Sigma V}^H

    where :math:`\mathbf{U}` and :math:`\mathbf{V}` are semi-unitary matrices
    and :math:`\mathbf{\Sigma}` is a diagonal matrix containing real
    non-negative numbers.
    This function can handle derivatives for degenerate singular values by
    setting non-zero ``degen_atol`` and ``degen_rtol`` in the backward option
    using the expressions in [1]_.

    Arguments
    ---------
    A: xitorch.LinearOperator
        The linear operator to be decomposed. It has a shape of ``(*BA, m, n)``
        where ``(*BA)`` is the batched dimension of ``A``.
    k: int or None
        The number of decomposition obtained. If ``None``, it will be
        ``min(*A.shape[-2:])``
    mode: str
        ``"lowest"`` or ``"uppermost"``/``"uppest"``. If ``"lowest"``,
        it will take the lowest ``k`` decomposition.
        If ``"uppest"``, it will take the uppermost ``k``.
    bck_options: dict
        Method-specific options for :func:`solve` which is used in the
        backpropagation calculation, with some additional arguments for
        computing the backward derivatives:

        * ``degen_atol`` (``float`` or None): Minimum absolute difference
          between two singular values to be treated as degenerate.
          If None, it is ``torch.finfo(dtype).eps**0.6``.
          If 0.0, no special treatment on degeneracy is applied.
          (default: None)
        * ``degen_rtol`` (``float`` or None): Minimum relative difference
          between two singular values to be treated as degenerate.
          If None, it is ``torch.finfo(dtype).eps**0.4``.
          If 0.0, no special treatment on degeneracy is applied.
          (default: None)

        Note: the default values of ``degen_atol`` and ``degen_rtol`` are going
        to change in the future. So, for future compatibility, please specify
        the specific values.
    method: str or callable or None
        Method for the svd (same options for :func:`symeig`). If ``None``,
        it will choose ``"exacteig"``.
    **fwd_options
        Method-specific options (see method section below).

    Returns
    -------
    tuple of tensors (u, s, vh)
        It will return ``u, s, vh`` with shapes respectively
        ``(*BA, m, k)``, ``(*BA, k)``, and ``(*BA, k, n)``.
        ``vh`` is the conjugate transpose of the right singular vectors,
        so that ``A ≈ u @ diag(s) @ vh`` also holds for complex operators.

    Note
    ----
    It is a naive implementation of symmetric eigendecomposition of ``A.H @ A``
    or ``A @ A.H`` (depending which one is cheaper)

    References
    ----------
    .. [1] Muhammad F. Kasim,
           "Derivatives of partial eigendecomposition of a real symmetric
           matrix for degenerate cases".
           arXiv:2011.04366 (2020)
           `https://arxiv.org/abs/2011.04366 <https://arxiv.org/abs/2011.04366>`_
    """
    # A: (*BA, m, n)
    # adapted from scipy.sparse.linalg.svds

    if is_debug_enabled():
        A.check()

    m = A.shape[-2]
    n = A.shape[-1]
    # eigendecompose whichever Gram operator is smaller: (m, m) or (n, n)
    if m < n:
        AAsym = A.matmul(A.H, is_hermitian=True)
    else:
        AAsym = A.H.matmul(A, is_hermitian=True)

    eivals, eivecs = symeig(AAsym, k, mode,
                            bck_options=bck_options,
                            method=method,
                            **fwd_options)  # (*BA, k) and (*BA, min(mn), k)

    # the Gram operator is positive semi-definite, so any negative eigenvalue
    # is numerical noise; clamp to zero before taking the square root
    eivals = torch.clamp(eivals, min=0.0)
    s = torch.sqrt(eivals)  # (*BA, k)
    # floor the divisor so (near-)zero singular values do not divide by zero
    sdiv = torch.clamp(s, min=1e-12).unsqueeze(-2)  # (*BA, 1, k)
    if m < n:
        u = eivecs  # (*BA, m, k)
        v = A.rmm(u) / sdiv  # (*BA, n, k)
    else:
        v = eivecs  # (*BA, n, k)
        u = A.mm(v) / sdiv  # (*BA, m, k)
    # V^H is the *conjugate* transpose; .conj() leaves real tensors unchanged
    vh = v.transpose(-2, -1).conj()  # (*BA, k, n)
    return u, s, vh