def check_grads(f, args, order, atol=None, rtol=None, eps=None):
  """Numerically checks the first-order JVP and VJP of ``f`` at ``args``.

  Args:
    f: function whose derivatives are checked.
    args: tuple of positional arguments at which ``f`` is differentiated.
    order: requested derivative order. Currently ignored: only first-order
      JVP/VJP checks are performed (see the TODO below).
    atol: absolute tolerance. ``None`` selects the default (1e-6 when
      ``FLAGS.jax_enable_x64`` is set, otherwise 1e-2); an explicit ``0`` is
      honoured rather than replaced by the default.
    rtol: relative tolerance, defaulted exactly like ``atol``.
    eps: finite-difference step. Any falsy value (``None`` or ``0``) falls
      back to the default, since a zero step cannot be used for differencing.
  """
  # TODO(mattjj,dougalm): add higher-order check
  default_tol = 1e-6 if FLAGS.jax_enable_x64 else 1e-2
  # Compare against None rather than relying on truthiness so that a caller
  # asking for a 0.0 tolerance is not silently given the (looser) default.
  atol = default_tol if atol is None else atol
  rtol = default_tol if rtol is None else rtol
  eps = eps or default_tol
  jtu.check_jvp(f, partial(api.jvp, f), args, atol, rtol, eps)
  jtu.check_vjp(f, partial(api.vjp, f), args, atol, rtol, eps)
def testQr(self, shape, dtype, full_matrices, rng_factory):
  """Compares np.linalg.qr (the implementation under test) with onp.linalg.qr.

  The input may carry leading batch dimensions. In order, this checks that
  A ~= QR, that the first k columns of Q match the per-matrix NumPy reference
  up to a per-column phase, and that conj(T(Q)) @ Q ~= I_k. In reduced mode
  with m >= n it also checks the JVP.
  """
  rng = rng_factory()
  _skip_if_unsupported_type(dtype)
  complex_input = np.issubdtype(dtype, onp.complexfloating)
  if complex_input and (jtu.device_under_test() == "tpu"
                        or jax.lib.version <= (0, 1, 27)):
    raise unittest.SkipTest("No complex QR implementation")

  batch_shape, (m, n) = shape[:-2], shape[-2:]
  mode, k = ("complete", m) if full_matrices else ("reduced", min(m, n))

  a = rng(shape, dtype)
  q_test, r_test = np.linalg.qr(a, mode=mode)

  # onp.linalg.qr has no batch support, so the reference is built one matrix
  # at a time; the implementation under test handles batches directly.
  q_ref = onp.zeros(batch_shape + (m, k), dtype)
  r_ref = onp.zeros(batch_shape + (k, n), dtype)
  for idx in onp.ndindex(*batch_shape):
    q_ref[idx], r_ref[idx] = onp.linalg.qr(a[idx], mode=mode)

  # Residual norms are measured in units of max(m, n) * machine epsilon.
  unit = max(m, n) * np.finfo(dtype).eps

  def scaled_norm(x):
    return onp.linalg.norm(x, axis=(-2, -1)) / unit

  # Reconstruction: A ~= QR.
  self.assertTrue(onp.all(scaled_norm(a - onp.matmul(q_test, r_test)) < 30))

  # Only the first k columns of Q are determined, and only up to a per-column
  # phase; the remaining columns are an arbitrary orthonormal basis. Rotate
  # the reference columns (in place) onto the tested ones before comparing.
  cols_ref, cols_test = q_ref[..., :k], q_test[..., :k]
  ratio_sum = onp.sum(onp.divide(cols_ref, cols_test), axis=-2, keepdims=True)
  cols_ref *= onp.divide(ratio_sum, onp.abs(ratio_sum))
  self.assertTrue(onp.all(scaled_norm(cols_ref - cols_test) < 30))

  # Q should be close to unitary: conj(T(Q)) @ Q ~= I_k.
  gram = onp.matmul(onp.conj(T(q_test)), q_test)
  self.assertTrue(onp.all(scaled_norm(onp.eye(k) - gram) < 5))

  if m >= n and not full_matrices:
    jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a,), atol=3e-3)
