Esempio n. 1
0
    def forward(x):
      """Build and apply a "layer" SquareModule twice, under ``stateful.remat``.

      If ``inline_hk_remat`` is true, ``stateful.remat`` is applied afresh on
      every loop iteration. Otherwise the helper is wrapped once, before the
      loop, and that single wrapped callable is reused for both iterations.
      """
      def create_and_use_layer(x):
        layer = SquareModule(name="layer")
        # Record the name the module was given (presumably inspected by the
        # enclosing test to check naming across remat boundaries -- confirm).
        log.append(layer.module_name)
        return layer(x)

      shared = create_and_use_layer
      if not inline_hk_remat:
        # Outer-remat mode: wrap once, up front.
        shared = stateful.remat(shared)

      for _ in range(2):
        # Inline-remat mode re-wraps on each iteration; otherwise reuse `shared`.
        step = stateful.remat(shared) if inline_hk_remat else shared
        x = step(x)
      return x
Esempio n. 2
0
 def test(remat):
   """Differentiate a CountingModule-backed function and report callback counts.

   Args:
     remat: when true, the function is wrapped in ``stateful.remat`` before
       ``stateful.value_and_grad`` is applied.

   Returns:
     A ``(num_forward, num_backward)`` tuple taken from the lengths of the
     enclosing-scope ``forward`` / ``backward`` collections (presumably filled
     by ``callback`` -- confirm). Both are emptied before returning.
   """
   x = jnp.array(3.)
   mod = CountingModule()
   self.assertEqual(mod.count, 0)

   fn = lambda x: callback(mod(x))
   fn = stateful.remat(fn) if remat else fn

   value, grad = stateful.value_and_grad(fn)(x)
   np.testing.assert_allclose(value, x ** 2, rtol=1e-3)
   np.testing.assert_allclose(grad, 2 * x, rtol=1e-3)
   # The module must have been invoked exactly once, remat or not.
   self.assertEqual(mod.count, 1)

   counts = len(forward), len(backward)
   # Reset the shared collections so the next invocation starts from zero.
   del forward[:], backward[:]
   return counts
Esempio n. 3
0
 def test_remat_no_transform(self):
   """Calling ``stateful.remat`` outside a transform must raise ValueError.

   The raised error is also expected to point users at ``jax.remat()``.
   ``assertRaisesRegex`` is used because ``assertRaises(..., msg=...)`` only
   sets the failure text shown when nothing is raised. It never inspects the
   exception's message, so any ValueError would have passed.
   """
   x = jnp.array(3.)
   # Parentheses and dot are escaped: assertRaisesRegex treats this as a
   # regular expression and matches it with re.search.
   with self.assertRaisesRegex(ValueError, r"Use jax\.remat\(\) instead"):
     stateful.remat(lambda x: x**2)(x)