def test_custom_linear_solve_errors(self):
  """custom_linear_solve must reject callables whose outputs don't match b."""
  identity_solve = lambda f, x: x

  # Each case wraps exactly one callable's output in a list, so that output's
  # pytree no longer matches the scalar right-hand side 1.0.
  pytree_cases = (
      ("matvec() output pytree",
       lambda x: [x], identity_solve, identity_solve),
      ("solve() output pytree",
       lambda x: x, lambda f, x: [x], identity_solve),
      ("transpose_solve() output pytree",
       lambda x: x, identity_solve, lambda f, x: [x]),
  )
  for message, matvec, solve_fn, transpose_fn in pytree_cases:
    with self.assertRaisesRegex(TypeError, re.escape(message)):
      lax.custom_linear_solve(matvec, 1.0, solve_fn, transpose_fn)

  # Same pytree, wrong shape: solve returns a length-2 vector for a scalar b.
  with self.assertRaisesRegex(ValueError, re.escape("solve() output shapes")):
    lax.custom_linear_solve(
        lambda x: x, 1.0, lambda f, x: np.ones(2), identity_solve)

  def bad_matvec_usage(a):
    return lax.custom_linear_solve(
        lambda x: a * np.ones(2), 1.0, identity_solve, identity_solve)

  # The matvec shape mismatch surfaces when differentiating through the solve.
  with self.assertRaisesRegex(ValueError, re.escape("matvec() output shapes")):
    api.jvp(bad_matvec_usage, (1.0,), (1.0,))
def matrix_free_solve_aux(matvec, b):
  """Solve ``matvec(x) = b`` with ``explicit_jacobian_solve_aux``.

  The same solver is used for both the forward and the transpose solve. The
  operator is declared symmetric, and ``has_aux=True`` tells
  ``custom_linear_solve`` that the solver also returns auxiliary data.
  """
  solver = explicit_jacobian_solve_aux
  return lax.custom_linear_solve(
      matvec, b, solve=solver, transpose_solve=solver,
      symmetric=True, has_aux=True)
def solve(a, b):
  """Solve ``a @ x = b`` through custom_linear_solve using dense solvers.

  The forward solve uses ``a`` and the transpose solve uses ``a.T``. Both
  ignore the operator argument they receive and work from ``a`` directly.
  """
  def forward(matvec, rhs):
    return jsp.linalg.solve(a, rhs)

  def transposed(vecmat, rhs):
    return jsp.linalg.solve(a.T, rhs)

  return lax.custom_linear_solve(
      partial(high_precision_dot, a), b, forward, transposed)
def pos_def_solve(g, b):
  """Symmetric custom_linear_solve of ``g(x) = b`` backed by ``cho_solve``."""
  # cho_solve's result is indexed with [0] to drop its extra output, so this
  # solver can also serve as a tangent_solve.
  def solve_without_aux(f, rhs):
    return cho_solve(f, rhs)[0]

  return lax.custom_linear_solve(g, b, solve_without_aux, symmetric=True)
def positive_definite_solve(a, b):
  """Solve ``a @ x = b`` using a precomputed Cholesky factorization of ``a``.

  The operator is declared symmetric, so no transpose solve is supplied.
  """
  cholesky = jsp.linalg.cho_factor(a)

  def cholesky_solve(matvec, rhs):
    # The factorization already encodes the operator; matvec is unused.
    return jsp.linalg.cho_solve(cholesky, rhs)

  return lax.custom_linear_solve(
      partial(high_precision_dot, a), b, cholesky_solve, symmetric=True)
def linear_solve(a, b):
  """Solve ``a @ x = b`` using LU factorizations of ``a`` and ``a.T``.

  Both factorizations are computed up front, ``a`` first: one serves the
  forward solve and the other the transpose solve.
  """
  lu_a = jsp.linalg.lu_factor(a)
  lu_a_t = jsp.linalg.lu_factor(a.T)

  def forward(matvec, rhs):
    return jsp.linalg.lu_solve(lu_a, rhs)

  def backward(vecmat, rhs):
    return jsp.linalg.lu_solve(lu_a_t, rhs)

  matvec = partial(high_precision_dot, a)
  return lax.custom_linear_solve(matvec, b, forward, backward)
