Code example #1
File: test_linop_fcns.py Project: AdityaJ7/xitorch
def solvefcn(amat, mmat, bmat, emat):
    # inner helper from test_solve_AEM (code example #14); fwd_options and
    # bck_options are defined in the enclosing test
    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5  # symmetrize M
    alinop = LinearOperator.m(amat)
    mlinop = LinearOperator.m(mmat)
    x = solve(A=alinop,
              B=bmat,
              E=emat,
              M=mlinop,
              fwd_options=fwd_options,
              bck_options=bck_options)
    return x
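This closure is not runnable on its own. A minimal self-contained version of the same call is sketched below; the import paths and the use of default solver options are assumptions about xitorch's public API, not part of the snippet above.

import torch
from xitorch import LinearOperator  # assumed import path
from xitorch.linalg import solve    # assumed import path

torch.manual_seed(0)
n, ncols = 3, 2
dtype = torch.float64
amat = torch.rand(n, n, dtype=dtype) + n * torch.eye(n, dtype=dtype)
mmat = torch.rand(n, n, dtype=dtype)
mmat = mmat @ mmat.transpose(-2, -1) + torch.eye(n, dtype=dtype)  # symmetric positive definite
bmat = torch.rand(n, ncols, dtype=dtype)
emat = torch.rand(ncols, dtype=dtype)

# solves A X - M X diag(E) = B for X of shape (n, ncols), with default options
x = solve(A=LinearOperator.m(amat),
          B=bmat,
          E=emat,
          M=LinearOperator.m(mmat))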
Code example #2
File: test_linop_fcns.py Project: AdityaJ7/xitorch
def lsymeig_fcn(amat, mmat):
    # inner helper from test_lsymeig_AM (code example #4); neig and
    # fwd_options are defined in the enclosing test
    # symmetrize both operands
    amat = (amat + amat.transpose(-2, -1)) * 0.5
    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
    alinop = LinearOperator.m(amat, is_hermitian=True)
    mlinop = LinearOperator.m(mmat, is_hermitian=True)
    eigvals_, eigvecs_ = lsymeig(alinop,
                                 M=mlinop,
                                 neig=neig,
                                 fwd_options=fwd_options)
    return eigvals_, eigvecs_
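As with example #1, this is an inner closure. A standalone sketch of the same lsymeig call follows; the imports are again assumptions about xitorch's public API.

import torch
from xitorch import LinearOperator  # assumed import path
from xitorch.linalg import lsymeig  # assumed import path

torch.manual_seed(0)
n, neig = 4, 2
dtype = torch.float64
amat = torch.rand(n, n, dtype=dtype)
amat = (amat + amat.transpose(-2, -1)) * 0.5  # symmetrize
mmat = torch.rand(n, n, dtype=dtype)
mmat = mmat @ mmat.transpose(-2, -1) + torch.eye(n, dtype=dtype)  # symmetric positive definite

# neig eigenpairs of the generalized problem A V = M V diag(w)
eigvals, eigvecs = lsymeig(LinearOperator.m(amat, is_hermitian=True),
                           M=LinearOperator.m(mmat, is_hermitian=True),
                           neig=neig)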
Code example #3
File: test_linop_fcns.py Project: AdityaJ7/xitorch
def test_lsymeig_mismatch_err():
    torch.manual_seed(seed)
    mat1 = torch.rand((3, 3))
    mat2 = torch.rand((2, 2))
    mat1 = mat1 + mat1.transpose(-2, -1)
    mat2 = mat2 + mat2.transpose(-2, -1)
    linop1 = LinearOperator.m(mat1, True)
    linop2 = LinearOperator.m(mat2, True)

    try:
        res = lsymeig(linop1, M=linop2)
        assert False, "A RuntimeError must be raised if A & M shape are mismatch"
    except RuntimeError:
        pass
Code example #4
File: test_linop_fcns.py Project: AdityaJ7/xitorch
def test_lsymeig_AM():
    torch.manual_seed(seed)
    shapes = [(3, 3), (2, 3, 3), (2, 1, 3, 3)]
    methods = ["exacteig", "davidson"]
    dtype = torch.float64
    for ashape, mshape, method in itertools.product(shapes, shapes, methods):
        mata = torch.rand(ashape, dtype=dtype)
        matm = torch.rand(mshape, dtype=dtype) + torch.eye(
            mshape[-1], dtype=dtype)  # make sure it's not singular
        mata = mata + mata.transpose(-2, -1)
        matm = matm + matm.transpose(-2, -1)
        mata = mata.requires_grad_()
        matm = matm.requires_grad_()
        linopa = LinearOperator.m(mata, True)
        linopm = LinearOperator.m(matm, True)
        fwd_options = {"method": method}

        na = ashape[-1]
        bshape = get_bcasted_dims(ashape[:-2], mshape[:-2])
        for neig in [2, ashape[-1]]:
            eigvals, eigvecs = lsymeig(
                linopa, M=linopm, neig=neig,
                fwd_options=fwd_options)  # eigvals: (..., neig)
            assert list(eigvals.shape) == list([*bshape, neig])
            assert list(eigvecs.shape) == list([*bshape, na, neig])

            ax = linopa.mm(eigvecs)
            mxe = linopm.mm(
                torch.matmul(eigvecs,
                             torch.diag_embed(eigvals, dim1=-2, dim2=-1)))
            assert torch.allclose(ax, mxe)

            # only perform gradcheck if neig is full, to reduce the computational cost
            if neig == ashape[-1]:

                def lsymeig_fcn(amat, mmat):
                    # symmetrize
                    amat = (amat + amat.transpose(-2, -1)) * 0.5
                    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
                    alinop = LinearOperator.m(amat, is_hermitian=True)
                    mlinop = LinearOperator.m(mmat, is_hermitian=True)
                    eigvals_, eigvecs_ = lsymeig(alinop,
                                                 M=mlinop,
                                                 neig=neig,
                                                 fwd_options=fwd_options)
                    return eigvals_, eigvecs_

                gradcheck(lsymeig_fcn, (mata, matm))
                gradgradcheck(lsymeig_fcn, (mata, matm))
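The relation this test asserts, A V = M V diag(w), can be cross-checked densely in plain torch via the standard Cholesky reduction of the generalized symmetric eigenproblem. A sketch, assuming M is symmetric positive definite:

import torch

torch.manual_seed(0)
n = 3
dtype = torch.float64
A = torch.rand(n, n, dtype=dtype)
A = A + A.transpose(-2, -1)
M = torch.rand(n, n, dtype=dtype)
M = M @ M.transpose(-2, -1) + torch.eye(n, dtype=dtype)  # symmetric positive definite

# reduce A v = w M v to a standard problem: with M = L L^T and y = L^T v,
# (L^-1 A L^-T) y = w y
L = torch.linalg.cholesky(M)
Linv = torch.linalg.inv(L)
w, y = torch.linalg.eigh(Linv @ A @ Linv.transpose(-2, -1))
V = Linv.transpose(-2, -1) @ y  # back-transform: v = L^-T y

assert torch.allclose(A @ V, M @ V @ torch.diag(w))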
Code example #5
File: test_linop_fcns.py Project: AdityaJ7/xitorch
def test_solve_A_gmres():
    torch.manual_seed(seed)
    na = 3
    dtype = torch.float64
    ashape = (na, na)
    bshape = (2, na, na)
    fwd_options = {"method": "gmres"}

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]

    amat = torch.rand(ashape, dtype=dtype) + torch.eye(ashape[-1], dtype=dtype)
    bmat = torch.rand(bshape, dtype=dtype)
    amat = amat + amat.transpose(-2, -1)

    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()

    def solvefcn(amat, bmat):
        alinop = LinearOperator.m(amat)
        x = solve(A=alinop, B=bmat, fwd_options=fwd_options)
        return x

    x = solvefcn(amat, bmat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(amat).mm(x)
    assert torch.allclose(ax, bmat)

    gradcheck(solvefcn, (amat, bmat))
    gradgradcheck(solvefcn, (amat, bmat))
Code example #6
File: test_linop_fcns.py Project: AdityaJ7/xitorch
def solvefcn(amat, bmat, emat):
    # inner helper from test_solve_AE (code example #12); fwd_options and
    # bck_options are defined in the enclosing test
    alinop = LinearOperator.m(amat)
    x = solve(A=alinop,
              B=bmat,
              E=emat,
              fwd_options=fwd_options,
              bck_options=bck_options)
    return x
Code example #7
File: test_linop_fcns.py Project: AdityaJ7/xitorch
def solvefcn(amat, bmat):
    # inner helper from test_solve_A (code example #11); hermit, fwd_options,
    # and bck_options are defined in the enclosing test
    # is_hermitian=hermit is required to force the hermitian status in numerical gradient
    alinop = LinearOperator.m(amat, is_hermitian=hermit)
    x = solve(A=alinop,
              B=bmat,
              fwd_options=fwd_options,
              bck_options=bck_options)
    return x
Code example #8
File: test_linop_fcns.py Project: AdityaJ7/xitorch
def test_lsymeig_nonhermit_err():
    torch.manual_seed(seed)
    mat = torch.rand((3, 3))
    linop = LinearOperator.m(mat, False)
    linop2 = LinearOperator.m(mat + mat.transpose(-2, -1), True)

    try:
        res = lsymeig(linop)
        assert False, "A RuntimeError must be raised if the A linear operator in lsymeig is not Hermitian"
    except RuntimeError:
        pass

    try:
        res = lsymeig(linop2, M=linop)
        assert False, "A RuntimeError must be raised if the M linear operator in lsymeig is not Hermitian"
    except RuntimeError:
        pass
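If pytest is available, the try/except-assert-False pattern used in these error tests has a more compact equivalent with pytest.raises. A sketch of the first check above (the _alt name is hypothetical):

import pytest

def test_lsymeig_nonhermit_err_alt():  # hypothetical name
    torch.manual_seed(seed)
    mat = torch.rand((3, 3))
    linop = LinearOperator.m(mat, False)
    with pytest.raises(RuntimeError):
        lsymeig(linop)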
Code example #9
File: test_linop_fcns.py Project: AdityaJ7/xitorch
def test_solve_nonsquare_err():
    torch.manual_seed(seed)
    mat = torch.rand((3, 2))
    mat2 = torch.rand((3, 3))
    linop = LinearOperator.m(mat)
    linop2 = LinearOperator.m(mat2)
    B = torch.rand(3, 1)

    try:
        res = solve(linop, B)
        assert False, "A RuntimeError must be raised if the A linear operator in solve not square"
    except RuntimeError:
        pass

    try:
        res = solve(linop2, B, M=linop)
        assert False, "A RuntimeError must be raised if the M linear operator in solve is not square"
    except RuntimeError:
        pass
Code example #10
File: test_linop_fcns.py Project: AdityaJ7/xitorch
def test_solve_mismatch_err():
    torch.manual_seed(seed)
    shapes = [
        #   A      B      M
        ([(3, 3), (2, 1), (3, 3)], "the B shape does not match with A"),
        ([(3, 3), (3, 2), (2, 2)], "the M shape does not match with A"),
    ]
    dtype = torch.float64
    for (ashape, bshape, mshape), msg in shapes:
        amat = torch.rand(ashape, dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        mmat = torch.rand(mshape, dtype=dtype) + torch.eye(mshape[-1],
                                                           dtype=dtype)
        amat = amat + amat.transpose(-2, -1)
        mmat = mmat + mmat.transpose(-2, -1)

        alinop = LinearOperator.m(amat)
        mlinop = LinearOperator.m(mmat)
        try:
            res = solve(alinop, B=bmat, M=mlinop)
            assert False, "A RuntimeError must be raised if %s" % msg
        except RuntimeError:
            pass
Code example #11
File: test_linop_fcns.py Project: AdityaJ7/xitorch
def test_solve_A():
    torch.manual_seed(seed)
    na = 2
    shapes = [(na, na), (2, na, na), (2, 1, na, na)]
    dtype = torch.float64
    # custom exactsolve to check the backward implementation
    methods = ["exactsolve", "custom_exactsolve", "lbfgs"]
    # hermitian check here to make sure the gradient is propagated symmetrically
    hermits = [False, True]
    for ashape, bshape, method, hermit in itertools.product(
            shapes, shapes, methods, hermits):
        print(ashape, bshape, method, hermit)
        checkgrad = method.endswith("exactsolve")

        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]
        fwd_options = {"method": method, "min_eps": 1e-9}
        bck_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) * 0.1 + torch.eye(ashape[-1],
                                                                 dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        if hermit:
            amat = (amat + amat.transpose(-2, -1)) * 0.5

        amat = amat.requires_grad_()
        bmat = bmat.requires_grad_()

        def solvefcn(amat, bmat):
            # is_hermitian=hermit is required to force the hermitian status in numerical gradient
            alinop = LinearOperator.m(amat, is_hermitian=hermit)
            x = solve(A=alinop,
                      B=bmat,
                      fwd_options=fwd_options,
                      bck_options=bck_options)
            return x

        x = solvefcn(amat, bmat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        assert torch.allclose(ax, bmat)

        if checkgrad:
            gradcheck(solvefcn, (amat, bmat))
            gradgradcheck(solvefcn, (amat, bmat))
Code example #12
File: test_linop_fcns.py Project: AdityaJ7/xitorch
def test_solve_AE():
    torch.manual_seed(seed)
    na = 2
    shapes = [(na, na), (2, na, na), (2, 1, na, na)]
    methods = ["exactsolve", "custom_exactsolve", "lbfgs"
               ]  # custom exactsolve to check the backward implementation
    dtype = torch.float64
    for ashape, bshape, eshape, method in itertools.product(
            shapes, shapes, shapes, methods):
        print(ashape, bshape, eshape, method)
        checkgrad = method.endswith("exactsolve")

        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        eshape = [*eshape[:-2], ncols]
        xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2],
                                       eshape[:-1])) + [na, ncols]
        fwd_options = {"method": method}
        bck_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) * 0.1 + torch.eye(ashape[-1],
                                                                 dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        emat = torch.rand(eshape, dtype=dtype)

        amat = amat.requires_grad_()
        bmat = bmat.requires_grad_()
        emat = emat.requires_grad_()

        def solvefcn(amat, bmat, emat):
            alinop = LinearOperator.m(amat)
            x = solve(A=alinop,
                      B=bmat,
                      E=emat,
                      fwd_options=fwd_options,
                      bck_options=bck_options)
            return x

        x = solvefcn(amat, bmat, emat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        xe = torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2))
        assert torch.allclose(ax - xe, bmat)
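The final assertion says that column i of the solution satisfies a shifted linear system, (A - E[i] * I) x_i = b_i, which is equivalent to A X - X diag(E) = B. A dense reference sketch for one non-batched case, using torch.linalg.solve:

import torch

torch.manual_seed(0)
n, ncols = 3, 2
dtype = torch.float64
A = torch.rand(n, n, dtype=dtype) + n * torch.eye(n, dtype=dtype)  # well-conditioned
B = torch.rand(n, ncols, dtype=dtype)
E = torch.rand(ncols, dtype=dtype)

# solve each shifted system (A - E[i] * I) x_i = b_i column by column
X = torch.stack([
    torch.linalg.solve(A - E[i] * torch.eye(n, dtype=dtype), B[:, i])
    for i in range(ncols)
], dim=-1)

assert torch.allclose(A @ X - X @ torch.diag(E), B)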
Code example #13
File: test_linop_fcns.py Project: AdityaJ7/xitorch
def test_lsymeig_A():
    torch.manual_seed(seed)
    shapes = [(4, 4), (2, 4, 4), (2, 3, 4, 4)]
    methods = ["exacteig", "davidson"]
    for shape, method in itertools.product(shapes, methods):
        mat1 = torch.rand(shape, dtype=torch.float64)
        mat1 = mat1 + mat1.transpose(-2, -1)
        mat1 = mat1.requires_grad_()
        linop1 = LinearOperator.m(mat1, True)
        fwd_options = {"method": method}

        for neig in [2, shape[-1]]:
            eigvals, eigvecs = lsymeig(
                linop1, neig=neig, fwd_options=fwd_options
            )  # eigvals: (..., neig), eigvecs: (..., na, neig)
            assert list(eigvecs.shape) == list([*linop1.shape[:-1], neig])
            assert list(eigvals.shape) == list([*linop1.shape[:-2], neig])

            ax = linop1.mm(eigvecs)
            xe = torch.matmul(eigvecs,
                              torch.diag_embed(eigvals, dim1=-2, dim2=-1))
            assert torch.allclose(ax, xe)

            # only perform gradcheck if neig is full, to reduce the computational cost
            if neig == shape[-1]:

                def lsymeig_fcn(amat):
                    amat = (amat + amat.transpose(-2, -1)) * 0.5  # symmetrize
                    alinop = LinearOperator.m(amat, is_hermitian=True)
                    eigvals_, eigvecs_ = lsymeig(alinop,
                                                 neig=neig,
                                                 fwd_options=fwd_options)
                    return eigvals_, eigvecs_

                gradcheck(lsymeig_fcn, (mat1, ))
                gradgradcheck(lsymeig_fcn, (mat1, ))
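The symmetrization inside lsymeig_fcn is what makes gradcheck valid here: the numerical gradient perturbs matrix entries one at a time, which breaks the symmetry of the probe matrices, so the function under test must re-symmetrize its input itself. A tiny illustration of the issue (eps is a hypothetical finite-difference step):

import torch

torch.manual_seed(0)
m = torch.rand(3, 3, dtype=torch.float64)
m = m + m.transpose(-2, -1)  # start from a symmetric matrix

eps = 1e-6  # hypothetical finite-difference step size
probe = m.clone()
probe[0, 1] += eps  # gradcheck-style single-entry perturbation
assert not torch.equal(probe, probe.transpose(-2, -1))  # symmetry is broken

sym = (probe + probe.transpose(-2, -1)) * 0.5  # the fix applied in lsymeig_fcn
assert torch.equal(sym, sym.transpose(-2, -1))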
Code example #14
File: test_linop_fcns.py Project: AdityaJ7/xitorch
def test_solve_AEM():
    torch.manual_seed(seed)
    na = 2
    shapes = [(na, na), (2, na, na), (2, 1, na, na)]
    dtype = torch.float64
    methods = ["exactsolve", "custom_exactsolve", "lbfgs"]
    for abeshape, mshape, method in itertools.product(shapes, shapes, methods):
        ashape = abeshape
        bshape = abeshape
        eshape = abeshape
        print(abeshape, mshape, method)
        checkgrad = method.endswith("exactsolve")

        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        eshape = [*eshape[:-2], ncols]
        xshape = list(
            get_bcasted_dims(ashape[:-2], bshape[:-2], eshape[:-1],
                             mshape[:-2])) + [na, ncols]
        fwd_options = {"method": method, "min_eps": 1e-9}
        # exactsolve at backward just to test the forward solve
        bck_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) * 0.1 + torch.eye(ashape[-1],
                                                                 dtype=dtype)
        mmat = torch.rand(mshape, dtype=dtype) * 0.1 + torch.eye(
            mshape[-1], dtype=dtype) * 0.5
        bmat = torch.rand(bshape, dtype=dtype)
        emat = torch.rand(eshape, dtype=dtype)
        mmat = (mmat + mmat.transpose(-2, -1)) * 0.5

        amat = amat.requires_grad_()
        mmat = mmat.requires_grad_()
        bmat = bmat.requires_grad_()
        emat = emat.requires_grad_()

        def solvefcn(amat, mmat, bmat, emat):
            mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
            alinop = LinearOperator.m(amat)
            mlinop = LinearOperator.m(mmat)
            x = solve(A=alinop,
                      B=bmat,
                      E=emat,
                      M=mlinop,
                      fwd_options=fwd_options,
                      bck_options=bck_options)
            return x

        x = solvefcn(amat, mmat, bmat, emat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        mxe = LinearOperator.m(mmat).mm(
            torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2)))
        y = ax - mxe
        assert torch.allclose(y, bmat)

        # gradient checker
        if checkgrad:
            gradcheck(solvefcn, (amat, mmat, bmat, emat))
            gradgradcheck(solvefcn, (amat, mmat, bmat, emat))
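Extending the dense reading from example #12: with M present, column i of the solution satisfies (A - E[i] * M) x_i = b_i, which is exactly the residual the test checks. A reference sketch for one non-batched case:

import torch

torch.manual_seed(0)
n, ncols = 3, 2
dtype = torch.float64
A = torch.rand(n, n, dtype=dtype) + n * torch.eye(n, dtype=dtype)
M = torch.rand(n, n, dtype=dtype) * 0.1 + 0.5 * torch.eye(n, dtype=dtype)
M = (M + M.transpose(-2, -1)) * 0.5
B = torch.rand(n, ncols, dtype=dtype)
E = torch.rand(ncols, dtype=dtype)

# solve each M-shifted system (A - E[i] * M) x_i = b_i column by column
X = torch.stack([torch.linalg.solve(A - E[i] * M, B[:, i])
                 for i in range(ncols)], dim=-1)

assert torch.allclose(A @ X - M @ (X @ torch.diag(E)), B)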