def test_value_and_grad_aux(self):
  """value_and_grad with has_aux=True returns ((value, aux), grad)."""
  sentinel = object()

  def square_with_aux(inp):
    # The aux output is an arbitrary Python object; it must come back by
    # identity, untouched by differentiation.
    return SquareModule()(inp), sentinel

  x = jnp.array(3.)
  grad_fn = stateful.value_and_grad(square_with_aux, has_aux=True)
  (value, aux), grad = grad_fn(x)
  self.assertEqual(value, jnp.float_power(x, 2))
  np.testing.assert_allclose(grad, 2 * x, rtol=1e-4)
  self.assertIs(aux, sentinel)
def test(remat):
  """Runs value_and_grad over a CountingModule, optionally under remat.

  Relies on `callback`, `forward`, `backward` and `self` from the enclosing
  scope (not visible here).

  Args:
    remat: if true, wrap the function in `stateful.remat` before
      differentiating.

  Returns:
    A `(num_forward, num_backward)` tuple: the lengths of the enclosing
    `forward` and `backward` lists after the call. Both lists are emptied
    in place before returning so the next call starts from zero.
  """
  x = jnp.array(3.)
  mod = CountingModule()
  self.assertEqual(mod.count, 0)
  # NOTE(review): `callback` presumably records forward/backward passes into
  # the enclosing `forward`/`backward` lists -- confirm against its definition.
  f = lambda x: callback(mod(x))
  if remat:
    f = stateful.remat(f)
  y, g = stateful.value_and_grad(f)(x)
  # Value and gradient of x**2 at x=3 (assumes CountingModule squares its
  # input -- TODO confirm).
  np.testing.assert_allclose(y, x ** 2, rtol=1e-3)
  np.testing.assert_allclose(g, 2 * x, rtol=1e-3)
  # Expected to be exactly 1 whether or not remat is used.
  self.assertEqual(mod.count, 1)
  num_forward = len(forward)
  num_backward = len(backward)
  # Slice deletion clears the shared lists without rebinding them.
  del forward[:], backward[:]
  return num_forward, num_backward
def f(x):
  """Returns (value, gradient) of a fresh SquareModule evaluated at x."""
  value_and_grad_fn = stateful.value_and_grad(SquareModule())
  value, grad = value_and_grad_fn(x)
  return value, grad
def test_value_and_grad_no_transform(self):
  """Calling stateful.value_and_grad with no enclosing transform fails.

  The ValueError is expected to point users at jax.grad. The check uses
  assertRaisesRegex because assertRaises(..., msg=...) only sets the
  failure text reported when nothing is raised; it never inspects the
  exception message.
  """
  x = jnp.array(3.)
  with self.assertRaisesRegex(ValueError, r"Use jax\.grad\(\) instead"):
    stateful.value_and_grad(lambda x: x**2)(x)
def test_value_and_grad(self):
  """value_and_grad of SquareModule at x=2 yields (x**2, 2*x)."""
  x = jnp.array(2.)
  value, grad = stateful.value_and_grad(SquareModule())(x)
  expected_grad = 2 * x
  self.assertEqual(value, x ** 2)
  np.testing.assert_allclose(grad, expected_grad, rtol=1e-4)
("fori_loop", lambda f: (lambda x: stateful.fori_loop(0, 1, base_test.ignore_index(f), x))), # ("map", lambda f: (lambda x: stateful.map(f, x))), ("scan", lambda f: (lambda x: stateful.scan(base_test.identity_carry(f), None, x))), ("switch", lambda f: (lambda x: stateful.switch(0, [f, f], x))), ("while_loop", lambda f: toggle( f, lambda x: stateful.while_loop(lambda xs: xs[0] == 0, lambda xs: (1, f(xs[1])), (0, x)))), # Automatic differentiation. # TODO(tomhennigan): Add missing features (e.g. custom_vjp, custom_jvp). ("grad", lambda f: stateful.grad(lambda x: f(x).sum())), ("value_and_grad", lambda f: stateful.value_and_grad(lambda x: f(x).sum())), ("checkpoint", stateful.remat), ) # pylint: enable=g-long-lambda class StatefulTest(parameterized.TestCase): @test_utils.transform_and_run def test_grad(self): x = jnp.array(3.) g = stateful.grad(SquareModule())(x) np.testing.assert_allclose(g, 2 * x, rtol=1e-4) def test_grad_no_transform(self): x = jnp.array(3.)