예제 #1
0
파일: ode_test.py 프로젝트: ahoenselaar/jax
  def test_pytree_state(self):
    """Check odeint gradients when the state y(t) is a pytree.

    The state is a tuple holding a 0-d array and a (1, 1, 1) array; the
    dynamics negate every leaf independently (dy/dt = -y, leaf-wise).
    """
    def dynamics(y, _t):
      return tree_map(jnp.negative, y)

    scalar_leaf = np.array(-0.1)
    nested_leaf = np.array([[[0.1]]])
    y0 = (scalar_leaf, nested_leaf)
    ts = np.linspace(0., 1., 11)

    # Use a looser tolerance when np.float64 is effectively 32-bit
    # (presumably x64 mode disabled).
    if jtu.num_float_bits(np.float64) == 32:
      tol = 1e-1
    else:
      tol = 1e-3

    jtu.check_grads(partial(odeint, dynamics), (y0, ts),
                    modes=["rev"], order=2, atol=tol, rtol=tol)
예제 #2
0
파일: ode_test.py 프로젝트: ahoenselaar/jax
  def test_weird_time_pendulum_grads(self):
    """Check gradients for pendulum-like dynamics that depend explicitly on t."""
    def dynamics(_np, y, t):
      # `_np` is the array namespace: jnp is bound below for odeint, and
      # check_against_scipy presumably supplies numpy.
      theta_dot = y[1] * -t
      omega_dot = -1 * y[1] - 9.8 * _np.sin(y[0])
      return _np.array([theta_dot, omega_dot])

    y0 = [np.pi - 0.1, 0.0]
    ts = np.linspace(0., 1., 11)
    float64_is_32bit = jtu.num_float_bits(np.float64) == 32
    tol = 1e-1 if float64_is_32bit else 1e-3

    self.check_against_scipy(dynamics, y0, ts, tol=tol)

    jnp_dynamics = partial(dynamics, jnp)
    jtu.check_grads(partial(odeint, jnp_dynamics), (y0, ts), modes=["rev"],
                    order=2, rtol=tol, atol=tol)
예제 #3
0
파일: ode_test.py 프로젝트: ahoenselaar/jax
  def test_swoop_bigger(self):
    """Compare against SciPy and check gradients for a 10-element state."""
    def swoop(_np, y, t, arg1, arg2):
      return _np.array(y - _np.sin(t) - _np.cos(t) * arg1 + arg2)

    tol = 1e-3
    if jtu.num_float_bits(np.float64) == 32:
      # np.float64 is effectively 32-bit here; relax the tolerance.
      tol = 1e-1
    ts = np.array([0.1, 0.2])
    big_y0 = np.linspace(1.1, 10.9, 10)
    arg1, arg2 = 0.1, 0.3

    self.check_against_scipy(swoop, big_y0, ts, arg1, arg2, tol=tol)

    # Build the jnp-bound dynamics once so every call reuses the same callable.
    jnp_swoop = partial(swoop, jnp)

    def integrate(y0, t, *extra):
      return odeint(jnp_swoop, y0, t, *extra)

    jtu.check_grads(integrate, (big_y0, ts, arg1, arg2), modes=["rev"],
                    order=2, rtol=tol, atol=tol)
예제 #4
0
파일: ode_test.py 프로젝트: ahoenselaar/jax
  def test_pend_grads(self):
    """Damped pendulum: compare against SciPy and check 2nd-order gradients."""
    def pend(_np, y, _, m, g):
      theta, omega = y
      # Derivative returned as a list pytree, mirroring the list-valued y0.
      return [omega, -m * omega - g * _np.sin(theta)]

    y0 = [np.pi - 0.1, 0.0]
    ts = np.linspace(0., 1., 11)
    args = (0.25, 9.8)  # (m, g) as consumed by `pend`
    low_precision = jtu.num_float_bits(np.float64) == 32
    tol = 1e-1 if low_precision else 1e-3

    self.check_against_scipy(pend, y0, ts, *args, tol=tol)

    jnp_pend = partial(pend, jnp)
    jtu.check_grads(partial(odeint, jnp_pend), (y0, ts) + args,
                    modes=["rev"], order=2, atol=tol, rtol=tol)
예제 #5
0
파일: ode_test.py 프로젝트: ahoenselaar/jax
  def test_complex_odeint(self):
    """Reverse-mode gradients through odeint with complex state and params.

    Regression test for https://github.com/google/jax/issues/3986.
    """
    def dy_dt(y, t, alpha):
      return alpha * y

    def f(y0, ts, alpha):
      # Only the real part of the complex trajectory is returned.
      solution = odeint(dy_dt, y0, ts, alpha)
      return solution.real

    alpha = 3 + 4j
    y0 = 1 + 2j
    ts = jnp.linspace(0., 1., 11)
    if jtu.num_float_bits(np.float64) == 32:
      tol = 1e-1
    else:
      tol = 1e-3

    jtu.check_grads(f, (y0, ts, alpha), modes=["rev"], order=2,
                    atol=tol, rtol=tol)
예제 #6
0
파일: ode_test.py 프로젝트: ahoenselaar/jax
  def test_decay(self):
    """Compare against SciPy and check gradients with randomly drawn inputs."""
    def decay(_np, y, t, arg1, arg2):
      return -_np.sqrt(t) - y + arg1 - _np.mean((y + arg2)**2)

    rng = self.rng()
    # Keep the original draw order (arg1, arg2, then y0) so the sampled
    # values are unchanged.
    arg1 = rng.randn(3)
    arg2 = rng.randn(3)
    y0 = rng.randn(3)
    # Times start at 0.1, presumably to keep sqrt(t) away from t = 0.
    ts = np.linspace(0.1, 0.2, 4)
    tol = 1e-3 if jtu.num_float_bits(np.float64) != 32 else 1e-1

    self.check_against_scipy(decay, y0, ts, arg1, arg2, tol=tol)

    integrate = partial(odeint, partial(decay, jnp))
    jtu.check_grads(integrate, (y0, ts, arg1, arg2), modes=["rev"], order=2,
                    rtol=tol, atol=tol)
예제 #7
0
파일: ode_test.py 프로젝트: xueeinstein/jax
    def test_complex_odeint(self):
        """Reverse-mode grads of odeint with complex, time-dependent dynamics.

        Regression test for:
          https://github.com/google/jax/issues/3986
          https://github.com/google/jax/issues/8757
        """

        def dy_dt(y, t, alpha):
            # Cast the real-valued time factor to y's dtype (complex here).
            time_factor = jnp.exp(-t).astype(y.dtype)
            return alpha * y * time_factor

        def f(y0, ts, alpha):
            trajectory = odeint(dy_dt, y0, ts, alpha)
            return trajectory.real

        alpha = 3 + 4j
        y0 = 1 + 2j
        ts = jnp.linspace(0., 1., 11)
        single_precision = jtu.num_float_bits(np.float64) == 32
        tol = 1e-1 if single_precision else 1e-3

        # The backward pass ravels all parameters into one flat array, so
        # dtype promotion cannot be avoided; allow 'standard' promotion.
        with jax.numpy_dtype_promotion('standard'):
            jtu.check_grads(f, (y0, ts, alpha), modes=["rev"], order=2,
                            atol=tol, rtol=tol)