def test_vmap_jit(self): add_two = primitive(dex.eval(r'\x:((Fin 2)=>Float). for i. x.i + 2.0')) x = jnp.linspace(jnp.array([0, 3]), jnp.array([5, 8]), num=4, dtype=jnp.float32) np.testing.assert_allclose(jax.jit(jax.vmap(add_two))(x), x + 2.0)
def test_abstract_eval_simple(): add_two = primitive( dex.eval(r'\x:((Fin 10)=>Float). for i. FToI $ x.i + 2.0')) x = jax.ShapeDtypeStruct((10, ), np.float32) output_shape = jax.eval_shape(add_two, x) assert output_shape.shape == (10, ) assert output_shape.dtype == np.int32
def test_vmap_nonzero_index(self): add_two = primitive(dex.eval(r'\x:((Fin 4)=>Float). for i. x.i + 2.0')) x = jnp.linspace(jnp.array([0, 3]), jnp.array([5, 8]), num=4, dtype=jnp.float32) np.testing.assert_allclose( jax.vmap(add_two, in_axes=1, out_axes=1)(x), x + 2.0)
def test_vmap_unbatched_array(self): add_arrays = primitive( dex.eval( r'\x:((Fin 10)=>Float) y:((Fin 10)=>Float). for i. x.i + y.i')) x = jnp.arange(10, dtype=np.float32) y = jnp.linspace(jnp.arange(10), jnp.arange(10, 20), num=5, dtype=jnp.float32) np.testing.assert_allclose( jax.vmap(add_arrays, in_axes=[None, 0])(x, y), x + y)
def test_grad(self): f_dex = primitive( dex.eval(r'\x:((Fin 10) => Float). ' 'sum $ for i. (i_to_f $ ordinal i) * x.i * x.i')) def f_jax(x): return jnp.sum(jnp.arange(10.) * x**2) x = jnp.linspace(-0.2, 0.5, num=10) grad_dex = jax.grad(f_dex)(x) grad_jax = jax.grad(f_jax)(x) np.testing.assert_allclose(grad_dex, grad_jax)
def test_jvp(self): f_dex = primitive( dex.eval(r'\x:((Fin 10) => Float) y:((Fin 10) => Float). ' 'for i. x.i * x.i + 2.0 * y.i')) def f_jax(x, y): return x**2 + 2 * y x = jnp.arange(10.) y = jnp.linspace(-0.2, 0.5, num=10) u = jnp.linspace(0.1, 0.3, num=10) v = jnp.linspace(2.0, -5.0, num=10) output_dex, tangent_dex = jax.jvp(f_dex, (x, y), (u, v)) output_jax, tangent_jax = jax.jvp(f_jax, (x, y), (u, v)) np.testing.assert_allclose(output_dex, output_jax) np.testing.assert_allclose(tangent_dex, tangent_jax)
def test_grad_binary_function_jit(self): f_dex = primitive( dex.eval(r'\x:((Fin 10) => Float) y:((Fin 10) => Float). ' 'sum $ for i. x.i * x.i + 2.0 * y.i')) def f_jax(x, y): return jnp.sum(x**2 + 2 * y) def grad_dex(x, y): return jax.grad(f_dex)(x, y) def grad_jax(x, y): return jax.grad(f_jax)(x, y) x = jnp.arange(10.) y = jnp.linspace(-0.2, 0.5, num=10) np.testing.assert_allclose( jax.jit(grad_dex)(x, y), jax.jit(grad_jax)(x, y))
def test_jit_scale(): scale = primitive( dex.eval(r'\x:((Fin 10)=>Float) y:Float. for i. x.i * y')) x = jnp.arange((10, ), dtype=np.float32) np.testing.assert_allclose(scale(x, 5.0), x * 5.0)
def test_jit_array(): add_two = primitive( dex.eval(r'\x:((Fin 10)=>Float). for i. FToI $ x.i + 2.0')) x = jnp.zeros((10, ), dtype=np.float32) np.testing.assert_allclose(jax.jit(add_two)(x), (x + 2.0).astype(np.int32))
def test_jit_scalar(): add_two = primitive(dex.eval(r'\x:Float. x + 2.0')) x = jnp.zeros((), dtype=np.float32) np.testing.assert_allclose(jax.jit(add_two)(x), 2.0)
def test_impl_array(): add_two = primitive(dex.eval(r'\x:((Fin 10)=>Float). for i. x.i + 2.0')) x = jnp.arange((10, ), dtype=np.float32) np.testing.assert_allclose(add_two(x), x + 2.0)
def test_impl_scalar(self): add_two = primitive(dex.eval(r'\x:Float. x + 2.0')) x = jnp.zeros((), dtype=np.float32) np.testing.assert_allclose(add_two(x), x + 2.0)