def _isolve(_isolve_solve, A, b, x0=None, *, tol=1e-5, atol=0.0,
            maxiter=None, M=None, check_symmetric=False):
  """Shared driver for iterative solvers built on ``lax.custom_linear_solve``.

  Fills in defaults, checks that ``x0`` matches ``b``, and then runs
  ``_isolve_solve`` (with the solver options bound) as both the forward and
  the transpose solve.

  Returns:
    A pair ``(x, None)``. The second element is a placeholder for
    convergence info.

  Raises:
    ValueError: if ``x0`` and ``b`` differ in tree structure or leaf shapes.
  """
  x0 = tree_map(jnp.zeros_like, b) if x0 is None else x0
  b, x0 = device_put((b, x0))

  if maxiter is None:
    # copied from scipy
    maxiter = 10 * sum(leaf.size for leaf in tree_leaves(b))

  A = _normalize_matvec(A)
  M = _normalize_matvec(_identity if M is None else M)

  x0_structure, b_structure = tree_structure(x0), tree_structure(b)
  if x0_structure != b_structure:
    raise ValueError('x0 and b must have matching tree structure: '
                     f'{x0_structure} vs {b_structure}')
  if _shapes(x0) != _shapes(b):
    raise ValueError('arrays in x0 and b must have matching shapes: '
                     f'{_shapes(x0)} vs {_shapes(b)}')

  bound_solve = partial(_isolve_solve, x0=x0, tol=tol, atol=atol,
                        maxiter=maxiter, M=M)

  symmetric = False
  if check_symmetric:
    # real-valued positive-definite linear operators are symmetric
    symmetric = all(not issubclass(leaf.dtype.type, np.complexfloating)
                    for leaf in tree_leaves(b))

  x = lax.custom_linear_solve(A, b, solve=bound_solve,
                              transpose_solve=bound_solve,
                              symmetric=symmetric)
  return x, None
def matrix_free_solve(matvec, b):
  """Solve ``matvec(x) = b`` using ``richardson_iteration`` in both directions."""
  solver = richardson_iteration
  return lax.custom_linear_solve(
      matvec, b, solve=solver, transpose_solve=solver)
def matrix_free_solve(matvec, b):
  """Solve ``matvec(x) = b`` using ``explicit_jacobian_solve`` in both directions."""
  solver = explicit_jacobian_solve
  return lax.custom_linear_solve(
      matvec, b, solve=solver, transpose_solve=solver)
def pos_def_solve(g, b):
  """Symmetric custom_linear_solve of ``g(x) = b`` using ``cho_solve`` directly."""
  return lax.custom_linear_solve(g, b, solve=cho_solve, symmetric=True)
