def test_gmres_arnoldi_step(dtype):
    """
    The Arnoldi decomposition within GMRES is correct.

    Runs `n_kry` calls of `kth_arnoldi_step` on a random `n x n` matrix `A`,
    starting from a normalized random vector, and checks that the first
    `n_kry` rows of the returned `H` equal `Q^H @ A @ Q` built from the first
    `n_kry` columns of the returned `Q`.
    """
    arnoldi = jitted_functions.gmres_wrapper(jax)
    # Use the dtype jax actually allocates when asked for `dtype`.
    dtype = jax.numpy.zeros(1, dtype=dtype).dtype

    n = 4
    n_kry = n
    np.random.seed(10)
    A = jax.numpy.array(np.random.rand(n, n).astype(dtype))
    x0 = jax.numpy.array(np.random.rand(n).astype(dtype))
    tol = A.size * jax.numpy.finfo(dtype).eps

    # First basis column is the normalized start vector; the rest start as 0.
    basis = np.zeros((n, n_kry + 1), dtype=x0.dtype)
    basis[:, 0] = x0 / jax.numpy.linalg.norm(x0)
    basis = jax.numpy.array(basis)
    hessenberg = jax.numpy.zeros((n_kry + 1, n_kry), dtype=x0.dtype)

    @jax.tree_util.Partial
    def A_mv(x):
        return A @ x

    for step in range(n_kry):
        basis, hessenberg = arnoldi.kth_arnoldi_step(
            step, A_mv, [], basis, hessenberg, tol)

    leading = basis[:, :n_kry]
    projected = leading.conj().T @ A @ leading
    np.testing.assert_allclose(hessenberg[:n_kry, :], projected, atol=tol)
def test_givens(dtype):
    """
    `givens_rotation` produces the correct rotation factors.

    The returned pair `(cs, sn)` is assembled into the 2x2 matrix
    `[[cs, -sn], [sn, cs]]`; applying it to the input vector must zero out
    the vector's last component.
    """
    givens = jitted_functions.gmres_wrapper(jax)
    np.random.seed(10)
    vec = jax.numpy.array(np.random.rand(2).astype(dtype))
    cs, sn = givens.givens_rotation(*vec)

    # Fill the rotation matrix entry by entry: diagonal first, then the
    # off-diagonal pair.
    rotation = np.zeros((2, 2), dtype=dtype)
    rotation[0, 0] = cs
    rotation[1, 1] = cs
    rotation[1, 0] = sn
    rotation[0, 1] = -sn
    rotation = jax.numpy.array(rotation)

    rotated = rotation @ vec
    tol = 4 * jax.numpy.finfo(dtype).eps
    np.testing.assert_allclose(rotated[-1], 0., atol=tol)
def test_gs(dtype):
    """
    The Gram-Schmidt process works.

    Builds an `n x 2` matrix `A` whose columns are orthonormal (a normalized
    vector supported on the first `n - 1` entries, and the last unit vector),
    scans `gs_step` over those columns with a random start vector as the
    carry, and checks that the resulting vector has vanishing overlap with
    both columns of `A`.
    """
    gmres = jitted_functions.gmres_wrapper(jax)
    # Use the dtype jax actually allocates when asked for `dtype`.
    dummy = jax.numpy.zeros(1, dtype=dtype)
    dtype = dummy.dtype
    n = 8

    # Seed the global generator so the random start vector (and therefore
    # the outcome of this test) is reproducible from run to run.
    np.random.seed(10)

    A = np.zeros((n, 2), dtype=dtype)
    # Column 0: ones in all but the last entry, then normalized.
    A[:-1, 0] = 1.0
    Ai = A[:, 0] / np.linalg.norm(A[:, 0])
    A[:, 0] = Ai
    # Column 1: the last unit vector, orthogonal to column 0 by construction.
    A[-1, -1] = 1.0
    A = jax.numpy.array(A)

    x0 = jax.numpy.array(np.random.rand(n).astype(dtype))
    # scan iterates over the rows of A.T, i.e. the columns of A.
    v_new, _ = jax.lax.scan(gmres.gs_step, x0, xs=A.T)
    dotcheck = v_new @ A
    tol = A.size * jax.numpy.finfo(dtype).eps
    np.testing.assert_allclose(dotcheck, np.zeros(2), atol=tol)
