# Common imports for the tests below. The library is assumed to be xitorch
# (LinearOperator, solve, lsymeig, symeig, svd); the exact import paths for
# the helpers are assumptions based on its layout.
import itertools
import torch
from torch.autograd import gradcheck, gradgradcheck
from xitorch import LinearOperator
from xitorch.linalg import lsymeig, symeig, svd, solve
from xitorch._utils.bcast import get_bcasted_dims  # assumed helper location

seed = 12345  # any fixed seed works; used for reproducibility

# Several tests take (dtype, device, shape, method, ...) parameters; in the
# original suite these come from @pytest.mark.parametrize decorators that are
# not shown in this section.


def test_solve_AEM(dtype, device, abeshape, mshape, method):
    torch.manual_seed(seed)
    na = abeshape[-1]
    ashape = abeshape
    bshape = abeshape
    eshape = abeshape
    checkgrad = method.endswith("exactsolve")
    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    eshape = [*eshape[:-2], ncols]
    xshape = list(
        get_bcasted_dims(ashape[:-2], bshape[:-2], eshape[:-1], mshape[:-2])) + [na, ncols]
    fwd_options = {"method": method, "min_eps": 1e-9}
    bck_options = {"method": method}  # exactsolve at backward just to test the forward solve

    amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(ashape[-1], dtype=dtype, device=device)
    mmat = torch.rand(mshape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(mshape[-1], dtype=dtype, device=device) * 0.5
    bmat = torch.rand(bshape, dtype=dtype, device=device)
    emat = torch.rand(eshape, dtype=dtype, device=device)
    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5

    amat = amat.requires_grad_()
    mmat = mmat.requires_grad_()
    bmat = bmat.requires_grad_()
    emat = emat.requires_grad_()

    def solvefcn(amat, mmat, bmat, emat):
        mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
        alinop = LinearOperator.m(amat)
        mlinop = LinearOperator.m(mmat)
        x = solve(A=alinop, B=bmat, E=emat, M=mlinop,
                  **fwd_options, bck_options=bck_options)
        return x

    x = solvefcn(amat, mmat, bmat, emat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(amat).mm(x)
    mxe = LinearOperator.m(mmat).mm(
        torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2)))
    y = ax - mxe
    assert torch.allclose(y, bmat)

    # gradient checker
    if checkgrad:
        gradcheck(solvefcn, (amat, mmat, bmat, emat))
        gradgradcheck(solvefcn, (amat, mmat, bmat, emat))
def test_lsymeig_AM():
    torch.manual_seed(seed)
    shapes = [(3, 3), (2, 3, 3), (2, 1, 3, 3)]
    # only 2 methods, because both gradient implementations are covered
    methods = ["exacteig", "custom_exacteig"]
    dtype = torch.float64
    for ashape, mshape, method in itertools.product(shapes, shapes, methods):
        mata = torch.rand(ashape, dtype=dtype)
        matm = torch.rand(mshape, dtype=dtype) + \
            torch.eye(mshape[-1], dtype=dtype)  # make sure it's not singular
        mata = mata + mata.transpose(-2, -1)
        matm = matm + matm.transpose(-2, -1)
        mata = mata.requires_grad_()
        matm = matm.requires_grad_()
        linopa = LinearOperator.m(mata, True)
        linopm = LinearOperator.m(matm, True)
        fwd_options = {"method": method}

        na = ashape[-1]
        bshape = get_bcasted_dims(ashape[:-2], mshape[:-2])
        for neig in [2, ashape[-1]]:
            eigvals, eigvecs = lsymeig(linopa, M=linopm, neig=neig,
                                       **fwd_options)  # eigvals: (..., neig)
            assert list(eigvals.shape) == list([*bshape, neig])
            assert list(eigvecs.shape) == list([*bshape, na, neig])

            ax = linopa.mm(eigvecs)
            mxe = linopm.mm(
                torch.matmul(eigvecs, torch.diag_embed(eigvals, dim1=-2, dim2=-1)))
            assert torch.allclose(ax, mxe)

            # only perform gradcheck if neig is full, to reduce the computational cost
            if neig == ashape[-1]:

                def lsymeig_fcn(amat, mmat):
                    # symmetrize
                    amat = (amat + amat.transpose(-2, -1)) * 0.5
                    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
                    alinop = LinearOperator.m(amat, is_hermitian=True)
                    mlinop = LinearOperator.m(mmat, is_hermitian=True)
                    eigvals_, eigvecs_ = lsymeig(alinop, M=mlinop, neig=neig,
                                                 **fwd_options)
                    return eigvals_, eigvecs_

                gradcheck(lsymeig_fcn, (mata, matm))
                gradgradcheck(lsymeig_fcn, (mata, matm))
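# The operator-arithmetic tests below use two custom operators, LinOp1 and
# LinOp2, that are not defined in this section. The following is a minimal
# sketch of what they plausibly look like, assuming xitorch's LinearOperator
# subclass protocol (declare the shape in the constructor, implement _mv);
# the actual definitions in the test suite may differ.
class LinOp1(LinearOperator):
    # Minimal matrix-backed operator: wraps a dense tensor and applies it in _mv.
    def __init__(self, mat):
        super().__init__(shape=mat.shape, dtype=mat.dtype, device=mat.device)
        self.mat = mat

    def _mv(self, x):
        # matrix-vector product along the last dimension
        return torch.matmul(self.mat, x.unsqueeze(-1)).squeeze(-1)

    def _rmv(self, x):
        # adjoint matrix-vector product
        return torch.matmul(self.mat.transpose(-2, -1), x.unsqueeze(-1)).squeeze(-1)

    def _getparamnames(self, prefix=""):
        return [prefix + "mat"]


class LinOp2(LinOp1):
    # Same behaviour as LinOp1; a distinct class so that the tests exercise
    # arithmetic between different LinearOperator types.
    pass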
def test_linop_add():
    mat = torch.randn((2, 3, 2))
    linop1 = LinOp1(mat)
    linop2 = LinOp2(mat + 1)

    # test using non-matrix linop
    c = linop1 + linop2
    assert torch.allclose(c.fullmatrix(), 2 * mat + 1)

    # test using matrix linear operator
    m1 = LinearOperator.m(mat)
    m2 = LinearOperator.m(mat + 1)
    m12 = m1 + m2
    assert torch.allclose(m12.fullmatrix(), 2 * mat + 1)
def test_lsymeig_mismatch_err(dtype, device):
    torch.manual_seed(seed)
    mat1 = torch.rand((3, 3), dtype=dtype, device=device)
    mat2 = torch.rand((2, 2), dtype=dtype, device=device)
    mat1 = mat1 + mat1.transpose(-2, -1)
    mat2 = mat2 + mat2.transpose(-2, -1)
    linop1 = LinearOperator.m(mat1, True)
    linop2 = LinearOperator.m(mat2, True)
    try:
        res = lsymeig(linop1, M=linop2)
        assert False, "A RuntimeError must be raised if the shapes of A & M mismatch"
    except RuntimeError:
        pass
def test_linop_sub():
    # test the behaviour of subtraction of LinearOperators
    mat = torch.randn((2, 3, 2))
    linop1 = LinOp1(mat)
    linop2 = LinOp2(-mat + 1)

    # test using non-matrix linop
    c = linop1 - linop2
    assert torch.allclose(c.fullmatrix(), 2 * mat - 1)

    # test using matrix linear operator
    m1 = LinearOperator.m(mat)
    m2 = LinearOperator.m(-mat + 1)
    m12 = m1 - m2
    assert torch.allclose(m12.fullmatrix(), 2 * mat - 1)
def test_solve_A_methods():
    torch.manual_seed(seed)
    na = 3
    dtype = torch.float64
    ashape = (na, na)
    bshape = (2, na, na)
    # compute the B shape once, outside the loop, so every method is tested
    # on the same problem (the original recomputed it per iteration, shrinking
    # ncols each time)
    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]

    methods = ["gmres", "lbfgs"]
    for method in methods:
        fwd_options = {"method": method}
        amat = torch.rand(ashape, dtype=dtype) + torch.eye(ashape[-1], dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        amat = amat + amat.transpose(-2, -1)

        amat = amat.requires_grad_()
        bmat = bmat.requires_grad_()

        def solvefcn(amat, bmat):
            alinop = LinearOperator.m(amat)
            x = solve(A=alinop, B=bmat, **fwd_options)
            return x

        x = solvefcn(amat, bmat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        assert torch.allclose(ax, bmat)
def test_svd_A(dtype, device, shape, method):
    torch.manual_seed(seed)
    mat1 = torch.rand(shape, dtype=dtype, device=device)
    mat1 = mat1.requires_grad_()
    linop1 = LinearOperator.m(mat1, is_hermitian=False)
    fwd_options = {"method": method}

    min_mn = min(shape[-1], shape[-2])
    for k in [min_mn]:
        # u: (..., m, k), s: (..., k), vh: (..., k, n)
        u, s, vh = svd(linop1, k=k, **fwd_options)

        assert list(u.shape) == list([*linop1.shape[:-1], k])
        assert list(s.shape) == list([*linop1.shape[:-2], k])
        assert list(vh.shape) == list([*linop1.shape[:-2], k, linop1.shape[-1]])

        keye = torch.zeros((*shape[:-2], k, k), dtype=dtype, device=device) + \
            torch.eye(k, dtype=dtype, device=device)
        assert torch.allclose(u.transpose(-2, -1) @ u, keye)
        assert torch.allclose(vh @ vh.transpose(-2, -1), keye)
        if k == min_mn:
            assert torch.allclose(mat1, u @ torch.diag_embed(s) @ vh)

        def svd_fcn(amat, only_s=False):
            alinop = LinearOperator.m(amat, is_hermitian=False)
            u_, s_, vh_ = svd(alinop, k=k, **fwd_options)
            if only_s:
                return s_
            else:
                return u_, s_, vh_

        gradcheck(svd_fcn, (mat1,))
        gradgradcheck(svd_fcn, (mat1, True))
def test_lsymeig_A(dtype, device, shape, method):
    torch.manual_seed(seed)
    mat1 = torch.rand(shape, dtype=dtype, device=device)
    mat1 = mat1 + mat1.transpose(-2, -1)
    mat1 = mat1.requires_grad_()
    linop1 = LinearOperator.m(mat1, True)
    fwd_options = {"method": method}

    for neig in [2, shape[-1]]:
        # eigvals: (..., neig), eigvecs: (..., na, neig)
        eigvals, eigvecs = lsymeig(linop1, neig=neig, **fwd_options)
        assert list(eigvecs.shape) == list([*linop1.shape[:-1], neig])
        assert list(eigvals.shape) == list([*linop1.shape[:-2], neig])

        ax = linop1.mm(eigvecs)
        xe = torch.matmul(eigvecs, torch.diag_embed(eigvals, dim1=-2, dim2=-1))
        assert torch.allclose(ax, xe)

        # only perform gradcheck if neig is full, to reduce the computational cost
        if neig == shape[-1]:

            def lsymeig_fcn(amat):
                amat = (amat + amat.transpose(-2, -1)) * 0.5  # symmetrize
                alinop = LinearOperator.m(amat, is_hermitian=True)
                eigvals_, eigvecs_ = lsymeig(alinop, neig=neig, **fwd_options)
                return eigvals_, eigvecs_

            gradcheck(lsymeig_fcn, (mat1,))
            gradgradcheck(lsymeig_fcn, (mat1,))
def test_lsymeig_nonhermit_err(dtype, device):
    torch.manual_seed(seed)
    mat = torch.rand((3, 3), dtype=dtype, device=device)
    linop = LinearOperator.m(mat, False)
    linop2 = LinearOperator.m(mat + mat.transpose(-2, -1), True)

    try:
        res = lsymeig(linop)
        assert False, "A RuntimeError must be raised if the A linear operator in lsymeig is not Hermitian"
    except RuntimeError:
        pass

    try:
        res = lsymeig(linop2, M=linop)
        assert False, "A RuntimeError must be raised if the M linear operator in lsymeig is not Hermitian"
    except RuntimeError:
        pass
def test_linop_mat_hermit_err():
    torch.manual_seed(100)
    mat = torch.rand(3, 3)  # square but not symmetric
    mat2 = torch.rand(3, 4)  # not even square
    msg = "Expecting a RuntimeError for a non-symmetric matrix " \
          "indicated as Hermitian"
    try:
        LinearOperator.m(mat, is_hermitian=True)
        assert False, msg
    except RuntimeError:
        pass
    try:
        LinearOperator.m(mat2, is_hermitian=True)
        assert False, msg
    except RuntimeError:
        pass
def test_lsymeig_AM(dtype, device, ashape, mshape, method):
    torch.manual_seed(seed)
    mata = torch.rand(ashape, dtype=dtype, device=device)
    matm = torch.rand(mshape, dtype=dtype, device=device) + \
        torch.eye(mshape[-1], dtype=dtype, device=device)  # make sure it's not singular
    mata = mata + mata.transpose(-2, -1)
    matm = matm + matm.transpose(-2, -1)
    mata = mata.requires_grad_()
    matm = matm.requires_grad_()
    linopa = LinearOperator.m(mata, True)
    linopm = LinearOperator.m(matm, True)
    fwd_options = {"method": method}

    na = ashape[-1]
    bshape = get_bcasted_dims(ashape[:-2], mshape[:-2])
    for neig in [2, ashape[-1]]:
        eigvals, eigvecs = lsymeig(linopa, M=linopm, neig=neig,
                                   **fwd_options)  # eigvals: (..., neig)
        assert list(eigvals.shape) == list([*bshape, neig])
        assert list(eigvecs.shape) == list([*bshape, na, neig])

        ax = linopa.mm(eigvecs)
        mxe = linopm.mm(
            torch.matmul(eigvecs, torch.diag_embed(eigvals, dim1=-2, dim2=-1)))
        assert torch.allclose(ax, mxe)

        # only perform gradcheck if neig is full, to reduce the computational cost
        if neig == ashape[-1]:

            def lsymeig_fcn(amat, mmat):
                # symmetrize
                amat = (amat + amat.transpose(-2, -1)) * 0.5
                mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
                alinop = LinearOperator.m(amat, is_hermitian=True)
                mlinop = LinearOperator.m(mmat, is_hermitian=True)
                eigvals_, eigvecs_ = lsymeig(alinop, M=mlinop, neig=neig,
                                             **fwd_options)
                return eigvals_, eigvecs_

            gradcheck(lsymeig_fcn, (mata, matm))
            gradgradcheck(lsymeig_fcn, (mata, matm))
def test_solve_nonsquare_err(dtype, device):
    torch.manual_seed(seed)
    mat = torch.rand((3, 2), dtype=dtype, device=device)
    mat2 = torch.rand((3, 3), dtype=dtype, device=device)
    linop = LinearOperator.m(mat)
    linop2 = LinearOperator.m(mat2)
    B = torch.rand((3, 1), dtype=dtype, device=device)

    try:
        res = solve(linop, B)
        assert False, "A RuntimeError must be raised if the A linear operator in solve is not square"
    except RuntimeError:
        pass

    try:
        res = solve(linop2, B, M=linop)
        assert False, "A RuntimeError must be raised if the M linear operator in solve is not square"
    except RuntimeError:
        pass
def test_solve_AEM_methods(dtype, device, method):
    torch.manual_seed(seed)
    na = 100
    nc = na // 2
    amshape = (na, na)
    eshape = (nc,)
    bshape = (2, na, nc)
    options = {
        "scipy_gmres": {},
        "broyden1": {},
        "cg": {
            "rtol": 1e-8  # stringent rtol required to meet the torch.allclose tols
        },
        "bicgstab": {},
    }[method]
    fwd_options = {"method": method, **options}

    amat = torch.rand(amshape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(amshape[-1], dtype=dtype, device=device)
    mmat = torch.rand(amshape, dtype=dtype, device=device) * 0.05 + \
        torch.eye(amshape[-1], dtype=dtype, device=device)
    bmat = torch.rand(bshape, dtype=dtype, device=device) + 0.1
    emat = torch.rand(eshape, dtype=dtype, device=device) * 0.1
    amat = (amat + amat.transpose(-2, -1)) * 0.5
    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5

    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()
    emat = emat.requires_grad_()

    def solvefcn(amat, bmat, emat, mmat):
        alinop = LinearOperator.m(amat)
        mlinop = LinearOperator.m(mmat)
        x = solve(A=alinop, B=bmat, E=emat, M=mlinop, **fwd_options)
        return x

    x = solvefcn(amat, bmat, emat, mmat)
    ax = LinearOperator.m(amat).mm(x)
    mxe = LinearOperator.m(mmat).mm(x) @ torch.diag_embed(emat)
    assert torch.allclose(ax - mxe, bmat)
def setup(self, minmaxeival, n):
    seed = 123
    ncols = 50
    torch.manual_seed(seed)
    min_eival, max_eival = minmaxeival
    A = create_random_square_matrix(n, is_hermitian=True,
                                    min_eival=min_eival, max_eival=max_eival,
                                    seed=seed)
    self.A = LinearOperator.m(A, is_hermitian=True)
def setup(self, is_hermitian, minmaxeival, n):
    seed = 123
    ncols = 50
    torch.manual_seed(seed)
    min_eival, max_eival = minmaxeival
    A = create_random_square_matrix(n, is_hermitian=is_hermitian,
                                    min_eival=min_eival, max_eival=max_eival,
                                    seed=seed)
    self.A = LinearOperator.m(A, is_hermitian=is_hermitian)
    X = torch.randn(n, ncols, dtype=A.dtype)
    self.B = self.A.mm(X)
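# create_random_square_matrix is not defined in this section. Below is a
# plausible sketch, assuming it builds a matrix with a controlled eigenvalue
# range via a similarity transform; the signature matches the calls above,
# but the implementation details are illustrative, not the original helper.
def create_random_square_matrix(n, is_hermitian=False, min_eival=0.0,
                                max_eival=1.0, seed=0, dtype=torch.float64):
    torch.manual_seed(seed)
    # eigenvalues drawn uniformly from [min_eival, max_eival]
    eivals = torch.rand(n, dtype=dtype) * (max_eival - min_eival) + min_eival
    if is_hermitian:
        # orthogonal similarity transform keeps the matrix symmetric
        q, _ = torch.linalg.qr(torch.randn(n, n, dtype=dtype))
        return q @ torch.diag(eivals) @ q.T
    # non-hermitian: a well-conditioned, non-orthogonal similarity transform
    p = torch.randn(n, n, dtype=dtype) + n * torch.eye(n, dtype=dtype)
    return p @ torch.diag(eivals) @ torch.linalg.inv(p)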
def test_solve_mismatch_err(dtype, device):
    torch.manual_seed(seed)
    shapes = [
        #  A       B       M
        ([(3, 3), (2, 1), (3, 3)], "the B shape does not match with A"),
        ([(3, 3), (3, 2), (2, 2)], "the M shape does not match with A"),
    ]
    for (ashape, bshape, mshape), msg in shapes:
        amat = torch.rand(ashape, dtype=dtype, device=device)
        bmat = torch.rand(bshape, dtype=dtype, device=device)
        mmat = torch.rand(mshape, dtype=dtype, device=device) + \
            torch.eye(mshape[-1], dtype=dtype, device=device)
        amat = amat + amat.transpose(-2, -1)
        mmat = mmat + mmat.transpose(-2, -1)
        alinop = LinearOperator.m(amat)
        mlinop = LinearOperator.m(mmat)
        try:
            res = solve(alinop, B=bmat, M=mlinop)
            assert False, "A RuntimeError must be raised if %s" % msg
        except RuntimeError:
            pass
def get_loss(a, matA, matM, P2):
    # get the orthogonal vectors for the eigenvectors
    P, _ = torch.qr(matA)
    PM, _ = torch.qr(matM)
    # line up the eigenvalues so that some of them are degenerate
    b = torch.cat((a[:2], a[1:2], a[2:], a[2:]))
    # construct the matrices
    diag = torch.diag_embed(b)
    A = torch.matmul(torch.matmul(P.T, diag), P)
    M = torch.matmul(PM.T, PM)
    Alinop = LinearOperator.m(A)
    Mlinop = LinearOperator.m(M)
    # neig and bck_options are taken from the enclosing test's scope
    eivals, eivecs = symeig(Alinop, M=Mlinop, neig=neig,
                            method="custom_exacteig", bck_options=bck_options)
    U = eivecs[:, 1:3]  # the degenerate eigenvectors
    loss = torch.einsum("rc,rc->", torch.matmul(P2, U), U)
    return loss
def test_linop_mul():
    # test the behaviour of multiplication of a LinearOperator with a number
    mat = torch.randn((2, 3, 2))
    linop1 = LinOp1(mat)
    linop2 = LinearOperator.m(mat)
    for f1 in [2, 4.0]:
        print(f1)
        # test using non-matrix linop multiplier
        c11l = linop1 * f1
        c11r = f1 * linop1
        # test using matrix linop multiplier
        c12l = linop2 * f1
        c12r = f1 * linop2
        assert torch.allclose(c11l.fullmatrix(), f1 * mat)
        assert torch.allclose(c11r.fullmatrix(), f1 * mat)
        assert torch.allclose(c12l.fullmatrix(), f1 * mat)
        assert torch.allclose(c12r.fullmatrix(), f1 * mat)
def test_solve_A():
    torch.manual_seed(seed)
    na = 2
    shapes = [(na, na), (2, na, na), (2, 1, na, na)]
    dtype = torch.float64
    # custom_exactsolve to check the backward implementation
    methods = ["exactsolve", "custom_exactsolve"]
    # hermitian check here to make sure the gradient is propagated symmetrically
    hermits = [False, True]
    for ashape, bshape, method, hermit in itertools.product(
            shapes, shapes, methods, hermits):
        print(ashape, bshape, method, hermit)
        checkgrad = method.endswith("exactsolve")
        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]
        fwd_options = {"method": method, "min_eps": 1e-9}
        bck_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) * 0.1 + \
            torch.eye(ashape[-1], dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        if hermit:
            amat = (amat + amat.transpose(-2, -1)) * 0.5

        amat = amat.requires_grad_()
        bmat = bmat.requires_grad_()

        def solvefcn(amat, bmat):
            # is_hermitian=hermit is required to force the hermitian status
            # in the numerical gradient
            alinop = LinearOperator.m(amat, is_hermitian=hermit)
            x = solve(A=alinop, B=bmat, **fwd_options, bck_options=bck_options)
            return x

        x = solvefcn(amat, bmat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        assert torch.allclose(ax, bmat)

        if checkgrad:
            gradcheck(solvefcn, (amat, bmat))
            gradgradcheck(solvefcn, (amat, bmat))
def test_solve_AE():
    torch.manual_seed(seed)
    na = 2
    shapes = [(na, na), (2, na, na), (2, 1, na, na)]
    # custom_exactsolve to check the backward implementation
    methods = ["exactsolve", "custom_exactsolve"]
    dtype = torch.float64
    for ashape, bshape, eshape, method in itertools.product(
            shapes, shapes, shapes, methods):
        print(ashape, bshape, eshape, method)
        checkgrad = method.endswith("exactsolve")
        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        eshape = [*eshape[:-2], ncols]
        xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2],
                                       eshape[:-1])) + [na, ncols]
        fwd_options = {"method": method}
        bck_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) * 0.1 + \
            torch.eye(ashape[-1], dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        emat = torch.rand(eshape, dtype=dtype)

        amat = amat.requires_grad_()
        bmat = bmat.requires_grad_()
        emat = emat.requires_grad_()

        def solvefcn(amat, bmat, emat):
            alinop = LinearOperator.m(amat)
            x = solve(A=alinop, B=bmat, E=emat,
                      **fwd_options, bck_options=bck_options)
            return x

        x = solvefcn(amat, bmat, emat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        xe = torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2))
        assert torch.allclose(ax - xe, bmat)
def get_loss(a, mat):
    # get the orthogonal vectors for the eigenvectors
    P, _ = torch.qr(mat)
    # line up the eigenvalues so that some of them are degenerate
    b = torch.cat((a[:2], a[1:2], a[2:], a[2:]))
    # construct the matrix
    diag = torch.diag_embed(b)
    A = torch.matmul(torch.matmul(P.T, diag), P)
    Alinop = LinearOperator.m(A)
    # neig and bck_options are taken from the enclosing test's scope
    eivals, eivecs = symeig(Alinop, neig=neig,
                            method="custom_exacteig", bck_options=bck_options)
    U = eivecs[:, :3]  # the degenerate eigenvectors are in columns 1 and 2
    loss = torch.sum(U**4)
    return loss
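# A hedged sketch of how a degenerate-eigenvalue loss like the one above
# could be driven through gradcheck. The values of `neig`, `bck_options`,
# and the shapes are illustrative assumptions, not the original test: for a
# 3-element `a`, the concatenation above yields 5 eigenvalues (degenerate at
# indices 1-2 and 3-4), so `mat` must be 5x5.
neig = 5
bck_options = {"method": "exactsolve"}
a = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64).requires_grad_()
mat = torch.randn(5, 5, dtype=torch.float64)
# gradcheck verifies the analytic gradient against finite differences,
# which is exactly where degenerate eigenvalues tend to break naive backwards
gradcheck(get_loss, (a, mat))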
def test_solve_A_methods(dtype, device, method):
    torch.manual_seed(seed)
    na = 100
    ashape = (na, na)
    bshape = (2, na, na)
    options = {
        "scipy_gmres": {},
        "broyden1": {},
        "cg": {
            "rtol": 1e-8  # stringent rtol required to meet the torch.allclose tols
        },
        "bicgstab": {
            "rtol": 1e-8,
        },
    }[method]
    fwd_options = {"method": method, **options}
    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]

    amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(ashape[-1], dtype=dtype, device=device)
    bmat = torch.rand(bshape, dtype=dtype, device=device) + 0.1
    amat = (amat + amat.transpose(-2, -1)) * 0.5

    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()

    def solvefcn(amat, bmat):
        alinop = LinearOperator.m(amat)
        x = solve(A=alinop, B=bmat, **fwd_options)
        return x

    x = solvefcn(amat, bmat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(amat).mm(x)
    assert torch.allclose(ax, bmat)
def test_solve_A(dtype, device, ashape, bshape, method, hermit):
    torch.manual_seed(seed)
    na = ashape[-1]
    checkgrad = method.endswith("exactsolve")
    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]
    fwd_options = {"method": method, "min_eps": 1e-9}
    bck_options = {"method": method}

    amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(ashape[-1], dtype=dtype, device=device)
    bmat = torch.rand(bshape, dtype=dtype, device=device)

    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()

    def prepare(amat):
        if hermit:
            return (amat + amat.transpose(-2, -1)) * 0.5
        return amat

    def solvefcn(amat, bmat):
        # is_hermitian=hermit is required to force the hermitian status
        # in the numerical gradient
        alinop = LinearOperator.m(prepare(amat), is_hermitian=hermit)
        x = solve(A=alinop, B=bmat, **fwd_options, bck_options=bck_options)
        return x

    x = solvefcn(amat, bmat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(prepare(amat)).mm(x)
    assert torch.allclose(ax, bmat)

    if checkgrad:
        gradcheck(solvefcn, (amat, bmat))
        gradgradcheck(solvefcn, (amat, bmat))
def test_lsymeig_A():
    torch.manual_seed(seed)
    shapes = [(4, 4), (2, 4, 4), (2, 3, 4, 4)]
    # only 2 methods, because both gradient implementations are covered
    methods = ["exacteig", "custom_exacteig"]
    for shape, method in itertools.product(shapes, methods):
        mat1 = torch.rand(shape, dtype=torch.float64)
        mat1 = mat1 + mat1.transpose(-2, -1)
        mat1 = mat1.requires_grad_()
        linop1 = LinearOperator.m(mat1, True)
        fwd_options = {"method": method}

        for neig in [2, shape[-1]]:
            # eigvals: (..., neig), eigvecs: (..., na, neig)
            eigvals, eigvecs = lsymeig(linop1, neig=neig, **fwd_options)
            assert list(eigvecs.shape) == list([*linop1.shape[:-1], neig])
            assert list(eigvals.shape) == list([*linop1.shape[:-2], neig])

            ax = linop1.mm(eigvecs)
            xe = torch.matmul(eigvecs, torch.diag_embed(eigvals, dim1=-2, dim2=-1))
            assert torch.allclose(ax, xe)

            # only perform gradcheck if neig is full, to reduce the computational cost
            if neig == shape[-1]:

                def lsymeig_fcn(amat):
                    amat = (amat + amat.transpose(-2, -1)) * 0.5  # symmetrize
                    alinop = LinearOperator.m(amat, is_hermitian=True)
                    eigvals_, eigvecs_ = lsymeig(alinop, neig=neig, **fwd_options)
                    return eigvals_, eigvecs_

                gradcheck(lsymeig_fcn, (mat1,))
                gradgradcheck(lsymeig_fcn, (mat1,))
def test_solve_AE(dtype, device, ashape, bshape, eshape, method):
    torch.manual_seed(seed)
    na = ashape[-1]
    checkgrad = method.endswith("exactsolve")
    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    eshape = [*eshape[:-2], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2],
                                   eshape[:-1])) + [na, ncols]
    fwd_options = {"method": method}
    bck_options = {"method": method}

    amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(ashape[-1], dtype=dtype, device=device)
    bmat = torch.rand(bshape, dtype=dtype, device=device)
    emat = torch.rand(eshape, dtype=dtype, device=device)

    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()
    emat = emat.requires_grad_()

    def solvefcn(amat, bmat, emat):
        alinop = LinearOperator.m(amat)
        x = solve(A=alinop, B=bmat, E=emat,
                  **fwd_options, bck_options=bck_options)
        return x

    x = solvefcn(amat, bmat, emat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(amat).mm(x)
    xe = torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2))
    assert torch.allclose(ax - xe, bmat)
def test_linop_repr():
    dtype = torch.float32
    device = torch.device("cpu")
    mat1 = torch.randn((2, 3, 2), dtype=dtype, device=device)
    a = LinOp1(mat1)
    arepr = ["LinearOperator", "LinOp1", "(2, 3, 2)", "float32", "cpu"]
    _assert_str_contains(a.__repr__(), arepr)

    b = a.H
    brepr = arepr + ["AdjointLinearOperator", "(2, 2, 3)"]
    _assert_str_contains(b.__repr__(), brepr)

    c = b.matmul(a)
    crepr = brepr + ["MatmulLinearOperator", "(2, 2, 2)"]
    _assert_str_contains(c.__repr__(), crepr)

    d = LinearOperator.m(mat1)
    drepr = ["MatrixLinearOperator", "(2, 3, 2)"]
    _assert_str_contains(d.__repr__(), drepr)

    e = d + a
    erepr = ["AddLinearOperator", "(2, 3, 2)"]
    _assert_str_contains(e.__repr__(), erepr)  # the original checked d/drepr again, a copy-paste slip
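# _assert_str_contains is not defined in this section; given how it is called
# above, it is presumably a small helper along these lines (a sketch, not
# necessarily the original):
def _assert_str_contains(s, slist):
    # assert that every expected substring appears in the repr string
    for sub in slist:
        assert sub in s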