def forward(x):
    """Push ``x`` through two freshly built ``SquareModule`` layers under remat.

    Each application builds a new ``SquareModule(name="layer")``, records its
    ``module_name`` in ``log`` and applies it to the running value.

    When ``inline_hk_remat`` is true, ``stateful.remat`` is applied to the
    layer function anew on every iteration. Otherwise the function is wrapped
    once up front, and that single wrapped version is reused for both
    iterations.

    ``log``, ``inline_hk_remat``, ``SquareModule`` and ``stateful`` are
    resolved from the enclosing scope, which is not visible here.
    """

    def create_and_use_layer(inputs):
        layer = SquareModule(name="layer")
        log.append(layer.module_name)
        return layer(inputs)

    if inline_hk_remat:
        def step(value):
            # Build a fresh remat wrapper at every call site.
            return stateful.remat(create_and_use_layer)(value)
    else:
        # Wrap once; the same remat'd function is reused on each iteration.
        step = stateful.remat(create_and_use_layer)

    for _ in range(2):
        x = step(x)
    return x
def test(remat):
    """Differentiate a ``CountingModule`` call, optionally remat'd, and count callbacks.

    Builds a ``CountingModule`` and routes its output through ``callback``.
    It then checks three things:
    - ``stateful.value_and_grad`` yields x**2 and 2*x at x = 3.
    - ``mod.count`` goes from 0 to 1.

    Afterwards it reports how many entries ``forward`` and ``backward`` hold,
    then empties both.

    Args:
      remat: if true, wrap the differentiated function in ``stateful.remat``.

    Returns:
      A tuple ``(len(forward), len(backward))`` taken before both are cleared.
      Presumably ``callback`` appends to these lists during the forward and
      backward passes; this is not visible here, so verify against the
      enclosing test.
    """
    inputs = jnp.array(3.)
    counter = CountingModule()
    self.assertEqual(counter.count, 0)

    def fn(v):
        return callback(counter(v))

    if remat:
        fn = stateful.remat(fn)

    value, grad = stateful.value_and_grad(fn)(inputs)
    np.testing.assert_allclose(value, inputs ** 2, rtol=1e-3)
    np.testing.assert_allclose(grad, 2 * inputs, rtol=1e-3)
    self.assertEqual(counter.count, 1)

    counts = (len(forward), len(backward))
    # Reset the shared recording lists for the next invocation.
    del forward[:], backward[:]
    return counts
def test_remat_no_transform(self):
    """Calling ``stateful.remat(...)`` with no enclosing transform should raise ValueError."""

    def square(v):
        return v**2

    arg = jnp.array(3.)
    # NOTE(review): ``msg`` is only the text reported when no ValueError is
    # raised; it does NOT check the exception's message. Consider
    # ``assertRaisesRegex`` once the exact error text is confirmed.
    with self.assertRaises(ValueError, msg="Use jax.remat() instead"):
        stateful.remat(square)(arg)