def test_grad(self, ad="simple"):
    """Check that differentiating through ``call_tf`` matches native JAX.

    Args:
      ad: key into ``CALL_TF_IMPLEMENTATIONS`` selecting which call_tf
        variant is under test (default "simple").
    """
    tf_call = CALL_TF_IMPLEMENTATIONS[ad]

    def native(v):
        return 3. * jnp.sin(2. * v)

    def via_tf(v):
        # Only the sine is evaluated by TensorFlow; the scaling stays in JAX.
        return 3. * tf_call(tf.math.sin, 2. * v, result_shape=v)

    point = 4.
    self.assertAllClose(native(point), via_tf(point))
    tf_grad = api.grad(via_tf)(point)
    self.assertAllClose(api.grad(native)(point), tf_grad)
def test_higher_order_grad(self, degree=4):
    """Check repeated ``api.grad`` through ``call_tf`` against pure JAX.

    Differentiates f(x) = 2 * x**3 up to ``degree`` times, once natively and
    once with the cube computed by TensorFlow via ``call_tf_full_ad``. The two
    are compared at x = 5 for every derivative order from 0 to ``degree``.

    Args:
      degree: highest derivative order to check (default 4).
    """
    call_tf = call_tf_full_ad

    def f_jax(x):
        return 2. * x * x * x

    def f_outside(x):
        return 2. * call_tf(lambda y: y * y * y, x, result_shape=x)

    point = 5.
    grad_jax = f_jax
    grad_outside = f_outside
    # Order 0: compare the primal values. This keeps degree=0 meaningful,
    # matching the original, which compared the undifferentiated functions.
    self.assertAllClose(grad_jax(point), grad_outside(point))
    # Compare after every differentiation step, not just the last one. For a
    # cubic the 4th derivative (the default degree) is identically 0, so a
    # final-only check would never exercise the nonzero lower orders.
    for _ in range(degree):
        grad_jax = api.grad(grad_jax)
        grad_outside = api.grad(grad_outside)
        self.assertAllClose(grad_jax(point), grad_outside(point))
def test_grad_pytree(self):
    """Gradients flow through ``call_tf`` when it returns a dict pytree."""
    tf_call = call_tf_full_ad

    def combine(parts):
        # Scalar objective built from the two dict entries: 3*a + 4*b.
        return 3. * parts["a"] + 4. * parts["b"]

    def native(pair):
        return combine(dict(a=2. * pair[0], b=pair[0] * pair[1]))

    def via_tf(pair):
        parts = tf_call(
            lambda inner: dict(a=2. * inner[0], b=inner[0] * inner[1]),
            pair,
            result_shape=dict(a=pair[0], b=pair[1]))
        return combine(parts)

    pair = (5., 6.)
    self.assertAllClose(native(pair), via_tf(pair))
    expected = api.grad(native)(pair)
    self.assertAllClose(expected, api.grad(via_tf)(pair))
def testJVPOfGradOfIndexing(self):
    """JVP of grad through integer indexing accepts a float0 index tangent.

    A value must come back even though the index tangent is an explicit
    float0 zeros array rather than a symbolic zero.
    """
    x = jnp.ones((3, 4), jnp.float32)
    idx = jnp.ones((3,), jnp.int32)

    def summed_rows(arr, rows):
        return jnp.sum(arr[rows])

    idx_tangent = np.zeros(idx.shape, dtypes.float0)
    primals, tangents = api.jvp(
        api.grad(summed_rows), (x, idx), (x, idx_tangent))

    # Every index is 1, so row 1 is picked three times and its gradient is 3.
    row_weights = np.array([0, 3, 0], dtype=np.float32)
    expected = np.broadcast_to(row_weights[:, None], (3, 4))
    self.assertAllClose(expected, primals)
    self.assertAllClose(np.zeros_like(x), tangents)
def testGammaGrad(self, alpha):
    """Gradient of gamma samples w.r.t. the shape parameter matches a reference.

    The reference is -dCDF/dalpha / pdf, evaluated at the drawn samples. The
    CDF derivative is estimated with a central finite difference.
    """
    key = random.PRNGKey(0)
    alphas = np.full((100,), alpha)
    samples = random.gamma(key, alphas)
    actual_grad = api.grad(lambda a: random.gamma(key, a).sum())(alphas)

    # The finite-difference step is scaled with alpha.
    eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
    cdf_hi = scipy.stats.gamma.cdf(samples, alpha + eps)
    cdf_lo = scipy.stats.gamma.cdf(samples, alpha - eps)
    cdf_dot = (cdf_hi - cdf_lo) / (2 * eps)
    density = scipy.stats.gamma.pdf(samples, alpha)
    expected_grad = -cdf_dot / density

    if jtu.device_under_test() == "tpu":
        tolerance = 2e-2
    else:
        tolerance = 7e-4
    self.assertAllClose(actual_grad, expected_grad, check_dtypes=True,
                        rtol=tolerance)
def testStopGradient(self):
    """stop_gradient blocks differentiation at 1st/2nd order and on pytrees."""
    def blocked(x):
        return lax.sin(x) * lax.cos(lax.stop_gradient(x))

    def two_args(x, y):
        # The same expression with the stopped input moved to a second
        # argument, so grad w.r.t. the first argument is the reference.
        return lax.sin(x) * lax.cos(y)

    point = 3.14
    for differentiate in (api.grad, lambda g: api.grad(api.grad(g))):
        self.assertAllClose(differentiate(blocked)(point),
                            differentiate(two_args)(point, point))

    # A value routed through stop_gradient inside a dict gets zero gradient.
    pytree_grad = api.grad(lambda v: lax.stop_gradient({'foo': v})['foo'])(3.)
    self.assertAllClose(pytree_grad, np.array(0.0), check_dtypes=False)

    # A function is not a valid stop_gradient input. Checks are disabled,
    # presumably so the TypeError comes from stop_gradient itself.
    with jax.enable_checks(False), self.assertRaises(TypeError):
        lax.stop_gradient(lambda x: x)
def testGradOfXlog1pyAtZero(self):
    """d/dy xlog1py(0, y) at y = -1 is 0.

    The naive derivative x / (1 + y) is 0/0 at that point.
    """
    xlog1py_zero_x = functools.partial(lsp_special.xlog1py, 0.)
    grad_at_minus_one = api.grad(xlog1py_zero_x)(-1.)
    self.assertAllClose(grad_at_minus_one, 0., check_dtypes=False)