def test_gmres_on_small_known_problem(dtype):
    """
    GMRES produces the correct result on an analytically solved linear system.

    The system is `[[1, 1], [3, -4]] @ x = [3, 2]`, whose solution is
    `x = [2, 1]`. `gmres_m` is run once with a two-vector Krylov space,
    starting from a vector of ones.
    """
    # Use the dtype jax actually allocates when asked for `dtype`.
    dtype = jax.numpy.zeros(1, dtype=dtype).dtype
    solver = jitted_functions.gmres_wrapper(jax)

    A = jax.numpy.array(([[1, 1], [3, -4]]), dtype=dtype)
    b = jax.numpy.array([3, 2], dtype=dtype)
    expected = jax.numpy.array([2., 1.], dtype=dtype)
    x0 = jax.numpy.ones(2, dtype=dtype)
    tol = A.size * jax.numpy.finfo(dtype).eps

    @jax.tree_util.Partial
    def A_mv(x):
        return A @ x

    n_kry = 2
    maxiter = 1
    # The same tolerance is passed in both tolerance slots.
    x, _, _, _ = solver.gmres_m(A_mv, [], b, x0, tol, tol, n_kry, maxiter)
    np.testing.assert_allclose(x, expected, atol=tol)
def test_gmres_krylov(dtype):
    """
    gmres_krylov correctly builds the QR-decomposed Arnoldi decomposition.

    This test assumes that `kth_arnoldi_step` (which is independently tested)
    is correct: the reference `V` and `H` are built by calling it directly,
    `H` is QR-factorized with `jax.numpy.linalg.qr`, and the results are
    compared against the `V` and `R` returned by `gmres_krylov`. Each row of
    both `R` factors is multiplied by the conjugated sign of its diagonal
    entry before the comparison.
    """
    # Use the dtype jax actually allocates when asked for `dtype`.
    dtype = jax.numpy.zeros(1, dtype=dtype).dtype
    solver = jitted_functions.gmres_wrapper(jax)

    def fix_phases(mat):
        # Scale each row by the conjugate of the sign of its diagonal entry.
        signs = jax.numpy.sign(jax.numpy.diagonal(mat))
        return signs.conj()[:, None] * mat

    n = 2
    n_kry = n
    np.random.seed(10)
    A = jax.numpy.array(np.random.rand(n, n).astype(dtype))

    @jax.tree_util.Partial
    def A_mv(x):
        return A @ x

    tol = A.size * jax.numpy.finfo(dtype).eps
    x0 = jax.numpy.array(np.random.rand(n).astype(dtype))
    b = jax.numpy.array(np.random.rand(n), dtype=dtype)

    # NOTE: `precision` is not defined inside this function; it is looked up
    # from the enclosing (module) scope at call time.
    r, beta = solver.gmres_residual(A_mv, [], b, x0)
    _, V, R, _ = solver.gmres_krylov(A_mv, [], n_kry, x0, r, beta, tol,
                                     jax.numpy.linalg.norm(b), precision)
    # Drop the last row of R before fixing the phases.
    R = fix_phases(R[:-1, :])

    # Reference construction via explicit Arnoldi steps.
    Vtest = np.zeros((n, n_kry + 1), dtype=x0.dtype)
    Vtest[:, 0] = r / beta
    Vtest = jax.numpy.array(Vtest)
    Htest = jax.numpy.zeros((n_kry + 1, n_kry), dtype=x0.dtype)
    for step in range(n_kry):
        Vtest, Htest = solver.kth_arnoldi_step(step, A_mv, [], Vtest, Htest,
                                               tol, precision)
    _, Rtest = jax.numpy.linalg.qr(Htest)
    Rtest = fix_phases(Rtest)

    np.testing.assert_allclose(V, Vtest, atol=tol)
    np.testing.assert_allclose(R, Rtest, atol=tol)
