def test_jvp(self):
    """Check jax.jvp through a djit-compiled sum-of-sines against plain jnp.

    The djit function is differentiated at a ``bbarray`` input with an
    all-ones tangent; both the primal and tangent outputs must match the
    values ``jax.jvp`` yields for ``jnp.sin(x).sum()`` on ``np.arange(2.)``.
    """
    @djax.djit
    def summed_sin(operand):
        return reduce_sum(sin(operand), axes=(0,))

    # NOTE(review): presumably (5,) is a shape bound and the payload holds
    # 2 elements -- confirm against bbarray's definition.
    bounded_input = bbarray((5,), jnp.arange(2.))
    primals = (bounded_input,)
    tangents = (ones_like(bounded_input),)
    actual_out, actual_tangent = jax.jvp(summed_sin, primals, tangents)

    def reference(values):
        sines = jnp.sin(values)
        return sines.sum()

    # Reference values come from the same JVP on an ordinary NumPy input.
    reference_out, reference_tangent = jax.jvp(
        reference, (np.arange(2.),), (np.ones(2),)
    )

    self.assertAllClose(
        np.array(actual_out), reference_out, check_dtypes=False
    )
    self.assertAllClose(
        np.array(actual_tangent), reference_tangent, check_dtypes=False
    )
def test_linearize(self):
    """Check jax.linearize through a djit-compiled sum-of-sines.

    The function is linearized at a ``bbarray`` input and the resulting
    linear map is applied to an all-ones tangent. The primal output and the
    tangent output must match what ``jax.jvp`` yields for
    ``jnp.sin(x).sum()`` on ``np.arange(2.)``.
    """
    @djax.djit
    def summed_sin(operand):
        return reduce_sum(sin(operand), axes=(0,))

    # NOTE(review): presumably (5,) is a shape bound and the payload holds
    # 2 elements -- confirm against bbarray's definition.
    bounded_input = bbarray((5,), jnp.arange(2.))

    # Checks are turned off only around the linearize call.
    # TODO implement dxla_call abs eval rule
    with jax.enable_checks(False):
        actual_out, linear_map = jax.linearize(summed_sin, bounded_input)
    actual_tangent = linear_map(ones_like(bounded_input))

    def reference(values):
        sines = jnp.sin(values)
        return sines.sum()

    # Reference values come from a forward-mode JVP on a plain NumPy input.
    reference_out, reference_tangent = jax.jvp(
        reference, (np.arange(2.),), (np.ones(2),)
    )

    self.assertAllClose(
        np.array(actual_out), reference_out, check_dtypes=False
    )
    self.assertAllClose(
        np.array(actual_tangent), reference_tangent, check_dtypes=False
    )