def _test_lsymeig():
    torch.manual_seed(seed)
    mata = torch.rand(ashape, dtype=dtype, device=device)
    matm = torch.rand(mshape, dtype=dtype, device=device) + \
        torch.eye(mshape[-1], dtype=dtype, device=device)  # make sure it's not singular
    mata = mata + mata.transpose(-2, -1)
    matm = matm + matm.transpose(-2, -1)
    mata = mata.requires_grad_()
    matm = matm.requires_grad_()

    fwd_options = {"method": method}
    na = ashape[-1]
    bshape = get_bcasted_dims(ashape[:-2], mshape[:-2])
    neig = ashape[-1]

    def lsymeig_fcn(amat, mmat):
        # symmetrize
        amat = (amat + amat.transpose(-2, -1)) * 0.5
        mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
        alinop = LinearOperator.m(amat, is_hermitian=True)
        mlinop = LinearOperator.m(mmat, is_hermitian=True)
        eigvals_, eigvecs_ = lsymeig(alinop, M=mlinop, neig=neig, **fwd_options)
        return eigvals_, eigvecs_

    eival, eivec = lsymeig_fcn(mata, matm)
    loss = (eival * eival).sum() + (eivec * eivec).sum()

    # using autograd.grad instead of .backward because backward has a known
    # memory leak problem
    grads = torch.autograd.grad(loss, (mata, matm), create_graph=True)
def test_solve_A_gmres():
    torch.manual_seed(seed)
    na = 3
    dtype = torch.float64
    ashape = (na, na)
    bshape = (2, na, na)
    fwd_options = {"method": "gmres"}

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]

    amat = torch.rand(ashape, dtype=dtype) + torch.eye(ashape[-1], dtype=dtype)
    bmat = torch.rand(bshape, dtype=dtype)
    amat = amat + amat.transpose(-2, -1)
    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()

    def solvefcn(amat, bmat):
        alinop = LinearOperator.m(amat)
        x = solve(A=alinop, B=bmat, fwd_options=fwd_options)
        return x

    x = solvefcn(amat, bmat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(amat).mm(x)
    assert torch.allclose(ax, bmat)

    gradcheck(solvefcn, (amat, bmat))
    gradgradcheck(solvefcn, (amat, bmat))
def test_solve_AEM(dtype, device, abeshape, mshape, method):
    torch.manual_seed(seed)
    na = abeshape[-1]
    ashape = abeshape
    bshape = abeshape
    eshape = abeshape
    checkgrad = method.endswith("exactsolve")

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    eshape = [*eshape[:-2], ncols]
    xshape = list(
        get_bcasted_dims(ashape[:-2], bshape[:-2], eshape[:-1], mshape[:-2])) + [na, ncols]
    fwd_options = {"method": method, "min_eps": 1e-9}
    bck_options = {"method": method}  # exactsolve at backward just to test the forward solve

    amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(ashape[-1], dtype=dtype, device=device)
    mmat = torch.rand(mshape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(mshape[-1], dtype=dtype, device=device) * 0.5
    bmat = torch.rand(bshape, dtype=dtype, device=device)
    emat = torch.rand(eshape, dtype=dtype, device=device)
    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5

    amat = amat.requires_grad_()
    mmat = mmat.requires_grad_()
    bmat = bmat.requires_grad_()
    emat = emat.requires_grad_()

    def solvefcn(amat, mmat, bmat, emat):
        mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
        alinop = LinearOperator.m(amat)
        mlinop = LinearOperator.m(mmat)
        x = solve(A=alinop, B=bmat, E=emat, M=mlinop,
                  **fwd_options, bck_options=bck_options)
        return x

    x = solvefcn(amat, mmat, bmat, emat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(amat).mm(x)
    mxe = LinearOperator.m(mmat).mm(
        torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2)))
    y = ax - mxe
    assert torch.allclose(y, bmat)

    # gradient checker
    if checkgrad:
        gradcheck(solvefcn, (amat, mmat, bmat, emat))
        gradgradcheck(solvefcn, (amat, mmat, bmat, emat))
def _get_batchdims(A: LinearOperator, B: torch.Tensor, E: Union[torch.Tensor, None],
                   M: Union[LinearOperator, None]):
    batchdims = [A.shape[:-2], B.shape[:-2]]
    if E is not None:
        batchdims.append(E.shape[:-1])
    if M is not None:
        batchdims.append(M.shape[:-2])
    return get_bcasted_dims(*batchdims)
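The helper above collects the batch dimensions of A, B, E, and M and merges them with get_bcasted_dims. A minimal sketch of the broadcasting semantics assumed here (a hypothetical stand-in, not xitorch's actual implementation) is:

import torch

def bcasted_dims_sketch(*shapes):
    # hypothetical stand-in for get_bcasted_dims: merge the given
    # batch-dimension tuples with standard numpy/torch broadcasting rules
    return tuple(torch.broadcast_shapes(*shapes))

# e.g. an A with batch dims (2, 1), a B with batch dims (3,), and an M with
# no batch dims would share the common batch shape (2, 3)
assert bcasted_dims_sketch((2, 1), (3,), ()) == (2, 3)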
def __init__(self, a: LinearOperator, b: LinearOperator, is_hermitian: bool = False):
    shape = (*get_bcasted_dims(a.shape[:-2], b.shape[:-2]), a.shape[-2], b.shape[-1])
    super(MatmulLinearOperator, self).__init__(
        shape=shape,
        is_hermitian=is_hermitian,
        dtype=a.dtype,
        device=a.device,
        _suppress_hermit_warning=True,
    )
    self.a = a
    self.b = b
def __init__(self, a: LinearOperator, b: LinearOperator):
    shape = (*get_bcasted_dims(a.shape[:-2], b.shape[:-2]), a.shape[-2], b.shape[-1])
    is_hermitian = a.is_hermitian and b.is_hermitian
    super(AddLinearOperator, self).__init__(
        shape=shape,
        is_hermitian=is_hermitian,
        dtype=a.dtype,
        device=a.device,
        _suppress_hermit_warning=True,
    )
    self.a = a
    self.b = b
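Both constructors above build the composed operator's shape by broadcasting the two batch parts and keeping (rows of a, columns of b). This mirrors how batched torch.matmul broadcasts its leading dimensions; a small check of that assumption:

import torch

a = torch.randn(2, 1, 4, 5)     # batch (2, 1), matrix shape (4, 5)
b = torch.randn(3, 5, 6)        # batch (3,),   matrix shape (5, 6)
c = torch.matmul(a, b)          # leading dims broadcast, matrix dims contract
assert c.shape == (2, 3, 4, 6)  # (*broadcast of batch dims, rows of a, cols of b)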
def test_lsymeig_AM():
    torch.manual_seed(seed)
    shapes = [(3, 3), (2, 3, 3), (2, 1, 3, 3)]
    # only 2 of the methods, because both gradient implementations are covered
    methods = ["exacteig", "custom_exacteig"]
    dtype = torch.float64
    for ashape, mshape, method in itertools.product(shapes, shapes, methods):
        mata = torch.rand(ashape, dtype=dtype)
        matm = torch.rand(mshape, dtype=dtype) + torch.eye(
            mshape[-1], dtype=dtype)  # make sure it's not singular
        mata = mata + mata.transpose(-2, -1)
        matm = matm + matm.transpose(-2, -1)
        mata = mata.requires_grad_()
        matm = matm.requires_grad_()
        linopa = LinearOperator.m(mata, True)
        linopm = LinearOperator.m(matm, True)
        fwd_options = {"method": method}

        na = ashape[-1]
        bshape = get_bcasted_dims(ashape[:-2], mshape[:-2])
        for neig in [2, ashape[-1]]:
            eigvals, eigvecs = lsymeig(linopa, M=linopm, neig=neig,
                                       **fwd_options)  # eigvals: (..., neig)
            assert list(eigvals.shape) == list([*bshape, neig])
            assert list(eigvecs.shape) == list([*bshape, na, neig])

            ax = linopa.mm(eigvecs)
            mxe = linopm.mm(
                torch.matmul(eigvecs, torch.diag_embed(eigvals, dim1=-2, dim2=-1)))
            assert torch.allclose(ax, mxe)

            # only perform gradcheck if neig is full, to reduce the computational cost
            if neig == ashape[-1]:

                def lsymeig_fcn(amat, mmat):
                    # symmetrize
                    amat = (amat + amat.transpose(-2, -1)) * 0.5
                    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
                    alinop = LinearOperator.m(amat, is_hermitian=True)
                    mlinop = LinearOperator.m(mmat, is_hermitian=True)
                    eigvals_, eigvecs_ = lsymeig(alinop, M=mlinop, neig=neig, **fwd_options)
                    return eigvals_, eigvecs_

                gradcheck(lsymeig_fcn, (mata, matm))
                gradgradcheck(lsymeig_fcn, (mata, matm))
def _test_solve():
    torch.manual_seed(seed)
    na = abeshape[-1]
    ashape = abeshape
    bshape = abeshape
    eshape = abeshape
    checkgrad = method.endswith("exactsolve")

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    eshape = [*eshape[:-2], ncols]
    xshape = list(
        get_bcasted_dims(ashape[:-2], bshape[:-2], eshape[:-1], mshape[:-2])) + [na, ncols]
    fwd_options = {"method": method, "min_eps": 1e-9}
    bck_options = {"method": method}  # exactsolve at backward just to test the forward solve

    amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(ashape[-1], dtype=dtype, device=device)
    mmat = torch.rand(mshape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(mshape[-1], dtype=dtype, device=device) * 0.5
    bmat = torch.rand(bshape, dtype=dtype, device=device)
    emat = torch.rand(eshape, dtype=dtype, device=device)
    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5

    amat = amat.requires_grad_()
    mmat = mmat.requires_grad_()
    bmat = bmat.requires_grad_()
    emat = emat.requires_grad_()

    def solvefcn(amat, mmat, bmat, emat):
        mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
        alinop = LinearOperator.m(amat)
        mlinop = LinearOperator.m(mmat)
        x = solve(A=alinop, B=bmat, E=emat, M=mlinop,
                  **fwd_options, bck_options=bck_options)
        return x

    loss = (solvefcn(amat, mmat, bmat, emat) ** 2).sum()

    # using autograd.grad instead of .backward because backward has a known
    # memory leak problem
    grads = torch.autograd.grad(loss, (amat, mmat, bmat, emat), create_graph=True)
def test_solve_A():
    torch.manual_seed(seed)
    na = 2
    shapes = [(na, na), (2, na, na), (2, 1, na, na)]
    dtype = torch.float64
    # custom exactsolve to check the backward implementation
    methods = ["exactsolve", "custom_exactsolve", "lbfgs"]
    # hermitian check here to make sure the gradient propagated symmetrically
    hermits = [False, True]
    for ashape, bshape, method, hermit in itertools.product(
            shapes, shapes, methods, hermits):
        print(ashape, bshape, method, hermit)
        checkgrad = method.endswith("exactsolve")

        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]
        fwd_options = {"method": method, "min_eps": 1e-9}
        bck_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) * 0.1 + torch.eye(ashape[-1], dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        if hermit:
            amat = (amat + amat.transpose(-2, -1)) * 0.5
        amat = amat.requires_grad_()
        bmat = bmat.requires_grad_()

        def solvefcn(amat, bmat):
            # is_hermitian=hermit is required to force the hermitian status
            # in numerical gradient
            alinop = LinearOperator.m(amat, is_hermitian=hermit)
            x = solve(A=alinop, B=bmat,
                      fwd_options=fwd_options, bck_options=bck_options)
            return x

        x = solvefcn(amat, bmat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        assert torch.allclose(ax, bmat)

        if checkgrad:
            gradcheck(solvefcn, (amat, bmat))
            gradgradcheck(solvefcn, (amat, bmat))
def test_solve_AE():
    torch.manual_seed(seed)
    na = 2
    shapes = [(na, na), (2, na, na), (2, 1, na, na)]
    # custom exactsolve to check the backward implementation
    methods = ["exactsolve", "custom_exactsolve", "lbfgs"]
    dtype = torch.float64
    for ashape, bshape, eshape, method in itertools.product(
            shapes, shapes, shapes, methods):
        print(ashape, bshape, eshape, method)
        checkgrad = method.endswith("exactsolve")

        ncols = bshape[-1] - 1
        bshape = [*bshape[:-1], ncols]
        eshape = [*eshape[:-2], ncols]
        xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2], eshape[:-1])) + [na, ncols]
        fwd_options = {"method": method}
        bck_options = {"method": method}

        amat = torch.rand(ashape, dtype=dtype) * 0.1 + torch.eye(ashape[-1], dtype=dtype)
        bmat = torch.rand(bshape, dtype=dtype)
        emat = torch.rand(eshape, dtype=dtype)

        amat = amat.requires_grad_()
        bmat = bmat.requires_grad_()
        emat = emat.requires_grad_()

        def solvefcn(amat, bmat, emat):
            alinop = LinearOperator.m(amat)
            x = solve(A=alinop, B=bmat, E=emat,
                      fwd_options=fwd_options, bck_options=bck_options)
            return x

        x = solvefcn(amat, bmat, emat)
        assert list(x.shape) == xshape

        ax = LinearOperator.m(amat).mm(x)
        xe = torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2))
        assert torch.allclose(ax - xe, bmat)
def test_lsymeig_AM(dtype, device, ashape, mshape, method):
    torch.manual_seed(seed)
    mata = torch.rand(ashape, dtype=dtype, device=device)
    matm = torch.rand(mshape, dtype=dtype, device=device) + \
        torch.eye(mshape[-1], dtype=dtype, device=device)  # make sure it's not singular
    mata = mata + mata.transpose(-2, -1)
    matm = matm + matm.transpose(-2, -1)
    mata = mata.requires_grad_()
    matm = matm.requires_grad_()
    linopa = LinearOperator.m(mata, True)
    linopm = LinearOperator.m(matm, True)
    fwd_options = {"method": method}

    na = ashape[-1]
    bshape = get_bcasted_dims(ashape[:-2], mshape[:-2])
    for neig in [2, ashape[-1]]:
        eigvals, eigvecs = lsymeig(linopa, M=linopm, neig=neig,
                                   **fwd_options)  # eigvals: (..., neig)
        assert list(eigvals.shape) == list([*bshape, neig])
        assert list(eigvecs.shape) == list([*bshape, na, neig])

        ax = linopa.mm(eigvecs)
        mxe = linopm.mm(
            torch.matmul(eigvecs, torch.diag_embed(eigvals, dim1=-2, dim2=-1)))
        assert torch.allclose(ax, mxe)

        # only perform gradcheck if neig is full, to reduce the computational cost
        if neig == ashape[-1]:

            def lsymeig_fcn(amat, mmat):
                # symmetrize
                amat = (amat + amat.transpose(-2, -1)) * 0.5
                mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
                alinop = LinearOperator.m(amat, is_hermitian=True)
                mlinop = LinearOperator.m(mmat, is_hermitian=True)
                eigvals_, eigvecs_ = lsymeig(alinop, M=mlinop, neig=neig, **fwd_options)
                return eigvals_, eigvecs_

            gradcheck(lsymeig_fcn, (mata, matm))
            gradgradcheck(lsymeig_fcn, (mata, matm))
def __adjoint_rmv(self, xt: torch.Tensor) -> torch.Tensor:
    # xt: (*BY, p)
    # xdummy: (*BY, q)
    # calculate the right matvec multiplication by using the adjoint trick

    BY = xt.shape[:-1]
    BA = self.shape[:-2]
    BAY = get_bcasted_dims(BY, BA)

    # calculate y = Ax
    p, q = self.shape[-2:]
    xdummy = torch.zeros((*BAY, q), dtype=xt.dtype, device=xt.device).requires_grad_()
    with torch.enable_grad():
        y = self.mv(xdummy)  # (*BAY, p)

    # calculate (dL/dx)^T = A^T (dL/dy)^T with (dL/dy)^T = xt
    xt2 = xt.contiguous().expand_as(y)  # (*BAY, p)
    res = torch.autograd.grad(y, xdummy, grad_outputs=xt2,
                              create_graph=torch.is_grad_enabled())[0]  # (*BAY, q)
    return res
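The same autograd-based adjoint trick can be illustrated outside the class: for a black-box matvec f(x) = A x, back-propagating a vector v through f at a dummy input yields A^T v without ever forming A^T. A self-contained sketch in plain torch (the helper name adjoint_mv is hypothetical, for illustration only):

import torch

def adjoint_mv(matvec, v, ncols):
    # compute A^T v for a black-box matvec via autograd:
    # y = A x  =>  d(v^T y)/dx = A^T v
    xdummy = torch.zeros(ncols, dtype=v.dtype).requires_grad_()
    with torch.enable_grad():
        y = matvec(xdummy)
    return torch.autograd.grad(y, xdummy, grad_outputs=v)[0]

A = torch.randn(3, 4)
v = torch.randn(3)
res = adjoint_mv(lambda x: A @ x, v, ncols=4)
assert torch.allclose(res, A.t() @ v)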
def test_solve_A_methods(dtype, device, method):
    torch.manual_seed(seed)
    na = 100
    ashape = (na, na)
    bshape = (2, na, na)
    options = {
        "scipy_gmres": {},
        "broyden1": {},
        "cg": {
            "rtol": 1e-8  # stringent rtol required to meet the torch.allclose tols
        },
        "bicgstab": {
            "rtol": 1e-8,
        },
    }[method]
    fwd_options = {"method": method, **options}

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]

    amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(ashape[-1], dtype=dtype, device=device)
    bmat = torch.rand(bshape, dtype=dtype, device=device) + 0.1

    amat = (amat + amat.transpose(-2, -1)) * 0.5
    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()

    def solvefcn(amat, bmat):
        alinop = LinearOperator.m(amat)
        x = solve(A=alinop, B=bmat, **fwd_options)
        return x

    x = solvefcn(amat, bmat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(amat).mm(x)
    assert torch.allclose(ax, bmat)
def test_solve_A(dtype, device, ashape, bshape, method, hermit):
    torch.manual_seed(seed)
    na = ashape[-1]
    checkgrad = method.endswith("exactsolve")

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2])) + [na, ncols]
    fwd_options = {"method": method, "min_eps": 1e-9}
    bck_options = {"method": method}

    amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(ashape[-1], dtype=dtype, device=device)
    bmat = torch.rand(bshape, dtype=dtype, device=device)
    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()

    def prepare(amat):
        if hermit:
            return (amat + amat.transpose(-2, -1)) * 0.5
        return amat

    def solvefcn(amat, bmat):
        # is_hermitian=hermit is required to force the hermitian status
        # in numerical gradient
        alinop = LinearOperator.m(prepare(amat), is_hermitian=hermit)
        x = solve(A=alinop, B=bmat, **fwd_options, bck_options=bck_options)
        return x

    x = solvefcn(amat, bmat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(prepare(amat)).mm(x)
    assert torch.allclose(ax, bmat)

    if checkgrad:
        gradcheck(solvefcn, (amat, bmat))
        gradgradcheck(solvefcn, (amat, bmat))
def test_solve_AE(dtype, device, ashape, bshape, eshape, method):
    torch.manual_seed(seed)
    na = ashape[-1]
    checkgrad = method.endswith("exactsolve")

    ncols = bshape[-1] - 1
    bshape = [*bshape[:-1], ncols]
    eshape = [*eshape[:-2], ncols]
    xshape = list(get_bcasted_dims(ashape[:-2], bshape[:-2], eshape[:-1])) + [na, ncols]
    fwd_options = {"method": method}
    bck_options = {"method": method}

    amat = torch.rand(ashape, dtype=dtype, device=device) * 0.1 + \
        torch.eye(ashape[-1], dtype=dtype, device=device)
    bmat = torch.rand(bshape, dtype=dtype, device=device)
    emat = torch.rand(eshape, dtype=dtype, device=device)

    amat = amat.requires_grad_()
    bmat = bmat.requires_grad_()
    emat = emat.requires_grad_()

    def solvefcn(amat, bmat, emat):
        alinop = LinearOperator.m(amat)
        x = solve(A=alinop, B=bmat, E=emat, **fwd_options, bck_options=bck_options)
        return x

    x = solvefcn(amat, bmat, emat)
    assert list(x.shape) == xshape

    ax = LinearOperator.m(amat).mm(x)
    xe = torch.matmul(x, torch.diag_embed(emat, dim2=-1, dim1=-2))
    assert torch.allclose(ax - xe, bmat)
def davidson(A: LinearOperator,
             neig: int,
             mode: str,
             M: Optional[LinearOperator] = None,
             max_niter: int = 1000,
             nguess: Optional[int] = None,
             v_init: str = "randn",
             max_addition: Optional[int] = None,
             min_eps: float = 1e-6,
             verbose: bool = False,
             **unused) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Using Davidson method for large sparse matrix eigendecomposition [1]_.

    Arguments
    ---------
    max_niter: int
        Maximum number of iterations
    v_init: str
        Mode of the initial guess (``"randn"``, ``"rand"``, ``"eye"``)
    max_addition: int or None
        Maximum number of new guesses to be added to the collected vectors.
        If None, set to ``neig``.
    min_eps: float
        Minimum residual error to be stopped
    verbose: bool
        Option to be verbose

    References
    ----------
    .. [1] P. Arbenz, "Lecture Notes on Solving Large Scale Eigenvalue Problems"
           http://people.inf.ethz.ch/arbenz/ewp/Lnotes/chapter12.pdf
    """
    # TODO: optimize for large linear operator and strict min_eps
    # Ideas:
    # (1) use better strategy to get the estimate on eigenvalues
    # (2) use restart strategy
    if nguess is None:
        nguess = neig
    if max_addition is None:
        max_addition = neig

    # get the shape of the transformation
    na = A.shape[-1]
    if M is None:
        bcast_dims = A.shape[:-2]
    else:
        bcast_dims = get_bcasted_dims(A.shape[:-2], M.shape[:-2])
    dtype = A.dtype
    device = A.device

    prev_eigvals = None
    prev_eigvalT = None
    stop_reason = "max_niter"
    shift_is_eigvalT = False
    idx = torch.arange(neig).unsqueeze(-1)  # (neig, 1)

    # set up the initial guess
    V = _set_initial_v(v_init.lower(), dtype, device,
                       bcast_dims, na, nguess,
                       M=M)  # (*BAM, na, nguess)

    best_resid: Union[float, torch.Tensor] = float("inf")
    AV = A.mm(V)
    for i in range(max_niter):
        VT = V.transpose(-2, -1)  # (*BAM,nguess,na)
        # Can be optimized by saving AV from the previous iteration and only
        # operate AV for the new V. This works because the old V has already
        # been orthogonalized, so it will stay the same
        # AV = A.mm(V) # (*BAM,na,nguess)
        T = torch.matmul(VT, AV)  # (*BAM,nguess,nguess)

        # eigvals are sorted from the lowest
        # eval: (*BAM, nguess), evec: (*BAM, nguess, nguess)
        eigvalT, eigvecT = torch.symeig(T, eigenvectors=True)
        eigvalT, eigvecT = _take_eigpairs(eigvalT, eigvecT, neig,
                                          mode)  # (*BAM, neig) and (*BAM, nguess, neig)

        # calculate the eigenvectors of A
        eigvecA = torch.matmul(V, eigvecT)  # (*BAM, na, neig)

        # calculate the residual
        AVs = torch.matmul(AV, eigvecT)  # (*BAM, na, neig)
        LVs = eigvalT.unsqueeze(-2) * eigvecA  # (*BAM, na, neig)
        if M is not None:
            LVs = M.mm(LVs)
        resid = AVs - LVs  # (*BAM, na, neig)

        # print information and check convergence
        max_resid = resid.abs().max()
        if prev_eigvalT is not None:
            deigval = eigvalT - prev_eigvalT
            max_deigval = deigval.abs().max()
            if verbose:
                print("Iter %3d (guess size: %d): resid: %.3e, devals: %.3e" %
                      (i + 1, nguess, max_resid, max_deigval))  # type:ignore

        if max_resid < best_resid:
            best_resid = max_resid
            best_eigvals = eigvalT
            best_eigvecs = eigvecA
        if max_resid < min_eps:
            break
        if AV.shape[-1] == AV.shape[-2]:
            break
        prev_eigvalT = eigvalT

        # apply the preconditioner
        t = -resid  # (*BAM, na, neig)

        # orthogonalize t with the rest of the V
        t = to_fortran_order(t)
        Vnew = torch.cat((V, t), dim=-1)
        if Vnew.shape[-1] > Vnew.shape[-2]:
            Vnew = Vnew[..., :Vnew.shape[-2]]
        nadd = Vnew.shape[-1] - V.shape[-1]
        nguess = nguess + nadd
        if M is not None:
            MV_ = M.mm(Vnew)
            V, R = tallqr(Vnew, MV=MV_)
        else:
            V, R = tallqr(Vnew)
        AVnew = A.mm(V[..., -nadd:])  # (*BAM,na,nadd)
        AVnew = to_fortran_order(AVnew)
        AV = torch.cat((AV, AVnew), dim=-1)

    eigvals = best_eigvals  # (*BAM, neig)
    eigvecs = best_eigvecs  # (*BAM, na, neig)
    return eigvals, eigvecs
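The core of each Davidson iteration above is a Rayleigh-Ritz step: project A onto the current search space V, diagonalize the small projected matrix, and form the residual that drives the next subspace expansion. A standalone sketch of that single step on a dense symmetric matrix (illustration only, using torch.linalg.eigh rather than the deprecated torch.symeig used in the original code):

import torch

torch.manual_seed(0)
na, nguess, neig = 20, 6, 3
A = torch.randn(na, na)
A = (A + A.t()) * 0.5                             # symmetric test matrix

V, _ = torch.linalg.qr(torch.randn(na, nguess))   # orthonormal search space
AV = A @ V
T = V.t() @ AV                                    # projected (nguess, nguess) matrix
eigval, eigvec = torch.linalg.eigh(T)             # ascending order, like the "lowest" mode
eigval, eigvec = eigval[:neig], eigvec[:, :neig]

ritz_vecs = V @ eigvec                            # approximate eigenvectors of A
resid = AV @ eigvec - ritz_vecs * eigval          # residual A v - lambda v per column
print("max residual:", resid.abs().max().item())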
def davidson(A, params, neig, mode, M=None, mparams=[], **options):
    """
    Iterative method to obtain the `neig` lowest eigenvalues and eigenvectors.
    This function is written so that the backpropagation can be done.
    It solves the eigendecomposition AV = MVE where V is the matrix of
    eigenvectors and E is the diagonal matrix consisting of the eigenvalues.

    Arguments
    ---------
    * A: LinearOperator instance (*BA, na, na)
        The linear operator object on which the eigenpairs are constructed.
    * params: list of differentiable torch.tensor of any shapes
        List of differentiable torch.tensor to be put to A.
    * neig: int
        The number of eigenpairs to be retrieved.
    * mode: str
        Take the `neig` "lowest" or "uppest" of the eigenpairs.
    * M: LinearOperator instance (*BM, na, na) or None
        The transformation on the right hand side. If None, then M=I.
    * mparams: list of differentiable torch.tensor of any shapes
        List of differentiable torch.tensor to be put to M.
    * **options:
        Iterative algorithm options.

    Returns
    -------
    * eigvals: torch.tensor (*BAM, neig)
    * eigvecs: torch.tensor (*BAM, na, neig)
        The `neig` lowest eigenpairs
    """
    # TODO: optimize for large linear operator and strict min_eps
    # Ideas:
    # (1) use better strategy to get the estimate on eigenvalues
    # (2) use restart strategy
    config = set_default_option(
        {
            "max_niter": 1000,
            "nguess": neig,  # number of initial guesses
            "min_eps": 1e-6,
            "verbose": False,
            "eps_cond": 1e-6,
            "v_init": "randn",
            "max_addition": neig,
        }, options)

    # get some of the options
    nguess = config["nguess"]
    max_niter = config["max_niter"]
    min_eps = config["min_eps"]
    verbose = config["verbose"]
    eps_cond = config["eps_cond"]
    max_addition = config["max_addition"]

    # get the shape of the transformation
    na = A.shape[-1]
    if M is None:
        bcast_dims = A.shape[:-2]
    else:
        bcast_dims = get_bcasted_dims(A.shape[:-2], M.shape[:-2])
    dtype = A.dtype
    device = A.device

    # TODO: A to use params
    prev_eigvals = None
    prev_eigvalT = None
    stop_reason = "max_niter"
    shift_is_eigvalT = False
    idx = torch.arange(neig).unsqueeze(-1)  # (neig, 1)

    with A.uselinopparams(*params), \
            M.uselinopparams(*mparams) if M is not None else dummy_context_manager():
        # set up the initial guess
        V = _set_initial_v(config["v_init"].lower(), dtype, device,
                           bcast_dims, na, nguess,
                           M=M, mparams=mparams)  # (*BAM, na, nguess)
        # V = V.reshape(*bcast_dims, na, nguess) # (*BAM, na, nguess)

        # estimating the lowest eigenvalues
        eig_est, rms_eig = _estimate_eigvals(A, neig, mode,
                                             bcast_dims=bcast_dims, na=na, ntest=20,
                                             dtype=V.dtype, device=V.device)

        best_resid = float("inf")
        AV = A.mm(V)
        for i in range(max_niter):
            VT = V.transpose(-2, -1)  # (*BAM,nguess,na)
            # Can be optimized by saving AV from the previous iteration and only
            # operate AV for the new V. This works because the old V has already
            # been orthogonalized, so it will stay the same
            # AV = A.mm(V) # (*BAM,na,nguess)
            T = torch.matmul(VT, AV)  # (*BAM,nguess,nguess)

            # eigvals are sorted from the lowest
            # eval: (*BAM, nguess), evec: (*BAM, nguess, nguess)
            eigvalT, eigvecT = torch.symeig(T, eigenvectors=True)
            eigvalT, eigvecT = _take_eigpairs(
                eigvalT, eigvecT, neig, mode)  # (*BAM, neig) and (*BAM, nguess, neig)

            # calculate the eigenvectors of A
            eigvecA = torch.matmul(V, eigvecT)  # (*BAM, na, neig)

            # calculate the residual
            AVs = torch.matmul(AV, eigvecT)  # (*BAM, na, neig)
            LVs = eigvalT.unsqueeze(-2) * eigvecA  # (*BAM, na, neig)
            if M is not None:
                LVs = M.mm(LVs)
            resid = AVs - LVs  # (*BAM, na, neig)

            # print information and check convergence
            max_resid = resid.abs().max()
            if prev_eigvalT is not None:
                deigval = eigvalT - prev_eigvalT
                max_deigval = deigval.abs().max()
                if verbose:
                    print("Iter %3d (guess size: %d): resid: %.3e, devals: %.3e" %
                          (i + 1, nguess, max_resid, max_deigval))

            if max_resid < best_resid:
                best_resid = max_resid
                best_eigvals = eigvalT
                best_eigvecs = eigvecA
            if max_resid < min_eps:
                break
            if AV.shape[-1] == AV.shape[-2]:
                break
            prev_eigvalT = eigvalT

            # apply the preconditioner
            # the initial guess of the eigenvalues actually helps a lot here
            if not shift_is_eigvalT:
                z = eig_est  # (*BAM,neig)
            else:
                z = eigvalT  # (*BAM,neig)
            # if A.is_precond_set():
            #     t = A.precond(-resid, *params, biases=z, M=M, mparams=mparams) # (nbatch, na, neig)
            # else:
            t = -resid  # (*BAM, na, neig)

            # set the estimate of the eigenvalues
            if not shift_is_eigvalT:
                eigvalT_pred = eigvalT + torch.einsum(
                    '...ae,...ae->...e', eigvecA, A.mm(t))  # (*BAM, neig)
                diff_eigvalT = (eigvalT - eigvalT_pred)  # (*BAM, neig)
                if diff_eigvalT.abs().max() < rms_eig * 1e-2:
                    shift_is_eigvalT = True
                else:
                    change_idx = eig_est > eigvalT
                    next_value = eigvalT - 2 * diff_eigvalT
                    eig_est[change_idx] = next_value[change_idx]

            # orthogonalize t with the rest of the V
            t = to_fortran_order(t)
            Vnew = torch.cat((V, t), dim=-1)
            if Vnew.shape[-1] > Vnew.shape[-2]:
                Vnew = Vnew[..., :Vnew.shape[-2]]
            nadd = Vnew.shape[-1] - V.shape[-1]
            nguess = nguess + nadd
            if M is not None:
                MV_ = M.mm(Vnew)
                V, R = tallqr(Vnew, MV=MV_)
            else:
                V, R = tallqr(Vnew)
            AVnew = A.mm(V[..., -nadd:])  # (*BAM,na,nadd)
            AVnew = to_fortran_order(AVnew)
            AV = torch.cat((AV, AVnew), dim=-1)

    eigvals = best_eigvals  # (*BAM, neig)
    eigvecs = best_eigvecs  # (*BAM, na, neig)
    return eigvals, eigvecs
def davidson(A, params, neig, mode, M=None, mparams=[],
             max_niter=1000, nguess=None, v_init="randn",
             max_addition=None, min_eps=1e-6, verbose=False, **unused):
    """
    Using Davidson method for large sparse matrix eigendecomposition [1]_.

    Arguments
    ---------
    max_niter: int
        Maximum number of iterations
    nguess: int or None
        The number of initial guesses of the eigenvectors.
        If None, set to ``neig``.
    v_init: str
        Mode of the initial guess (``"randn"``, ``"rand"``, ``"eye"``)
    max_addition: int or None
        Maximum number of new guesses to be added to the collected vectors.
        If None, set to ``neig``.
    min_eps: float
        Minimum residual error to be stopped
    verbose: bool
        Option to be verbose

    References
    ----------
    .. [1] P. Arbenz, "Lecture Notes on Solving Large Scale Eigenvalue Problems"
           http://people.inf.ethz.ch/arbenz/ewp/Lnotes/chapter12.pdf
    """
    # TODO: optimize for large linear operator and strict min_eps
    # Ideas:
    # (1) use better strategy to get the estimate on eigenvalues
    # (2) use restart strategy
    if nguess is None:
        nguess = neig
    if max_addition is None:
        max_addition = neig

    # get the shape of the transformation
    na = A.shape[-1]
    if M is None:
        bcast_dims = A.shape[:-2]
    else:
        bcast_dims = get_bcasted_dims(A.shape[:-2], M.shape[:-2])
    dtype = A.dtype
    device = A.device

    # TODO: A to use params
    prev_eigvals = None
    prev_eigvalT = None
    stop_reason = "max_niter"
    shift_is_eigvalT = False
    idx = torch.arange(neig).unsqueeze(-1)  # (neig, 1)

    with A.uselinopparams(*params), \
            M.uselinopparams(*mparams) if M is not None else dummy_context_manager():
        # set up the initial guess
        V = _set_initial_v(v_init.lower(), dtype, device,
                           bcast_dims, na, nguess,
                           M=M, mparams=mparams)  # (*BAM, na, nguess)
        # V = V.reshape(*bcast_dims, na, nguess) # (*BAM, na, nguess)

        # estimating the lowest eigenvalues
        eig_est, rms_eig = _estimate_eigvals(A, neig, mode,
                                             bcast_dims=bcast_dims, na=na, ntest=20,
                                             dtype=V.dtype, device=V.device)

        best_resid = float("inf")
        AV = A.mm(V)
        for i in range(max_niter):
            VT = V.transpose(-2, -1)  # (*BAM,nguess,na)
            # Can be optimized by saving AV from the previous iteration and only
            # operate AV for the new V. This works because the old V has already
            # been orthogonalized, so it will stay the same
            # AV = A.mm(V) # (*BAM,na,nguess)
            T = torch.matmul(VT, AV)  # (*BAM,nguess,nguess)

            # eigvals are sorted from the lowest
            # eval: (*BAM, nguess), evec: (*BAM, nguess, nguess)
            eigvalT, eigvecT = torch.symeig(T, eigenvectors=True)
            eigvalT, eigvecT = _take_eigpairs(
                eigvalT, eigvecT, neig, mode)  # (*BAM, neig) and (*BAM, nguess, neig)

            # calculate the eigenvectors of A
            eigvecA = torch.matmul(V, eigvecT)  # (*BAM, na, neig)

            # calculate the residual
            AVs = torch.matmul(AV, eigvecT)  # (*BAM, na, neig)
            LVs = eigvalT.unsqueeze(-2) * eigvecA  # (*BAM, na, neig)
            if M is not None:
                LVs = M.mm(LVs)
            resid = AVs - LVs  # (*BAM, na, neig)

            # print information and check convergence
            max_resid = resid.abs().max()
            if prev_eigvalT is not None:
                deigval = eigvalT - prev_eigvalT
                max_deigval = deigval.abs().max()
                if verbose:
                    print("Iter %3d (guess size: %d): resid: %.3e, devals: %.3e" %
                          (i + 1, nguess, max_resid, max_deigval))

            if max_resid < best_resid:
                best_resid = max_resid
                best_eigvals = eigvalT
                best_eigvecs = eigvecA
            if max_resid < min_eps:
                break
            if AV.shape[-1] == AV.shape[-2]:
                break
            prev_eigvalT = eigvalT

            # apply the preconditioner
            # the initial guess of the eigenvalues actually helps a lot here
            if not shift_is_eigvalT:
                z = eig_est  # (*BAM,neig)
            else:
                z = eigvalT  # (*BAM,neig)
            # if A.is_precond_set():
            #     t = A.precond(-resid, *params, biases=z, M=M, mparams=mparams) # (nbatch, na, neig)
            # else:
            t = -resid  # (*BAM, na, neig)

            # set the estimate of the eigenvalues
            if not shift_is_eigvalT:
                eigvalT_pred = eigvalT + torch.einsum(
                    '...ae,...ae->...e', eigvecA, A.mm(t))  # (*BAM, neig)
                diff_eigvalT = (eigvalT - eigvalT_pred)  # (*BAM, neig)
                if diff_eigvalT.abs().max() < rms_eig * 1e-2:
                    shift_is_eigvalT = True
                else:
                    change_idx = eig_est > eigvalT
                    next_value = eigvalT - 2 * diff_eigvalT
                    eig_est[change_idx] = next_value[change_idx]

            # orthogonalize t with the rest of the V
            t = to_fortran_order(t)
            Vnew = torch.cat((V, t), dim=-1)
            if Vnew.shape[-1] > Vnew.shape[-2]:
                Vnew = Vnew[..., :Vnew.shape[-2]]
            nadd = Vnew.shape[-1] - V.shape[-1]
            nguess = nguess + nadd
            if M is not None:
                MV_ = M.mm(Vnew)
                V, R = tallqr(Vnew, MV=MV_)
            else:
                V, R = tallqr(Vnew)
            AVnew = A.mm(V[..., -nadd:])  # (*BAM,na,nadd)
            AVnew = to_fortran_order(AVnew)
            AV = torch.cat((AV, AVnew), dim=-1)

    eigvals = best_eigvals  # (*BAM, neig)
    eigvecs = best_eigvecs  # (*BAM, na, neig)
    return eigvals, eigvecs