Example #1
0
    def test_richardson_solve_precision(self):
        """Richardson solves must stay precise across many orders of magnitude.

        The system is (I - 0.1 * S) x = e_0, where S holds ones on the first
        subdiagonal of a 50x50 matrix. Its closed-form solution is
        x_i = 10**-i, so the entries span 50 decades.
        """
        size = 50
        subdiagonal = jnp.diag(jnp.full([size - 1], 0.1), -1)
        matrix = (jnp.eye(size) - subdiagonal).astype("float32")
        rhs = jnp.array([1.] + [0.] * (size - 1), dtype="float32")

        def apply_matrix(vec):
            return matrix @ vec

        solution = linear_solvers.richardson_solve(apply_matrix, rhs,
                                                   iterations=size)
        closed_form = jnp.power(10., -jnp.arange(size)).astype("float32")

        # The closed-form and estimated values should both round off to zero
        # at the same entry. The only precision limit should be float32
        # resolution, not inaccuracy introduced by the solve, hence atol=0.
        np.testing.assert_allclose(solution, closed_form, atol=0, rtol=1e-5)
Example #2
0
 def structured_iter_solve(m1, m2, m3, b1, b2):
     """Run `linear_solvers.richardson_solve` on a structured operator.

     The operator is `structured_matvec` with its first three arguments bound
     to `m1`, `m2` and `m3`. The right-hand side is the pair `(b1, b2)`, and
     200 iterations are requested.
     """
     operator = functools.partial(structured_matvec, m1, m2, m3)
     rhs = (b1, b2)
     return linear_solvers.richardson_solve(operator, rhs, iterations=200)
Example #3
0
 def iter_solve(matrix, b):
     """Run `linear_solvers.richardson_solve` for `matrix @ x = b`.

     The operator is dense multiplication by `matrix`, and 100 iterations are
     requested.
     """
     def matvec(vector):
         return matrix @ vector

     return linear_solvers.richardson_solve(matvec, b, iterations=100)