Esempio n. 1
0
  def testRemainder(self):
    """Check gradients of lax.rem for two pairings of operand shapes.

    For each pairing, x is drawn from uniform(-0.9, 9) and y from
    uniform(0.7, 1.9) using a generator seeded with 0, and check_grads is run
    with order 2 in both forward and reverse mode.
    """
    shape_pairs = (
        ((3, 4), (3, 1)),  # y has a size-1 trailing axis
        ((1, 4), (3, 4)),  # x has a size-1 leading axis
    )
    for x_shape, y_shape in shape_pairs:
      # A fresh generator with the same seed is created for every case.
      prng = np.random.RandomState(0)
      x = prng.uniform(-0.9, 9, size=x_shape)
      y = prng.uniform(0.7, 1.9, size=y_shape)
      # The two operands must not have any value in common.
      assert set(np.unique(x)).isdisjoint(np.unique(y))
      # Use the loose tolerance when jtu reports 32 bits for np.float64.
      if jtu.num_float_bits(np.float64) == 32:
        tol = 1e-1
      else:
        tol = 1e-3
      check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)
Esempio n. 2
0
 def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
   """Check forward- and reverse-mode gradients of `op` on random inputs.

   Args:
     op: the function whose gradients are checked.
     rng_factory: called with self.rng() to build the input sampler.
     shapes: one shape per positional argument of `op`.
     dtype: dtype passed to the sampler for every argument.
     order: derivative order forwarded to check_grads.
     tol: base tolerance; joined with 1e-1 when `dtype` is reported as 32-bit.

   Raises:
     SkipTest: when running on TPU and `op` is lax.pow.
   """
   sampler = rng_factory(self.rng())
   running_on_tpu = jtu.device_under_test() == "tpu"
   if running_on_tpu and op is lax.pow:
     raise SkipTest("pow grad imprecise on tpu")
   if jtu.num_float_bits(dtype) == 32:
     tol = jtu.join_tolerance(1e-1, tol)
   args = tuple([sampler(shape, dtype) for shape in shapes])
   check_grads(op, args, order, ["fwd", "rev"], tol, tol)
Esempio n. 3
0
  def test_pytree_state(self):
    """Check odeint gradients when the state y(t) is a pytree.

    The initial state is a tuple of two arrays of different shapes and the
    dynamics negate every leaf. Gradients are checked in reverse mode with
    order 2.
    """
    def negate_state(state, _t):
      return tree_map(jnp.negative, state)

    ts = np.linspace(0., 1., 11)
    y0 = (np.array(-0.1), np.array([[[0.1]]]))

    # Use the loose tolerance when jtu reports 32 bits for np.float64.
    if jtu.num_float_bits(np.float64) == 32:
      tol = 1e-1
    else:
      tol = 1e-3

    solve = partial(odeint, negate_state)
    jtu.check_grads(solve, (y0, ts), modes=["rev"], order=2,
                    rtol=tol, atol=tol)
Esempio n. 4
0
  def test_weird_time_pendulum_grads(self):
    """Check odeint gradients for dynamics that depend explicitly on t.

    The dynamics are first passed to self.check_against_scipy, then the
    jnp-backed integrator is checked in reverse mode with order 2.
    """
    def dynamics(_np, y, t):
      # The first component is scaled by -t, making the system time-dependent.
      dy0 = y[1] * -t
      dy1 = -1 * y[1] - 9.8 * _np.sin(y[0])
      return _np.array([dy0, dy1])

    ts = np.linspace(0., 1., 11)
    y0 = [np.pi - 0.1, 0.0]

    # Use the loose tolerance when jtu reports 32 bits for np.float64.
    if jtu.num_float_bits(np.float64) == 32:
      tol = 1e-1
    else:
      tol = 1e-3

    self.check_against_scipy(dynamics, y0, ts, tol=tol)

    solve = partial(odeint, partial(dynamics, jnp))
    jtu.check_grads(solve, (y0, ts), modes=["rev"], order=2,
                    atol=tol, rtol=tol)
