# Example 1
def test_symeig_A_large_methods():
    """Check symeig on a large Hermitian operator: eigenvalue ranges,
    output shapes, and the eigen-decomposition identity A X = X diag(e)."""
    torch.manual_seed(seed)

    class ALarge(LinearOperator):
        # Hermitian operator whose diagonal is arange(na) with small
        # (1e-3) wrap-around couplings to the neighboring entries.
        def __init__(self, shape, dtype):
            super().__init__(shape, is_hermitian=True, dtype=dtype)
            na = shape[-1]
            # diagonal values 0..na-1, repeated over the batch dimensions
            self.b = torch.arange(na, dtype=dtype).repeat(*shape[:-2], 1)

        def _mv(self, x):
            # x: (*BX, na)
            diag_term = x * self.b
            coupled = x * 1e-3
            rolled_fwd = torch.roll(coupled, shifts=1, dims=-1)
            rolled_bwd = torch.roll(coupled, shifts=-1, dims=-1)
            return diag_term + rolled_fwd + rolled_bwd

        def _getparamnames(self, prefix=""):
            return [prefix + "b"]

    na = 1000
    shapes = [(na, na), (2, na, na), (2, 3, na, na)]
    # list the methods here
    methods = ["davidson"]
    modes = ["uppermost", "lowest"]
    neig = 2
    dtype = torch.float64
    for shape, method, mode in itertools.product(shapes, methods, modes):
        linop1 = ALarge(shape, dtype=dtype)
        fwd_options = {"method": method, "min_eps": 1e-8}

        # eigvals: (..., neig), eigvecs: (..., na, neig)
        eigvals, eigvecs = symeig(
            linop1, mode=mode, neig=neig,
            **fwd_options)

        # the matrix's eigenvalues will be around arange(na)
        if mode == "lowest":
            assert (eigvals < neig * 2).all()
        elif mode == "uppermost":
            assert (eigvals > na - neig * 2).all()

        assert list(eigvecs.shape) == [*linop1.shape[:-1], neig]
        assert list(eigvals.shape) == [*linop1.shape[:-2], neig]

        # verify A @ eigvecs == eigvecs @ diag(eigvals)
        ax = linop1.mm(eigvecs)
        xe = torch.matmul(eigvecs, torch.diag_embed(eigvals, dim1=-2, dim2=-1))
        assert torch.allclose(ax, xe)
# Example 2
    def get_loss(a, mat):
        """Construct a symmetric matrix with repeated (degenerate)
        eigenvalues from ``a`` and return a loss built from the
        eigenvectors computed by ``symeig``.
        """
        # get the orthogonal vector for the eigenvectors
        # (torch.qr is deprecated; torch.linalg.qr gives the same (Q, R)
        # in its default "reduced" mode)
        P, _ = torch.linalg.qr(mat)

        # line up the eigenvalues: (a0, a1, a1, a2..., a2...) — the
        # duplicated entries make some eigenvalues degenerate
        b = torch.cat((a[:2], a[1:2], a[2:], a[2:]))

        # construct the matrix A = P^T diag(b) P
        diag = torch.diag_embed(b)
        A = torch.matmul(torch.matmul(P.T, diag), P)
        Alinop = LinearOperator.m(A)

        eivals, eivecs = symeig(Alinop,
                                neig=neig,
                                method="custom_exacteig",
                                bck_options=bck_options)
        U = eivecs[:, :3]  # the degenerate eigenvectors are in 1,2
        # sum of 4th powers of the eigenvector entries
        loss = torch.sum(U**4)
        return loss
# Example 3
    def get_loss(a, matA, matM, P2):
        """Construct a generalized eigenproblem (A, M) with repeated
        (degenerate) eigenvalues and return a loss built from the
        degenerate eigenvectors computed by ``symeig``.
        """
        # get the orthogonal vector for the eigenvectors
        # (torch.qr is deprecated; torch.linalg.qr gives the same (Q, R)
        # in its default "reduced" mode)
        P, _ = torch.linalg.qr(matA)
        PM, _ = torch.linalg.qr(matM)

        # line up the eigenvalues: (a0, a1, a1, a2..., a2...) — the
        # duplicated entries make some eigenvalues degenerate
        b = torch.cat((a[:2], a[1:2], a[2:], a[2:]))

        # construct A = P^T diag(b) P and M = PM^T PM
        diag = torch.diag_embed(b)
        A = torch.matmul(torch.matmul(P.T, diag), P)
        M = torch.matmul(PM.T, PM)
        Alinop = LinearOperator.m(A)
        Mlinop = LinearOperator.m(M)

        eivals, eivecs = symeig(Alinop,
                                M=Mlinop,
                                neig=neig,
                                method="custom_exacteig",
                                bck_options=bck_options)
        U = eivecs[:, 1:3]  # the degenerate eigenvectors

        # loss = sum over all entries of (P2 @ U) * U (elementwise)
        loss = torch.einsum("rc,rc->", torch.matmul(P2, U), U)
        return loss