Example 1
    def _test_lsymeig():
        torch.manual_seed(seed)
        mata = torch.rand(ashape, dtype=dtype, device=device)
        matm = torch.rand(mshape, dtype=dtype, device=device) + \
            torch.eye(mshape[-1], dtype=dtype, device=device)  # make sure it's not singular
        mata = mata + mata.transpose(-2, -1)
        matm = matm + matm.transpose(-2, -1)
        mata = mata.requires_grad_()
        matm = matm.requires_grad_()
        fwd_options = {"method": method}

        na = ashape[-1]
        bshape = get_bcasted_dims(ashape[:-2], mshape[:-2])
        neig = ashape[-1]

        def lsymeig_fcn(amat, mmat):
            # symmetrize
            amat = (amat + amat.transpose(-2, -1)) * 0.5
            mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
            alinop = LinearOperator.m(amat, is_hermitian=True)
            mlinop = LinearOperator.m(mmat, is_hermitian=True)
            eigvals_, eigvecs_ = lsymeig(alinop,
                                         M=mlinop,
                                         neig=neig,
                                         **fwd_options)
            return eigvals_, eigvecs_

        eival, eivec = lsymeig_fcn(mata, matm)
        loss = (eival * eival).sum() + (eivec * eivec).sum()
        # using autograd.grad instead of .backward because backward has a known
        # memory leak problem
        grads = torch.autograd.grad(loss, (mata, matm), create_graph=True)
Example 2
def test_solve_A_gmres():
    torch.manual_seed(seed)
    na = 3
    dtype = torch.float64
    ashape = (na, na)
    bshape = (2, na, na)
    fwd_options = {"method": "gmres"}

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]

    amat = torch.rand(ashape, dtype=dtype) + torch.eye(ashape[-1], dtype=dtype)
    bmat = torch.rand(bshape, dtype=dtype)
    amat = amat + amat.transpose(-2, -1)

    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()

    def solvefcn(amat, bmat):
        alinop = LinearOperator.m(amat)
        x = solve(A=alinop, B=bmat, fwd_options=fwd_options)
        return x

    x = solvefcn(amat, bmat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(amat).mm(x)
    assert torch.allclose(ax, bmat)

    gradcheck(solvefcn, (amat, bmat))
    gradgradcheck(solvefcn, (amat, bmat))
Example 3
def test_solve_AEM(dtype, device, abeshape, mshape, method):
    torch.manual_seed(seed)
    na = abeshape[-1]
    ashape = abeshape
    bshape = abeshape
    eshape = abeshape
    checkgrad = method.endswith("exactsolve")

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    eshape = [*eshape[:-2], ncols]
    xshape = list(
        get_bcasted_dims(ashape[:-2], bshape[:-2], eshape[:-1],
                         mshape[:-2])) + [na, ncols]
    fwd_options = {"method": method, "min_eps": 1e-9}
    bck_options = {
        "method": method
    }  # exactsolve at backward just to test the forward solve

    amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(ashape[-1], dtype=dtype, device=device)
    mmat = torch.rand(mshape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(mshape[-1], dtype=dtype, device=device) * 0.5
    bmat = torch.rand(bshape, dtype=dtype, device=device)
    emat = torch.rand(eshape, dtype=dtype, device=device)
    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5

    amat = amat.requires_grad_()
    mmat = mmat.requires_grad_()
    bmat = bmat.requires_grad_()
    emat = emat.requires_grad_()

    def solvefcn(amat, mmat, bmat, emat):
        mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
        alinop = LinearOperator.m(amat)
        mlinop = LinearOperator.m(mmat)
        x = solve(A=alinop,
                  B=bmat,
                  E=emat,
                  M=mlinop,
                  **fwd_options,
                  bck_options=bck_options)
        return x

    x = solvefcn(amat, mmat, bmat, emat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(amat).mm(x)
    mxe = LinearOperator.m(mmat).mm(
        torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2)))
    y = ax - mxe
    assert torch.allclose(y, bmat)

    # gradient checker
    if checkgrad:
        gradcheck(solvefcn, (amat, mmat, bmat, emat))
        gradgradcheck(solvefcn, (amat, mmat, bmat, emat))
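For reference, the assertion above checks the relation that solve with E and M is presumably targeting,

    A X - M X diag(E) = B,

where diag(E) is the diagonal matrix built from each batch of E (torch.diag_embed); without E and M this reduces to the plain linear system A X = B.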
Example 4
def _get_batchdims(A: LinearOperator, B: torch.Tensor,
                   E: Union[torch.Tensor, None], M: Union[LinearOperator,
                                                          None]):
    batchdims = [A.shape[:-2], B.shape[:-2]]
    if E is not None:
        batchdims.append(E.shape[:-1])
        if M is not None:
            batchdims.append(M.shape[:-2])
    return get_bcasted_dims(*batchdims)
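As an illustration, the batch dimensions collected above are assumed to broadcast following the usual NumPy/PyTorch rules; a minimal sketch using torch.broadcast_shapes (not the actual get_bcasted_dims implementation):

import torch

# batch dims of A (A.shape[:-2]), B (B.shape[:-2]) and E (E.shape[:-1])
batch_a = (2, 1)
batch_b = (3,)
batch_e = (2, 3)
print(torch.broadcast_shapes(batch_a, batch_b, batch_e))  # torch.Size([2, 3])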
Example 5
    def __init__(self, a: LinearOperator, b: LinearOperator, is_hermitian: bool = False):
        shape = (*get_bcasted_dims(a.shape[:-2], b.shape[:-2]), a.shape[-2], b.shape[-1])
        super(MatmulLinearOperator, self).__init__(
            shape=shape,
            is_hermitian=is_hermitian,
            dtype=a.dtype,
            device=a.device,
            _suppress_hermit_warning=True,
        )
        self.a = a
        self.b = b
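The shape formula above mirrors how batched matrix multiplication broadcasts the leading dimensions; a small dense-tensor sketch (illustrative only, not part of the operator class):

import torch

a = torch.randn(2, 1, 4, 5)   # batch dims (2, 1), block shape (4, 5)
b = torch.randn(3, 5, 6)      # batch dims (3,),   block shape (5, 6)
# broadcasted batch dims (2, 3) plus (a rows, b cols) = (4, 6)
print(torch.matmul(a, b).shape)  # torch.Size([2, 3, 4, 6])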
Example 6
    def __init__(self, a: LinearOperator, b: LinearOperator):
        shape = (*get_bcasted_dims(a.shape[:-2], b.shape[:-2]), a.shape[-2],
                 b.shape[-1])
        is_hermitian = a.is_hermitian and b.is_hermitian
        super(AddLinearOperator, self).__init__(
            shape=shape,
            is_hermitian=is_hermitian,
            dtype=a.dtype,
            device=a.device,
            _suppress_hermit_warning=True,
        )
        self.a = a
        self.b = b
Example 7
def test_lsymeig_AM():
    torch.manual_seed(seed)
    shapes = [(3, 3), (2, 3, 3), (2, 1, 3, 3)]
    # only 2 methods, because both gradient implementations are covered
    methods = ["exacteig", "custom_exacteig"]
    dtype = torch.float64
    for ashape, mshape, method in itertools.product(shapes, shapes, methods):
        mata = torch.rand(ashape, dtype=dtype)
        matm = torch.rand(mshape, dtype=dtype) + torch.eye(
            mshape[-1], dtype=dtype)  # make sure it's not singular
        mata = mata + mata.transpose(-2, -1)
        matm = matm + matm.transpose(-2, -1)
        mata = mata.requires_grad_()
        matm = matm.requires_grad_()
        linopa = LinearOperator.m(mata, True)
        linopm = LinearOperator.m(matm, True)
        fwd_options = {"method": method}

        na = ashape[-1]
        bshape = get_bcasted_dims(ashape[:-2], mshape[:-2])
        for neig in [2, ashape[-1]]:
            eigvals, eigvecs = lsymeig(linopa,
                                       M=linopm,
                                       neig=neig,
                                       **fwd_options)  # eigvals: (..., neig)
            assert list(eigvals.shape) == list([*bshape, neig])
            assert list(eigvecs.shape) == list([*bshape, na, neig])

            ax = linopa.mm(eigvecs)
            mxe = linopm.mm(
                torch.matmul(eigvecs,
                             torch.diag_embed(eigvals, dim1=-2, dim2=-1)))
            assert torch.allclose(ax, mxe)

            # only perform gradcheck if neig is full, to reduce the computational cost
            if neig == ashape[-1]:

                def lsymeig_fcn(amat, mmat):
                    # symmetrize
                    amat = (amat + amat.transpose(-2, -1)) * 0.5
                    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
                    alinop = LinearOperator.m(amat, is_hermitian=True)
                    mlinop = LinearOperator.m(mmat, is_hermitian=True)
                    eigvals_, eigvecs_ = lsymeig(alinop,
                                                 M=mlinop,
                                                 neig=neig,
                                                 **fwd_options)
                    return eigvals_, eigvecs_

                gradcheck(lsymeig_fcn, (mata, matm))
                gradgradcheck(lsymeig_fcn, (mata, matm))
Example 8
    def _test_solve():
        torch.manual_seed(seed)
        na = abeshape[-1]
        ashape = abeshape
        bshape = abeshape
        eshape = abeshape
        checkgrad = method.endswith("exactsolve")

        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        eshape = [*eshape[:-2], ncols]
        xshape = list(
            get_bcasted_dims(ashape[:-2], bshape[:-2], eshape[:-1],
                             mshape[:-2])) + [na, ncols]
        fwd_options = {"method": method, "min_eps": 1e-9}
        bck_options = {
            "method": method
        }  # exactsolve at backward just to test the forward solve

        amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
            torch.eye(ashape[-1], dtype=dtype, device=device)
        mmat = torch.rand(mshape, dtype=dtype, device=device) * 0.1 + \
            torch.eye(mshape[-1], dtype=dtype, device=device) * 0.5
        bmat = torch.rand(bshape, dtype=dtype, device=device)
        emat = torch.rand(eshape, dtype=dtype, device=device)
        mmat = (mmat + mmat.transpose(-2, -1)) * 0.5

        amat = amat.requires_grad_()
        mmat = mmat.requires_grad_()
        bmat = bmat.requires_grad_()
        emat = emat.requires_grad_()

        def solvefcn(amat, mmat, bmat, emat):
            mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
            alinop = LinearOperator.m(amat)
            mlinop = LinearOperator.m(mmat)
            x = solve(A=alinop,
                      B=bmat,
                      E=emat,
                      M=mlinop,
                      **fwd_options,
                      bck_options=bck_options)
            return x

        loss = (solvefcn(amat, mmat, bmat, emat)**2).sum()
        # using autograd.grad instead of .backward because backward has a known
        # memory leak problem
        grads = torch.autograd.grad(loss, (amat, mmat, bmat, emat),
                                    create_graph=True)
Example 9
def test_solve_A():
    torch.manual_seed(seed)
    na = 2
    shapes = [(na, na), (2, na, na), (2, 1, na, na)]
    dtype = torch.float64
    # custom exactsolve to check the backward implementation
    methods = ["exactsolve", "custom_exactsolve", "lbfgs"]
    # hermitian check here to make sure the gradient is propagated symmetrically
    hermits = [False, True]
    for ashape, bshape, method, hermit in itertools.product(
            shapes, shapes, methods, hermits):
        print(ashape, bshape, method, hermit)
        checkgrad = method.endswith("exactsolve")

        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]
        fwd_options = {"method": method, "min_eps": 1e-9}
        bck_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) * 0.1 + torch.eye(ashape[-1],
                                                                 dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        if hermit:
            amat = (amat + amat.transpose(-2, -1)) * 0.5

        amat = amat.requires_grad_()
        bmat = bmat.requires_grad_()

        def solvefcn(amat, bmat):
            # is_hermitian=hermit is required to enforce the hermitian property in the numerical gradient
            alinop = LinearOperator.m(amat, is_hermitian=hermit)
            x = solve(A=alinop,
                      B=bmat,
                      fwd_options=fwd_options,
                      bck_options=bck_options)
            return x

        x = solvefcn(amat, bmat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        assert torch.allclose(ax, bmat)

        if checkgrad:
            gradcheck(solvefcn, (amat, bmat))
            gradgradcheck(solvefcn, (amat, bmat))
Example 10
def test_solve_AE():
    torch.manual_seed(seed)
    na = 2
    shapes = [(na, na), (2, na, na), (2, 1, na, na)]
    methods = ["exactsolve", "custom_exactsolve", "lbfgs"
               ]  # custom exactsolve to check the backward implementation
    dtype = torch.float64
    for ashape, bshape, eshape, method in itertools.product(
            shapes, shapes, shapes, methods):
        print(ashape, bshape, eshape, method)
        checkgrad = method.endswith("exactsolve")

        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        eshape = [*eshape[:-2], ncols]
        xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2],
                                       eshape[:-1])) + [na, ncols]
        fwd_options = {"method": method}
        bck_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) * 0.1 + torch.eye(ashape[-1],
                                                                 dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        emat = torch.rand(eshape, dtype=dtype)

        amat = amat.requires_grad_()
        bmat = bmat.requires_grad_()
        emat = emat.requires_grad_()

        def solvefcn(amat, bmat, emat):
            alinop = LinearOperator.m(amat)
            x = solve(A=alinop,
                      B=bmat,
                      E=emat,
                      fwd_options=fwd_options,
                      bck_options=bck_options)
            return x

        x = solvefcn(amat, bmat, emat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        xe = torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2))
        assert torch.allclose(ax - xe, bmat)
Example 11
def test_lsymeig_AM(dtype, device, ashape, mshape, method):
    torch.manual_seed(seed)
    mata = torch.rand(ashape, dtype=dtype, device=device)
    matm = torch.rand(mshape, dtype=dtype, device=device) + \
        torch.eye(mshape[-1], dtype=dtype, device=device)  # make sure it's not singular
    mata = mata + mata.transpose(-2, -1)
    matm = matm + matm.transpose(-2, -1)
    mata = mata.requires_grad_()
    matm = matm.requires_grad_()
    linopa = LinearOperator.m(mata, True)
    linopm = LinearOperator.m(matm, True)
    fwd_options = {"method": method}

    na = ashape[-1]
    bshape = get_bcasted_dims(ashape[:-2], mshape[:-2])
    for neig in [2, ashape[-1]]:
        eigvals, eigvecs = lsymeig(linopa, M=linopm, neig=neig,
                                   **fwd_options)  # eigvals: (..., neig)
        assert list(eigvals.shape) == list([*bshape, neig])
        assert list(eigvecs.shape) == list([*bshape, na, neig])

        ax = linopa.mm(eigvecs)
        mxe = linopm.mm(
            torch.matmul(eigvecs, torch.diag_embed(eigvals, dim1=-2, dim2=-1)))
        assert torch.allclose(ax, mxe)

        # only perform gradcheck if neig is full, to reduce the computational cost
        if neig == ashape[-1]:

            def lsymeig_fcn(amat, mmat):
                # symmetrize
                amat = (amat + amat.transpose(-2, -1)) * 0.5
                mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
                alinop = LinearOperator.m(amat, is_hermitian=True)
                mlinop = LinearOperator.m(mmat, is_hermitian=True)
                eigvals_, eigvecs_ = lsymeig(alinop,
                                             M=mlinop,
                                             neig=neig,
                                             **fwd_options)
                return eigvals_, eigvecs_

            gradcheck(lsymeig_fcn, (mata, matm))
            gradgradcheck(lsymeig_fcn, (mata, matm))
Example 12
    def __adjoint_rmv(self, xt: torch.Tensor) -> torch.Tensor:
        # xt: (*BY, p)
        # xdummy: (*BAY, q)
        # calculate the right matrix-vector multiplication (x^T A, obtained as A^T x) using the adjoint trick

        BY = xt.shape[:-1]
        BA = self.shape[:-2]
        BAY = get_bcasted_dims(BY, BA)

        # calculate y = Ax
        p, q = self.shape[-2:]
        xdummy = torch.zeros((*BAY, q), dtype=xt.dtype, device=xt.device).requires_grad_()
        with torch.enable_grad():
            y = self.mv(xdummy)  # (*BAY, p)

        # calculate (dL/dx)^T = A^T (dL/dy)^T with (dL/dy)^T = xt
        xt2 = xt.contiguous().expand_as(y)  # (*BAY, p)
        res = torch.autograd.grad(y, xdummy, grad_outputs=xt2,
                                  create_graph=torch.is_grad_enabled())[0]  # (*BAY, q)
        return res
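A minimal dense-matrix sketch of the adjoint trick used above (plain torch, not the LinearOperator API): the gradient of y = A x_dummy with respect to x_dummy, with grad_outputs set to xt, gives A^T xt.

import torch

A = torch.randn(5, 3, dtype=torch.float64)
xt = torch.randn(5, dtype=torch.float64)  # plays the role of (dL/dy)^T
xdummy = torch.zeros(3, dtype=torch.float64, requires_grad=True)
with torch.enable_grad():
    y = A.mv(xdummy)  # forward pass through the operator
res = torch.autograd.grad(y, xdummy, grad_outputs=xt)[0]
assert torch.allclose(res, A.t().mv(xt))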
Example 13
def test_solve_A_methods(dtype, device, method):
    torch.manual_seed(seed)
    na = 100
    ashape = (na, na)
    bshape = (2, na, na)
    options = {
        "scipy_gmres": {},
        "broyden1": {},
        "cg": {
            "rtol":
            1e-8  # stringent rtol required to meet the torch.allclose tols
        },
        "bicgstab": {
            "rtol": 1e-8,
        },
    }[method]
    fwd_options = {"method": method, **options}

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]

    amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(ashape[-1], dtype=dtype, device=device)
    bmat = torch.rand(bshape, dtype=dtype, device=device) + 0.1
    amat = (amat + amat.transpose(-2, -1)) * 0.5

    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()

    def solvefcn(amat, bmat):
        alinop = LinearOperator.m(amat)
        x = solve(A=alinop, B=bmat, **fwd_options)
        return x

    x = solvefcn(amat, bmat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(amat).mm(x)
    assert torch.allclose(ax, bmat)
Example 14
def test_solve_A(dtype, device, ashape, bshape, method, hermit):
    torch.manual_seed(seed)
    na = ashape[-1]
    checkgrad = method.endswith("exactsolve")

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]
    fwd_options = {"method": method, "min_eps": 1e-9}
    bck_options = {"method": method}

    amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(ashape[-1], dtype=dtype, device=device)
    bmat = torch.rand(bshape, dtype=dtype, device=device)

    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()

    def prepare(amat):
        if hermit:
            return (amat + amat.transpose(-2, -1)) * 0.5
        return amat

    def solvefcn(amat, bmat):
        # is_hermitian=hermit is required to enforce the hermitian property in the numerical gradient
        alinop = LinearOperator.m(prepare(amat), is_hermitian=hermit)
        x = solve(A=alinop, B=bmat, **fwd_options, bck_options=bck_options)
        return x

    x = solvefcn(amat, bmat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(prepare(amat)).mm(x)
    assert torch.allclose(ax, bmat)

    if checkgrad:
        gradcheck(solvefcn, (amat, bmat))
        gradgradcheck(solvefcn, (amat, bmat))
Example 15
def test_solve_AE(dtype, device, ashape, bshape, eshape, method):
    torch.manual_seed(seed)
    na = ashape[-1]
    checkgrad = method.endswith("exactsolve")

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    eshape = [*eshape[:-2], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2],
                                   eshape[:-1])) + [na, ncols]
    fwd_options = {"method": method}
    bck_options = {"method": method}

    amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(ashape[-1], dtype=dtype, device=device)
    bmat = torch.rand(bshape, dtype=dtype, device=device)
    emat = torch.rand(eshape, dtype=dtype, device=device)

    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()
    emat = emat.requires_grad_()

    def solvefcn(amat, bmat, emat):
        alinop = LinearOperator.m(amat)
        x = solve(A=alinop,
                  B=bmat,
                  E=emat,
                  **fwd_options,
                  bck_options=bck_options)
        return x

    x = solvefcn(amat, bmat, emat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(amat).mm(x)
    xe = torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2))
    assert torch.allclose(ax - xe, bmat)
Example 16
def davidson(A: LinearOperator, neig: int,
             mode: str,
             M: Optional[LinearOperator] = None,
             max_niter: int = 1000,
             nguess: Optional[int] = None,
             v_init: str = "randn",
             max_addition: Optional[int] = None,
             min_eps: float = 1e-6,
             verbose: bool = False,
             **unused) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Use the Davidson method for large sparse matrix eigendecomposition [1]_.

    Arguments
    ---------
    max_niter: int
        Maximum number of iterations
    nguess: int or None
        The number of initial guesses for the eigenvectors.
        If None, set to ``neig``.
    v_init: str
        Mode of the initial guess (``"randn"``, ``"rand"``, ``"eye"``)
    max_addition: int or None
        Maximum number of new guesses to be added to the collected vectors.
        If None, set to ``neig``.
    min_eps: float
        Minimum residual error at which the iteration is stopped
    verbose: bool
        Whether to print the iteration progress

    References
    ----------
    .. [1] P. Arbenz, "Lecture Notes on Solving Large Scale Eigenvalue Problems"
           http://people.inf.ethz.ch/arbenz/ewp/Lnotes/chapter12.pdf
    """
    # TODO: optimize for large linear operator and strict min_eps
    # Ideas:
    # (1) use better strategy to get the estimate on eigenvalues
    # (2) use restart strategy

    if nguess is None:
        nguess = neig
    if max_addition is None:
        max_addition = neig

    # get the shape of the transformation
    na = A.shape[-1]
    if M is None:
        bcast_dims = A.shape[:-2]
    else:
        bcast_dims = get_bcasted_dims(A.shape[:-2], M.shape[:-2])
    dtype = A.dtype
    device = A.device

    prev_eigvals = None
    prev_eigvalT = None
    stop_reason = "max_niter"
    shift_is_eigvalT = False
    idx = torch.arange(neig).unsqueeze(-1)  # (neig, 1)

    # set up the initial guess
    V = _set_initial_v(v_init.lower(), dtype, device,
                       bcast_dims, na, nguess,
                       M=M)  # (*BAM, na, nguess)

    best_resid: Union[float, torch.Tensor] = float("inf")
    AV = A.mm(V)
    for i in range(max_niter):
        VT = V.transpose(-2, -1)  # (*BAM,nguess,na)
        # Can be optimized by saving AV from the previous iteration and only
        # operate AV for the new V. This works because the old V has already
        # been orthogonalized, so it will stay the same
        # AV = A.mm(V) # (*BAM,na,nguess)
        T = torch.matmul(VT, AV)  # (*BAM,nguess,nguess)

        # eigvals are sorted from the lowest
        # eval: (*BAM, nguess), evec: (*BAM, nguess, nguess)
        eigvalT, eigvecT = torch.symeig(T, eigenvectors=True)
        eigvalT, eigvecT = _take_eigpairs(eigvalT, eigvecT, neig, mode)  # (*BAM, neig) and (*BAM, nguess, neig)

        # calculate the eigenvectors of A
        eigvecA = torch.matmul(V, eigvecT)  # (*BAM, na, neig)

        # calculate the residual
        AVs = torch.matmul(AV, eigvecT)  # (*BAM, na, neig)
        LVs = eigvalT.unsqueeze(-2) * eigvecA  # (*BAM, na, neig)
        if M is not None:
            LVs = M.mm(LVs)
        resid = AVs - LVs  # (*BAM, na, neig)

        # print information and check convergence
        max_resid = resid.abs().max()
        if prev_eigvalT is not None:
            deigval = eigvalT - prev_eigvalT
            max_deigval = deigval.abs().max()
            if verbose:
                print("Iter %3d (guess size: %d): resid: %.3e, devals: %.3e" %
                      (i + 1, nguess, max_resid, max_deigval))  # type:ignore

        if max_resid < best_resid:
            best_resid = max_resid
            best_eigvals = eigvalT
            best_eigvecs = eigvecA
        if max_resid < min_eps:
            break
        if AV.shape[-1] == AV.shape[-2]:
            break
        prev_eigvalT = eigvalT

        # apply the preconditioner
        t = -resid  # (*BAM, na, neig)

        # orthogonalize t with the rest of the V
        t = to_fortran_order(t)
        Vnew = torch.cat((V, t), dim=-1)
        if Vnew.shape[-1] > Vnew.shape[-2]:
            Vnew = Vnew[..., :Vnew.shape[-2]]
        nadd = Vnew.shape[-1] - V.shape[-1]
        nguess = nguess + nadd
        if M is not None:
            MV_ = M.mm(Vnew)
            V, R = tallqr(Vnew, MV=MV_)
        else:
            V, R = tallqr(Vnew)
        AVnew = A.mm(V[..., -nadd:])  # (*BAM,na,nadd)
        AVnew = to_fortran_order(AVnew)
        AV = torch.cat((AV, AVnew), dim=-1)

    eigvals = best_eigvals  # (*BAM, neig)
    eigvecs = best_eigvecs  # (*BAM, na, neig)
    return eigvals, eigvecs
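A hedged usage sketch for the routine above, assuming the signature shown in this example and that LinearOperator.m wraps a dense matrix as in the earlier test examples:

mat = torch.randn(100, 100, dtype=torch.float64)
mat = (mat + mat.transpose(-2, -1)) * 0.5  # symmetrize
A = LinearOperator.m(mat, is_hermitian=True)
eigvals, eigvecs = davidson(A, neig=3, mode="lowest", min_eps=1e-7)
# eigvals: (3,), eigvecs: (100, 3)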
Example 17
def davidson(A, params, neig, mode, M=None, mparams=[], **options):
    """
    Iterative method to obtain the `neig` lowest eigenvalues and eigenvectors.
    This function is written so that backpropagation can be performed.
    It solves the generalized eigendecomposition AV = MVE, where V is the matrix
    of eigenvectors and E is the diagonal matrix containing the eigenvalues.

    Arguments
    ---------
    * A: LinearOperator instance (*BA, na, na)
        The linear operator object on which the eigenpairs are constructed.
    * params: list of differentiable torch.tensor of any shapes
        List of differentiable torch.tensor to be supplied to A.
    * neig: int
        The number of eigenpairs to be retrieved.
    * mode: str
        Take the `neig` "lowest" or "uppest" eigenpairs.
    * M: LinearOperator instance (*BM, na, na) or None
        The transformation on the right hand side. If None, then M=I.
    * mparams: list of differentiable torch.tensor of any shapes
        List of differentiable torch.tensor to be supplied to M.
    * **options:
        Iterative algorithm options.

    Returns
    -------
    * eigvals: torch.tensor (*BAM, neig)
    * eigvecs: torch.tensor (*BAM, na, neig)
        The `neig` eigenpairs requested according to `mode`.
    """
    # TODO: optimize for large linear operator and strict min_eps
    # Ideas:
    # (1) use better strategy to get the estimate on eigenvalues
    # (2) use restart strategy
    config = set_default_option(
        {
            "max_niter": 1000,
            "nguess": neig,  # number of initial guess
            "min_eps": 1e-6,
            "verbose": False,
            "eps_cond": 1e-6,
            "v_init": "randn",
            "max_addition": neig,
        },
        options)

    # get some of the options
    nguess = config["nguess"]
    max_niter = config["max_niter"]
    min_eps = config["min_eps"]
    verbose = config["verbose"]
    eps_cond = config["eps_cond"]
    max_addition = config["max_addition"]

    # get the shape of the transformation
    na = A.shape[-1]
    if M is None:
        bcast_dims = A.shape[:-2]
    else:
        bcast_dims = get_bcasted_dims(A.shape[:-2], M.shape[:-2])
    dtype = A.dtype
    device = A.device

    # TODO: A to use params
    prev_eigvals = None
    prev_eigvalT = None
    stop_reason = "max_niter"
    shift_is_eigvalT = False
    idx = torch.arange(neig).unsqueeze(-1)  # (neig, 1)

    with A.uselinopparams(*params), M.uselinopparams(
            *mparams) if M is not None else dummy_context_manager():
        # set up the initial guess
        V = _set_initial_v(config["v_init"].lower(),
                           dtype,
                           device,
                           bcast_dims,
                           na,
                           nguess,
                           M=M,
                           mparams=mparams)  # (*BAM, na, nguess)
        # V = V.reshape(*bcast_dims, na, nguess) # (*BAM, na, nguess)

        # estimating the lowest eigenvalues
        eig_est, rms_eig = _estimate_eigvals(A,
                                             neig,
                                             mode,
                                             bcast_dims=bcast_dims,
                                             na=na,
                                             ntest=20,
                                             dtype=V.dtype,
                                             device=V.device)

        best_resid = float("inf")
        AV = A.mm(V)
        for i in range(max_niter):
            VT = V.transpose(-2, -1)  # (*BAM,nguess,na)
            # Can be optimized by saving AV from the previous iteration and only
            # operate AV for the new V. This works because the old V has already
            # been orthogonalized, so it will stay the same
            # AV = A.mm(V) # (*BAM,na,nguess)
            T = torch.matmul(VT, AV)  # (*BAM,nguess,nguess)

            # eigvals are sorted from the lowest
            # eval: (*BAM, nguess), evec: (*BAM, nguess, nguess)
            eigvalT, eigvecT = torch.symeig(T, eigenvectors=True)
            eigvalT, eigvecT = _take_eigpairs(
                eigvalT, eigvecT, neig,
                mode)  # (*BAM, neig) and (*BAM, nguess, neig)

            # calculate the eigenvectors of A
            eigvecA = torch.matmul(V, eigvecT)  # (*BAM, na, neig)

            # calculate the residual
            AVs = torch.matmul(AV, eigvecT)  # (*BAM, na, neig)
            LVs = eigvalT.unsqueeze(-2) * eigvecA  # (*BAM, na, neig)
            if M is not None:
                LVs = M.mm(LVs)
            resid = AVs - LVs  # (*BAM, na, neig)

            # print information and check convergence
            max_resid = resid.abs().max()
            if prev_eigvalT is not None:
                deigval = eigvalT - prev_eigvalT
                max_deigval = deigval.abs().max()
                if verbose:
                    print("Iter %3d (guess size: %d): resid: %.3e, devals: %.3e" % \
                          (i+1, nguess, max_resid, max_deigval))

            if max_resid < best_resid:
                best_resid = max_resid
                best_eigvals = eigvalT
                best_eigvecs = eigvecA
            if max_resid < min_eps:
                break
            if AV.shape[-1] == AV.shape[-2]:
                break
            prev_eigvalT = eigvalT

            # apply the preconditioner
            # the initial guess of the eigenvalues actually helps a lot
            if not shift_is_eigvalT:
                z = eig_est  # (*BAM,neig)
            else:
                z = eigvalT  # (*BAM,neig)
            # if A.is_precond_set():
            #     t = A.precond(-resid, *params, biases=z, M=M, mparams=mparams) # (nbatch, na, neig)
            # else:
            t = -resid  # (*BAM, na, neig)

            # set the estimate of the eigenvalues
            if not shift_is_eigvalT:
                eigvalT_pred = eigvalT + torch.einsum(
                    '...ae,...ae->...e', eigvecA, A.mm(t))  # (*BAM, neig)
                diff_eigvalT = (eigvalT - eigvalT_pred)  # (*BAM, neig)
                if diff_eigvalT.abs().max() < rms_eig * 1e-2:
                    shift_is_eigvalT = True
                else:
                    change_idx = eig_est > eigvalT
                    next_value = eigvalT - 2 * diff_eigvalT
                    eig_est[change_idx] = next_value[change_idx]

            # orthogonalize t with the rest of the V
            t = to_fortran_order(t)
            Vnew = torch.cat((V, t), dim=-1)
            if Vnew.shape[-1] > Vnew.shape[-2]:
                Vnew = Vnew[..., :Vnew.shape[-2]]
            nadd = Vnew.shape[-1] - V.shape[-1]
            nguess = nguess + nadd
            if M is not None:
                MV_ = M.mm(Vnew)
                V, R = tallqr(Vnew, MV=MV_)
            else:
                V, R = tallqr(Vnew)
            AVnew = A.mm(V[..., -nadd:])  # (*BAM,na,nadd)
            AVnew = to_fortran_order(AVnew)
            AV = torch.cat((AV, AVnew), dim=-1)

    eigvals = best_eigvals  # (*BAM, neig)
    eigvecs = best_eigvecs  # (*BAM, na, neig)
    return eigvals, eigvecs
Example 18
def davidson(A,
             params,
             neig,
             mode,
             M=None,
             mparams=[],
             max_niter=1000,
             nguess=None,
             v_init="randn",
             max_addition=None,
             min_eps=1e-6,
             verbose=False,
             **unused):
    """
    Use the Davidson method for large sparse matrix eigendecomposition [1]_.

    Arguments
    ---------
    max_niter: int
        Maximum number of iterations
    nguess: int or None
        The number of initial guesses for the eigenvectors.
        If None, set to ``neig``.
    v_init: str
        Mode of the initial guess (``"randn"``, ``"rand"``, ``"eye"``)
    max_addition: int or None
        Maximum number of new guesses to be added to the collected vectors.
        If None, set to ``neig``.
    min_eps: float
        Minimum residual error at which the iteration is stopped
    verbose: bool
        Whether to print the iteration progress

    References
    ----------
    .. [1] P. Arbenz, "Lecture Notes on Solving Large Scale Eigenvalue Problems"
           http://people.inf.ethz.ch/arbenz/ewp/Lnotes/chapter12.pdf
    """
    # TODO: optimize for large linear operator and strict min_eps
    # Ideas:
    # (1) use better strategy to get the estimate on eigenvalues
    # (2) use restart strategy

    if nguess is None:
        nguess = neig
    if max_addition is None:
        max_addition = neig

    # get the shape of the transformation
    na = A.shape[-1]
    if M is None:
        bcast_dims = A.shape[:-2]
    else:
        bcast_dims = get_bcasted_dims(A.shape[:-2], M.shape[:-2])
    dtype = A.dtype
    device = A.device

    # TODO: A to use params
    prev_eigvals = None
    prev_eigvalT = None
    stop_reason = "max_niter"
    shift_is_eigvalT = False
    idx = torch.arange(neig).unsqueeze(-1)  # (neig, 1)

    with A.uselinopparams(*params), M.uselinopparams(
            *mparams) if M is not None else dummy_context_manager():
        # set up the initial guess
        V = _set_initial_v(v_init.lower(),
                           dtype,
                           device,
                           bcast_dims,
                           na,
                           nguess,
                           M=M,
                           mparams=mparams)  # (*BAM, na, nguess)
        # V = V.reshape(*bcast_dims, na, nguess) # (*BAM, na, nguess)

        # estimating the lowest eigenvalues
        eig_est, rms_eig = _estimate_eigvals(A,
                                             neig,
                                             mode,
                                             bcast_dims=bcast_dims,
                                             na=na,
                                             ntest=20,
                                             dtype=V.dtype,
                                             device=V.device)

        best_resid = float("inf")
        AV = A.mm(V)
        for i in range(max_niter):
            VT = V.transpose(-2, -1)  # (*BAM,nguess,na)
            # Can be optimized by saving AV from the previous iteration and only
            # operate AV for the new V. This works because the old V has already
            # been orthogonalized, so it will stay the same
            # AV = A.mm(V) # (*BAM,na,nguess)
            T = torch.matmul(VT, AV)  # (*BAM,nguess,nguess)

            # eigvals are sorted from the lowest
            # eval: (*BAM, nguess), evec: (*BAM, nguess, nguess)
            eigvalT, eigvecT = torch.symeig(T, eigenvectors=True)
            eigvalT, eigvecT = _take_eigpairs(
                eigvalT, eigvecT, neig,
                mode)  # (*BAM, neig) and (*BAM, nguess, neig)

            # calculate the eigenvectors of A
            eigvecA = torch.matmul(V, eigvecT)  # (*BAM, na, neig)

            # calculate the residual
            AVs = torch.matmul(AV, eigvecT)  # (*BAM, na, neig)
            LVs = eigvalT.unsqueeze(-2) * eigvecA  # (*BAM, na, neig)
            if M is not None:
                LVs = M.mm(LVs)
            resid = AVs - LVs  # (*BAM, na, neig)

            # print information and check convergence
            max_resid = resid.abs().max()
            if prev_eigvalT is not None:
                deigval = eigvalT - prev_eigvalT
                max_deigval = deigval.abs().max()
                if verbose:
                    print("Iter %3d (guess size: %d): resid: %.3e, devals: %.3e" % \
                          (i+1, nguess, max_resid, max_deigval))

            if max_resid < best_resid:
                best_resid = max_resid
                best_eigvals = eigvalT
                best_eigvecs = eigvecA
            if max_resid < min_eps:
                break
            if AV.shape[-1] == AV.shape[-2]:
                break
            prev_eigvalT = eigvalT

            # apply the preconditioner
            # the initial guess of the eigenvalues actually helps a lot
            if not shift_is_eigvalT:
                z = eig_est  # (*BAM,neig)
            else:
                z = eigvalT  # (*BAM,neig)
            # if A.is_precond_set():
            #     t = A.precond(-resid, *params, biases=z, M=M, mparams=mparams) # (nbatch, na, neig)
            # else:
            t = -resid  # (*BAM, na, neig)

            # set the estimate of the eigenvalues
            if not shift_is_eigvalT:
                eigvalT_pred = eigvalT + torch.einsum(
                    '...ae,...ae->...e', eigvecA, A.mm(t))  # (*BAM, neig)
                diff_eigvalT = (eigvalT - eigvalT_pred)  # (*BAM, neig)
                if diff_eigvalT.abs().max() < rms_eig * 1e-2:
                    shift_is_eigvalT = True
                else:
                    change_idx = eig_est > eigvalT
                    next_value = eigvalT - 2 * diff_eigvalT
                    eig_est[change_idx] = next_value[change_idx]

            # orthogonalize t with the rest of the V
            t = to_fortran_order(t)
            Vnew = torch.cat((V, t), dim=-1)
            if Vnew.shape[-1] > Vnew.shape[-2]:
                Vnew = Vnew[..., :Vnew.shape[-2]]
            nadd = Vnew.shape[-1] - V.shape[-1]
            nguess = nguess + nadd
            if M is not None:
                MV_ = M.mm(Vnew)
                V, R = tallqr(Vnew, MV=MV_)
            else:
                V, R = tallqr(Vnew)
            AVnew = A.mm(V[..., -nadd:])  # (*BAM,na,nadd)
            AVnew = to_fortran_order(AVnew)
            AV = torch.cat((AV, AVnew), dim=-1)

    eigvals = best_eigvals  # (*BAM, neig)
    eigvecs = best_eigvecs  # (*BAM, na, neig)
    return eigvals, eigvecs