def testRemainder(self):
    """Check 2nd-order fwd/rev gradients of `lax.rem` on mismatched shapes.

    Each case pairs operands where one of them has a size-1 axis, so `rem`
    is exercised with broadcasting along that axis: first `y` (3, 1)
    against `x` (3, 4), then `x` (1, 4) against `y` (3, 4).
    """
    # Presumably 32 bits here means x64 mode is off and float64 is
    # canonicalized to float32, so tolerances are loosened — TODO confirm.
    tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3
    shape_cases = [((3, 4), (3, 1)), ((1, 4), (3, 4))]
    for x_shape, y_shape in shape_cases:
        # Re-seed for every case so each case draws the same values it would
        # in isolation (matches the original, unrolled version of this test).
        rng = np.random.RandomState(0)
        x = rng.uniform(-0.9, 9, size=x_shape)
        y = rng.uniform(0.7, 1.9, size=y_shape)
        # Precondition: no value appears in both operands. Presumably this
        # keeps probes away from x == y, where rem is non-smooth — TODO
        # confirm. Use a TestCase assertion so it is not stripped under -O.
        self.assertFalse(
            set(np.unique(x)) & set(np.unique(y)),
            "x and y unexpectedly share a value",
        )
        check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)
def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
    """Run `check_grads` on `op` up to `order` in both fwd and rev mode.

    Args:
      op: the function under test (skipped for `lax.pow` on TPU).
      rng_factory: called with `self.rng()`; returns a sampler taking
        `(shape, dtype)`.
      shapes: one shape per positional operand of `op`.
      dtype: dtype of every generated operand.
      order: highest derivative order to check.
      tol: base tolerance, used for both atol and rtol.
    """
    sample = rng_factory(self.rng())
    on_tpu = jtu.device_under_test() == "tpu"
    if on_tpu and op is lax.pow:
        raise SkipTest("pow grad imprecise on tpu")
    if jtu.num_float_bits(dtype) == 32:
        # 32-bit floats: combine the given tolerance with a looser 1e-1.
        tol = jtu.join_tolerance(1e-1, tol)
    operands = tuple(sample(shp, dtype) for shp in shapes)
    check_grads(op, operands, order, modes=["fwd", "rev"], atol=tol, rtol=tol)
def test_pytree_state(self):
    """Test calling odeint with y(t) values that are pytrees."""
    def dynamics(y, _t):
        # dy/dt = -y, applied leaf-by-leaf over the state pytree.
        return tree_map(jnp.negative, y)

    # State is a tuple of a 0-d array and a (1, 1, 1) array.
    initial_state = (np.array(-0.1), np.array([[[0.1]]]))
    times = np.linspace(0., 1., 11)
    if jtu.num_float_bits(np.float64) == 32:
        tol = 1e-1
    else:
        tol = 1e-3
    jtu.check_grads(partial(odeint, dynamics), (initial_state, times),
                    order=2, modes=["rev"], rtol=tol, atol=tol)
def test_weird_time_pendulum_grads(self):
    """Test that gradients are correct when the dynamics depend on t."""
    def dynamics(_np, y, t):
        # `_np` is the array module: jnp when handed to odeint below; the
        # scipy comparison presumably supplies numpy — TODO confirm.
        velocity = y[1]
        return _np.array([velocity * -t, -1 * velocity - 9.8 * _np.sin(y[0])])

    start = [np.pi - 0.1, 0.0]
    times = np.linspace(0., 1., 11)
    if jtu.num_float_bits(np.float64) == 32:
        tol = 1e-1
    else:
        tol = 1e-3
    self.check_against_scipy(dynamics, start, times, tol=tol)
    solve = partial(odeint, partial(dynamics, jnp))
    jtu.check_grads(solve, (start, times), order=2, modes=["rev"],
                    atol=tol, rtol=tol)
def test_swoop_bigger(self):
    """Compare odeint with scipy and check gradients for a 10-element state."""
    def swoop(_np, y, t, arg1, arg2):
        return _np.array(y - _np.sin(t) - _np.cos(t) * arg1 + arg2)

    times = np.array([0.1, 0.2])
    if jtu.num_float_bits(np.float64) == 32:
        tol = 1e-1
    else:
        tol = 1e-3
    state0 = np.linspace(1.1, 10.9, 10)
    extra = (0.1, 0.3)
    self.check_against_scipy(swoop, state0, times, *extra, tol=tol)
    solve = partial(odeint, partial(swoop, jnp))
    jtu.check_grads(solve, (state0, times) + extra, order=2, modes=["rev"],
                    atol=tol, rtol=tol)
def test_pend_grads(self):
    """Compare a damped-pendulum odeint solve with scipy, then check grads."""
    def pend(_np, y, _, m, g):
        # State is [theta, omega]; m and g are extra ODE parameters.
        theta, omega = y
        return [omega, -m * omega - g * _np.sin(theta)]

    extra = (0.25, 9.8)
    start = [np.pi - 0.1, 0.0]
    times = np.linspace(0., 1., 11)
    if jtu.num_float_bits(np.float64) == 32:
        tol = 1e-1
    else:
        tol = 1e-3
    self.check_against_scipy(pend, start, times, *extra, tol=tol)
    solve = partial(odeint, partial(pend, jnp))
    jtu.check_grads(solve, (start, times) + extra, order=2, modes=["rev"],
                    rtol=tol, atol=tol)
def test_decay(self):
    """Compare a nonlinear decay odeint solve with scipy, then check grads."""
    def decay(_np, y, t, arg1, arg2):
        return -_np.sqrt(t) - y + arg1 - _np.mean((y + arg2) ** 2)

    rs = np.random.RandomState(0)
    # Draw order matters for reproducibility: both extra args before y0.
    extra = (rs.randn(3), rs.randn(3))
    state0 = rs.randn(3)
    times = np.linspace(0.1, 0.2, 4)
    if jtu.num_float_bits(np.float64) == 32:
        tol = 1e-1
    else:
        tol = 1e-3
    self.check_against_scipy(decay, state0, times, *extra, tol=tol)
    solve = partial(odeint, partial(decay, jnp))
    jtu.check_grads(solve, (state0, times) + extra, order=2, modes=["rev"],
                    atol=tol, rtol=tol)
def test_complex_odeint(self):
    """Check grads of odeint with complex state and parameter.

    Regression test for https://github.com/google/jax/issues/3986
    """
    def dy_dt(y, t, alpha):
        return alpha * y

    def f(y0, ts, alpha):
        # Differentiate through the real part of the complex trajectory.
        return odeint(dy_dt, y0, ts, alpha).real

    rate = 3 + 4j
    start = 1 + 2j
    times = jnp.linspace(0., 1., 11)
    if jtu.num_float_bits(np.float64) == 32:
        tol = 1e-1
    else:
        tol = 1e-3
    jtu.check_grads(f, (start, times, rate), order=2, modes=["rev"],
                    rtol=tol, atol=tol)