def testQr(self, shape, dtype, full_matrices, rng):
  """Compares np.linalg.qr (the implementation under test) with onp.linalg.qr.

  The input may carry leading batch dimensions. In order, this checks that
  A ~= QR, that the first k columns of Q match the per-matrix NumPy reference
  up to a per-column phase, and that conj(T(Q)) @ Q ~= I_k. In reduced mode
  with m >= n it also checks the JVP.
  """
  m, n = shape[-2:]
  if full_matrices:
    mode, k = "complete", m
  else:
    mode, k = "reduced", min(m, n)

  a = rng(shape, dtype)
  lq, lr = np.linalg.qr(a, mode=mode)

  # onp.linalg.qr doesn't support batch dimensions. But it seems like an
  # inevitable extension so we support it in our version; the reference is
  # therefore computed one matrix at a time.
  nq = onp.zeros(shape[:-2] + (m, k), dtype)
  nr = onp.zeros(shape[:-2] + (k, n), dtype)
  for index in onp.ndindex(*shape[:-2]):
    nq[index], nr[index] = onp.linalg.qr(a[index], mode=mode)

  max_rank = max(m, n)

  # Norm, adjusted for dimension and type.
  def norm(x):
    fro = onp.linalg.norm(x, axis=(-2, -1))
    return fro / (max_rank * onp.finfo(dtype).eps)

  def compare_orthogonal(q1, q2):
    # Q is unique only up to a per-column phase (a sign for real dtypes), so
    # normalize it first. Note: q1 is rescaled in place.
    sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True)
    phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios))
    q1 *= phases
    self.assertTrue(onp.all(norm(q1 - q2) < 30))

  # Check a ~= qr
  self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30))

  # Compare the first 'k' vectors of Q; the remainder form an arbitrary
  # orthonormal basis for the null space.
  compare_orthogonal(nq[..., :k], lq[..., :k])

  # Check that q is close to unitary: conj(T(Q)) @ Q ~= I. The conjugate is
  # required for complex dtypes (T(Q) @ Q is not the identity there) and is a
  # no-op for real dtypes.
  self.assertTrue(
      onp.all(norm(onp.eye(k) - onp.matmul(onp.conj(T(lq)), lq)) < 5))

  if not full_matrices and m >= n:
    jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a, ))
def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, rng_factory):
  """Checks np.linalg.svd on a random input of shape ``b + (m, n)``.

  With compute_uv, this verifies the reconstruction A ~= (s * U) @ Vh and the
  orthonormality of U and Vh. Otherwise it compares the singular values with
  onp.linalg.svd. It then compiles the op and, for the reduced SVD, checks
  its JVP.
  """
  rng = rng_factory()
  _skip_if_unsupported_type(dtype)
  if b != () and jax.lib.version <= (0, 1, 28):
    raise unittest.SkipTest("Batched SVD requires jaxlib 0.1.29")

  def args_maker():
    return [rng(b + (m, n), dtype)]

  def scaled_norm(x):
    # Per-matrix Frobenius norm in units of max(m, n) * machine epsilon.
    denom = max(m, n) * onp.finfo(dtype).eps
    return onp.linalg.norm(x, axis=(-2, -1)) / denom

  def assert_small(residual, bound):
    self.assertTrue(onp.all(scaled_norm(residual) < bound))

  a, = args_maker()
  out = np.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)

  if not compute_uv:
    # Singular values only: compare directly with the NumPy reference.
    self.assertTrue(onp.allclose(onp.linalg.svd(a, compute_uv=False),
                                 onp.asarray(out), atol=1e-4, rtol=1e-4))
  else:
    u, s, vh = out[0], out[1], out[2]
    k = min(m, n)
    # Choose factor slices that make the product conformable, together with
    # the residual bound used for each layout.
    if not full_matrices:
      u_part, vh_part, bound = u, vh, 300
    elif m < n:
      u_part, vh_part, bound = u, vh[..., :k, :], 50
    else:
      u_part, vh_part, bound = u[..., :, :k], vh, 350
    # s[..., None, :] * U scales column j of U by s[..., j].
    assert_small(a - onp.matmul(s[..., None, :] * u_part, vh_part), bound)

    # Orthonormality: conj(T(U)) @ U ~= I; for Vh, conj(T(Vh)) @ Vh ~= I
    # when m >= n, otherwise Vh @ conj(T(Vh)) ~= I.
    assert_small(onp.eye(u.shape[-1]) - onp.matmul(onp.conj(T(u)), u), 10)
    if m >= n:
      assert_small(onp.eye(vh.shape[-1]) - onp.matmul(onp.conj(T(vh)), vh), 10)
    else:
      assert_small(onp.eye(vh.shape[-2]) - onp.matmul(vh, onp.conj(T(vh))), 20)

  self._CompileAndCheck(
      partial(np.linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv),
      args_maker, check_dtypes=True)

  if not full_matrices:
    reduced_svd = partial(np.linalg.svd, full_matrices=False)
    jtu.check_jvp(reduced_svd, partial(jvp, reduced_svd), (a,),
                  rtol=1e-2, atol=1e-1)
