Ejemplo n.º 1
0
  def test_value_and_grad_aux(self):
    """Checks `stateful.value_and_grad(..., has_aux=True)` outputs.

    The differentiated function returns a `(value, aux)` pair; the test
    asserts on the value, on the gradient, and that the auxiliary object
    handed back is the very same object that the function returned.
    """
    sentinel = object()

    def square_with_aux(inp):
      # Second element of the pair is the auxiliary output.
      return SquareModule()(inp), sentinel

    x = jnp.array(3.)
    grad_fn = stateful.value_and_grad(square_with_aux, has_aux=True)
    (value, aux), grad = grad_fn(x)

    self.assertEqual(value, jnp.float_power(x, 2))
    np.testing.assert_allclose(grad, 2 * x, rtol=1e-4)
    # Identity check, not equality: the same object must come back.
    self.assertIs(aux, sentinel)
Ejemplo n.º 2
0
 def test(remat):
   """Differentiates `callback(mod(x))` once and reports callback counts.

   Args:
     remat: If truthy, the function is wrapped with `stateful.remat`
       before being passed to `stateful.value_and_grad`.

   Returns:
     A `(num_forward, num_backward)` pair holding the lengths of the
     enclosing `forward` and `backward` sequences. Both sequences are
     emptied before returning so the helper can be called again.
   """
   x = jnp.array(3.)
   mod = CountingModule()
   self.assertEqual(mod.count, 0)

   def fn(inp):
     return callback(mod(inp))

   if remat:
     fn = stateful.remat(fn)

   value, grad = stateful.value_and_grad(fn)(x)
   np.testing.assert_allclose(value, x ** 2, rtol=1e-3)
   np.testing.assert_allclose(grad, 2 * x, rtol=1e-3)
   self.assertEqual(mod.count, 1)

   # Snapshot the counts first, then reset the shared records in place.
   counts = (len(forward), len(backward))
   del forward[:]
   del backward[:]
   return counts
Ejemplo n.º 3
0
 def f(x):
   """Applies `stateful.value_and_grad` to a new `SquareModule` at `x`.

   Returns:
     The two results of the call as a `(y, g)` tuple.
   """
   grad_fn = stateful.value_and_grad(SquareModule())
   value, grad = grad_fn(x)
   return value, grad
Ejemplo n.º 4
0
 def test_value_and_grad_no_transform(self):
   """Expects `stateful.value_and_grad` on a plain function to raise."""
   x = jnp.array(3.)

   def square(v):
     return v**2

   # NOTE(review): `msg` is only the text reported when no ValueError is
   # raised; it is not matched against the exception's message. If matching
   # was intended, use `assertRaisesRegex` - confirm the real message first.
   with self.assertRaises(ValueError, msg="Use jax.grad() instead"):
     stateful.value_and_grad(square)(x)
Ejemplo n.º 5
0
 def test_value_and_grad(self):
   """Checks the value and gradient returned for a `SquareModule` at 2.0."""
   x = jnp.array(2.)
   grad_fn = stateful.value_and_grad(SquareModule())
   value, grad = grad_fn(x)
   self.assertEqual(value, x ** 2)
   np.testing.assert_allclose(grad, 2 * x, rtol=1e-4)
Ejemplo n.º 6
0
    ("fori_loop", lambda f:
     (lambda x: stateful.fori_loop(0, 1, base_test.ignore_index(f), x))),
    # ("map", lambda f: (lambda x: stateful.map(f, x))),
    ("scan", lambda f:
     (lambda x: stateful.scan(base_test.identity_carry(f), None, x))),
    ("switch", lambda f: (lambda x: stateful.switch(0, [f, f], x))),
    ("while_loop", lambda f: toggle(
        f, lambda x: stateful.while_loop(lambda xs: xs[0] == 0,
                                         lambda xs: (1, f(xs[1])),
                                         (0, x)))),

    # Automatic differentiation.
    # TODO(tomhennigan): Add missing features (e.g. custom_vjp, custom_jvp).

    ("grad", lambda f: stateful.grad(lambda x: f(x).sum())),
    ("value_and_grad", lambda f: stateful.value_and_grad(lambda x: f(x).sum())),
    ("checkpoint", stateful.remat),
)
# pylint: enable=g-long-lambda


class StatefulTest(parameterized.TestCase):

  @test_utils.transform_and_run
  def test_grad(self):
    """Checks that `stateful.grad` of a `SquareModule` at 3.0 is near 2*x."""
    x = jnp.array(3.)
    grad_fn = stateful.grad(SquareModule())
    np.testing.assert_allclose(grad_fn(x), 2 * x, rtol=1e-4)

  def test_grad_no_transform(self):
    x = jnp.array(3.)