def test_linearize(self):
  @djax.djit
  def f(x):
    y = sin(x)
    return reduce_sum(y, axes=(0,))
  x = bbarray((5,), jnp.arange(2.))

  with jax.enable_checks(False):  # TODO implement dxla_call abs eval rule
    z, f_lin = jax.linearize(f, x)
  z_dot = f_lin(ones_like(x))

  def g(x): return jnp.sin(x).sum()
  expected_z, expected_z_dot = jax.jvp(g, (np.arange(2.),), (np.ones(2),))

  self.assertAllClose(np.array(z), expected_z, check_dtypes=False)
  self.assertAllClose(np.array(z_dot), expected_z_dot, check_dtypes=False)
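# Illustrative sketch (not part of the test above; helper name is hypothetical):
# jax.linearize(f, x) returns f(x) together with a linear function f_lin such that
# f_lin(v) matches the tangent output of jax.jvp(f, (x,), (v,)), which is the
# relationship the test checks for the djax-decorated function.
def _linearize_sketch():
  import jax
  import jax.numpy as jnp
  f = lambda x: jnp.sin(x).sum()
  x, v = jnp.arange(2.), jnp.ones(2)
  y, f_lin = jax.linearize(f, x)          # primal output and linearized map
  y_jvp, y_dot = jax.jvp(f, (x,), (v,))   # same primal and tangent via jvp
  assert jnp.allclose(y, y_jvp) and jnp.allclose(f_lin(v), y_dot)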
def testStopGradient(self):
  def f(x):
    return lax.sin(x) * lax.cos(lax.stop_gradient(x))

  def f2(x, y):
    return lax.sin(x) * lax.cos(y)

  x = 3.14
  ans = api.grad(f)(x)
  expected = api.grad(f2)(x, x)
  self.assertAllClose(ans, expected)

  ans = api.grad(api.grad(f))(x)
  expected = api.grad(api.grad(f2))(x, x)
  self.assertAllClose(ans, expected)

  ans = api.grad(lambda x: lax.stop_gradient({'foo': x})['foo'])(3.)
  expected = np.array(0.0)
  self.assertAllClose(ans, expected, check_dtypes=False)

  with jax.enable_checks(False):
    with self.assertRaises(TypeError):
      lax.stop_gradient(lambda x: x)
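# Illustrative sketch (assumption, helper name is hypothetical): stop_gradient treats
# its argument as a constant under differentiation, so differentiating
# sin(x) * cos(stop_gradient(x)) only differentiates the sin factor, giving
# cos(x) * cos(x) -- the same value the test obtains from f2(x, y) with y held fixed.
def _stop_gradient_sketch():
  import jax
  import jax.numpy as jnp
  g = jax.grad(lambda x: jnp.sin(x) * jnp.cos(jax.lax.stop_gradient(x)))
  assert jnp.allclose(g(3.14), jnp.cos(3.14) ** 2)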
def testHardTanhMemory(self):
  # see https://github.com/google/jax/pull/1640
  with jax.enable_checks(False):  # With checks we materialize the array
    jax.make_jaxpr(lambda: nn.hard_tanh(jnp.ones((10 ** 12,))))  # don't oom
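# Illustrative sketch (assumption, helper name is hypothetical): jax.make_jaxpr only
# traces the function, recording shapes and dtypes rather than allocating buffers,
# which is why the 10**12-element array above does not OOM. A small concrete version
# of the same idea:
def _make_jaxpr_sketch():
  import jax
  import jax.numpy as jnp
  # Tracing stages out the computation; printing the jaxpr shows the recorded ops.
  print(jax.make_jaxpr(lambda: jax.nn.hard_tanh(jnp.ones((8,))))())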