def testSVD(self, m, n, dtype, full_matrices, compute_uv, rng):
  """Checks np.linalg.svd on a single random m x n matrix.

  With compute_uv, this verifies the reconstruction A ~= (s * U) @ Vh and the
  orthonormality of U and Vh. Otherwise it compares the singular values with
  onp.linalg.svd. It then compiles the op and, for the reduced SVD, checks
  its JVP.
  """
  _skip_if_unsupported_type(dtype)

  def args_maker():
    return [rng((m, n), dtype)]

  def scaled_norm(x):
    # Frobenius norm in units of max(m, n) * machine epsilon.
    denom = max(m, n) * onp.finfo(dtype).eps
    return onp.linalg.norm(x, axis=(-2, -1)) / denom

  def assert_small(residual, bound):
    self.assertTrue(onp.all(scaled_norm(residual) < bound))

  a, = args_maker()
  out = np.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)

  if not compute_uv:
    # Singular values only: compare directly with the NumPy reference.
    self.assertTrue(onp.allclose(onp.linalg.svd(a, compute_uv=False),
                                 onp.asarray(out), atol=1e-4, rtol=1e-4))
  else:
    u, s, vh = out[0], out[1], out[2]
    k = min(m, n)
    # Choose factor slices that make the product conformable.
    if not full_matrices:
      u_part, vh_part = u, vh
    elif m < n:
      u_part, vh_part = u, vh[:k, :]
    else:
      u_part, vh_part = u[:, :k], vh
    # s * U broadcasts over the last axis, scaling column j of U by s[j].
    assert_small(a - onp.matmul(s * u_part, vh_part), 50)

    # Orthonormality: conj(T(U)) @ U ~= I; for Vh, conj(T(Vh)) @ Vh ~= I
    # when m >= n, otherwise Vh @ conj(T(Vh)) ~= I.
    assert_small(onp.eye(u.shape[1]) - onp.matmul(onp.conj(T(u)), u), 10)
    if m >= n:
      assert_small(onp.eye(vh.shape[1]) - onp.matmul(onp.conj(T(vh)), vh), 10)
    else:
      assert_small(onp.eye(vh.shape[0]) - onp.matmul(vh, onp.conj(T(vh))), 20)

  self._CompileAndCheck(
      partial(np.linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv),
      args_maker, check_dtypes=True)

  if not full_matrices:
    reduced_svd = partial(np.linalg.svd, full_matrices=False)
    jtu.check_jvp(reduced_svd, partial(jvp, reduced_svd), (a,),
                  atol=1e-1 if FLAGS.jax_enable_x64 else jtu.ATOL)
def test_jvp_linearized(self, f, args):
  """Numerically checks ``partial(jvp_unlinearized, f)`` as the JVP of ``f``.

  The stray debug ``print(f)`` has been removed so that test runs do not
  write to stdout.
  """
  jtu.check_jvp(f, partial(jvp_unlinearized, f), args)
def test_jvp(self, f, args):
  """Numerically checks ``partial(jvp, f)`` as the JVP of ``f`` at ``args``.

  The stray debug ``print(f)`` has been removed so that test runs do not
  write to stdout.
  """
  jtu.check_jvp(f, partial(jvp, f), args)
