def test_symeig_A_large_methods():
    """Check ``symeig`` on a large, nearly diagonal, matrix-free operator.

    The operator is diag(0, 1, ..., na-1) plus a 1e-3 cyclic
    nearest-neighbour coupling, applied only through ``_mv``. For every
    combination of batch shape, method (currently only "davidson") and mode,
    this checks three things:

    * the eigenvalues lie at the expected end of the spectrum;
    * eigenvectors and eigenvalues have the expected shapes;
    * the eigenpairs satisfy A @ V == V @ diag(lambda) (``torch.allclose``).

    ``seed``, ``LinearOperator`` and ``symeig`` come from the enclosing
    module (not visible here).
    """
    torch.manual_seed(seed)

    class ALarge(LinearOperator):
        def __init__(self, shape, dtype):
            super().__init__(shape, is_hermitian=True, dtype=dtype)
            n = shape[-1]
            # diagonal entries 0..n-1, tiled over the batch dims -> (*shape[:-2], n)
            self.b = torch.arange(n, dtype=dtype).repeat(*shape[:-2], 1)

        def _mv(self, x):
            # x: (*BX, na)
            # diagonal part plus a small symmetric cyclic coupling to both neighbours
            weak = x * 1e-3
            return (x * self.b
                    + torch.roll(weak, shifts=1, dims=-1)
                    + torch.roll(weak, shifts=-1, dims=-1))

        def _getparamnames(self, prefix=""):
            # "b" must match the attribute name set in __init__
            return [prefix + "b"]

    na = 1000
    shapes = [(na, na), (2, na, na), (2, 3, na, na)]
    methods = ["davidson"]  # list the methods under test here
    modes = ["uppermost", "lowest"]
    neig = 2
    dtype = torch.float64

    for shape, method, mode in itertools.product(shapes, methods, modes):
        linop = ALarge(shape, dtype=dtype)
        eigvals, eigvecs = symeig(linop, mode=mode, neig=neig,
                                  method=method, min_eps=1e-8)

        # eigvals: (..., neig), eigvecs: (..., na, neig)
        # The diagonal is arange(na) and the coupling is only 1e-3, so the
        # eigenvalues sit close to the integers 0..na-1.
        if mode == "lowest":
            assert (eigvals < 2 * neig).all()
        elif mode == "uppermost":
            assert (eigvals > na - 2 * neig).all()

        expected_vec_shape = [*linop.shape[:-1], neig]
        expected_val_shape = [*linop.shape[:-2], neig]
        assert list(eigvecs.shape) == expected_vec_shape
        assert list(eigvals.shape) == expected_val_shape

        # eigen-equation residual check: A V == V diag(lambda)
        lhs = linop.mm(eigvecs)
        rhs = eigvecs @ torch.diag_embed(eigvals, dim1=-2, dim2=-1)
        assert torch.allclose(lhs, rhs)
def get_loss(a, mat):
    """Build a symmetric matrix with repeated eigenvalues and return a loss on its eigenvectors.

    ``a`` supplies the eigenvalues. ``a[1]`` and each element of ``a[2:]`` are
    duplicated, so the constructed spectrum contains degenerate eigenvalues.
    ``mat`` is orthogonalised with a QR decomposition, and its Q factor rotates
    the diagonal eigenvalue matrix. ``neig``, ``bck_options``, ``symeig`` and
    ``LinearOperator`` are taken from the enclosing scope (not visible here).

    Returns the sum of the fourth powers of the entries of the first three
    eigenvectors returned by ``symeig``.
    """
    # Q factor of mat: orthonormal basis used to rotate the diagonal matrix.
    # NOTE: torch.qr is deprecated in favour of torch.linalg.qr in newer PyTorch.
    Q, _ = torch.qr(mat)

    # eigenvalue layout: a[0], a[1], a[1], a[2:], a[2:]
    spectrum = torch.cat([a[:2], a[1:2], a[2:], a[2:]])

    # A = Q^T diag(spectrum) Q, evaluated left to right
    A = Q.T @ torch.diag_embed(spectrum) @ Q
    _, eivecs = symeig(LinearOperator.m(A), neig=neig, method="custom_exacteig",
                       bck_options=bck_options)

    # Columns 1 and 2 are expected to hold the degenerate eigenvectors.
    # NOTE(review): this presumably relies on `a` being sorted ascending — confirm.
    leading = eivecs[:, :3]
    return leading.pow(4).sum()
def get_loss(a, matA, matM, P2):
    """Solve a generalized eigenproblem with degenerate eigenvalues and return a quadratic loss.

    ``a`` supplies the eigenvalues. ``a[1]`` and each element of ``a[2:]`` are
    duplicated to create degeneracies. ``matA`` and ``matM`` are orthogonalised
    with QR to build A and M. ``P2`` weights the quadratic form evaluated on the
    degenerate eigenvectors. ``neig``, ``bck_options``, ``symeig`` and
    ``LinearOperator`` come from the enclosing scope (not visible here).

    Returns sum_{r,c} (P2 @ U)[r, c] * U[r, c], i.e. trace(U^T P2 U), where U
    is columns 1..2 of the returned eigenvectors.
    """
    # NOTE: torch.qr is deprecated in favour of torch.linalg.qr in newer PyTorch.
    QA, _ = torch.qr(matA)
    QM, _ = torch.qr(matM)

    # eigenvalue layout: a[0], a[1], a[1], a[2:], a[2:]
    spectrum = torch.cat([a[:2], a[1:2], a[2:], a[2:]])

    # A = QA^T diag(spectrum) QA, evaluated left to right
    A = QA.T @ torch.diag_embed(spectrum) @ QA
    # NOTE(review): QM is a QR Q-factor with orthonormal columns, so QM.T @ QM
    # equals the identity up to rounding; M is therefore effectively I here.
    # Confirm this is intended.
    M = QM.T @ QM

    _, eivecs = symeig(LinearOperator.m(A), M=LinearOperator.m(M), neig=neig,
                       method="custom_exacteig", bck_options=bck_options)

    degenerate = eivecs[:, 1:3]  # the degenerate eigenvectors
    projected = P2 @ degenerate
    return torch.einsum("rc,rc->", projected, degenerate)