Example #1
def lsymeig_fcn(amat, mmat):
    # symmetrize
    amat = (amat + amat.transpose(-2, -1)) * 0.5
    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
    alinop = LinearOperator.m(amat, is_hermitian=True)
    mlinop = LinearOperator.m(mmat, is_hermitian=True)
    # neig and fwd_options are captured from the enclosing test scope
    # (see Example #4)
    eigvals_, eigvecs_ = lsymeig(alinop,
                                 M=mlinop,
                                 neig=neig,
                                 fwd_options=fwd_options)
    return eigvals_, eigvecs_
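
This helper is a closure taken from Example #4, so a minimal usage sketch looks like the following, assuming mata and matm are symmetric float64 tensors with requires_grad_() enabled as in that test:

from torch.autograd import gradcheck, gradgradcheck

gradcheck(lsymeig_fcn, (mata, matm))
gradgradcheck(lsymeig_fcn, (mata, matm))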
Example #2
def solvefcn(amat, mmat, bmat, emat):
    # symmetrize M so it stays a valid Hermitian operator
    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
    alinop = LinearOperator.m(amat)
    mlinop = LinearOperator.m(mmat)
    # fwd_options and bck_options are captured from the enclosing
    # test scope (see Example #18)
    x = solve(A=alinop,
              B=bmat,
              E=emat,
              M=mlinop,
              fwd_options=fwd_options,
              bck_options=bck_options)
    return x
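
A short sketch of how the returned solution can be verified (mirroring Example #18, where this closure originates): x should satisfy AX - MXE = B.

x = solvefcn(amat, mmat, bmat, emat)
ax = LinearOperator.m(amat).mm(x)
mxe = LinearOperator.m(mmat).mm(
    torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2)))
assert torch.allclose(ax - mxe, bmat)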
Example #3
def test_lsymeig_mismatch_err():
    torch.manual_seed(seed)
    mat1 = torch.rand((3, 3))
    mat2 = torch.rand((2, 2))
    mat1 = mat1 + mat1.transpose(-2, -1)
    mat2 = mat2 + mat2.transpose(-2, -1)
    linop1 = LinearOperator.m(mat1, True)
    linop2 = LinearOperator.m(mat2, True)

    try:
        res = lsymeig(linop1, M=linop2)
        assert False, "A RuntimeError must be raised if A & M shape are mismatch"
    except RuntimeError:
        pass
Example #4
def test_lsymeig_AM():
    torch.manual_seed(seed)
    shapes = [(3, 3), (2, 3, 3), (2, 1, 3, 3)]
    methods = ["exacteig", "davidson"]
    dtype = torch.float64
    for ashape, mshape, method in itertools.product(shapes, shapes, methods):
        mata = torch.rand(ashape, dtype=dtype)
        matm = torch.rand(mshape, dtype=dtype) + torch.eye(
            mshape[-1], dtype=dtype)  # make sure it's not singular
        mata = mata + mata.transpose(-2, -1)
        matm = matm + matm.transpose(-2, -1)
        mata = mata.requires_grad_()
        matm = matm.requires_grad_()
        linopa = LinearOperator.m(mata, True)
        linopm = LinearOperator.m(matm, True)
        fwd_options = {"method": method}

        na = ashape[-1]
        bshape = get_bcasted_dims(ashape[:-2], mshape[:-2])
        for neig in [2, ashape[-1]]:
            eigvals, eigvecs = lsymeig(
                linopa, M=linopm, neig=neig,
                fwd_options=fwd_options)  # eigvals: (..., neig)
            assert list(eigvals.shape) == list([*bshape, neig])
            assert list(eigvecs.shape) == list([*bshape, na, neig])

            ax = linopa.mm(eigvecs)
            mxe = linopm.mm(
                torch.matmul(eigvecs,
                             torch.diag_embed(eigvals, dim1=-2, dim2=-1)))
            assert torch.allclose(ax, mxe)

            # only perform gradcheck if neig is full, to reduce the computational cost
            if neig == ashape[-1]:

                def lsymeig_fcn(amat, mmat):
                    # symmetrize
                    amat = (amat + amat.transpose(-2, -1)) * 0.5
                    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
                    alinop = LinearOperator.m(amat, is_hermitian=True)
                    mlinop = LinearOperator.m(mmat, is_hermitian=True)
                    eigvals_, eigvecs_ = lsymeig(alinop,
                                                 M=mlinop,
                                                 neig=neig,
                                                 fwd_options=fwd_options)
                    return eigvals_, eigvecs_

                gradcheck(lsymeig_fcn, (mata, matm))
                gradgradcheck(lsymeig_fcn, (mata, matm))
Example #5
def test_solve_A_gmres():
    torch.manual_seed(seed)
    na = 3
    dtype = torch.float64
    ashape = (na, na)
    bshape = (2, na, na)
    fwd_options = {"method": "gmres"}

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]

    amat = torch.rand(ashape, dtype=dtype) + torch.eye(ashape[-1], dtype=dtype)
    bmat = torch.rand(bshape, dtype=dtype)
    amat = amat + amat.transpose(-2, -1)

    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()

    def solvefcn(amat, bmat):
        alinop = LinearOperator.m(amat)
        x = solve(A=alinop, B=bmat, fwd_options=fwd_options)
        return x

    x = solvefcn(amat, bmat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(amat).mm(x)
    assert torch.allclose(ax, bmat)

    gradcheck(solvefcn, (amat, bmat))
    gradgradcheck(solvefcn, (amat, bmat))
Example #6
def exacteig(A: LinearOperator, neig: Union[int, None], mode: str,
             M: Union[LinearOperator, None]):
    Amatrix = A.fullmatrix()  # (*BA, q, q)
    if neig is None:
        neig = A.shape[-1]
    if M is None:
        evals, evecs = torch.symeig(Amatrix,
                                    eigenvectors=True)  # (*BA, q), (*BA, q, q)
        return _take_eigpairs(evals, evecs, neig, mode)
    else:
        Mmatrix = M.fullmatrix()  # (*BM, q, q)

        # M decomposition to make A symmetric
        # it is done this way to make it numerically stable in avoiding
        # complex eigenvalues for (near-)degenerate case
        L = torch.cholesky(Mmatrix, upper=False)  # (*BM, q, q)
        Linv = torch.inverse(L)  # (*BM, q, q)
        LinvT = Linv.transpose(-2, -1)  # (*BM, q, q)
        A2 = torch.matmul(Linv, torch.matmul(Amatrix, LinvT))  # (*BAM, q, q)

        # calculate the eigenvalues and eigenvectors
        # (the eigvecs are normalized in M-space)
        evals, evecs = torch.symeig(A2, eigenvectors=True)  # (*BAM, q), (*BAM, q, q)
        evals, evecs = _take_eigpairs(evals, evecs, neig,
                                      mode)  # (*BAM, neig) and (*BAM, q, neig)
        evecs = torch.matmul(LinvT, evecs)
        return evals, evecs
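
The Cholesky step above reduces the generalized problem Ax = eMx to an ordinary symmetric one: with M = L L^T, the matrix A2 = L^-1 A L^-T is symmetric with the same eigenvalues, and its eigenvectors map back via x = L^-T y. A minimal numerical check of that identity, as a standalone sketch using the same pre-1.8 torch API as the function (newer PyTorch exposes these as torch.linalg.cholesky and torch.linalg.eigh):

import torch

torch.manual_seed(0)
n = 4
A = torch.rand(n, n, dtype=torch.float64)
A = (A + A.transpose(-2, -1)) * 0.5
M = torch.rand(n, n, dtype=torch.float64) + n * torch.eye(n, dtype=torch.float64)
M = (M + M.transpose(-2, -1)) * 0.5  # symmetric positive definite

L = torch.cholesky(M, upper=False)
Linv = torch.inverse(L)
A2 = Linv @ A @ Linv.transpose(-2, -1)
evals, y = torch.symeig(A2, eigenvectors=True)
x = Linv.transpose(-2, -1) @ y  # eigenvectors in the original space

# every eigenpair satisfies A x_i = e_i M x_i, i.e. AX = MX diag(e)
assert torch.allclose(A @ x, M @ x @ torch.diag(evals))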
Example #7
def solvefcn(amat, bmat, emat):
    alinop = LinearOperator.m(amat)
    # fwd_options and bck_options are captured from the enclosing
    # test scope (see Example #14)
    x = solve(A=alinop,
              B=bmat,
              E=emat,
              fwd_options=fwd_options,
              bck_options=bck_options)
    return x
Example #8
def solvefcn(amat, bmat):
    # is_hermitian=hermit is required to force the hermitian status
    # in the numerical gradient; hermit, fwd_options, and bck_options
    # are captured from the enclosing test scope (see Example #13)
    alinop = LinearOperator.m(amat, is_hermitian=hermit)
    x = solve(A=alinop,
              B=bmat,
              fwd_options=fwd_options,
              bck_options=bck_options)
    return x
Example #9
def test_lsymeig_nonhermit_err():
    torch.manual_seed(seed)
    mat = torch.rand((3, 3))
    linop = LinearOperator.m(mat, False)
    linop2 = LinearOperator.m(mat + mat.transpose(-2, -1), True)

    try:
        res = lsymeig(linop)
        assert False, "A RuntimeError must be raised if the A linear operator in lsymeig is not Hermitian"
    except RuntimeError:
        pass

    try:
        res = lsymeig(linop2, M=linop)
        assert False, "A RuntimeError must be raised if the M linear operator in lsymeig is not Hermitian"
    except RuntimeError:
        pass
Example #10
def test_solve_nonsquare_err():
    torch.manual_seed(seed)
    mat = torch.rand((3, 2))
    mat2 = torch.rand((3, 3))
    linop = LinearOperator.m(mat)
    linop2 = LinearOperator.m(mat2)
    B = torch.rand(3, 1)

    try:
        res = solve(linop, B)
        assert False, "A RuntimeError must be raised if the A linear operator in solve not square"
    except RuntimeError:
        pass

    try:
        res = solve(linop2, B, M=linop)
        assert False, "A RuntimeError must be raised if the M linear operator in solve is not square"
    except RuntimeError:
        pass
Example #11
def exactsolve(A: LinearOperator, B: torch.Tensor,
               E: Union[torch.Tensor, None], M: Union[LinearOperator, None]):
    # A: (*BA, na, na)
    # B: (*BB, na, ncols)
    # E: (*BE, ncols)
    # M: (*BM, na, na)
    if E is None:
        Amatrix = A.fullmatrix()  # (*BA, na, na)
        x, _ = torch.solve(B, Amatrix)  # (*BAB, na, ncols)
    elif M is None:
        Amatrix = A.fullmatrix()
        x = _solve_ABE(Amatrix, B, E)
    else:
        Mmatrix = M.fullmatrix()  # (*BM, na, na)
        L = torch.cholesky(Mmatrix, upper=False)  # (*BM, na, na)
        Linv = torch.inverse(L)  # (*BM, na, na)
        LinvT = Linv.transpose(-2, -1)  # (*BM, na, na)
        A2 = torch.matmul(Linv, A.mm(LinvT))  # (*BAM, na, na)
        B2 = torch.matmul(Linv, B)  # (*BBM, na, ncols)

        X2 = _solve_ABE(A2, B2, E)  # (*BABEM, na, ncols)
        x = torch.matmul(LinvT, X2)  # (*BABEM, na, ncols)
    return x
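
_solve_ABE is a private helper that is not shown in these excerpts. From the shape comments and the equation AX - XE = B, each column decouples into an ordinary shifted solve (A - e_i I) x_i = b_i, so a plausible reference implementation looks like the sketch below; this is an illustration under that assumption (with A, B, E already broadcast to a common batch shape), not the library's actual code:

def _solve_ABE_sketch(A, B, E):
    # A: (*B, na, na), B: (*B, na, ncols), E: (*B, ncols)
    # AX - X diag(E) = B decouples per column: (A - e_i * I) x_i = b_i
    na = A.shape[-1]
    eye = torch.eye(na, dtype=A.dtype, device=A.device)
    cols = []
    for i in range(B.shape[-1]):
        shifted = A - E[..., i, None, None] * eye  # (*B, na, na)
        xi, _ = torch.solve(B[..., i:i + 1], shifted)  # (*B, na, 1)
        cols.append(xi)
    return torch.cat(cols, dim=-1)  # (*B, na, ncols)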
Example #12
def test_solve_mismatch_err():
    torch.manual_seed(seed)
    shapes = [
        #   A      B      M
        ([(3, 3), (2, 1), (3, 3)], "the B shape does not match with A"),
        ([(3, 3), (3, 2), (2, 2)], "the M shape does not match with A"),
    ]
    dtype = torch.float64
    for (ashape, bshape, mshape), msg in shapes:
        amat = torch.rand(ashape, dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        mmat = torch.rand(mshape, dtype=dtype) + torch.eye(mshape[-1],
                                                           dtype=dtype)
        amat = amat + amat.transpose(-2, -1)
        mmat = mmat + mmat.transpose(-2, -1)

        alinop = LinearOperator.m(amat)
        mlinop = LinearOperator.m(mmat)
        try:
            res = solve(alinop, B=bmat, M=mlinop)
            assert False, "A RuntimeError must be raised if %s" % msg
        except RuntimeError:
            pass
Example #13
def test_solve_A():
    torch.manual_seed(seed)
    na = 2
    shapes = [(na, na), (2, na, na), (2, 1, na, na)]
    dtype = torch.float64
    # custom exactsolve to check the backward implementation
    methods = ["exactsolve", "custom_exactsolve", "lbfgs"]
    # hermitian check here to make sure the gradient propagated symmetrically
    hermits = [False, True]
    for ashape, bshape, method, hermit in itertools.product(
            shapes, shapes, methods, hermits):
        print(ashape, bshape, method, hermit)
        checkgrad = method.endswith("exactsolve")

        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]
        fwd_options = {"method": method, "min_eps": 1e-9}
        bck_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) * 0.1 + torch.eye(ashape[-1],
                                                                 dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        if hermit:
            amat = (amat + amat.transpose(-2, -1)) * 0.5

        amat = amat.requires_grad_()
        bmat = bmat.requires_grad_()

        def solvefcn(amat, bmat):
            # is_hermitian=hermit is required to force the hermitian status in the numerical gradient
            alinop = LinearOperator.m(amat, is_hermitian=hermit)
            x = solve(A=alinop,
                      B=bmat,
                      fwd_options=fwd_options,
                      bck_options=bck_options)
            return x

        x = solvefcn(amat, bmat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        assert torch.allclose(ax, bmat)

        if checkgrad:
            gradcheck(solvefcn, (amat, bmat))
            gradgradcheck(solvefcn, (amat, bmat))
Example #14
def test_solve_AE():
    torch.manual_seed(seed)
    na = 2
    shapes = [(na, na), (2, na, na), (2, 1, na, na)]
    methods = ["exactsolve", "custom_exactsolve", "lbfgs"
               ]  # custom exactsolve to check the backward implementation
    dtype = torch.float64
    for ashape, bshape, eshape, method in itertools.product(
            shapes, shapes, shapes, methods):
        print(ashape, bshape, eshape, method)
        checkgrad = method.endswith("exactsolve")

        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        eshape = [*eshape[:-2], ncols]
        xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2],
                                       eshape[:-1])) + [na, ncols]
        fwd_options = {"method": method}
        bck_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) * 0.1 + torch.eye(ashape[-1],
                                                                 dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        emat = torch.rand(eshape, dtype=dtype)

        amat = amat.requires_grad_()
        bmat = bmat.requires_grad_()
        emat = emat.requires_grad_()

        def solvefcn(amat, bmat, emat):
            alinop = LinearOperator.m(amat)
            x = solve(A=alinop,
                      B=bmat,
                      E=emat,
                      fwd_options=fwd_options,
                      bck_options=bck_options)
            return x

        x = solvefcn(amat, bmat, emat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        xe = torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2))
        assert torch.allclose(ax - xe, bmat)
Example #15
def test_lsymeig_A():
    torch.manual_seed(seed)
    shapes = [(4, 4), (2, 4, 4), (2, 3, 4, 4)]
    methods = ["exacteig", "davidson"]
    for shape, method in itertools.product(shapes, methods):
        mat1 = torch.rand(shape, dtype=torch.float64)
        mat1 = mat1 + mat1.transpose(-2, -1)
        mat1 = mat1.requires_grad_()
        linop1 = LinearOperator.m(mat1, True)
        fwd_options = {"method": method}

        for neig in [2, shape[-1]]:
            eigvals, eigvecs = lsymeig(
                linop1, neig=neig, fwd_options=fwd_options
            )  # eigvals: (..., neig), eigvecs: (..., na, neig)
            assert list(eigvecs.shape) == list([*linop1.shape[:-1], neig])
            assert list(eigvals.shape) == list([*linop1.shape[:-2], neig])

            ax = linop1.mm(eigvecs)
            xe = torch.matmul(eigvecs,
                              torch.diag_embed(eigvals, dim1=-2, dim2=-1))
            assert torch.allclose(ax, xe)

            # only perform gradcheck if neig is full, to reduce the computational cost
            if neig == shape[-1]:

                def lsymeig_fcn(amat):
                    amat = (amat + amat.transpose(-2, -1)) * 0.5  # symmetrize
                    alinop = LinearOperator.m(amat, is_hermitian=True)
                    eigvals_, eigvecs_ = lsymeig(alinop,
                                                 neig=neig,
                                                 fwd_options=fwd_options)
                    return eigvals_, eigvecs_

                gradcheck(lsymeig_fcn, (mat1, ))
                gradgradcheck(lsymeig_fcn, (mat1, ))
Example #16
def solve(A: LinearOperator,
          B: torch.Tensor,
          E: Union[torch.Tensor, None] = None,
          M: Union[LinearOperator, None] = None,
          posdef=False,
          fwd_options: Mapping[str, Any] = {},
          bck_options: Mapping[str, Any] = {}):
    """
    Performing iterative method to solve the equation AX=B or
    AX-MXE=B, where E is a diagonal matrix.
    This function can also solve batched multiple inverse equation at the
    same time by applying A to a tensor X with shape (...,na,ncols).
    The applied E are not necessarily identical for each column.

    Arguments
    ---------
    * A: xitorch.LinearOperator instance with shape (*BA, na, na)
        A function that takes an input X and produce the vectors in the same
        space as B.
    * B: torch.tensor (*BB, na, ncols)
        The tensor on the right hand side.
    * E: torch.tensor (*BE, ncols) or None
        If not None, it will solve AX-MXE = B. Otherwise, it just solves
        AX = B and M is ignored. E would be applied to every column.
    * M: xitorch.LinearOperator instance (*BM, na, na) or None
        The transformation on the E side. If E is None,
        then this argument is ignored. If E is not None and M is None, then M=I.
        This LinearOperator must be Hermitian.
    * fwd_options: dict
        Options of the iterative solver in the forward calculation
    * bck_options: dict
        Options of the iterative solver in the backward calculation
    """
    assert_runtime(A.shape[-1] == A.shape[-2],
                   "The linear operator A must have a square shape")
    assert_runtime(
        A.shape[-1] == B.shape[-2],
        "Mismatch shape of A & B (A: %s, B: %s)" % (A.shape, B.shape))
    if M is not None:
        assert_runtime(M.shape[-1] == M.shape[-2],
                       "The linear operator M must have a square shape")
        assert_runtime(
            M.shape[-1] == A.shape[-1],
            "The shape of A & M must match (A: %s, M: %s)" %
            (A.shape, M.shape))
        assert_runtime(M.is_hermitian,
                       "The linear operator M must be a Hermitian matrix")
    if E is not None:
        assert_runtime(
            E.shape[-1] == B.shape[-1],
            "The last dimension of E & B must match (E: %s, B: %s)" %
            (E.shape, B.shape))
    if E is None and M is not None:
        warnings.warn(
            "M is supplied but will be ignored because E is not supplied")

    # perform expensive check if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if "method" not in fwd_options or fwd_options["method"].lower(
    ) == "exactsolve":
        return exactsolve(A, B, E, M)
    else:
        # get the unique parameters of A
        params = A.getlinopparams()
        mparams = M.getlinopparams() if M is not None else []
        na = len(params)
        return solve_torchfcn.apply(A, B, E, M, posdef, fwd_options,
                                    bck_options, na, *params, *mparams)
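
A minimal call sketch for the plain AX = B case, mirroring the tests above (with no "method" given, it goes through the exactsolve branch):

amat = torch.rand(3, 3, dtype=torch.float64) + 3 * torch.eye(3, dtype=torch.float64)
bmat = torch.rand(3, 2, dtype=torch.float64)
x = solve(LinearOperator.m(amat), B=bmat)
assert torch.allclose(LinearOperator.m(amat).mm(x), bmat)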
Example #17
def symeig(A: LinearOperator,
           neig: Union[int, None] = None,
           mode: str = "lowest",
           M: Union[LinearOperator, None] = None,
           fwd_options: Mapping[str, Any] = {},
           bck_options: Mapping[str, Any] = {}):
    """
    Obtain `neig` eigenvalues and eigenvectors of a linear operator (the
    lowest ones by default). If M is specified, it solves the generalized
    eigendecomposition Ax = eMx.

    Arguments
    ---------
    * A: xitorch.LinearOperator hermitian instance with shape (*BA, q, q)
        The linear module object on which the eigenpairs are constructed.
    * neig: int or None
        The number of eigenpairs to be retrieved. If None, all eigenpairs are
        retrieved
    * mode: str
        "lowest" or "uppermost"/"uppest". If "lowest", it will take the lowest
        `neig` eigenpairs. If "uppest", it will take the uppermost `neig`.
    * M: xitorch.LinearOperator hermitian instance with shape (*BM, q, q) or None
        The transformation on the right hand side. If None, then M=I.
    * fwd_options: dict with str as key
        Eigendecomposition iterative algorithm options.
    * bck_options: dict with str as key
        Conjugate gradient options to calculate the gradient in
        backpropagation calculation.

    Returns
    -------
    * eigvals: (*BAM, neig)
    * eigvecs: (*BAM, na, neig)
        The requested eigenvalues and eigenvectors, where *BAM is the
        broadcasted shape of *BA and *BM.
    """
    assert_runtime(A.is_hermitian, "The linear operator A must be Hermitian")
    if M is not None:
        assert_runtime(M.is_hermitian,
                       "The linear operator M must be Hermitian")
        assert_runtime(
            M.shape[-1] == A.shape[-1],
            "The shape of A & M must match (A: %s, M: %s)" %
            (A.shape, M.shape))
    mode = mode.lower()
    if mode == "uppermost":
        mode = "uppest"

    # perform expensive check if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if "method" not in fwd_options or fwd_options["method"].lower(
    ) == "exacteig":
        return exacteig(A, neig, mode, M)
    else:
        # get the unique parameters of A & M
        params = A.getlinopparams()
        mparams = M.getlinopparams() if M is not None else []
        na = len(params)
        return symeig_torchfcn.apply(A, neig, mode, M, fwd_options,
                                     bck_options, na, *params, *mparams)
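
A minimal call sketch (with no "method" given, it goes through the exacteig branch); the tests above reach this through lsymeig, which presumably wraps symeig with mode="lowest":

mat = torch.rand(4, 4, dtype=torch.float64)
mat = mat + mat.transpose(-2, -1)
linop = LinearOperator.m(mat, is_hermitian=True)
eigvals, eigvecs = symeig(linop, neig=2)  # eigvals: (2,), eigvecs: (4, 2)
assert torch.allclose(linop.mm(eigvecs),
                      torch.matmul(eigvecs, torch.diag_embed(eigvals)))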
Example #18
def test_solve_AEM():
    torch.manual_seed(seed)
    na = 2
    shapes = [(na, na), (2, na, na), (2, 1, na, na)]
    dtype = torch.float64
    methods = ["exactsolve", "custom_exactsolve", "lbfgs"]
    for abeshape, mshape, method in itertools.product(shapes, shapes, methods):
        ashape = abeshape
        bshape = abeshape
        eshape = abeshape
        print(abeshape, mshape, method)
        checkgrad = method.endswith("exactsolve")

        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        eshape = [*eshape[:-2], ncols]
        xshape = list(
            get_bcasted_dims(ashape[:-2], bshape[:-2], eshape[:-1],
                             mshape[:-2])) + [na, ncols]
        fwd_options = {"method": method, "min_eps": 1e-9}
        # use exactsolve at backward just to test the forward solve
        bck_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) * 0.1 + torch.eye(ashape[-1],
                                                                 dtype=dtype)
        mmat = torch.rand(mshape, dtype=dtype) * 0.1 + torch.eye(
            mshape[-1], dtype=dtype) * 0.5
        bmat = torch.rand(bshape, dtype=dtype)
        emat = torch.rand(eshape, dtype=dtype)
        mmat = (mmat + mmat.transpose(-2, -1)) * 0.5

        amat = amat.requires_grad_()
        mmat = mmat.requires_grad_()
        bmat = bmat.requires_grad_()
        emat = emat.requires_grad_()

        def solvefcn(amat, mmat, bmat, emat):
            mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
            alinop = LinearOperator.m(amat)
            mlinop = LinearOperator.m(mmat)
            x = solve(A=alinop,
                      B=bmat,
                      E=emat,
                      M=mlinop,
                      fwd_options=fwd_options,
                      bck_options=bck_options)
            return x

        x = solvefcn(amat, mmat, bmat, emat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        mxe = LinearOperator.m(mmat).mm(
            torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2)))
        y = ax - mxe
        assert torch.allclose(y, bmat)

        # gradient checker
        if checkgrad:
            gradcheck(solvefcn, (amat, mmat, bmat, emat))
            gradgradcheck(solvefcn, (amat, mmat, bmat, emat))