def gmres(self,
          A_mv: Callable,
          b: Tensor,
          A_args: Optional[List] = None,
          A_kwargs: Optional[dict] = None,
          x0: Optional[Tensor] = None,
          tol: float = 1E-05,
          atol: Optional[float] = None,
          num_krylov_vectors: Optional[int] = None,
          maxiter: Optional[int] = 1,
          M: Optional[Callable] = None) -> Tuple[Tensor, int]:
    """
    GMRES solves the linear system A @ x = b for x given a vector `b` and
    a general (not necessarily symmetric/Hermitian) linear operator `A`.

    As a Krylov method, GMRES does not require a concrete matrix
    representation of the n by n `A`, but only a function
    `vector1 = A_mv(vector0, *A_args, **A_kwargs)`
    prescribing a one-to-one linear map from vector0 to vector1 (that is,
    A must be square, and thus vector0 and vector1 the same size). If `A` is a
    dense matrix, or if it is a symmetric/Hermitian operator, a different
    linear solver will usually be preferable.

    GMRES works by first constructing the Krylov basis
    K = (x0, A_mv@x0, A_mv@A_mv@x0, ..., (A_mv^num_krylov_vectors)@x_0) and
    then solving a certain dense linear system K @ q0 = q1 from whose
    solution x can be approximated. For `num_krylov_vectors = n` the solution
    is provably exact in infinite precision, but the expense is cubic in
    `num_krylov_vectors` so one is typically interested in the
    `num_krylov_vectors << n` case. The solution can in this case be
    repeatedly improved, to a point, by restarting the Arnoldi iterations
    each time `num_krylov_vectors` is reached. Unfortunately the optimal
    parameter choices balancing expense and accuracy are difficult to predict
    in advance, so applying this function requires a degree of
    experimentation.

    In a tensor network code one is typically interested in A_mv implementing
    some tensor contraction. This implementation thus allows `b` and `x0` to
    be of whatever arbitrary, though identical, shape `b = A_mv(x0, ...)`
    expects. Reshaping to and from a matrix problem is handled internally.

    The Jax backend version of GMRES uses a homemade implementation that, for
    now, is suboptimal for num_krylov_vectors ~ b.size.

    For the same reason as described in eigsh_lanczos, the function A_mv
    should be Jittable (or already Jitted) and, if at all possible, defined
    only once at the global scope. A new compilation will be triggered each
    time an A_mv with a new function signature is passed in, even if the
    'new' function is identical to the old one (function identity is
    undecidable).

    The flattening wrapper built around `A_mv` is cached per
    `(A_mv, b.shape)` pair, so the same `A_mv` may be reused with right-hand
    sides of different shapes.

    Args:
      A_mv     : A function `v0 = A_mv(v, *A_args, **A_kwargs)` where `v0` and
                 `v` have the same shape.
      b        : The `b` in `A @ x = b`; it should be of the shape `A_mv`
                 operates on.
      A_args   : Positional arguments to `A_mv`, supplied to this interface
                 as a list.
                 Default: None.
      A_kwargs : In the other backends, keyword arguments to `A_mv`, supplied
                 as a dictionary. However, the Jax backend does not support
                 A_mv accepting keyword arguments since this causes problems
                 with Jit. Therefore, an error is thrown if A_kwargs is
                 specified.
                 Default: None.
      x0       : An optional guess solution. Zeros are used by default.
                 If `x0` is supplied, its shape and dtype must match those of
                 `b`, or an error will be thrown.
                 Default: zeros.
      tol, atol: Solution tolerance to achieve,
                 norm(residual) <= max(tol*norm(b), atol).
                 Default: tol=1E-05
                          atol=tol
      num_krylov_vectors
               : Size of the Krylov space to build at each restart.
                 Expense is cubic in this parameter. If supplied, it must be
                 an integer in 0 < num_krylov_vectors <= b.size.
                 Default: b.size.
      maxiter  : The Krylov space will be repeatedly rebuilt up to this many
                 times. Large values of this argument
                 should be used only with caution, since especially for nearly
                 symmetric matrices and small `num_krylov_vectors` convergence
                 might well freeze at a value significantly larger than `tol`.
                 Default: 1
      M        : Inverse of the preconditioner of A; see the docstring for
                 `scipy.sparse.linalg.gmres`. This is unsupported in the Jax
                 backend, and NotImplementedError will be raised if it is
                 supplied.
                 Default: None.

    Raises:
      ValueError: -if `x0` is supplied but its shape differs from that of `b`.
                  -if `x0` is supplied but its dtype differs from that of `b`.
                  -if num_krylov_vectors is 0 or exceeds b.size.
                  -if tol or atol was negative.
      NotImplementedError: - If M is supplied.
                           - If A_kwargs is supplied.

    Returns:
      x       : The converged solution. It has the same shape as `b`.
      info    : 0 if convergence was achieved, the number of restarts
                otherwise.
    """
    # --- Argument validation (order determines which error surfaces first).
    if x0 is not None:
        if x0.shape != b.shape:
            errstring = (
                f"If x0 is supplied, its shape, {x0.shape}, must match b's"
                f", {b.shape}.")
            raise ValueError(errstring)
        if x0.dtype != b.dtype:
            errstring = (
                f"If x0 is supplied, its dtype, {x0.dtype}, must match b's"
                f", {b.dtype}.")
            raise ValueError(errstring)
        x0 = x0.ravel()
    else:
        x0 = self.zeros(b.shape, b.dtype).ravel()

    if num_krylov_vectors is None:
        num_krylov_vectors = b.size
    if num_krylov_vectors <= 0 or num_krylov_vectors > b.size:
        errstring = (f"num_krylov_vectors must be in "
                     f"0 < {num_krylov_vectors} <= {b.size}.")
        raise ValueError(errstring)

    if tol < 0:
        raise ValueError(f"tol = {tol} must be positive.")

    if atol is None:
        atol = tol
    elif atol < 0:
        raise ValueError(f"atol = {atol} must be positive.")

    if M is not None:
        raise NotImplementedError("M is not supported by the Jax backend.")

    if A_kwargs is not None:
        raise NotImplementedError(
            "A_kwargs is not supported by the Jax backend.")

    if A_args is None:
        A_args = []

    # --- Build (or fetch) the wrapper mapping flat vectors through A_mv.
    # The wrapper must reshape to the shape of *this* call's `b`. Keying the
    # cache on A_mv alone would reuse a wrapper bound to the shape of the
    # first `b` ever seen for that A_mv, so the shape is part of the key.
    # Only the shape tuple is captured, not `b` itself, so the cache does not
    # keep the right-hand side alive.
    b_shape = tuple(b.shape)
    matvec_key = (A_mv, b_shape)
    if matvec_key not in _CACHED_MATVECS:
        @libjax.tree_util.Partial
        def matrix_matvec(x, *args):
            x = x.reshape(b_shape)
            result = A_mv(x, *args)
            return result.ravel()
        _CACHED_MATVECS[matvec_key] = matrix_matvec

    if "gmres" not in _CACHED_FUNCTIONS:
        _CACHED_FUNCTIONS["gmres"] = jitted_functions.gmres_wrapper(libjax)
    gmres_m = _CACHED_FUNCTIONS["gmres"].gmres_m

    x, _, n_iter, converged = gmres_m(_CACHED_MATVECS[matvec_key], A_args,
                                      b.ravel(), x0, tol, atol,
                                      num_krylov_vectors, maxiter)
    # Report 0 on convergence, otherwise the iteration count from the solver.
    info = 0 if converged else n_iter
    x = self.reshape(x, b_shape)
    return x, info