def test_grad_aux(self):
  """stateful.grad(has_aux=True) returns grads and forwards aux unchanged."""
  sentinel = object()

  def fn(x):
    mod = SquareModule()
    # Return (output, aux); aux must be passed through by grad untouched.
    return mod(x), sentinel

  point = jnp.array(3.)
  grads, aux = stateful.grad(fn, has_aux=True)(point)
  # d/dx x**2 == 2x.
  np.testing.assert_allclose(grads, 2 * point, rtol=1e-4)
  # Identity (not equality): the exact aux object must come back.
  self.assertIs(aux, sentinel)
def f(x):
  """Returns the gradient of a fresh SquareModule evaluated at ``x``."""
  return stateful.grad(SquareModule())(x)
def test_grad(self):
  """Gradient of SquareModule (x**2) at x should be 2*x."""
  point = jnp.array(3.)
  gradient = stateful.grad(SquareModule())(point)
  np.testing.assert_allclose(gradient, 2 * point, rtol=1e-4)
def test_grad_no_transform(self):
  """stateful.grad used outside a Haiku transform must raise ValueError."""
  point = jnp.array(3.)
  with self.assertRaises(ValueError, msg="Use jax.grad() instead"):
    stateful.grad(lambda v: v**2)(point)
("cond", lambda f: (lambda x: stateful.cond(True, f, f, x))), ("fori_loop", lambda f: (lambda x: stateful.fori_loop(0, 1, base_test.ignore_index(f), x))), # ("map", lambda f: (lambda x: stateful.map(f, x))), ("scan", lambda f: (lambda x: stateful.scan(base_test.identity_carry(f), None, x))), ("switch", lambda f: (lambda x: stateful.switch(0, [f, f], x))), ("while_loop", lambda f: toggle( f, lambda x: stateful.while_loop(lambda xs: xs[0] == 0, lambda xs: (1, f(xs[1])), (0, x)))), # Automatic differentiation. # TODO(tomhennigan): Add missing features (e.g. custom_vjp, custom_jvp). ("grad", lambda f: stateful.grad(lambda x: f(x).sum())), ("value_and_grad", lambda f: stateful.value_and_grad(lambda x: f(x).sum())), ("checkpoint", stateful.remat), ) # pylint: enable=g-long-lambda class StatefulTest(parameterized.TestCase): @test_utils.transform_and_run def test_grad(self): x = jnp.array(3.) g = stateful.grad(SquareModule())(x) np.testing.assert_allclose(g, 2 * x, rtol=1e-4) def test_grad_no_transform(self):