def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
          M=None, solve_method='batched'):
  """
  GMRES solves the linear system A x = b for x, given A and b.

  A is specified as a function performing A(vi) -> vf = A @ vi, and in principle
  need not have any particular special properties, such as symmetry. However,
  convergence is often slow for nearly symmetric operators.

  Parameters
  ----------
  A: ndarray or function
      2D array or function that calculates the linear map (matrix-vector
      product) ``Ax`` when called like ``A(x)``. ``A`` must return array(s) with
      the same structure and shape as its argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``.
  info : integer array (scalar)
      ``-1`` if the norm of ``x`` is NaN (the solve broke down), ``0``
      otherwise. Unlike SciPy, this does not report an iteration count.

  Other Parameters
  ----------------
  x0 : array, optional
      Starting guess for the solution. Must have the same structure as ``b``.
      If this is unspecified, zeroes are used.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``gmres``.
  restart : integer, optional
      Size of the Krylov subspace ("number of iterations") built between
      restarts. GMRES works by approximating the true solution x as its
      projection into a Krylov space of this dimension - this parameter
      therefore bounds the maximum accuracy achievable from any guess
      solution. Larger values increase both number of iterations and iteration
      cost, but may be necessary for convergence. The algorithm terminates
      early if convergence is achieved before the full subspace is built.
      Default is 20. Values larger than the total number of elements in ``b``
      are clipped to that number.
  maxiter : integer, optional
      Maximum number of times to rebuild the size-``restart`` Krylov space
      starting from the solution found at the last iteration. If GMRES
      halts or is very slow, decreasing this parameter may help.
      Default is ``10 * size``, where ``size`` is the total number of elements
      in ``b`` (the same heuristic SciPy uses).
  M : ndarray or function, optional
      Preconditioner for A. The preconditioner should approximate the
      inverse of A. Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance.
  solve_method : 'incremental' or 'batched'
      The 'incremental' solve method builds a QR decomposition for the Krylov
      subspace incrementally during the GMRES process using Givens rotations.
      This improves numerical stability and gives a free estimate of the
      residual norm that allows for early termination within a single "restart".
      In contrast, the 'batched' solve method solves the least squares problem
      from scratch at the end of each GMRES iteration. It does not allow for
      early termination, but has much less overhead on GPUs. Default is
      'batched'.

  Raises
  ------
  ValueError
      If ``x0`` and ``b`` have different tree structures, or if
      ``solve_method`` is neither 'incremental' nor 'batched'.

  See also
  --------
  scipy.sparse.linalg.gmres
  jax.lax.custom_linear_solve
  """
  if x0 is None:
    x0 = tree_map(jnp.zeros_like, b)
  if M is None:
    M = _identity
  A = _normalize_matvec(A)
  M = _normalize_matvec(M)

  b, x0 = device_put((b, x0))
  size = sum(bi.size for bi in tree_leaves(b))

  if maxiter is None:
    maxiter = 10 * size  # copied from scipy
  # A Krylov subspace cannot usefully exceed the dimension of the problem.
  restart = min(restart, size)

  if tree_structure(x0) != tree_structure(b):
    raise ValueError(
        'x0 and b must have matching tree structure: '
        f'{tree_structure(x0)} vs {tree_structure(b)}')

  # Outer (true-residual) tolerance, and the matching tolerance for the
  # preconditioned residual used inside each restart.
  b_norm = _norm(b)
  atol = jnp.maximum(tol * b_norm, atol)

  Mb = M(b)
  Mb_norm = _norm(Mb)
  ptol = Mb_norm * jnp.minimum(1.0, atol / b_norm)

  if solve_method == 'incremental':
    gmres_func = _gmres_incremental
  elif solve_method == 'batched':
    gmres_func = _gmres_batched
  else:
    raise ValueError(
        f"invalid solve_method {solve_method}, must be either "
        "'incremental' or 'batched'")

  def _solve(A, b):
    return _gmres_solve(A, b, x0, atol, ptol, restart, maxiter, M, gmres_func)

  x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)

  # Flag a breakdown only when the solution norm is NaN.
  failed = jnp.isnan(_norm(x))
  info = jnp.where(failed, x=-1, y=0)
  return x, info
def bad_matvec_usage(a):
  """Call custom_linear_solve with a matvec that ignores its input.

  The matvec always returns ``a * np.ones(2)``, while the right-hand side is the
  scalar ``1.0``. This is presumably meant to trigger the matvec output-shape
  check.
  """
  def mismatched_matvec(x):
    return a * np.ones(2)

  return lax.custom_linear_solve(mismatched_matvec, 1.0, solve, solve)