def test_jvp_linearized(self, f, args):
  """Runs jtu.check_jvp on ``f`` using ``jvp_unlinearized`` as its JVP.

  ``rtol`` is passed as the per-dtype mapping ``{onp.float32: 3e-2}``.
  """
  f_jvp = partial(jvp_unlinearized, f)
  float32_rtol = {onp.float32: 3e-2}
  jtu.check_jvp(f, f_jvp, args, rtol=float32_rtol)
def test_jvp(self, f, args):
  """Runs jtu.check_jvp on ``f`` using ``jvp`` as its JVP.

  ``rtol`` is passed as the per-dtype mapping ``{onp.float32: 3e-2}``.
  """
  f_jvp = partial(jvp, f)
  float32_rtol = {onp.float32: 3e-2}
  jtu.check_jvp(f, f_jvp, args, rtol=float32_rtol)
import pdb

# Parameters for the test function
# TODO: Test more functions
D = 4
t0 = 0.1
t1 = 0.11
y0 = np.linspace(0.1, 0.9, D)
fargs = (0.1, 0.2)


def f(y, t, args):
  """Example dynamics for the ODE; ``args`` is the pair ``(arg1, arg2)``."""
  # Python 3 removed tuple parameters (PEP 3113), so unpack explicitly.
  arg1, arg2 = args
  return -np.sqrt(t) - np.sin(np.dot(y, arg1)) - np.mean((y + arg2)**2)


def test_odeint_jvp():
  """Numerically checks jvp_odeint as the JVP of odeint over [t0, t1]."""
  def odeint2(y0, t0, t1, fargs):
    return odeint(y0, np.array([t0, t1]), fargs, func=f, atol=1e-8, rtol=1e-8)

  def odeint2_jvp(primals, tangents):
    # check_jvp calls this with a (primals, tangents) pair of tuples that
    # mirror odeint2's positional arguments.
    y0, t0, t1, fargs = primals
    tan_y, tan_t0, tan_t1, tan_fargs = tangents
    return jvp_odeint((y0, np.array([t0, t1]), fargs),
                      (tan_y, np.array([tan_t0, tan_t1]), tan_fargs),
                      func=f)

  check_jvp(odeint2, odeint2_jvp, (y0, t0, t1, fargs))
y0 = np.linspace(0.1, 0.9, D) arg = np.zeros((0, )) def f(y, t, args): return -np.sqrt(t) - y @custom_transforms def onearg_odeint(y0): return odeint(f, y0, np.array([t0, t1]), atol=1e-8, rtol=1e-8)[1] def onearg_jvp((y0, arg), (tangent_all, )): return jvp_odeint(tangent_all, f, y0, t0, t1, arg) ad.defjvp(onearg_odeint.primitive, onearg_jvp) check_jvp(onearg_odeint, onearg_jvp, (y0, arg)) def test_odeint_jvp_all(): D = 10 t0 = 0.1 t1 = 0.2 y0 = np.linspace(0.1, 0.9, D) fargs = (0.1, 0.2) def f(y, t, arg1, arg2): return -np.sqrt(t) - y + arg1 - np.mean((y + arg2)**2) @custom_transforms def twoarg_odeint(y0, args): return odeint(f,
from jax.config import config
config.update("jax_enable_x64", True)

from jax.test_util import check_jvp
from jaxsde import ito_integrate, jvp_ito_integrate
from test_sdeint import make_example_sde


def test_ito_int_jvp():
  """Forward mode: numerically checks jvp_ito_integrate w.r.t. ``y0``."""
  f, g, b, y0, ts, dt = make_example_sde()

  def onearg_int(y0):
    return ito_integrate(f, g, y0, ts, b, dt)

  def onearg_int_jvp(primals, tangents):
    # check_jvp passes (primals, tangents) tuples. Python 3 has no tuple
    # parameters (PEP 3113), so unpack the single element of each here.
    (y0,), (tan_y,) = primals, tangents
    return jvp_ito_integrate(tan_y, y0, f, g, ts, b, dt=dt, args=())

  check_jvp(onearg_int, onearg_int_jvp, (y0,), atol=1e-3, rtol=1e-3)