Esempio n. 5
0
  def test_swoop_bigger(self):
    """Check odeint on the `swoop` dynamics with a 10-element initial state.

    The dynamics are first passed to self.check_against_scipy, then the
    jnp-backed integrator is checked in reverse mode with order 2. The two
    extra scalars in `args` are forwarded to the dynamics.
    """
    def swoop(_np, y, t, arg1, arg2):
      rate = y - _np.sin(t) - _np.cos(t) * arg1 + arg2
      return _np.array(rate)

    args = (0.1, 0.3)
    big_y0 = np.linspace(1.1, 10.9, 10)
    ts = np.array([0.1, 0.2])

    # Use the loose tolerance when jtu reports 32 bits for np.float64.
    if jtu.num_float_bits(np.float64) == 32:
      tol = 1e-1
    else:
      tol = 1e-3

    self.check_against_scipy(swoop, big_y0, ts, *args, tol=tol)

    solve = partial(odeint, partial(swoop, jnp))
    grad_inputs = (big_y0, ts) + args
    jtu.check_grads(solve, grad_inputs, modes=["rev"], order=2,
                    atol=tol, rtol=tol)
Esempio n. 6
0
  def test_pend_grads(self):
    """Check odeint on the `pend` dynamics with parameters m and g.

    The dynamics are first passed to self.check_against_scipy, then the
    jnp-backed integrator is checked in reverse mode with order 2, including
    with respect to the extra arguments (0.25, 9.8).
    """
    def pend(_np, y, _, m, g):
      theta, omega = y
      d_omega = -m * omega - g * _np.sin(theta)
      return [omega, d_omega]

    args = (0.25, 9.8)
    ts = np.linspace(0., 1., 11)
    y0 = [np.pi - 0.1, 0.0]

    # Use the loose tolerance when jtu reports 32 bits for np.float64.
    if jtu.num_float_bits(np.float64) == 32:
      tol = 1e-1
    else:
      tol = 1e-3

    self.check_against_scipy(pend, y0, ts, *args, tol=tol)

    solve = partial(odeint, partial(pend, jnp))
    grad_inputs = (y0, ts) + args
    jtu.check_grads(solve, grad_inputs, modes=["rev"], order=2,
                    rtol=tol, atol=tol)
Esempio n. 7
0
  def test_decay(self):
    """Check odeint on the `decay` dynamics with random vector arguments.

    Both extra arguments and the initial state are 3-element draws from a
    RandomState seeded with 0. The dynamics are first passed to
    self.check_against_scipy, then the jnp-backed integrator is checked in
    reverse mode with order 2.
    """
    def decay(_np, y, t, arg1, arg2):
      mean_sq = _np.mean((y + arg2)**2)
      return -_np.sqrt(t) - y + arg1 - mean_sq

    prng = np.random.RandomState(0)
    # Draw order matters for reproducibility: args first, then y0.
    args = (prng.randn(3), prng.randn(3))
    y0 = prng.randn(3)
    ts = np.linspace(0.1, 0.2, 4)

    # Use the loose tolerance when jtu reports 32 bits for np.float64.
    if jtu.num_float_bits(np.float64) == 32:
      tol = 1e-1
    else:
      tol = 1e-3

    self.check_against_scipy(decay, y0, ts, *args, tol=tol)

    solve = partial(odeint, partial(decay, jnp))
    grad_inputs = (y0, ts) + args
    jtu.check_grads(solve, grad_inputs, modes=["rev"], order=2,
                    atol=tol, rtol=tol)
Esempio n. 8
0
    def test_complex_odeint(self):
        """Check odeint gradients with a complex state and a complex parameter.

        The integrated function returns the real part of the solution of
        dy/dt = alpha * y; gradients are checked in reverse mode with order 2.
        """
        # https://github.com/google/jax/issues/3986

        def linear_rhs(y, t, alpha):
            return alpha * y

        def real_solution(y0, ts, alpha):
            solution = odeint(linear_rhs, y0, ts, alpha)
            return solution.real

        y0 = 1 + 2j
        alpha = 3 + 4j
        ts = jnp.linspace(0., 1., 11)

        # Use the loose tolerance when jtu reports 32 bits for np.float64.
        if jtu.num_float_bits(np.float64) == 32:
            tol = 1e-1
        else:
            tol = 1e-3

        jtu.check_grads(real_solution, (y0, ts, alpha), modes=["rev"],
                        order=2, rtol=tol, atol=tol)