def positive_definive_solve(a, b):
  """Solve ``a @ x = b`` for positive-definite ``a`` via a Cholesky factor.

  The operator is declared symmetric, so no transpose solve is supplied.
  """
  chol = jsp.linalg.cho_factor(a)

  def cholesky_solve(matvec, rhs):
    del matvec  # the factorization already encodes the operator
    return jsp.linalg.cho_solve(chol, rhs)

  return lax.custom_linear_solve(
      partial(np.dot, a), b, solve=cholesky_solve, symmetric=True)
def custom_unrolled_lower_tri_solve(mat, b):
  """Solve ``mat @ x = b`` using unrolled substitution.

  The forward solve passes ``lower_tri=True`` and the transpose solve passes
  ``lower_tri=False``.
  """
  matvec = partial(unrolled_matvec, mat)
  forward_solve = partial(unrolled_substitution_solve, lower_tri=True)
  transpose_solve = partial(unrolled_substitution_solve, lower_tri=False)
  return lax.custom_linear_solve(matvec, b, forward_solve, transpose_solve)
def _gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
           M=None, qr_mode=False):
  """
  GMRES solves the linear system A x = b for x, given A and b.

  A is specified as a function performing A(vi) -> vf = A @ vi, and in principle
  need not have any particular special properties, such as symmetry. However,
  convergence is often slow for nearly symmetric operators.

  Parameters
  ----------
  A: function
      Function that calculates the linear map (matrix-vector product)
      ``Ax`` when called like ``A(x)``. ``A`` must return array(s) with the same
      structure and shape as its argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``. If ``b`` has
      zero norm, ``b`` itself is returned immediately.
  info : int or integer array
      ``0`` on the zero-``b`` early return. Otherwise ``-1`` if the norm of
      ``x`` is NaN (the solve broke down) and ``0`` if not.

  Other Parameters
  ----------------
  x0 : array, optional
      Starting guess for the solution. Must have the same structure as ``b``.
      If this is unspecified, zeroes are used.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``gmres``.
  restart : integer, optional
      Size of the Krylov subspace ("number of iterations") built between
      restarts. GMRES works by approximating the true solution x as its
      projection into a Krylov space of this dimension - this parameter
      therefore bounds the maximum accuracy achievable from any guess
      solution. Larger values increase both number of iterations and iteration
      cost, but may be necessary for convergence. The algorithm terminates
      early if convergence is achieved before the full subspace is built.
      Default is 20. Values larger than the total number of elements in ``b``
      are clipped to that number.
  maxiter : integer, optional
      Maximum number of times to rebuild the size-``restart`` Krylov space
      starting from the solution found at the last iteration. If GMRES
      halts or is very slow, decreasing this parameter may help.
      Default is ``10 * size``, where ``size`` is the total number of elements
      in ``b`` (the same heuristic SciPy uses).
  M : function, optional
      Preconditioner for A. The preconditioner should approximate the
      inverse of A. Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance.
  qr_mode : bool
      If True, the algorithm builds an internal Krylov subspace using a QR
      based algorithm (``_gmres_qr``), which reduces overhead and improves
      stability. However, it may degrade performance significantly on GPUs or
      TPUs, in which case this flag should be set False (``_gmres_plain``).

  Raises
  ------
  ValueError
      If ``x0`` and ``b`` have different tree structures.

  See also
  --------
  scipy.sparse.linalg.gmres
  jax.lax.custom_linear_solve
  """
  if x0 is None:
    x0 = tree_map(jnp.zeros_like, b)
  if M is None:
    M = _identity

  b, x0 = device_put((b, x0))
  size = sum(bi.size for bi in tree_leaves(b))

  if maxiter is None:
    maxiter = 10 * size  # copied from scipy
  restart = min(restart, size)

  if tree_structure(x0) != tree_structure(b):
    raise ValueError(
        'x0 and b must have matching tree structure: '
        f'{tree_structure(x0)} vs {tree_structure(b)}')

  b_norm = _norm_tree(b)
  if b_norm == 0:
    # b is all zeros, and x = 0 solves the linear system A x = 0 exactly.
    return b, 0

  outer_tol = jnp.maximum(tol * b_norm, atol)

  Mb = M(b)
  Mb_norm = _norm_tree(Mb)
  inner_tol = Mb_norm * min(1.0, outer_tol / b_norm)

  # Dispatch must agree with the documented meaning of qr_mode:
  # True -> QR-based iteration, False -> plain iteration.
  gmres_func = _gmres_qr if qr_mode else _gmres_plain

  def _solve(A, b):
    return _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,
                        gmres_func)

  x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)

  # Flag a breakdown only when the solution norm is NaN.
  failed = jnp.isnan(_norm_tree(x))
  info = jnp.where(failed, x=-1, y=0)
  return x, info
