コード例 #1
0
def test_gmres_arnoldi_step(dtype):
  """
  The Arnoldi decomposition within GMRES is correct.

  Runs `n_kry` calls of `kth_arnoldi_step` on a random 4x4 matrix `A`,
  starting from a normalized random vector, and asserts that the first
  `n_kry` rows of the resulting `H` equal `Q^dagger @ A @ Q`, with `Q`
  restricted to its first `n_kry` columns, to within `A.size * eps`.
  """
  jnp = jax.numpy
  gmres = jitted_functions.gmres_wrapper(jax)
  # Read the dtype back from a Jax array built with the requested one
  # (presumably so the NumPy casts below match what Jax allocates).
  dtype = jnp.zeros(1, dtype=dtype).dtype
  dim = 4
  n_kry = dim
  np.random.seed(10)
  A = jnp.array(np.random.rand(dim, dim).astype(dtype))
  x0 = jnp.array(np.random.rand(dim).astype(dtype))
  tol = A.size * jnp.finfo(dtype).eps

  # Column 0 holds the normalized start vector; the other columns are zero
  # until the Arnoldi steps fill them in.
  basis = np.zeros((dim, n_kry + 1), dtype=x0.dtype)
  basis[:, 0] = x0 / jnp.linalg.norm(x0)
  Q = jnp.array(basis)
  H = jnp.zeros((n_kry + 1, n_kry), dtype=x0.dtype)

  @jax.tree_util.Partial
  def A_mv(x):
    return A @ x

  for step in range(n_kry):
    Q, H = gmres.kth_arnoldi_step(step, A_mv, [], Q, H, tol)

  Q_kry = Q[:, :n_kry]
  expected = Q_kry.conj().T @ A @ Q_kry
  np.testing.assert_allclose(H[:n_kry, :], expected, atol=tol)
コード例 #2
0
def test_givens(dtype):
  """
  `givens_rotation` produces the correct rotation factors.

  Builds the 2x2 matrix [[cs, -sn], [sn, cs]] from the factors returned for
  a random 2-vector and asserts that applying it to that vector leaves a
  last component of zero to within `4 * eps`.
  """
  gmres = jitted_functions.gmres_wrapper(jax)
  np.random.seed(10)
  vec = jax.numpy.array(np.random.rand(2).astype(dtype))
  first, second = vec
  cs, sn = gmres.givens_rotation(first, second)

  # Cosine on the diagonal, sine (with opposite signs) off the diagonal.
  rot_np = np.zeros((2, 2), dtype=dtype)
  rot_np[0, 0] = rot_np[1, 1] = cs
  rot_np[1, 0] = sn
  rot_np[0, 1] = -sn

  rotated = jax.numpy.array(rot_np) @ vec
  tol = 4 * jax.numpy.finfo(dtype).eps
  np.testing.assert_allclose(rotated[-1], 0., atol=tol)
コード例 #3
0
def test_gs(dtype):
  """
  The Gram-Schmidt process works.

  Builds an 8x2 matrix `A` whose columns are orthonormal (the first is a
  normalized vector of ones with a trailing zero, the second is the last
  unit vector), folds `gmres.gs_step` over those columns starting from a
  random vector, and asserts that the result has zero overlap with both
  columns to within `A.size * eps`.
  """
  gmres = jitted_functions.gmres_wrapper(jax)
  dummy = jax.numpy.zeros(1, dtype=dtype)
  dtype = dummy.dtype
  n = 8
  # Seed the global NumPy RNG so the random start vector, and therefore any
  # failure of this test, is reproducible regardless of test ordering.
  np.random.seed(10)

  A = np.zeros((n, 2), dtype=dtype)
  A[:-1, 0] = 1.0
  A[:, 0] = A[:, 0] / np.linalg.norm(A[:, 0])
  A[-1, -1] = 1.0
  A = jax.numpy.array(A)

  x0 = jax.numpy.array(np.random.rand(n).astype(dtype))
  # scan iterates over the rows of A.T, i.e. the columns of A, threading the
  # vector being orthogonalized through as the carry.
  v_new, _ = jax.lax.scan(gmres.gs_step, x0, xs=A.T)
  dotcheck = v_new @ A
  tol = A.size*jax.numpy.finfo(dtype).eps
  np.testing.assert_allclose(dotcheck, np.zeros(2), atol=tol)
