def test_richardson_solve_precision(self):
    """Richardson solves must stay accurate across many orders of magnitude.

    The system is I - 0.1 * (sub-diagonal ones) with right-hand side e_0, whose
    closed-form solution is x_k = 10**-k. The entries therefore span 50 decades.
    """
    size = 50
    # Identity minus 0.1 on the first sub-diagonal, forced to float32.
    system = (
        jnp.eye(size) - jnp.diag(jnp.full([size - 1], 0.1), -1)
    ).astype("float32")
    rhs = jnp.array([1.] + [0.] * (size - 1), dtype="float32")

    estimate = linear_solvers.richardson_solve(
        lambda v: system @ v, rhs, iterations=size)
    reference = jnp.power(10., -jnp.arange(size)).astype("float32")

    # Only float32 resolution should cause deviations: the closed-form values
    # and the iterative estimate should underflow to zero at the same index.
    # Any loss of accuracy inside the solver itself would break this check.
    np.testing.assert_allclose(estimate, reference, atol=0, rtol=1e-5)
def structured_iter_solve(m1, m2, m3, b1, b2):
    """Run a 200-step Richardson solve on a structured operator.

    The operator is ``structured_matvec`` with ``m1``, ``m2``, ``m3`` bound as its
    leading arguments. The right-hand side is passed as the tuple ``(b1, b2)``.
    """
    # NOTE(review): assumes structured_matvec(m1, m2, m3, x) applies the
    # operator to a tuple-shaped x matching (b1, b2); confirm at its definition.
    apply_operator = functools.partial(structured_matvec, m1, m2, m3)
    rhs = (b1, b2)
    return linear_solvers.richardson_solve(apply_operator, rhs, iterations=200)
def iter_solve(matrix, b):
    """Estimate x with ``matrix @ x = b`` via a fixed 100-step Richardson solve."""

    def apply_matrix(vec):
        # Matrix-vector product handed to the solver as its linear operator.
        return matrix @ vec

    return linear_solvers.richardson_solve(apply_matrix, b, iterations=100)