def loss(a, b):
  """Sum of the entries of x, where x solves ``a @ x = b`` via ``explicit_jacobian_solve``."""
  solution = lax.custom_linear_solve(
      partial(high_precision_dot, a), b, explicit_jacobian_solve)
  return np.sum(solution)
def f1():
  # `err` is passed as the matvec; the caller presumably expects it to raise.
  result = lax.custom_linear_solve(err, b, solve=solve)
  return result
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
  """Use Conjugate Gradient iteration to solve ``Ax = b``.

  The numerics of JAX's ``cg`` should exactly match SciPy's ``cg`` (up to
  numerical precision), but note that the interface is slightly different: you
  need to supply the linear operator ``A`` as a function instead of a sparse
  matrix or ``LinearOperator``.

  Derivatives of ``cg`` are implemented via implicit differentiation with
  another ``cg`` solve, rather than by differentiating *through* the solver.
  They will be accurate only if both solves converge.

  Parameters
  ----------
  A : function
      Function that calculates the matrix-vector product ``Ax`` when called
      like ``A(x)``. ``A`` must represent a hermitian, positive definite
      matrix, and must return array(s) with the same structure and shape as its
      argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``.
  info : None
      Placeholder for convergence information. In the future, JAX will report
      the number of iterations when convergence is not achieved, like SciPy.

  Other Parameters
  ----------------
  x0 : array, optional
      Starting guess for the solution. Must have the same structure as ``b``.
      If this is unspecified, zeroes are used.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``.
  maxiter : integer, optional
      Maximum number of iterations. Iteration will stop after maxiter steps
      even if the specified tolerance has not been achieved. Default is
      ``10 * size``, where ``size`` is the total number of elements in ``b``
      (the same heuristic SciPy uses).
  M : function, optional
      Preconditioner for A. The preconditioner should approximate the
      inverse of A. Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance. If unspecified, ``_identity`` is used.

  Raises
  ------
  ValueError
      If the shapes of the leaves of ``x0`` and ``b`` do not match.

  See also
  --------
  scipy.sparse.linalg.cg
  jax.lax.custom_linear_solve
  """
  if x0 is None:
    x0 = tree_map(jnp.zeros_like, b)

  b, x0 = device_put((b, x0))

  if maxiter is None:
    size = sum(bi.size for bi in tree_leaves(b))
    maxiter = 10 * size  # copied from scipy

  if M is None:
    M = _identity

  # Map every leaf to its shape so that structure and shapes are compared at once.
  shape = partial(tree_map, lambda x: x.shape)
  if shape(x0) != shape(b):
    raise ValueError(
        f'x0 and b must have matching shape: {shape(x0)} vs {shape(b)}')

  cg_solve = partial(
      _cg_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)

  # real-valued positive-definite linear operators are symmetric
  real_valued = lambda x: not issubclass(x.dtype.type, np.complexfloating)
  symmetric = all(map(real_valued, tree_leaves(b)))
  x = lax.custom_linear_solve(
      A, b, solve=cg_solve, transpose_solve=cg_solve, symmetric=symmetric)
  info = None  # TODO(shoyer): return the real iteration count here
  return x, info
def f2():
  # `err` is passed as the solve function; the caller presumably expects it to raise.
  result = lax.custom_linear_solve(matvec, b, solve=err)
  return result