コード例 #4
0
def test_gmres_on_small_known_problem(dtype):
  """
  GMRES produces the correct result on an analytically solved
  linear system.

  The system is [[1, 1], [3, -4]] @ x = [3, 2], whose exact solution is
  [2, 1]. It is solved with `n_kry = 2` and `maxiter = 1`, starting from a
  vector of ones, and compared to the exact solution within `A.size * eps`.
  """
  jnp = jax.numpy
  dtype = jnp.zeros(1, dtype=dtype).dtype
  gmres = jitted_functions.gmres_wrapper(jax)

  A = jnp.array([[1, 1], [3, -4]], dtype=dtype)
  rhs = jnp.array([3, 2], dtype=dtype)
  guess = jnp.ones(2, dtype=dtype)
  expected = jnp.array([2., 1.], dtype=dtype)
  tol = A.size * jnp.finfo(dtype).eps
  n_kry, maxiter = 2, 1

  @jax.tree_util.Partial
  def A_mv(x):
    return A @ x

  # The same tolerance is passed for both tolerance arguments of gmres_m.
  x, _, _, _ = gmres.gmres_m(A_mv, [], rhs, guess, tol, tol, n_kry, maxiter)
  np.testing.assert_allclose(x, expected, atol=tol)
コード例 #5
0
def test_gmres_krylov(dtype):
    """
    gmres_krylov correctly builds the QR-decomposed Arnoldi decomposition.

    The reference is assembled by calling `kth_arnoldi_step` (assumed
    correct here, since it is tested independently) `n_kry` times and then
    QR-factorizing the resulting `H`. Before comparing, each `R` factor has
    its rows multiplied by the conjugate sign of its own diagonal,
    presumably so the comparison does not depend on the phase convention
    of the factorization.

    NOTE(review): `precision` is not defined inside this function; it is
    presumably a module-level name -- confirm.
    """
    jnp = jax.numpy
    dtype = jnp.zeros(1, dtype=dtype).dtype
    gmres = jitted_functions.gmres_wrapper(jax)

    dim = 2
    n_kry = dim
    np.random.seed(10)
    # Draw order (A, then x0, then b) fixes the random inputs for the seed.
    A = jnp.array(np.random.rand(dim, dim).astype(dtype))
    x0 = jnp.array(np.random.rand(dim).astype(dtype))
    b = jnp.array(np.random.rand(dim), dtype=dtype)
    tol = A.size * jnp.finfo(dtype).eps

    @jax.tree_util.Partial
    def A_mv(x):
        return A @ x

    def fix_phases(upper):
        # Scale each row by the conjugate sign of its diagonal entry.
        signs = jnp.sign(jnp.diagonal(upper))
        return signs.conj()[:, None] * upper

    # Result under test.
    r, beta = gmres.gmres_residual(A_mv, [], b, x0)
    b_norm = jnp.linalg.norm(b)
    _, V, R, _ = gmres.gmres_krylov(A_mv, [], n_kry, x0, r, beta, tol,
                                    b_norm, precision)
    R = fix_phases(R[:-1, :])

    # Reference built step by step from kth_arnoldi_step.
    V_ref = np.zeros((dim, n_kry + 1), dtype=x0.dtype)
    V_ref[:, 0] = r / beta
    V_ref = jnp.array(V_ref)
    H_ref = jnp.zeros((n_kry + 1, n_kry), dtype=x0.dtype)
    for step in range(n_kry):
        V_ref, H_ref = gmres.kth_arnoldi_step(step, A_mv, [], V_ref, H_ref,
                                              tol, precision)
    _, R_ref = jnp.linalg.qr(H_ref)
    R_ref = fix_phases(R_ref)

    np.testing.assert_allclose(V, V_ref, atol=tol)
    np.testing.assert_allclose(R, R_ref, atol=tol)
