Example #1
def exactsolve(A: LinearOperator, B: torch.Tensor,
               E: Union[torch.Tensor, None], M: Union[LinearOperator, None]):
    """
    Solve the linear equation by constructing the full matrix of the LinearOperators.

    Warnings
    --------
    * As this method constructs the linear operators explicitly, it might require
      a large amount of memory.
    """
    # A: (*BA, na, na)
    # B: (*BB, na, ncols)
    # E: (*BE, ncols)
    # M: (*BM, na, na)
    if E is None:
        Amatrix = A.fullmatrix()  # (*BA, na, na)
        x, _ = torch.solve(B, Amatrix)  # (*BAB, na, ncols)
    elif M is None:
        Amatrix = A.fullmatrix()
        x = _solve_ABE(Amatrix, B, E)
    else:
        Mmatrix = M.fullmatrix()  # (*BM, na, na)
        L = torch.cholesky(Mmatrix, upper=False)  # (*BM, na, na)
        Linv = torch.inverse(L)  # (*BM, na, na)
        LinvT = Linv.transpose(-2, -1)  # (*BM, na, na)
        A2 = torch.matmul(Linv, A.mm(LinvT))  # (*BAM, na, na)
        B2 = torch.matmul(Linv, B)  # (*BBM, na, ncols)

        X2 = _solve_ABE(A2, B2, E)  # (*BABEM, na, ncols)
        x = torch.matmul(LinvT, X2)  # (*BABEM, na, ncols)
    return x
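For reference, here is a minimal dense-tensor sketch of the per-column solve that a helper like ``_solve_ABE`` appears to perform, assuming it solves ``(A - E_i * I) x_i = b_i`` for every column ``i`` of ``B`` (the function name and implementation below are illustrative, not the library's actual code):

import torch

def solve_ABE_dense(A: torch.Tensor, B: torch.Tensor, E: torch.Tensor) -> torch.Tensor:
    # A: (*, na, na), B: (*, na, ncols), E: (*, ncols)
    # Solve (A - E_i * I) x_i = b_i for every column i and stack the results.
    na = A.shape[-1]
    eye = torch.eye(na, dtype=A.dtype, device=A.device)
    cols = []
    for i in range(B.shape[-1]):
        Ai = A - E[..., i, None, None] * eye             # (*, na, na)
        xi = torch.linalg.solve(Ai, B[..., :, i:i + 1])  # (*, na, 1)
        cols.append(xi)
    return torch.cat(cols, dim=-1)                       # (*, na, ncols)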
Example #2
def _setup_linear_problem(A: LinearOperator, B: torch.Tensor,
                          E: Optional[torch.Tensor], M: Optional[LinearOperator],
                          batchdims: Sequence[int],
                          posdef: Optional[bool],
                          need_hermit: bool) -> \
        Tuple[Callable[[torch.Tensor], torch.Tensor],
              Callable[[torch.Tensor], torch.Tensor],
              torch.Tensor, bool]:

    # get the linear operator (including the MXE part)
    if E is None:
        A_fcn = lambda x: A.mm(x)
        AT_fcn = lambda x: A.rmm(x)
        B_new = B
        col_swapped = False
    else:
        # A: (*BA, nr, nr) linop
        # B: (*BB, nr, ncols)
        # E: (*BE, ncols)
        # M: (*BM, nr, nr) linop
        if M is None:
            BAs, BBs, BEs = normalize_bcast_dims(A.shape[:-2], B.shape[:-2],
                                                 E.shape[:-1])
        else:
            BAs, BBs, BEs, BMs = normalize_bcast_dims(A.shape[:-2],
                                                      B.shape[:-2],
                                                      E.shape[:-1],
                                                      M.shape[:-2])
        E = E.reshape(*BEs, *E.shape[-1:])
        E_new = E.unsqueeze(0).transpose(-1, 0).unsqueeze(
            -1)  # (ncols, *BEs, 1, 1)
        B = B.reshape(*BBs, *B.shape[-2:])  # (*BBs, nr, ncols)
        B_new = B.unsqueeze(0).transpose(-1, 0)  # (ncols, *BBs, nr, 1)

        def A_fcn(x):
            # x: (ncols, *BX, nr, 1)
            Ax = A.mm(x)  # (ncols, *BAX, nr, 1)
            Mx = M.mm(x) if M is not None else x  # (ncols, *BMX, nr, 1)
            MxE = Mx * E_new  # (ncols, *BMXE, nr, 1)
            return Ax - MxE

        def AT_fcn(x):
            # x: (ncols, *BX, nr, 1)
            ATx = A.rmm(x)
            MTx = M.rmm(x) if M is not None else x
            MTxE = MTx * E_new
            return ATx - MTxE

        col_swapped = True

    # estimate if it's posdef with power iteration
    if need_hermit:
        is_hermit = A.is_hermitian and (M is None or M.is_hermitian)
        if not is_hermit:
            posdef = False
    if posdef is None:
        nr, ncols = B.shape[-2:]
        x0shape = (ncols, *batchdims, nr, 1) if col_swapped else (*batchdims,
                                                                  nr, ncols)
        x0 = torch.randn(x0shape, dtype=A.dtype, device=A.device)
        x0 = x0 / x0.norm(dim=-2, keepdim=True)
        largest_eival = _get_largest_eival(A_fcn, x0)  # (*, 1, nc)
        negeival = largest_eival <= 0

        # if the largest eigenvalue is negative, then it's not posdef
        if torch.all(negeival):
            posdef = False
        else:
            offset = torch.clamp(largest_eival, min=0.0)
            A_fcn2 = lambda x: A_fcn(x) - offset * x
            mostneg_eival = _get_largest_eival(A_fcn2, x0)  # (*, 1, nc)
            posdef = bool(
                torch.all(torch.logical_or(-mostneg_eival <= offset,
                                           negeival)).item())

    # if it is not posdef, convert to the normal equation (A -> AT.A)
    if posdef:
        return A_fcn, AT_fcn, B_new, col_swapped
    else:

        def A_new_fcn(x):
            return AT_fcn(A_fcn(x))

        B2 = AT_fcn(B_new)
        return A_new_fcn, A_new_fcn, B2, col_swapped
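The positive-definiteness test above relies on a power-iteration estimate of the largest eigenvalue (``_get_largest_eival``). A minimal sketch of such an estimator, assuming a fixed number of iterations and a Rayleigh-quotient readout per column (illustrative only, not the library's implementation):

import torch

def largest_eival_poweriter(A_fcn, x0: torch.Tensor, niter: int = 20) -> torch.Tensor:
    # x0: (*, nr, ncols); one power-iteration vector per column
    x = x0
    for _ in range(niter):
        x = A_fcn(x)
        x = x / x.norm(dim=-2, keepdim=True)  # normalize along the row dimension
    Ax = A_fcn(x)
    # Rayleigh quotient <x, A x> / <x, x> per column: (*, 1, ncols)
    return (x * Ax).sum(dim=-2, keepdim=True) / (x * x).sum(dim=-2, keepdim=True)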
Example #3
def svd(A: LinearOperator,
        k: Optional[int] = None,
        mode: str = "uppest",
        bck_options: Mapping[str, Any] = {},
        method: Union[str, Callable, None] = None,
        **fwd_options) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Perform the singular value decomposition (SVD):

    .. math::

        \mathbf{A} = \mathbf{U\Sigma V}^H

    where :math:`\mathbf{U}` and :math:`\mathbf{V}` are semi-unitary matrices and
    :math:`\mathbf{\Sigma}` is a diagonal matrix containing real non-negative
    numbers.
    This function can handle derivatives for degenerate singular values by setting non-zero
    ``degen_atol`` and ``degen_rtol`` in the backward options, using the expressions
    in [1]_.

    Arguments
    ---------
    A: xitorch.LinearOperator
        The linear operator to be decomposed. It has a shape of ``(*BA, m, n)``
        where ``(*BA)`` is the batched dimension of ``A``.
    k: int or None
        The number of decompositions obtained. If ``None``, it will be
        ``min(*A.shape[-2:])``.
    mode: str
        ``"lowest"`` or ``"uppermost"``/``"uppest"``. If ``"lowest"``,
        it will take the lowest ``k`` decomposition.
        If ``"uppest"``, it will take the uppermost ``k``.
    bck_options: dict
        Method-specific options for :func:`solve` which are used in the
        backpropagation calculation, with some additional arguments for computing
        the backward derivatives:

        * ``degen_atol`` (``float`` or None): Minimum absolute difference between
          two singular values to be treated as degenerate. If None, it is
          ``torch.finfo(dtype).eps**0.6``. If 0.0, no special treatment on
          degeneracy is applied. (default: None)
        * ``degen_rtol`` (``float`` or None): Minimum relative difference between
          two singular values to be treated as degenerate. If None, it is
          ``torch.finfo(dtype).eps**0.4``. If 0.0, no special treatment on
          degeneracy is applied. (default: None)

        Note: the default values of ``degen_atol`` and ``degen_rtol`` are going
        to change in the future. So, for future compatibility, please specify
        the specific values.

    method: str or callable or None
        Method for the svd (same options as :func:`symeig`). If ``None``,
        it will choose ``"exacteig"``.
    **fwd_options
        Method-specific options (see method section below).

    Returns
    -------
    tuple of tensors (u, s, vh)
        It will return ``u, s, vh`` with shapes respectively
        ``(*BA, m, k)``, ``(*BA, k)``, and ``(*BA, k, n)``.

    Note
    ----
    It is a naive implementation of the symmetric eigendecomposition of ``A.H @ A``
    or ``A @ A.H`` (depending on which one is cheaper).

    References
    ----------
    .. [1] Muhammad F. Kasim,
           "Derivatives of partial eigendecomposition of a real symmetric matrix for degenerate cases".
           arXiv:2011.04366 (2020)
           `https://arxiv.org/abs/2011.04366 <https://arxiv.org/abs/2011.04366>`_
    """
    # A: (*BA, m, n)
    # adapted from scipy.sparse.linalg.svds

    if is_debug_enabled():
        A.check()
    BA = A.shape[:-2]

    m = A.shape[-2]
    n = A.shape[-1]
    if m < n:
        AAsym = A.matmul(A.H, is_hermitian=True)
        min_nm = m
    else:
        AAsym = A.H.matmul(A, is_hermitian=True)
        min_nm = n

    eivals, eivecs = symeig(AAsym,
                            k,
                            mode,
                            bck_options=bck_options,
                            method=method,
                            **fwd_options)  # (*BA, k) and (*BA, min(mn), k)

    # clamp the eigenvalues to be non-negative to avoid numerical
    # instability
    eivals = torch.clamp(eivals, min=0.0)
    s = torch.sqrt(eivals)  # (*BA, k)
    sdiv = torch.clamp(s, min=1e-12).unsqueeze(-2)  # (*BA, 1, k)
    if m < n:
        u = eivecs  # (*BA, m, k)
        v = A.rmm(u) / sdiv  # (*BA, n, k)
    else:
        v = eivecs  # (*BA, n, k)
        u = A.mm(v) / sdiv  # (*BA, m, k)
    vh = v.transpose(-2, -1)
    return u, s, vh
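A small usage sketch of the function above, assuming the dense ``xitorch.LinearOperator.m`` wrapper and ``xitorch.linalg.svd`` are available as public API (the matrix below is just a toy example):

import torch
import xitorch
from xitorch.linalg import svd

mat = torch.randn(3, 5, dtype=torch.float64)
A = xitorch.LinearOperator.m(mat)       # wrap a dense matrix as a LinearOperator

u, s, vh = svd(A, k=2, mode="uppest")   # top-2 singular triplets
approx = u @ torch.diag_embed(s) @ vh   # rank-2 approximation of mat, shape (3, 5)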
Example #4
def svd(A: LinearOperator,
        k: Optional[int] = None,
        mode: str = "uppest",
        bck_options: Mapping[str, Any] = {},
        method: Union[str, Callable, None] = None,
        **fwd_options) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Perform the singular value decomposition (SVD):

    .. math::

        \mathbf{A} = \mathbf{U\Sigma V}^H

    where :math:`\mathbf{U}` and :math:`\mathbf{V}` are semi-unitary matrices and
    :math:`\mathbf{\Sigma}` is a diagonal matrix containing real non-negative
    numbers.

    Arguments
    ---------
    A: xitorch.LinearOperator
        The linear operator to be decomposed. It has a shape of ``(*BA, m, n)``
        where ``(*BA)`` is the batched dimension of ``A``.
    k: int or None
        The number of decompositions obtained. If ``None``, it will be
        ``min(*A.shape[-2:])``.
    mode: str
        ``"lowest"`` or ``"uppermost"``/``"uppest"``. If ``"lowest"``,
        it will take the lowest ``k`` decomposition.
        If ``"uppest"``, it will take the uppermost ``k``.
    bck_options: dict
        Method-specific options for :func:`solve` which are used in the
        backpropagation calculation.
    method: str or callable or None
        Method for the svd (same options as :func:`symeig`). If ``None``,
        it will choose ``"exacteig"``.
    **fwd_options
        Method-specific options (see method section below).

    Returns
    -------
    tuple of tensors (u, s, vh)
        It will return ``u, s, vh`` with shapes respectively
        ``(*BA, m, k)``, ``(*BA, k)``, and ``(*BA, k, n)``.

    Note
    ----
    It is a naive implementation of the symmetric eigendecomposition of ``A.H @ A``
    or ``A @ A.H`` (depending on which one is cheaper).

    Warnings
    --------
    * If ``s`` contains very small numbers or degenerate values, the
      calculation and its gradient might be inaccurate.
    * The second derivative through U or V might be unstable.
      Extra care must be taken.
    """
    # A: (*BA, m, n)
    # adapted from scipy.sparse.linalg.svds

    if is_debug_enabled():
        A.check()
    BA = A.shape[:-2]

    m = A.shape[-2]
    n = A.shape[-1]
    if m < n:
        AAsym = A.matmul(A.H, is_hermitian=True)
        min_nm = m
    else:
        AAsym = A.H.matmul(A, is_hermitian=True)
        min_nm = n

    eivals, eivecs = symeig(AAsym,
                            k,
                            mode,
                            bck_options=bck_options,
                            method=method,
                            **fwd_options)  # (*BA, k) and (*BA, min(mn), k)

    # clamp the eigenvalues to be non-negative to avoid numerical
    # instability
    eivals = torch.clamp(eivals, min=0.0)
    s = torch.sqrt(eivals)  # (*BA, k)
    sdiv = torch.clamp(s, min=1e-12).unsqueeze(-2)  # (*BA, 1, k)
    if m < n:
        u = eivecs  # (*BA, m, k)
        v = A.rmm(u) / sdiv  # (*BA, n, k)
    else:
        v = eivecs  # (*BA, n, k)
        u = A.mm(v) / sdiv  # (*BA, m, k)
    vh = v.transpose(-2, -1)
    return u, s, vh
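The note above describes the implementation as an eigendecomposition of ``A.H @ A`` (or ``A @ A.H``). A dense-tensor sketch of that route for the ``m >= n`` case, written in plain PyTorch rather than the library's code:

import torch

A = torch.randn(6, 4, dtype=torch.float64)
evals, v = torch.linalg.eigh(A.transpose(-2, -1) @ A)  # ascending eigenvalues of A^T A
s = torch.sqrt(torch.clamp(evals, min=0.0))            # singular values (ascending)
u = A @ v / torch.clamp(s, min=1e-12)                  # recover u = A v / s
# A is recovered (up to column ordering) by u @ diag(s) @ v^T
recon = u @ torch.diag_embed(s) @ v.transpose(-2, -1)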
Example #5
def davidson(A: LinearOperator, neig: int,
             mode: str,
             M: Optional[LinearOperator] = None,
             max_niter: int = 1000,
             nguess: Optional[int] = None,
             v_init: str = "randn",
             max_addition: Optional[int] = None,
             min_eps: float = 1e-6,
             verbose: bool = False,
             **unused) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Use the Davidson method for large sparse matrix eigendecomposition [1]_.

    Arguments
    ---------
    max_niter: int
        Maximum number of iterations
    v_init: str
        Mode of the initial guess (``"randn"``, ``"rand"``, ``"eye"``)
    max_addition: int or None
        Maximum number of new guesses to be added to the collected vectors.
        If None, set to ``neig``.
    min_eps: float
        Residual error threshold at which the iteration is stopped
    verbose: bool
        Option to be verbose

    References
    ----------
    .. [1] P. Arbenz, "Lecture Notes on Solving Large Scale Eigenvalue Problems"
           http://people.inf.ethz.ch/arbenz/ewp/Lnotes/chapter12.pdf
    """
    # TODO: optimize for large linear operator and strict min_eps
    # Ideas:
    # (1) use better strategy to get the estimate on eigenvalues
    # (2) use restart strategy

    if nguess is None:
        nguess = neig
    if max_addition is None:
        max_addition = neig

    # get the shape of the transformation
    na = A.shape[-1]
    if M is None:
        bcast_dims = A.shape[:-2]
    else:
        bcast_dims = get_bcasted_dims(A.shape[:-2], M.shape[:-2])
    dtype = A.dtype
    device = A.device

    prev_eigvals = None
    prev_eigvalT = None
    stop_reason = "max_niter"
    shift_is_eigvalT = False
    idx = torch.arange(neig).unsqueeze(-1)  # (neig, 1)

    # set up the initial guess
    V = _set_initial_v(v_init.lower(), dtype, device,
                       bcast_dims, na, nguess,
                       M=M)  # (*BAM, na, nguess)

    best_resid: Union[float, torch.Tensor] = float("inf")
    AV = A.mm(V)
    for i in range(max_niter):
        VT = V.transpose(-2, -1)  # (*BAM,nguess,na)
        # Can be optimized by saving AV from the previous iteration and only
        # computing AV for the new V. This works because the old V has already
        # been orthogonalized, so it stays the same.
        # AV = A.mm(V) # (*BAM,na,nguess)
        T = torch.matmul(VT, AV)  # (*BAM,nguess,nguess)

        # eigvals are sorted from the lowest
        # eval: (*BAM, nguess), evec: (*BAM, nguess, nguess)
        eigvalT, eigvecT = torch.symeig(T, eigenvectors=True)
        eigvalT, eigvecT = _take_eigpairs(eigvalT, eigvecT, neig, mode)  # (*BAM, neig) and (*BAM, nguess, neig)

        # calculate the eigenvectors of A
        eigvecA = torch.matmul(V, eigvecT)  # (*BAM, na, neig)

        # calculate the residual
        AVs = torch.matmul(AV, eigvecT)  # (*BAM, na, neig)
        LVs = eigvalT.unsqueeze(-2) * eigvecA  # (*BAM, na, neig)
        if M is not None:
            LVs = M.mm(LVs)
        resid = AVs - LVs  # (*BAM, na, neig)

        # print information and check convergence
        max_resid = resid.abs().max()
        if prev_eigvalT is not None:
            deigval = eigvalT - prev_eigvalT
            max_deigval = deigval.abs().max()
            if verbose:
                print("Iter %3d (guess size: %d): resid: %.3e, devals: %.3e" %
                      (i + 1, nguess, max_resid, max_deigval))  # type:ignore

        if max_resid < best_resid:
            best_resid = max_resid
            best_eigvals = eigvalT
            best_eigvecs = eigvecA
        if max_resid < min_eps:
            break
        if AV.shape[-1] == AV.shape[-2]:
            break
        prev_eigvalT = eigvalT

        # apply the preconditioner
        t = -resid  # (*BAM, na, neig)

        # orthogonalize t with the rest of the V
        t = to_fortran_order(t)
        Vnew = torch.cat((V, t), dim=-1)
        if Vnew.shape[-1] > Vnew.shape[-2]:
            Vnew = Vnew[..., :Vnew.shape[-2]]
        nadd = Vnew.shape[-1] - V.shape[-1]
        nguess = nguess + nadd
        if M is not None:
            MV_ = M.mm(Vnew)
            V, R = tallqr(Vnew, MV=MV_)
        else:
            V, R = tallqr(Vnew)
        AVnew = A.mm(V[..., -nadd:])  # (*BAM,na,nadd)
        AVnew = to_fortran_order(AVnew)
        AV = torch.cat((AV, AVnew), dim=-1)

    eigvals = best_eigvals  # (*BAM, neig)
    eigvecs = best_eigvecs  # (*BAM, na, neig)
    return eigvals, eigvecs
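A usage sketch, assuming this routine is exposed through ``xitorch.linalg.symeig`` with ``method="davidson"`` and that the dense ``LinearOperator.m`` wrapper is available (the matrix size and tolerances below are illustrative):

import torch
import xitorch
from xitorch.linalg import symeig

mat = torch.randn(100, 100, dtype=torch.float64)
mat = (mat + mat.transpose(-2, -1)) * 0.5            # make it symmetric
A = xitorch.LinearOperator.m(mat, is_hermitian=True)

# lowest 3 eigenpairs computed with the Davidson routine above
eigvals, eigvecs = symeig(A, neig=3, mode="lowest", method="davidson",
                          max_niter=200, min_eps=1e-7)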