Code example #1
File: solve.py Project: mfkasim1/xitorch
def exactsolve(A: LinearOperator, B: torch.Tensor,
               E: Union[torch.Tensor, None], M: Union[LinearOperator, None]):
    """
    Solve the linear equation by constructing the full matrix of the LinearOperators.

    Warnings
    --------
    * As this method constructs the linear operators explicitly, it might require
      a large amount of memory.
    """
    # A: (*BA, na, na)
    # B: (*BB, na, ncols)
    # E: (*BE, ncols)
    # M: (*BM, na, na)
    if E is None:
        Amatrix = A.fullmatrix()  # (*BA, na, na)
        x, _ = torch.solve(B, Amatrix)  # (*BAB, na, ncols)
    elif M is None:
        Amatrix = A.fullmatrix()
        x = _solve_ABE(Amatrix, B, E)
    else:
        Mmatrix = M.fullmatrix()  # (*BM, na, na)
        L = torch.cholesky(Mmatrix, upper=False)  # (*BM, na, na)
        Linv = torch.inverse(L)  # (*BM, na, na)
        LinvT = Linv.transpose(-2, -1)  # (*BM, na, na)
        A2 = torch.matmul(Linv, A.mm(LinvT))  # (*BAM, na, na)
        B2 = torch.matmul(Linv, B)  # (*BBM, na, ncols)

        X2 = _solve_ABE(A2, B2, E)  # (*BABEM, na, ncols)
        x = torch.matmul(LinvT, X2)  # (*BABEM, na, ncols)
    return x
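
The else-branch above relies on a Cholesky reduction: with M = L Lᵀ, solving A X − M X diag(E) = B is equivalent to solving A₂ X₂ − X₂ diag(E) = B₂ with A₂ = L⁻¹ A L⁻ᵀ and B₂ = L⁻¹ B, after which X = L⁻ᵀ X₂. A minimal standalone sketch in plain torch checks this identity numerically; the column-wise solve below merely stands in for the _solve_ABE helper, whose body is not included in this listing:

import torch

torch.manual_seed(0)
na, ncols = 4, 2
dtype = torch.float64
A = torch.rand(na, na, dtype=dtype)
A = (A + A.T) * 0.5 + na * torch.eye(na, dtype=dtype)  # symmetric, well-conditioned
M = torch.rand(na, na, dtype=dtype) * 0.1
M = (M + M.T) * 0.5 + torch.eye(na, dtype=dtype)  # symmetric positive definite
B = torch.rand(na, ncols, dtype=dtype)
E = torch.rand(ncols, dtype=dtype)

L = torch.linalg.cholesky(M)  # M = L @ L.T
Linv = torch.inverse(L)
A2 = Linv @ A @ Linv.T
B2 = Linv @ B

# column j of X2 satisfies (A2 - E[j] * I) @ x2_j = b2_j, which is
# (presumably) the batched equation _solve_ABE takes care of
X2 = torch.stack([
    torch.linalg.solve(A2 - E[j] * torch.eye(na, dtype=dtype), B2[:, j])
    for j in range(ncols)
], dim=-1)
X = Linv.T @ X2

# X solves the original generalized problem A X - M X diag(E) = B
assert torch.allclose(A @ X - M @ X @ torch.diag(E), B)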
Code example #2
File: test_linop_fcns.py Project: mfkasim1/xitorch
def test_solve_AEM(dtype, device, abeshape, mshape, method):
    torch.manual_seed(seed)
    na = abeshape[-1]
    ashape = abeshape
    bshape = abeshape
    eshape = abeshape
    checkgrad = method.endswith("exactsolve")

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    eshape = [*eshape[:-2], ncols]
    xshape = list(
        get_bcasted_dims(ashape[:-2], bshape[:-2], eshape[:-1],
                         mshape[:-2])) + [na, ncols]
    fwd_options = {"method": method, "min_eps": 1e-9}
    bck_options = {
        "method": method
    }  # exactsolve at backward just to test the forward solve

    amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(ashape[-1], dtype=dtype, device=device)
    mmat = torch.rand(mshape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(mshape[-1], dtype=dtype, device=device) * 0.5
    bmat = torch.rand(bshape, dtype=dtype, device=device)
    emat = torch.rand(eshape, dtype=dtype, device=device)
    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5

    amat = amat.requires_grad_()
    mmat = mmat.requires_grad_()
    bmat = bmat.requires_grad_()
    emat = emat.requires_grad_()

    def solvefcn(amat, mmat, bmat, emat):
        mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
        alinop = LinearOperator.m(amat)
        mlinop = LinearOperator.m(mmat)
        x = solve(A=alinop,
                  B=bmat,
                  E=emat,
                  M=mlinop,
                  **fwd_options,
                  bck_options=bck_options)
        return x

    x = solvefcn(amat, mmat, bmat, emat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(amat).mm(x)
    mxe = LinearOperator.m(mmat).mm(
        torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2)))
    y = ax - mxe
    assert torch.allclose(y, bmat)

    # gradient checker
    if checkgrad:
        gradcheck(solvefcn, (amat, mmat, bmat, emat))
        gradgradcheck(solvefcn, (amat, mmat, bmat, emat))
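
The expected result shape in tests like the one above is computed with get_bcasted_dims, whose body is not part of this listing. A minimal stand-in with the apparent semantics, assuming it simply broadcasts the given batch shapes together (this reimplementation is a guess from the call sites, not the project's code):

import torch

def get_bcasted_dims(*shapes):
    # broadcast the given batch shapes together,
    # e.g. get_bcasted_dims((2, 1), (3,), ()) -> torch.Size([2, 3])
    return torch.broadcast_shapes(*[tuple(s) for s in shapes])

assert list(get_bcasted_dims((2, 1), (2, 3), ())) == [2, 3]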
Code example #3
File: test_linop_fcns.py Project: mfkasim1/xitorch
 def lsymeig_fcn(amat, mmat):
     # symmetrize
     amat = (amat + amat.transpose(-2, -1)) * 0.5
     mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
     alinop = LinearOperator.m(amat, is_hermitian=True)
     mlinop = LinearOperator.m(mmat, is_hermitian=True)
     eigvals_, eigvecs_ = lsymeig(alinop,
                                  M=mlinop,
                                  neig=neig,
                                  **fwd_options)
     return eigvals_, eigvecs_
Code example #4
File: test_linop_fcns.py Project: mfkasim1/xitorch
 def solvefcn(amat, mmat, bmat, emat):
     mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
     alinop = LinearOperator.m(amat)
     mlinop = LinearOperator.m(mmat)
     x = solve(A=alinop,
               B=bmat,
               E=emat,
               M=mlinop,
               **fwd_options,
               bck_options=bck_options)
     return x
Code example #5
def test_lsymeig_AM():
    torch.manual_seed(seed)
    shapes = [(3, 3), (2, 3, 3), (2, 1, 3, 3)]
    # only 2 methods here, because they already cover both gradient implementations
    methods = ["exacteig", "custom_exacteig"]
    dtype = torch.float64
    for ashape, mshape, method in itertools.product(shapes, shapes, methods):
        mata = torch.rand(ashape, dtype=dtype)
        matm = torch.rand(mshape, dtype=dtype) + torch.eye(
            mshape[-1], dtype=dtype)  # make sure it's not singular
        mata = mata + mata.transpose(-2, -1)
        matm = matm + matm.transpose(-2, -1)
        mata = mata.requires_grad_()
        matm = matm.requires_grad_()
        linopa = LinearOperator.m(mata, True)
        linopm = LinearOperator.m(matm, True)
        fwd_options = {"method": method}

        na = ashape[-1]
        bshape = get_bcasted_dims(ashape[:-2], mshape[:-2])
        for neig in [2, ashape[-1]]:
            eigvals, eigvecs = lsymeig(linopa,
                                       M=linopm,
                                       neig=neig,
                                       **fwd_options)  # eigvals: (..., neig)
            assert list(eigvals.shape) == list([*bshape, neig])
            assert list(eigvecs.shape) == list([*bshape, na, neig])

            ax = linopa.mm(eigvecs)
            mxe = linopm.mm(
                torch.matmul(eigvecs,
                             torch.diag_embed(eigvals, dim1=-2, dim2=-1)))
            assert torch.allclose(ax, mxe)

            # only perform gradcheck if neig is full, to reduce the computational cost
            if neig == ashape[-1]:

                def lsymeig_fcn(amat, mmat):
                    # symmetrize
                    amat = (amat + amat.transpose(-2, -1)) * 0.5
                    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
                    alinop = LinearOperator.m(amat, is_hermitian=True)
                    mlinop = LinearOperator.m(mmat, is_hermitian=True)
                    eigvals_, eigvecs_ = lsymeig(alinop,
                                                 M=mlinop,
                                                 neig=neig,
                                                 **fwd_options)
                    return eigvals_, eigvecs_

                gradcheck(lsymeig_fcn, (mata, matm))
                gradgradcheck(lsymeig_fcn, (mata, matm))
Code example #6
def test_linop_add():
    mat = torch.randn((2, 3, 2))
    linop1 = LinOp1(mat)
    linop2 = LinOp2(mat + 1)

    # test using non-matrix linop
    c = linop1 + linop2
    assert torch.allclose(c.fullmatrix(), 2 * mat + 1)

    # test using matrix linear operator
    m1 = LinearOperator.m(mat)
    m2 = LinearOperator.m(mat + 1)
    m12 = m1 + m2
    assert torch.allclose(m12.fullmatrix(), 2 * mat + 1)
Code example #7
File: test_linop_fcns.py Project: mfkasim1/xitorch
def test_lsymeig_mismatch_err(dtype, device):
    torch.manual_seed(seed)
    mat1 = torch.rand((3, 3), dtype=dtype, device=device)
    mat2 = torch.rand((2, 2), dtype=dtype, device=device)
    mat1 = mat1 + mat1.transpose(-2, -1)
    mat2 = mat2 + mat2.transpose(-2, -1)
    linop1 = LinearOperator.m(mat1, True)
    linop2 = LinearOperator.m(mat2, True)

    try:
        res = lsymeig(linop1, M=linop2)
        assert False, "A RuntimeError must be raised if A & M shape are mismatch"
    except RuntimeError:
        pass
Code example #8
def test_linop_sub():
    # test the behaviour of subtraction of LinearOperators
    mat = torch.randn((2, 3, 2))
    linop1 = LinOp1(mat)
    linop2 = LinOp2(-mat + 1)

    # test using non-matrix linop
    c = linop1 - linop2
    assert torch.allclose(c.fullmatrix(), 2 * mat - 1)

    # test using matrix linear operator
    m1 = LinearOperator.m(mat)
    m2 = LinearOperator.m(-mat + 1)
    m12 = m1 - m2
    assert torch.allclose(m12.fullmatrix(), 2 * mat - 1)
Code example #9
def test_solve_A_methods():
    torch.manual_seed(seed)
    na = 3
    dtype = torch.float64
    ashape = (na, na)
    bshape = (2, na, na)
    methods = ["gmres", "lbfgs"]
    # compute the shapes once, outside the loop, so bshape is not
    # shrunk again on every iteration
    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]
    for method in methods:
        fwd_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) + torch.eye(ashape[-1],
                                                           dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        amat = amat + amat.transpose(-2, -1)

        amat = amat.requires_grad_()
        bmat = bmat.requires_grad_()

        def solvefcn(amat, bmat):
            alinop = LinearOperator.m(amat)
            x = solve(A=alinop, B=bmat, **fwd_options)
            return x

        x = solvefcn(amat, bmat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        assert torch.allclose(ax, bmat)
Code example #10
File: test_linop_fcns.py Project: mfkasim1/xitorch
 def svd_fcn(amat, only_s=False):
     alinop = LinearOperator.m(amat, is_hermitian=False)
     u_, s_, vh_ = svd(alinop, k=k, **fwd_options)
     if only_s:
         return s_
     else:
         return u_, s_, vh_
Code example #11
def exacteig(A: LinearOperator, neig: int,
             mode: str, M: Optional[LinearOperator]) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Eigendecomposition using explicit matrix construction.
    No additional option for this method.

    Warnings
    --------
    * As this method constructs the linear operators explicitly, it might require
      a large amount of memory.
    """
    Amatrix = A.fullmatrix()  # (*BA, q, q)
    if M is None:
        evals, evecs = torch.symeig(Amatrix, eigenvectors=True)  # (*BA, q), (*BA, q, q)
        return _take_eigpairs(evals, evecs, neig, mode)
    else:
        Mmatrix = M.fullmatrix()  # (*BM, q, q)

        # decompose M so that the transformed problem stays symmetric;
        # doing it this way is numerically stable, avoiding complex
        # eigenvalues in the (near-)degenerate case
        L = torch.cholesky(Mmatrix, upper=False)  # (*BM, q, q)
        Linv = torch.inverse(L)  # (*BM, q, q)
        LinvT = Linv.transpose(-2, -1)  # (*BM, q, q)
        A2 = torch.matmul(Linv, torch.matmul(Amatrix, LinvT))  # (*BAM, q, q)

        # calculate the eigenvalues and eigenvectors
        # (the eigvecs are normalized in M-space)
        evals, evecs = torch.symeig(A2, eigenvectors=True)  # (*BAM, q), (*BAM, q, q)
        evals, evecs = _take_eigpairs(evals, evecs, neig, mode)  # (*BAM, neig) and (*BAM, q, neig)
        evecs = torch.matmul(LinvT, evecs)
        return evals, evecs
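
Two things in exacteig are worth spelling out: _take_eigpairs is not included in this listing, and the comment claims the returned eigenvectors are normalized in M-space. Below is a plausible stand-in for the helper (assuming mode picks the lowest or highest neig pairs, inferred from the call sites rather than confirmed) together with a numerical check of the M-orthonormality claim, written against the modern torch.linalg API:

import torch

def _take_eigpairs(eival, eivec, neig, mode):
    # hypothetical stand-in: eigh returns eigenvalues in ascending order,
    # so "lowest" keeps the first neig pairs and "uppest" the last neig
    if mode == "lowest":
        return eival[..., :neig], eivec[..., :, :neig]
    elif mode == "uppest":
        return eival[..., -neig:], eivec[..., :, -neig:]
    raise ValueError("Unknown mode: %s" % mode)

torch.manual_seed(0)
q, neig = 5, 3
dtype = torch.float64
A = torch.rand(q, q, dtype=dtype)
A = (A + A.T) * 0.5
M = torch.rand(q, q, dtype=dtype) * 0.1
M = (M + M.T) * 0.5 + torch.eye(q, dtype=dtype)  # symmetric positive definite

L = torch.linalg.cholesky(M)
Linv = torch.inverse(L)
A2 = Linv @ A @ Linv.T
evals, evecs = torch.linalg.eigh(A2)  # ascending eigenvalues
evals, evecs = _take_eigpairs(evals, evecs, neig, "lowest")
evecs = Linv.T @ evecs

# the eigenvectors are orthonormal in M-space and solve A x = M x * eval
eye = torch.eye(neig, dtype=dtype)
assert torch.allclose(evecs.T @ M @ evecs, eye)
assert torch.allclose(A @ evecs, M @ evecs @ torch.diag(evals))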
Code example #12
File: test_linop_fcns.py Project: mfkasim1/xitorch
def test_lsymeig_A(dtype, device, shape, method):
    torch.manual_seed(seed)
    mat1 = torch.rand(shape, dtype=dtype, device=device)
    mat1 = mat1 + mat1.transpose(-2, -1)
    mat1 = mat1.requires_grad_()
    linop1 = LinearOperator.m(mat1, True)
    fwd_options = {"method": method}

    for neig in [2, shape[-1]]:
        eigvals, eigvecs = lsymeig(
            linop1, neig=neig,
            **fwd_options)  # eigvals: (..., neig), eigvecs: (..., na, neig)
        assert list(eigvecs.shape) == list([*linop1.shape[:-1], neig])
        assert list(eigvals.shape) == list([*linop1.shape[:-2], neig])

        ax = linop1.mm(eigvecs)
        xe = torch.matmul(eigvecs, torch.diag_embed(eigvals, dim1=-2, dim2=-1))
        assert torch.allclose(ax, xe)

        # only perform gradcheck if neig is full, to reduce the computational cost
        if neig == shape[-1]:

            def lsymeig_fcn(amat):
                amat = (amat + amat.transpose(-2, -1)) * 0.5  # symmetrize
                alinop = LinearOperator.m(amat, is_hermitian=True)
                eigvals_, eigvecs_ = lsymeig(alinop, neig=neig, **fwd_options)
                return eigvals_, eigvecs_

            gradcheck(lsymeig_fcn, (mat1, ))
            gradgradcheck(lsymeig_fcn, (mat1, ))
Code example #13
File: test_linop_fcns.py Project: mfkasim1/xitorch
def test_svd_A(dtype, device, shape, method):
    torch.manual_seed(seed)
    mat1 = torch.rand(shape, dtype=dtype, device=device)
    mat1 = mat1.requires_grad_()
    linop1 = LinearOperator.m(mat1, is_hermitian=False)
    fwd_options = {"method": method}

    min_mn = min(shape[-1], shape[-2])
    for k in [min_mn]:
        u, s, vh = svd(
            linop1, k=k,
            **fwd_options)  # u: (..., m, k), s: (..., k), vh: (..., k, n)
        assert list(u.shape) == list([*linop1.shape[:-1], k])
        assert list(s.shape) == list([*linop1.shape[:-2], k])
        assert list(vh.shape) == list(
            [*linop1.shape[:-2], k, linop1.shape[-1]])

        keye = torch.zeros((*shape[:-2], k, k), dtype=dtype, device=device) + \
            torch.eye(k, dtype=dtype, device=device)
        assert torch.allclose(u.transpose(-2, -1) @ u, keye)
        assert torch.allclose(vh @ vh.transpose(-2, -1), keye)
        if k == min_mn:
            assert torch.allclose(mat1, u @ torch.diag_embed(s) @ vh)

        def svd_fcn(amat, only_s=False):
            alinop = LinearOperator.m(amat, is_hermitian=False)
            u_, s_, vh_ = svd(alinop, k=k, **fwd_options)
            if only_s:
                return s_
            else:
                return u_, s_, vh_

        gradcheck(svd_fcn, (mat1, ))
        gradgradcheck(svd_fcn, (mat1, True))
Code example #14
File: test_linop_fcns.py Project: mfkasim1/xitorch
 def solvefcn(amat, bmat, emat):
     alinop = LinearOperator.m(amat)
     x = solve(A=alinop,
               B=bmat,
               E=emat,
               **fwd_options,
               bck_options=bck_options)
     return x
Code example #15
File: test_linop_fcns.py Project: mfkasim1/xitorch
def test_lsymeig_nonhermit_err(dtype, device):
    torch.manual_seed(seed)
    mat = torch.rand((3, 3), dtype=dtype, device=device)
    linop = LinearOperator.m(mat, False)
    linop2 = LinearOperator.m(mat + mat.transpose(-2, -1), True)

    try:
        res = lsymeig(linop)
        assert False, "A RuntimeError must be raised if the A linear operator in lsymeig is not Hermitian"
    except RuntimeError:
        pass

    try:
        res = lsymeig(linop2, M=linop)
        assert False, "A RuntimeError must be raised if the M linear operator in lsymeig is not Hermitian"
    except RuntimeError:
        pass
Code example #16
def test_linop_mat_hermit_err():
    torch.manual_seed(100)
    mat = torch.rand(3, 3)
    mat2 = torch.rand(3, 4)
    msg = "Expecting a RuntimeError for non-symmetric matrix "\
          "indicated as a Hermitian"

    try:
        LinearOperator.m(mat, is_hermitian=True)
        assert False, msg
    except RuntimeError:
        pass

    try:
        LinearOperator.m(mat2, is_hermitian=True)
        assert False, msg
    except RuntimeError:
        pass
Code example #17
File: test_linop_fcns.py Project: mfkasim1/xitorch
def test_lsymeig_AM(dtype, device, ashape, mshape, method):
    torch.manual_seed(seed)
    mata = torch.rand(ashape, dtype=dtype, device=device)
    matm = torch.rand(mshape, dtype=dtype, device=device) + \
        torch.eye(mshape[-1], dtype=dtype, device=device)  # make sure it's not singular
    mata = mata + mata.transpose(-2, -1)
    matm = matm + matm.transpose(-2, -1)
    mata = mata.requires_grad_()
    matm = matm.requires_grad_()
    linopa = LinearOperator.m(mata, True)
    linopm = LinearOperator.m(matm, True)
    fwd_options = {"method": method}

    na = ashape[-1]
    bshape = get_bcasted_dims(ashape[:-2], mshape[:-2])
    for neig in [2, ashape[-1]]:
        eigvals, eigvecs = lsymeig(linopa, M=linopm, neig=neig,
                                   **fwd_options)  # eigvals: (..., neig)
        assert list(eigvals.shape) == list([*bshape, neig])
        assert list(eigvecs.shape) == list([*bshape, na, neig])

        ax = linopa.mm(eigvecs)
        mxe = linopm.mm(
            torch.matmul(eigvecs, torch.diag_embed(eigvals, dim1=-2, dim2=-1)))
        assert torch.allclose(ax, mxe)

        # only perform gradcheck if neig is full, to reduce the computational cost
        if neig == ashape[-1]:

            def lsymeig_fcn(amat, mmat):
                # symmetrize
                amat = (amat + amat.transpose(-2, -1)) * 0.5
                mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
                alinop = LinearOperator.m(amat, is_hermitian=True)
                mlinop = LinearOperator.m(mmat, is_hermitian=True)
                eigvals_, eigvecs_ = lsymeig(alinop,
                                             M=mlinop,
                                             neig=neig,
                                             **fwd_options)
                return eigvals_, eigvecs_

            gradcheck(lsymeig_fcn, (mata, matm))
            gradgradcheck(lsymeig_fcn, (mata, matm))
Code example #18
File: test_linop_fcns.py Project: mfkasim1/xitorch
def test_solve_nonsquare_err(dtype, device):
    torch.manual_seed(seed)
    mat = torch.rand((3, 2), dtype=dtype, device=device)
    mat2 = torch.rand((3, 3), dtype=dtype, device=device)
    linop = LinearOperator.m(mat)
    linop2 = LinearOperator.m(mat2)
    B = torch.rand((3, 1), dtype=dtype, device=device)

    try:
        res = solve(linop, B)
        assert False, "A RuntimeError must be raised if the A linear operator in solve not square"
    except RuntimeError:
        pass

    try:
        res = solve(linop2, B, M=linop)
        assert False, "A RuntimeError must be raised if the M linear operator in solve is not square"
    except RuntimeError:
        pass
Code example #19
File: test_linop_fcns.py Project: mfkasim1/xitorch
def test_solve_AEM_methods(dtype, device, method):
    torch.manual_seed(seed)
    na = 100
    nc = na // 2

    amshape = (na, na)
    eshape = (nc, )
    bshape = (2, na, nc)
    options = {
        "scipy_gmres": {},
        "broyden1": {},
        "cg": {
            "rtol":
            1e-8  # stringent rtol required to meet the torch.allclose tols
        },
        "bicgstab": {},
    }[method]
    fwd_options = {"method": method, **options}

    amat = torch.rand(amshape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(amshape[-1], dtype=dtype, device=device)
    mmat = torch.rand(amshape, dtype=dtype, device=device) * 0.05 + \
        torch.eye(amshape[-1], dtype=dtype, device=device)
    bmat = torch.rand(bshape, dtype=dtype, device=device) + 0.1
    emat = torch.rand(eshape, dtype=dtype, device=device) * 0.1
    amat = (amat + amat.transpose(-2, -1)) * 0.5
    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5

    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()
    emat = emat.requires_grad_()

    def solvefcn(amat, bmat, emat, mmat):
        alinop = LinearOperator.m(amat)
        mlinop = LinearOperator.m(mmat)
        x = solve(A=alinop, B=bmat, E=emat, M=mlinop, **fwd_options)
        return x

    x = solvefcn(amat, bmat, emat, mmat)
    ax = LinearOperator.m(amat).mm(x)
    mxe = LinearOperator.m(mmat).mm(x) @ torch.diag_embed(emat)
    assert torch.allclose(ax - mxe, bmat)
Code example #20
 def setup(self, minmaxeival, n):
     seed = 123
     ncols = 50
     torch.manual_seed(seed)
     min_eival, max_eival = minmaxeival
     A = create_random_square_matrix(n,
                                     is_hermitian=True,
                                     min_eival=min_eival,
                                     max_eival=max_eival,
                                     seed=seed)
     self.A = LinearOperator.m(A, is_hermitian=True)
Code example #21
 def setup(self, is_hermitian, minmaxeival, n):
     seed = 123
     ncols = 50
     torch.manual_seed(seed)
     min_eival, max_eival = minmaxeival
     A = create_random_square_matrix(n,
                                     is_hermitian=is_hermitian,
                                     min_eival=min_eival,
                                     max_eival=max_eival,
                                     seed=seed)
     self.A = LinearOperator.m(A, is_hermitian=is_hermitian)
     X = torch.randn(n, ncols, dtype=A.dtype)
     self.B = self.A.mm(X)
Code example #22
File: test_linop_fcns.py Project: mfkasim1/xitorch
def test_solve_mismatch_err(dtype, device):
    torch.manual_seed(seed)
    shapes = [
        #   A      B      M
        ([(3, 3), (2, 1), (3, 3)], "the B shape does not match A"),
        ([(3, 3), (3, 2), (2, 2)], "the M shape does not match A"),
    ]
    for (ashape, bshape, mshape), msg in shapes:
        amat = torch.rand(ashape, dtype=dtype, device=device)
        bmat = torch.rand(bshape, dtype=dtype, device=device)
        mmat = torch.rand(mshape, dtype=dtype, device=device) + \
            torch.eye(mshape[-1], dtype=dtype, device=device)
        amat = amat + amat.transpose(-2, -1)
        mmat = mmat + mmat.transpose(-2, -1)

        alinop = LinearOperator.m(amat)
        mlinop = LinearOperator.m(mmat)
        try:
            res = solve(alinop, B=bmat, M=mlinop)
            assert False, "A RuntimeError must be raised if %s" % msg
        except RuntimeError:
            pass
Code example #23
    def get_loss(a, matA, matM, P2):
        # get the orthogonal vector for the eigenvectors
        P, _ = torch.qr(matA)
        PM, _ = torch.qr(matM)

        # line up the eigenvalues
        b = torch.cat((a[:2], a[1:2], a[2:], a[2:]))

        # construct the matrix
        diag = torch.diag_embed(b)
        A = torch.matmul(torch.matmul(P.T, diag), P)
        M = torch.matmul(PM.T, PM)
        Alinop = LinearOperator.m(A)
        Mlinop = LinearOperator.m(M)

        eivals, eivecs = symeig(Alinop,
                                M=Mlinop,
                                neig=neig,
                                method="custom_exacteig",
                                bck_options=bck_options)
        U = eivecs[:, 1:3]  # the degenerate eigenvectors

        loss = torch.einsum("rc,rc->", torch.matmul(P2, U), U)
        return loss
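
The b = torch.cat(...) line above is what makes this a degenerate-eigenvalue test: it repeats entries of a so that the matrix assembled from diag_embed(b) has eigenvalues of multiplicity two, the case the custom backward of symeig has to handle. A quick illustration of the construction (assuming a 1-D a with at least three elements):

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.cat((a[:2], a[1:2], a[2:], a[2:]))
print(b)  # tensor([1., 2., 2., 3., 3.]): a[1] and a[2] each appear twice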
Code example #24
def test_linop_mul():
    # test the behaviour of multiplication of LinearOperator with a number
    mat = torch.randn((2, 3, 2))
    linop1 = LinOp1(mat)
    linop2 = LinearOperator.m(mat)

    for f1 in [2, 4.0]:
        print(f1)
        # test using non-matrix linop multiplier
        c11l = linop1 * f1
        c11r = f1 * linop1
        # test using matrix linop multiplier
        c12l = linop2 * f1
        c12r = f1 * linop2
        assert torch.allclose(c11l.fullmatrix(), f1 * mat)
        assert torch.allclose(c11r.fullmatrix(), f1 * mat)
        assert torch.allclose(c12l.fullmatrix(), f1 * mat)
        assert torch.allclose(c12r.fullmatrix(), f1 * mat)
Code example #25
def test_solve_A():
    torch.manual_seed(seed)
    na = 2
    shapes = [(na, na), (2, na, na), (2, 1, na, na)]
    dtype = torch.float64
    # custom exactsolve to check the backward implementation
    methods = ["exactsolve", "custom_exactsolve"]
    # hermitian check here to make sure the gradient is propagated symmetrically
    hermits = [False, True]
    for ashape, bshape, method, hermit in itertools.product(
            shapes, shapes, methods, hermits):
        print(ashape, bshape, method, hermit)
        checkgrad = method.endswith("exactsolve")

        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]
        fwd_options = {"method": method, "min_eps": 1e-9}
        bck_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) * 0.1 + torch.eye(ashape[-1],
                                                                 dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        if hermit:
            amat = (amat + amat.transpose(-2, -1)) * 0.5

        amat = amat.requires_grad_()
        bmat = bmat.requires_grad_()

        def solvefcn(amat, bmat):
            # is_hermitian=hermit is required to force the hermitian status in numerical gradient
            alinop = LinearOperator.m(amat, is_hermitian=hermit)
            x = solve(A=alinop, B=bmat, **fwd_options, bck_options=bck_options)
            return x

        x = solvefcn(amat, bmat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        assert torch.allclose(ax, bmat)

        if checkgrad:
            gradcheck(solvefcn, (amat, bmat))
            gradgradcheck(solvefcn, (amat, bmat))
Code example #26
def test_solve_AE():
    torch.manual_seed(seed)
    na = 2
    shapes = [(na, na), (2, na, na), (2, 1, na, na)]
    methods = ["exactsolve", "custom_exactsolve"
               ]  # custom exactsolve to check the backward implementation
    dtype = torch.float64
    for ashape, bshape, eshape, method in itertools.product(
            shapes, shapes, shapes, methods):
        print(ashape, bshape, eshape, method)
        checkgrad = method.endswith("exactsolve")

        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        eshape = [*eshape[:-2], ncols]
        xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2],
                                       eshape[:-1])) + [na, ncols]
        fwd_options = {"method": method}
        bck_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) * 0.1 + torch.eye(ashape[-1],
                                                                 dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        emat = torch.rand(eshape, dtype=dtype)

        amat = amat.requires_grad_()
        bmat = bmat.requires_grad_()
        emat = emat.requires_grad_()

        def solvefcn(amat, bmat, emat):
            alinop = LinearOperator.m(amat)
            x = solve(A=alinop,
                      B=bmat,
                      E=emat,
                      **fwd_options,
                      bck_options=bck_options)
            return x

        x = solvefcn(amat, bmat, emat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        xe = torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2))
        assert torch.allclose(ax - xe, bmat)
Code example #27
    def get_loss(a, mat):
        # get the orthogonal vector for the eigenvectors
        P, _ = torch.qr(mat)

        # line up the eigenvalues
        b = torch.cat((a[:2], a[1:2], a[2:], a[2:]))

        # construct the matrix
        diag = torch.diag_embed(b)
        A = torch.matmul(torch.matmul(P.T, diag), P)
        Alinop = LinearOperator.m(A)

        eivals, eivecs = symeig(Alinop,
                                neig=neig,
                                method="custom_exacteig",
                                bck_options=bck_options)
        U = eivecs[:, :3]  # the degenerate eigenvectors are in columns 1 and 2
        loss = torch.sum(U**4)
        return loss
Code example #28
File: test_linop_fcns.py Project: mfkasim1/xitorch
def test_solve_A_methods(dtype, device, method):
    torch.manual_seed(seed)
    na = 100
    ashape = (na, na)
    bshape = (2, na, na)
    options = {
        "scipy_gmres": {},
        "broyden1": {},
        "cg": {
            "rtol":
            1e-8  # stringent rtol required to meet the torch.allclose tols
        },
        "bicgstab": {
            "rtol": 1e-8,
        },
    }[method]
    fwd_options = {"method": method, **options}

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]

    amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(ashape[-1], dtype=dtype, device=device)
    bmat = torch.rand(bshape, dtype=dtype, device=device) + 0.1
    amat = (amat + amat.transpose(-2, -1)) * 0.5

    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()

    def solvefcn(amat, bmat):
        alinop = LinearOperator.m(amat)
        x = solve(A=alinop, B=bmat, **fwd_options)
        return x

    x = solvefcn(amat, bmat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(amat).mm(x)
    assert torch.allclose(ax, bmat)
Code example #29
File: test_linop_fcns.py Project: mfkasim1/xitorch
def test_solve_A(dtype, device, ashape, bshape, method, hermit):
    torch.manual_seed(seed)
    na = ashape[-1]
    checkgrad = method.endswith("exactsolve")

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]
    fwd_options = {"method": method, "min_eps": 1e-9}
    bck_options = {"method": method}

    amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(ashape[-1], dtype=dtype, device=device)
    bmat = torch.rand(bshape, dtype=dtype, device=device)

    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()

    def prepare(amat):
        if hermit:
            return (amat + amat.transpose(-2, -1)) * 0.5
        return amat

    def solvefcn(amat, bmat):
        # is_hermitian=hermit is required to force the hermitian status in numerical gradient
        alinop = LinearOperator.m(prepare(amat), is_hermitian=hermit)
        x = solve(A=alinop, B=bmat, **fwd_options, bck_options=bck_options)
        return x

    x = solvefcn(amat, bmat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(prepare(amat)).mm(x)
    assert torch.allclose(ax, bmat)

    if checkgrad:
        gradcheck(solvefcn, (amat, bmat))
        gradgradcheck(solvefcn, (amat, bmat))
Code example #30
File: test_linop_fcns.py Project: mfkasim1/xitorch
def test_solve_AE(dtype, device, ashape, bshape, eshape, method):
    torch.manual_seed(seed)
    na = ashape[-1]
    checkgrad = method.endswith("exactsolve")

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    eshape = [*eshape[:-2], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2],
                                   eshape[:-1])) + [na, ncols]
    fwd_options = {"method": method}
    bck_options = {"method": method}

    amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(ashape[-1], dtype=dtype, device=device)
    bmat = torch.rand(bshape, dtype=dtype, device=device)
    emat = torch.rand(eshape, dtype=dtype, device=device)

    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()
    emat = emat.requires_grad_()

    def solvefcn(amat, bmat, emat):
        alinop = LinearOperator.m(amat)
        x = solve(A=alinop,
                  B=bmat,
                  E=emat,
                  **fwd_options,
                  bck_options=bck_options)
        return x

    x = solvefcn(amat, bmat, emat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(amat).mm(x)
    xe = torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2))
    assert torch.allclose(ax - xe, bmat)