コード例 #6
0
    def gmres(self,
              A_mv: Callable,
              b: Tensor,
              A_args: Optional[List] = None,
              A_kwargs: Optional[dict] = None,
              x0: Optional[Tensor] = None,
              tol: float = 1E-05,
              atol: Optional[float] = None,
              num_krylov_vectors: Optional[int] = None,
              maxiter: Optional[int] = 1,
              M: Optional[Callable] = None) -> Tuple[Tensor, int]:
        """ GMRES solves the linear system A @ x = b for x given a vector `b`
        and a general (not necessarily symmetric/Hermitian) linear operator
        `A`.

        As a Krylov method, GMRES does not require a concrete matrix
        representation of the n by n `A`, but only a function
        `vector1 = A_mv(vector0, *A_args, **A_kwargs)`
        prescribing a one-to-one linear map from vector0 to vector1 (that
        is, A must be square, and thus vector0 and vector1 the same size).
        If `A` is a dense matrix, or if it is a symmetric/Hermitian
        operator, a different linear solver will usually be preferable.

        GMRES works by first constructing the Krylov basis
        K = (x0, A_mv@x0, A_mv@A_mv@x0, ..., (A_mv^num_krylov_vectors)@x_0)
        and then solving a certain dense linear system K @ q0 = q1 from
        whose solution x can be approximated. For `num_krylov_vectors = n`
        the solution is provably exact in infinite precision, but the
        expense is cubic in `num_krylov_vectors` so one is typically
        interested in the `num_krylov_vectors << n` case.
        The solution can in this case be repeatedly improved, to a point,
        by restarting the Arnoldi iterations each time
        `num_krylov_vectors` is reached. Unfortunately the optimal
        parameter choices balancing expense and accuracy are difficult to
        predict in advance, so applying this function requires a degree of
        experimentation.

        In a tensor network code one is typically interested in A_mv
        implementing some tensor contraction. This implementation thus
        allows `b` and `x0` to be of whatever arbitrary, though identical,
        shape `b = A_mv(x0, ...)` expects. Reshaping to and from a matrix
        problem is handled internally.

        The Jax backend version of GMRES uses a homemade implementation
        that, for now, is suboptimal for num_krylov_vecs ~ b.size.

        For the same reason as described in eigsh_lanczos, the function
        A_mv should be Jittable (or already Jitted) and, if at all
        possible, defined only once at the global scope. A new compilation
        will be triggered each time an A_mv with a new function signature
        is passed in, even if the 'new' function is identical to the old
        one (function identity is undecidable).

        The flattening wrapper built around `A_mv` is cached per
        `(A_mv, b.shape)` pair, so the same `A_mv` may be reused with
        right-hand sides of different shapes.

        Args:
          A_mv     : A function `v0 = A_mv(v, *A_args, **A_kwargs)` where
                     `v0` and `v` have the same shape.
          b        : The `b` in `A @ x = b`; it should be of the shape
                     `A_mv` operates on.
          A_args   : Positional arguments to `A_mv`, supplied to this
                     interface as a list.
                     Default: None.
          A_kwargs : In the other backends, keyword arguments to `A_mv`,
                     supplied as a dictionary. However, the Jax backend
                     does not support A_mv accepting keyword arguments
                     since this causes problems with Jit. Therefore, an
                     error is thrown if A_kwargs is specified.
                     Default: None.
          x0       : An optional guess solution. Zeros are used by default.
                     If `x0` is supplied, its shape and dtype must match
                     those of `b`, or an error will be thrown.
                     Default: zeros.
          tol, atol: Solution tolerance to achieve,
                     norm(residual) <= max(tol*norm(b), atol).
                     Default: tol=1E-05
                              atol=tol
          num_krylov_vectors
                   : Size of the Krylov space to build at each restart.
                     Expense is cubic in this parameter. If supplied, it
                     must be an integer in 0 < num_krylov_vectors <= b.size.
                     Default: b.size.
          maxiter  : The Krylov space will be repeatedly rebuilt up to this
                     many times. Large values of this argument should be
                     used only with caution, since especially for nearly
                     symmetric matrices and small `num_krylov_vectors`
                     convergence might well freeze at a value significantly
                     larger than `tol`.
                     Default: 1
          M        : Inverse of the preconditioner of A; see the docstring
                     for `scipy.sparse.linalg.gmres`. This is unsupported
                     in the Jax backend, and NotImplementedError will be
                     raised if it is supplied.
                     Default: None.

        Raises:
          ValueError: -if `x0` is supplied but its shape differs from that
                       of `b`.
                      -if `x0` is supplied but its dtype differs from that
                       of `b`.
                      -if num_krylov_vectors is 0 or exceeds b.size.
                      -if tol or atol was negative.
          NotImplementedError: - If M is supplied.
                               - If A_kwargs is supplied.

        Returns:
          x       : The converged solution. It has the same shape as `b`.
          info    : 0 if convergence was achieved, the number of restarts
                    otherwise.
        """

        if x0 is not None:
            if x0.shape != b.shape:
                errstring = (
                    f"If x0 is supplied, its shape, {x0.shape}, must match b's"
                    f", {b.shape}.")
                raise ValueError(errstring)
            if x0.dtype != b.dtype:
                errstring = (
                    f"If x0 is supplied, its dtype, {x0.dtype}, must match b's"
                    f", {b.dtype}.")
                raise ValueError(errstring)
            x0 = x0.ravel()
        else:
            x0 = self.zeros(b.shape, b.dtype).ravel()

        if num_krylov_vectors is None:
            num_krylov_vectors = b.size
        if num_krylov_vectors <= 0 or num_krylov_vectors > b.size:
            errstring = (f"num_krylov_vectors must be in "
                         f"0 < {num_krylov_vectors} <= {b.size}.")
            raise ValueError(errstring)

        if tol < 0:
            raise ValueError(f"tol = {tol} must be positive.")

        if atol is None:
            atol = tol
        elif atol < 0:
            raise ValueError(f"atol = {atol} must be positive.")

        if M is not None:
            raise NotImplementedError("M is not supported by the Jax backend.")
        if A_kwargs is not None:
            raise NotImplementedError(
                "A_kwargs is not supported by the Jax backend.")

        if A_args is None:
            A_args = []

        # The wrapper below bakes in the shape it reshapes to, so the cache
        # must be keyed on that shape as well as on A_mv. Keying on A_mv
        # alone would hand a later call with a differently shaped `b` a
        # wrapper that reshapes to the first call's shape. Only the shape
        # tuple is captured, so the first `b` is not kept alive by the
        # cache.
        b_shape = tuple(b.shape)
        matvec_key = (A_mv, b_shape)
        if matvec_key not in _CACHED_MATVECS:

            @libjax.tree_util.Partial
            def matrix_matvec(x, *args):
                # Present the flat Krylov vector to A_mv in the caller's
                # shape, then flatten the result again.
                x = x.reshape(b_shape)
                result = A_mv(x, *args)
                return result.ravel()

            _CACHED_MATVECS[matvec_key] = matrix_matvec

        if "gmres" not in _CACHED_FUNCTIONS:
            _CACHED_FUNCTIONS["gmres"] = jitted_functions.gmres_wrapper(libjax)
        gmres_m = _CACHED_FUNCTIONS["gmres"].gmres_m
        x, _, n_iter, converged = gmres_m(_CACHED_MATVECS[matvec_key], A_args,
                                          b.ravel(), x0, tol, atol,
                                          num_krylov_vectors, maxiter)
        info = 0 if converged else n_iter
        x = self.reshape(x, b.shape)
        return x, info