def test_pytree_state(self):
  """Check reverse-mode gradients of odeint when the state is a pytree.

  The initial state is a tuple holding a 0-d array and a (1, 1, 1) array,
  and the dynamics negate every leaf of that state.
  """

  def negate_leaves(state, _time):
    # Time is ignored; each leaf of the state is mapped through jnp.negative.
    return tree_map(jnp.negative, state)

  # Use the looser tolerance when float64 is reported as 32 bits wide.
  if jtu.num_float_bits(np.float64) == 32:
    tol = 1e-1
  else:
    tol = 1e-3

  times = np.linspace(0., 1., 11)
  init_state = (np.array(-0.1), np.array([[[0.1]]]))

  jtu.check_grads(partial(odeint, negate_leaves), (init_state, times),
                  modes=["rev"], order=2, rtol=tol, atol=tol)
def test_weird_time_pendulum_grads(self):
  """Check odeint values and gradients for dynamics that depend on time.

  The first component of the derivative multiplies the state by ``-t``, so
  the right-hand side is explicitly time dependent.
  """

  def time_dependent_pendulum(_np, y, t):
    # `_np` is the array namespace to compute with (jnp is bound below).
    first = y[1] * -t
    second = -1 * y[1] - 9.8 * _np.sin(y[0])
    return _np.array([first, second])

  # Use the looser tolerance when float64 is reported as 32 bits wide.
  if jtu.num_float_bits(np.float64) == 32:
    tol = 1e-1
  else:
    tol = 1e-3

  times = np.linspace(0., 1., 11)
  init_state = [np.pi - 0.1, 0.0]

  # Forward solution is compared through the class helper first.
  self.check_against_scipy(time_dependent_pendulum, init_state, times,
                           tol=tol)

  jax_dynamics = partial(time_dependent_pendulum, jnp)
  jtu.check_grads(partial(odeint, jax_dynamics), (init_state, times),
                  modes=["rev"], order=2, atol=tol, rtol=tol)
def test_swoop_bigger(self):
  """Check odeint values and gradients for `swoop` on a 10-element state.

  The dynamics take two extra scalar arguments (0.1 and 0.3) that are
  forwarded through odeint and included in the gradient check.
  """

  def swoop(_np, y, t, arg1, arg2):
    # `_np` is the array namespace to compute with (jnp is bound below).
    forcing = _np.cos(t) * arg1
    return _np.array(y - _np.sin(t) - forcing + arg2)

  extra_args = (0.1, 0.3)
  init_state = np.linspace(1.1, 10.9, 10)
  times = np.array([0.1, 0.2])

  # Use the looser tolerance when float64 is reported as 32 bits wide.
  if jtu.num_float_bits(np.float64) == 32:
    tol = 1e-1
  else:
    tol = 1e-3

  # Forward solution is compared through the class helper first.
  self.check_against_scipy(swoop, init_state, times, *extra_args, tol=tol)

  jax_swoop = partial(swoop, jnp)
  jtu.check_grads(partial(odeint, jax_swoop),
                  (init_state, times, *extra_args),
                  modes=["rev"], order=2, atol=tol, rtol=tol)
def test_pend_grads(self):
  """Check odeint values and gradients for a two-component pendulum system.

  The state is ``[theta, omega]`` and the dynamics take two extra arguments
  ``m`` and ``g`` (0.25 and 9.8 here) that are part of the gradient check.
  """

  def pend(_np, y, _, m, g):
    # `_np` is the array namespace to compute with; the time argument is
    # unused.
    theta, omega = y
    damping = -m * omega
    restoring = g * _np.sin(theta)
    return [omega, damping - restoring]

  # Use the looser tolerance when float64 is reported as 32 bits wide.
  if jtu.num_float_bits(np.float64) == 32:
    tol = 1e-1
  else:
    tol = 1e-3

  params = (0.25, 9.8)
  times = np.linspace(0., 1., 11)
  init_state = [np.pi - 0.1, 0.0]

  # Forward solution is compared through the class helper first.
  self.check_against_scipy(pend, init_state, times, *params, tol=tol)

  jax_pend = partial(pend, jnp)
  jtu.check_grads(partial(odeint, jax_pend), (init_state, times, *params),
                  modes=["rev"], order=2, rtol=tol, atol=tol)
def test_complex_odeint(self):
  """Check reverse-mode gradients of odeint with complex state and parameter.

  Regression test for https://github.com/google/jax/issues/3986.
  """
  # NOTE(review): another method named `test_complex_odeint` is defined later
  # in this file; if both live in the same class, the later definition
  # replaces this one and this body never runs -- confirm and rename/remove.

  def linear_rhs(state, _time, rate):
    return rate * state

  def real_part_of_solution(y0, ts, alpha):
    solution = odeint(linear_rhs, y0, ts, alpha)
    return solution.real

  # Use the looser tolerance when float64 is reported as 32 bits wide.
  if jtu.num_float_bits(np.float64) == 32:
    tol = 1e-1
  else:
    tol = 1e-3

  times = jnp.linspace(0., 1., 11)
  init_state = 1 + 2j
  rate = 3 + 4j

  jtu.check_grads(real_part_of_solution, (init_state, times, rate),
                  modes=["rev"], order=2, rtol=tol, atol=tol)
def test_decay(self):
  """Check odeint values and gradients for `decay` with random inputs.

  The initial state and both extra arguments are length-3 vectors drawn
  from the test's random number generator.
  """

  def decay(_np, y, t, arg1, arg2):
    # `_np` is the array namespace to compute with (jnp is bound below).
    mean_square = _np.mean((y + arg2)**2)
    return -_np.sqrt(t) - y + arg1 - mean_square

  # Draw order matters for reproducibility: both extra arguments first, then
  # the initial state.
  rng = self.rng()
  first_arg = rng.randn(3)
  second_arg = rng.randn(3)
  extra_args = (first_arg, second_arg)
  init_state = rng.randn(3)

  times = np.linspace(0.1, 0.2, 4)

  # Use the looser tolerance when float64 is reported as 32 bits wide.
  if jtu.num_float_bits(np.float64) == 32:
    tol = 1e-1
  else:
    tol = 1e-3

  # Forward solution is compared through the class helper first.
  self.check_against_scipy(decay, init_state, times, *extra_args, tol=tol)

  jax_decay = partial(decay, jnp)
  jtu.check_grads(partial(odeint, jax_decay),
                  (init_state, times, *extra_args),
                  modes=["rev"], order=2, atol=tol, rtol=tol)
def test_complex_odeint(self):
  """Check reverse-mode gradients of odeint with complex, time-varying dynamics.

  Regression test for:
    https://github.com/google/jax/issues/3986
    https://github.com/google/jax/issues/8757
  """

  def damped_rhs(state, time, rate):
    # The real-valued decay factor is cast to the state's dtype before use.
    decay = jnp.exp(-time).astype(state.dtype)
    return rate * state * decay

  def real_part_of_solution(y0, ts, alpha):
    solution = odeint(damped_rhs, y0, ts, alpha)
    return solution.real

  # Use the looser tolerance when float64 is reported as 32 bits wide.
  if jtu.num_float_bits(np.float64) == 32:
    tol = 1e-1
  else:
    tol = 1e-3

  times = jnp.linspace(0., 1., 11)
  init_state = 1 + 2j
  rate = 3 + 4j

  # The backward pass ravels all parameters into a single array, so dtype
  # promotion cannot be avoided; allow standard promotion for the check.
  with jax.numpy_dtype_promotion('standard'):
    jtu.check_grads(real_part_of_solution, (init_state, times, rate),
                    modes=["rev"], order=2, rtol=tol, atol=tol)