def test_lsymeig_mismatch_err():
    torch.manual_seed(seed)
    mat1 = torch.rand((3, 3))
    mat2 = torch.rand((2, 2))
    mat1 = mat1 + mat1.transpose(-2, -1)
    mat2 = mat2 + mat2.transpose(-2, -1)
    linop1 = LinearOperator.m(mat1, True)
    linop2 = LinearOperator.m(mat2, True)
    try:
        res = lsymeig(linop1, M=linop2)
        assert False, "A RuntimeError must be raised if the shapes of A & M mismatch"
    except RuntimeError:
        pass

def test_lsymeig_AM():
    torch.manual_seed(seed)
    shapes = [(3, 3), (2, 3, 3), (2, 1, 3, 3)]
    methods = ["exacteig", "davidson"]
    dtype = torch.float64
    for ashape, mshape, method in itertools.product(shapes, shapes, methods):
        mata = torch.rand(ashape, dtype=dtype)
        matm = torch.rand(mshape, dtype=dtype) + \
            torch.eye(mshape[-1], dtype=dtype)  # make sure it's not singular
        mata = mata + mata.transpose(-2, -1)
        matm = matm + matm.transpose(-2, -1)
        mata = mata.requires_grad_()
        matm = matm.requires_grad_()
        linopa = LinearOperator.m(mata, True)
        linopm = LinearOperator.m(matm, True)
        fwd_options = {"method": method}

        na = ashape[-1]
        bshape = get_bcasted_dims(ashape[:-2], mshape[:-2])
        for neig in [2, ashape[-1]]:
            eigvals, eigvecs = lsymeig(linopa, M=linopm, neig=neig,
                                       fwd_options=fwd_options)  # eigvals: (..., neig)
            assert list(eigvals.shape) == list([*bshape, neig])
            assert list(eigvecs.shape) == list([*bshape, na, neig])

            ax = linopa.mm(eigvecs)
            mxe = linopm.mm(torch.matmul(eigvecs, torch.diag_embed(eigvals, dim1=-2, dim2=-1)))
            assert torch.allclose(ax, mxe)

            # only perform gradcheck if neig is full, to reduce the computational cost
            if neig == ashape[-1]:
                def lsymeig_fcn(amat, mmat):
                    # symmetrize
                    amat = (amat + amat.transpose(-2, -1)) * 0.5
                    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
                    alinop = LinearOperator.m(amat, is_hermitian=True)
                    mlinop = LinearOperator.m(mmat, is_hermitian=True)
                    eigvals_, eigvecs_ = lsymeig(alinop, M=mlinop, neig=neig,
                                                 fwd_options=fwd_options)
                    return eigvals_, eigvecs_

                gradcheck(lsymeig_fcn, (mata, matm))
                gradgradcheck(lsymeig_fcn, (mata, matm))

def test_solve_A_gmres():
    torch.manual_seed(seed)
    na = 3
    dtype = torch.float64
    ashape = (na, na)
    bshape = (2, na, na)
    fwd_options = {"method": "gmres"}

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]

    amat = torch.rand(ashape, dtype=dtype) + torch.eye(ashape[-1], dtype=dtype)
    bmat = torch.rand(bshape, dtype=dtype)
    amat = amat + amat.transpose(-2, -1)
    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()

    def solvefcn(amat, bmat):
        alinop = LinearOperator.m(amat)
        x = solve(A=alinop, B=bmat, fwd_options=fwd_options)
        return x

    x = solvefcn(amat, bmat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(amat).mm(x)
    assert torch.allclose(ax, bmat)

    gradcheck(solvefcn, (amat, bmat))
    gradgradcheck(solvefcn, (amat, bmat))

def exacteig(A: LinearOperator, neig: Union[int, None],
             mode: str, M: Union[LinearOperator, None]):
    Amatrix = A.fullmatrix()  # (*BA, q, q)
    if neig is None:
        neig = A.shape[-1]
    if M is None:
        evals, evecs = torch.symeig(Amatrix, eigenvectors=True)  # (*BA, q), (*BA, q, q)
        return _take_eigpairs(evals, evecs, neig, mode)
    else:
        Mmatrix = M.fullmatrix()  # (*BM, q, q)

        # decompose M to make A symmetric
        # it is done this way to make it numerically stable, avoiding
        # complex eigenvalues for the (near-)degenerate case
        L = torch.cholesky(Mmatrix, upper=False)  # (*BM, q, q)
        Linv = torch.inverse(L)  # (*BM, q, q)
        LinvT = Linv.transpose(-2, -1)  # (*BM, q, q)
        A2 = torch.matmul(Linv, torch.matmul(Amatrix, LinvT))  # (*BAM, q, q)

        # calculate the eigenvalues and eigenvectors
        # (the eigvecs are normalized in M-space)
        evals, evecs = torch.symeig(A2, eigenvectors=True)  # (*BAM, q), (*BAM, q, q)
        evals, evecs = _take_eigpairs(evals, evecs, neig, mode)  # (*BAM, neig), (*BAM, q, neig)
        evecs = torch.matmul(LinvT, evecs)
        return evals, evecs

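# A minimal, self-contained sketch (plain torch, no xitorch) of the Cholesky
# reduction used in exacteig above: the generalized problem A x = lambda M x
# is turned into a standard symmetric problem A2 v = lambda v with
# A2 = L^-1 A L^-T and x = L^-T v, where M = L L^T. The sizes below are
# illustrative only.
def _example_cholesky_reduction():
    import torch
    torch.manual_seed(0)
    n = 4
    A = torch.rand(n, n, dtype=torch.float64)
    A = (A + A.transpose(-2, -1)) * 0.5  # symmetric A
    M = torch.rand(n, n, dtype=torch.float64) + torch.eye(n, dtype=torch.float64) * n
    M = (M + M.transpose(-2, -1)) * 0.5  # symmetric positive definite M

    L = torch.cholesky(M, upper=False)  # M = L L^T
    Linv = torch.inverse(L)
    A2 = Linv @ A @ Linv.transpose(-2, -1)  # symmetric by construction
    evals, evecs2 = torch.symeig(A2, eigenvectors=True)
    evecs = Linv.transpose(-2, -1) @ evecs2  # back-transform; M-orthonormal

    # verify A x = lambda M x for all eigenpairs
    assert torch.allclose(A @ evecs, M @ evecs @ torch.diag(evals))
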
def test_lsymeig_nonhermit_err():
    torch.manual_seed(seed)
    mat = torch.rand((3, 3))
    linop = LinearOperator.m(mat, False)
    linop2 = LinearOperator.m(mat + mat.transpose(-2, -1), True)

    try:
        res = lsymeig(linop)
        assert False, "A RuntimeError must be raised if the A linear operator in lsymeig is not Hermitian"
    except RuntimeError:
        pass

    try:
        res = lsymeig(linop2, M=linop)
        assert False, "A RuntimeError must be raised if the M linear operator in lsymeig is not Hermitian"
    except RuntimeError:
        pass

def test_solve_nonsquare_err():
    torch.manual_seed(seed)
    mat = torch.rand((3, 2))
    mat2 = torch.rand((3, 3))
    linop = LinearOperator.m(mat)
    linop2 = LinearOperator.m(mat2)
    B = torch.rand(3, 1)

    try:
        res = solve(linop, B)
        assert False, "A RuntimeError must be raised if the A linear operator in solve is not square"
    except RuntimeError:
        pass

    try:
        res = solve(linop2, B, M=linop)
        assert False, "A RuntimeError must be raised if the M linear operator in solve is not square"
    except RuntimeError:
        pass

def exactsolve(A: LinearOperator, B: torch.Tensor,
               E: Union[torch.Tensor, None],
               M: Union[LinearOperator, None]):
    # A: (*BA, na, na)
    # B: (*BB, na, ncols)
    # E: (*BE, ncols)
    # M: (*BM, na, na)
    if E is None:
        Amatrix = A.fullmatrix()  # (*BA, na, na)
        x, _ = torch.solve(B, Amatrix)  # (*BAB, na, ncols)
    elif M is None:
        Amatrix = A.fullmatrix()
        x = _solve_ABE(Amatrix, B, E)
    else:
        Mmatrix = M.fullmatrix()  # (*BM, na, na)
        L = torch.cholesky(Mmatrix, upper=False)  # (*BM, na, na)
        Linv = torch.inverse(L)  # (*BM, na, na)
        LinvT = Linv.transpose(-2, -1)  # (*BM, na, na)
        A2 = torch.matmul(Linv, A.mm(LinvT))  # (*BAM, na, na)
        B2 = torch.matmul(Linv, B)  # (*BBM, na, ncols)

        X2 = _solve_ABE(A2, B2, E)  # (*BABEM, na, ncols)
        x = torch.matmul(LinvT, X2)  # (*BABEM, na, ncols)
    return x

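# _solve_ABE is used by exactsolve above but is not shown in this excerpt.
# The sketch below is an assumption about its contract, derived from the
# equation AX - XE = B with E applied per column; it is not the library's
# actual (batched) implementation. It solves (A - E[i] * I) x_i = b_i for
# each column i of B.
def _solve_ABE_sketch(A, B, E):
    import torch
    # A: (na, na), B: (na, ncols), E: (ncols,)
    na = A.shape[-1]
    cols = []
    for i in range(B.shape[-1]):
        Ashift = A - E[i] * torch.eye(na, dtype=A.dtype)
        xi, _ = torch.solve(B[:, i:i + 1], Ashift)  # one linear solve per column
        cols.append(xi)
    return torch.cat(cols, dim=-1)  # (na, ncols)
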
def test_solve_mismatch_err():
    torch.manual_seed(seed)
    shapes = [
        # A       B       M
        ([(3, 3), (2, 1), (3, 3)], "the B shape does not match with A"),
        ([(3, 3), (3, 2), (2, 2)], "the M shape does not match with A"),
    ]
    dtype = torch.float64
    for (ashape, bshape, mshape), msg in shapes:
        amat = torch.rand(ashape, dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        mmat = torch.rand(mshape, dtype=dtype) + torch.eye(mshape[-1], dtype=dtype)
        amat = amat + amat.transpose(-2, -1)
        mmat = mmat + mmat.transpose(-2, -1)

        alinop = LinearOperator.m(amat)
        mlinop = LinearOperator.m(mmat)
        try:
            res = solve(alinop, B=bmat, M=mlinop)
            assert False, "A RuntimeError must be raised if %s" % msg
        except RuntimeError:
            pass

def test_solve_A():
    torch.manual_seed(seed)
    na = 2
    shapes = [(na, na), (2, na, na), (2, 1, na, na)]
    dtype = torch.float64
    # custom exactsolve to check the backward implementation
    methods = ["exactsolve", "custom_exactsolve", "lbfgs"]
    # hermitian check here to make sure the gradient is propagated symmetrically
    hermits = [False, True]
    for ashape, bshape, method, hermit in itertools.product(shapes, shapes, methods, hermits):
        print(ashape, bshape, method, hermit)
        checkgrad = method.endswith("exactsolve")

        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]
        fwd_options = {"method": method, "min_eps": 1e-9}
        bck_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) * 0.1 + torch.eye(ashape[-1], dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        if hermit:
            amat = (amat + amat.transpose(-2, -1)) * 0.5
        amat = amat.requires_grad_()
        bmat = bmat.requires_grad_()

        def solvefcn(amat, bmat):
            # is_hermitian=hermit is required to force the hermitian status in the numerical gradient
            alinop = LinearOperator.m(amat, is_hermitian=hermit)
            x = solve(A=alinop, B=bmat,
                      fwd_options=fwd_options,
                      bck_options=bck_options)
            return x

        x = solvefcn(amat, bmat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        assert torch.allclose(ax, bmat)

        if checkgrad:
            gradcheck(solvefcn, (amat, bmat))
            gradgradcheck(solvefcn, (amat, bmat))

def test_solve_AE():
    torch.manual_seed(seed)
    na = 2
    shapes = [(na, na), (2, na, na), (2, 1, na, na)]
    # custom exactsolve to check the backward implementation
    methods = ["exactsolve", "custom_exactsolve", "lbfgs"]
    dtype = torch.float64
    for ashape, bshape, eshape, method in itertools.product(shapes, shapes, shapes, methods):
        print(ashape, bshape, eshape, method)
        checkgrad = method.endswith("exactsolve")

        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        eshape = [*eshape[:-2], ncols]
        xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2], eshape[:-1])) + [na, ncols]
        fwd_options = {"method": method}
        bck_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) * 0.1 + torch.eye(ashape[-1], dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        emat = torch.rand(eshape, dtype=dtype)
        amat = amat.requires_grad_()
        bmat = bmat.requires_grad_()
        emat = emat.requires_grad_()

        def solvefcn(amat, bmat, emat):
            alinop = LinearOperator.m(amat)
            x = solve(A=alinop, B=bmat, E=emat,
                      fwd_options=fwd_options,
                      bck_options=bck_options)
            return x

        x = solvefcn(amat, bmat, emat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        xe = torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2))
        assert torch.allclose(ax - xe, bmat)

def test_lsymeig_A():
    torch.manual_seed(seed)
    shapes = [(4, 4), (2, 4, 4), (2, 3, 4, 4)]
    methods = ["exacteig", "davidson"]
    for shape, method in itertools.product(shapes, methods):
        mat1 = torch.rand(shape, dtype=torch.float64)
        mat1 = mat1 + mat1.transpose(-2, -1)
        mat1 = mat1.requires_grad_()
        linop1 = LinearOperator.m(mat1, True)
        fwd_options = {"method": method}

        for neig in [2, shape[-1]]:
            # eigvals: (..., neig), eigvecs: (..., na, neig)
            eigvals, eigvecs = lsymeig(linop1, neig=neig, fwd_options=fwd_options)
            assert list(eigvecs.shape) == list([*linop1.shape[:-1], neig])
            assert list(eigvals.shape) == list([*linop1.shape[:-2], neig])

            ax = linop1.mm(eigvecs)
            xe = torch.matmul(eigvecs, torch.diag_embed(eigvals, dim1=-2, dim2=-1))
            assert torch.allclose(ax, xe)

            # only perform gradcheck if neig is full, to reduce the computational cost
            if neig == shape[-1]:
                def lsymeig_fcn(amat):
                    amat = (amat + amat.transpose(-2, -1)) * 0.5  # symmetrize
                    alinop = LinearOperator.m(amat, is_hermitian=True)
                    eigvals_, eigvecs_ = lsymeig(alinop, neig=neig, fwd_options=fwd_options)
                    return eigvals_, eigvecs_

                gradcheck(lsymeig_fcn, (mat1,))
                gradgradcheck(lsymeig_fcn, (mat1,))

def solve(A: LinearOperator, B: torch.Tensor,
          E: Union[torch.Tensor, None] = None,
          M: Union[LinearOperator, None] = None,
          posdef=False,
          fwd_options: Mapping[str, Any] = {},
          bck_options: Mapping[str, Any] = {}):
    """
    Perform an iterative method to solve the equation AX=B or AX-MXE=B,
    where E is a diagonal matrix.
    This function can also solve multiple batched inverse equations at the
    same time by applying A to a tensor X with shape (..., na, ncols).
    The applied E is not necessarily identical for each column.

    Arguments
    ---------
    * A: xitorch.LinearOperator instance with shape (*BA, na, na)
        A function that takes an input X and produces the vectors in the
        same space as B.
    * B: torch.tensor (*BB, na, ncols)
        The tensor on the right hand side.
    * E: torch.tensor (*BE, ncols) or None
        If not None, it will solve AX-MXE = B. Otherwise, it just solves
        AX = B and M is ignored. E would be applied to every column.
    * M: xitorch.LinearOperator instance (*BM, na, na) or None
        The transformation on the E side. If E is None, then this argument
        is ignored. If E is not None and M is None, then M=I.
        This LinearOperator must be Hermitian.
    * fwd_options: dict
        Options of the iterative solver in the forward calculation.
    * bck_options: dict
        Options of the iterative solver in the backward calculation.
    """
    assert_runtime(A.shape[-1] == A.shape[-2],
                   "The linear operator A must have a square shape")
    assert_runtime(A.shape[-1] == B.shape[-2],
                   "Mismatch shape of A & B (A: %s, B: %s)" % (A.shape, B.shape))
    if M is not None:
        assert_runtime(M.shape[-1] == M.shape[-2],
                       "The linear operator M must have a square shape")
        assert_runtime(M.shape[-1] == A.shape[-1],
                       "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape))
        assert_runtime(M.is_hermitian,
                       "The linear operator M must be a Hermitian matrix")
    if E is not None:
        assert_runtime(E.shape[-1] == B.shape[-1],
                       "The last dimension of E & B must match (E: %s, B: %s)" % (E.shape, B.shape))
    if E is None and M is not None:
        warnings.warn("M is supplied but will be ignored because E is not supplied")

    # perform expensive checks if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if "method" not in fwd_options or fwd_options["method"].lower() == "exactsolve":
        return exactsolve(A, B, E, M)
    else:
        # get the unique parameters of A
        params = A.getlinopparams()
        mparams = M.getlinopparams() if M is not None else []
        na = len(params)
        return solve_torchfcn.apply(A, B, E, M, posdef,
                                    fwd_options, bck_options,
                                    na, *params, *mparams)

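# Usage sketch for solve. The import paths follow xitorch's public API;
# treat them as an assumption if your version differs. The matrix sizes
# are illustrative only.
def _example_solve_usage():
    import torch
    from xitorch import LinearOperator
    from xitorch.linalg import solve

    torch.manual_seed(0)
    na, ncols = 4, 2
    dtype = torch.float64
    # diagonally dominant matrix, so the system is well-conditioned
    amat = torch.rand(na, na, dtype=dtype) + torch.eye(na, dtype=dtype) * na
    bmat = torch.rand(na, ncols, dtype=dtype)

    A = LinearOperator.m(amat)
    x = solve(A, B=bmat)  # no "method" in fwd_options, so exactsolve is used
    assert torch.allclose(A.mm(x), bmat)
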
def symeig(A: LinearOperator, neig: Union[int, None] = None,
           mode: str = "lowest",
           M: Union[LinearOperator, None] = None,
           fwd_options: Mapping[str, Any] = {},
           bck_options: Mapping[str, Any] = {}):
    """
    Obtain `neig` lowest eigenvalues and eigenvectors of a linear operator.
    If M is specified, it solves the generalized eigendecomposition Ax = eMx.

    Arguments
    ---------
    * A: xitorch.LinearOperator hermitian instance with shape (*BA, q, q)
        The linear module object on which the eigenpairs are constructed.
    * neig: int or None
        The number of eigenpairs to be retrieved. If None, all eigenpairs
        are retrieved.
    * mode: str
        "lowest" or "uppermost"/"uppest". If "lowest", it will take the
        lowest `neig` eigenpairs. If "uppest", it will take the uppermost
        `neig`.
    * M: xitorch.LinearOperator hermitian instance with shape (*BM, q, q) or None
        The transformation on the right hand side. If None, then M=I.
    * fwd_options: dict with str as key
        Eigendecomposition iterative algorithm options.
    * bck_options: dict with str as key
        Conjugate gradient options to calculate the gradient in the
        backpropagation calculation.

    Returns
    -------
    * eigvals: (*BAM, neig)
    * eigvecs: (*BAM, na, neig)
        The lowest eigenvalues and eigenvectors, where *BAM is the
        broadcasted shape of *BA and *BM.
    """
    assert_runtime(A.is_hermitian, "The linear operator A must be Hermitian")
    if M is not None:
        assert_runtime(M.is_hermitian, "The linear operator M must be Hermitian")
        assert_runtime(M.shape[-1] == A.shape[-1],
                       "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape))
    mode = mode.lower()
    if mode == "uppermost":
        mode = "uppest"

    # perform expensive checks if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if "method" not in fwd_options or fwd_options["method"].lower() == "exacteig":
        return exacteig(A, neig, mode, M)
    else:
        # get the unique parameters of A & M
        params = A.getlinopparams()
        mparams = M.getlinopparams() if M is not None else []
        na = len(params)
        return symeig_torchfcn.apply(A, neig, mode, M,
                                     fwd_options, bck_options,
                                     na, *params, *mparams)

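# Usage sketch for symeig. The import paths follow xitorch's public API;
# treat them as an assumption if your version differs.
def _example_symeig_usage():
    import torch
    from xitorch import LinearOperator
    from xitorch.linalg import symeig

    torch.manual_seed(0)
    q = 4
    amat = torch.rand(q, q, dtype=torch.float64)
    amat = (amat + amat.transpose(-2, -1)) * 0.5  # A must be Hermitian

    A = LinearOperator.m(amat, is_hermitian=True)
    eigvals, eigvecs = symeig(A, neig=2, mode="lowest")  # defaults to exacteig
    # verify A x = x diag(e) for the retrieved eigenpairs
    assert torch.allclose(A.mm(eigvecs), eigvecs @ torch.diag(eigvals))
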
def test_solve_AEM():
    torch.manual_seed(seed)
    na = 2
    shapes = [(na, na), (2, na, na), (2, 1, na, na)]
    dtype = torch.float64
    methods = ["exactsolve", "custom_exactsolve", "lbfgs"]
    for abeshape, mshape, method in itertools.product(shapes, shapes, methods):
        ashape = abeshape
        bshape = abeshape
        eshape = abeshape
        print(abeshape, mshape, method)
        checkgrad = method.endswith("exactsolve")

        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        eshape = [*eshape[:-2], ncols]
        xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2], eshape[:-1], mshape[:-2])) + [na, ncols]
        fwd_options = {"method": method, "min_eps": 1e-9}
        # exactsolve at backward just to test the forward solve
        bck_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) * 0.1 + torch.eye(ashape[-1], dtype=dtype)
        mmat = torch.rand(mshape, dtype=dtype) * 0.1 + torch.eye(mshape[-1], dtype=dtype) * 0.5
        bmat = torch.rand(bshape, dtype=dtype)
        emat = torch.rand(eshape, dtype=dtype)
        mmat = (mmat + mmat.transpose(-2, -1)) * 0.5

        amat = amat.requires_grad_()
        mmat = mmat.requires_grad_()
        bmat = bmat.requires_grad_()
        emat = emat.requires_grad_()

        def solvefcn(amat, mmat, bmat, emat):
            mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
            alinop = LinearOperator.m(amat)
            mlinop = LinearOperator.m(mmat)
            x = solve(A=alinop, B=bmat, E=emat, M=mlinop,
                      fwd_options=fwd_options,
                      bck_options=bck_options)
            return x

        x = solvefcn(amat, mmat, bmat, emat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        mxe = LinearOperator.m(mmat).mm(torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2)))
        y = ax - mxe
        assert torch.allclose(y, bmat)

        # gradient checker
        if checkgrad:
            gradcheck(solvefcn, (amat, mmat, bmat, emat))
            gradgradcheck(solvefcn, (amat, mmat, bmat, emat))