def test_function_call():
    """Functions reachable through a dex.Module can be called with Dex atoms."""
    module = dex.Module(dedent("""
    def addOne (x: Float) : Float = x + 1.0
    """))
    scalar = dex.eval("2.5")
    table = dex.eval("[2, 3, 4]")

    # addOne is defined in the module text above.
    assert str(module.addOne(scalar)) == "3.5"
    # sum is not defined in the module text; presumably it is resolved from
    # the Dex prelude -- TODO confirm.
    assert str(module.sum(table)) == "9"
def test_abstract_eval_simple():
    """jax.eval_shape infers the output shape and dtype of a Dex primitive."""
    dex_fn = dex.eval(r'\x:((Fin 10)=>Float). for i. FToI $ x.i + 2.0')
    add_two = primitive(dex_fn)

    input_spec = jax.ShapeDtypeStruct((10, ), np.float32)
    result_spec = jax.eval_shape(add_two, input_spec)

    # FToI converts the float result to an integer, hence int32 output.
    assert result_spec.shape == (10, )
    assert result_spec.dtype == np.int32
def test_vmap_jit(self):
    """A vmapped Dex primitive still works when wrapped in jax.jit."""
    add_two = primitive(dex.eval(r'\x:((Fin 2)=>Float). for i. x.i + 2.0'))
    # linspace between two length-2 endpoints with num=4 yields a (4, 2)
    # batch; vmap maps over axis 0, so each slice has length 2.
    batch = jnp.linspace(jnp.array([0, 3]), jnp.array([5, 8]), num=4,
                         dtype=jnp.float32)
    batched_add_two = jax.jit(jax.vmap(add_two))
    np.testing.assert_allclose(batched_add_two(batch), batch + 2.0)
def test_eval(self):
    """Printing an evaluated Dex literal reproduces the literal's source text."""
    for literal in ("2.5", "4", "[2, 3, 4]"):
        assert str(dex.eval(literal)) == literal
def test_vmap_nonzero_index(self):
    """vmap over a non-leading axis (in_axes=1, out_axes=1) of a Dex primitive."""
    add_two = primitive(dex.eval(r'\x:((Fin 4)=>Float). for i. x.i + 2.0'))
    # The (4, 2) array is mapped along axis 1, giving length-4 slices that
    # match the (Fin 4) input type.
    data = jnp.linspace(jnp.array([0, 3]), jnp.array([5, 8]), num=4,
                        dtype=jnp.float32)
    mapped = jax.vmap(add_two, in_axes=1, out_axes=1)
    np.testing.assert_allclose(mapped(data), data + 2.0)
def test_vmap_unbatched_array(self):
    """vmap with one broadcast (in_axes=None) and one batched argument."""
    dex_fn = dex.eval(
        r'\x:((Fin 10)=>Float) y:((Fin 10)=>Float). for i. x.i + y.i')
    add_arrays = primitive(dex_fn)

    shared = jnp.arange(10, dtype=np.float32)
    # Five rows of length 10; only this argument is mapped over.
    batched = jnp.linspace(jnp.arange(10), jnp.arange(10, 20), num=5,
                           dtype=jnp.float32)

    vmapped = jax.vmap(add_arrays, in_axes=[None, 0])
    np.testing.assert_allclose(vmapped(shared, batched), shared + batched)
def test_tuple_return(self):
    """A compiled Dex function returning a tuple yields one array per element."""
    dex_func = dex.eval(r"\x: ((Fin 10) => Float). (x, 2. .* x, 3. .* x)")
    x = np.arange(10, dtype=np.float32)

    actual = dex_func.compile()(x)
    expected = (x, 2 * x, 3 * x)

    self.assertEqual(len(actual), len(expected))
    for got, want in zip(actual, expected):
        np.testing.assert_allclose(got, want)
def test_grad(self):
    """jax.grad of a Dex primitive matches jax.grad of the equivalent JAX code."""
    f_dex = primitive(dex.eval(
        r'\x:((Fin 10) => Float). '
        'sum $ for i. (i_to_f $ ordinal i) * x.i * x.i'))

    def f_jax(x):
        # Weighted sum of squares, weight i for the i-th element.
        weights = jnp.arange(10.)
        return jnp.sum(weights * x**2)

    point = jnp.linspace(-0.2, 0.5, num=10)
    np.testing.assert_allclose(jax.grad(f_dex)(point), jax.grad(f_jax)(point))
def test_jvp(self):
    """Forward-mode JVP of a two-argument Dex primitive matches pure JAX."""
    f_dex = primitive(dex.eval(
        r'\x:((Fin 10) => Float) y:((Fin 10) => Float). '
        'for i. x.i * x.i + 2.0 * y.i'))

    def f_jax(x, y):
        return x**2 + 2 * y

    primals = (jnp.arange(10.), jnp.linspace(-0.2, 0.5, num=10))
    tangents = (jnp.linspace(0.1, 0.3, num=10),
                jnp.linspace(2.0, -5.0, num=10))

    out_dex, tan_dex = jax.jvp(f_dex, primals, tangents)
    out_jax, tan_jax = jax.jvp(f_jax, primals, tangents)

    np.testing.assert_allclose(out_dex, out_jax)
    np.testing.assert_allclose(tan_dex, tan_jax)
def test_grad_binary_function_jit(self):
    """jit(grad(...)) of a two-argument Dex primitive matches pure JAX."""
    f_dex = primitive(dex.eval(
        r'\x:((Fin 10) => Float) y:((Fin 10) => Float). '
        'sum $ for i. x.i * x.i + 2.0 * y.i'))

    def f_jax(x, y):
        return jnp.sum(x**2 + 2 * y)

    def grad_wrt_first(fn):
        # jax.grad differentiates with respect to the first argument only.
        def wrapped(x, y):
            return jax.grad(fn)(x, y)
        return wrapped

    x = jnp.arange(10.)
    y = jnp.linspace(-0.2, 0.5, num=10)
    np.testing.assert_allclose(
        jax.jit(grad_wrt_first(f_dex))(x, y),
        jax.jit(grad_wrt_first(f_jax))(x, y))
def test_jit_scale():
    """A jitted Dex primitive taking an array and a scalar matches ``x * y``."""
    scale = primitive(
        dex.eval(r'\x:((Fin 10)=>Float) y:Float. for i. x.i * y'))
    # jnp.arange takes a stop value, not a shape tuple; this builds the
    # length-10 vector required by the (Fin 10)=>Float argument.
    x = jnp.arange(10, dtype=np.float32)
    # Run through jax.jit, as the test name and the sibling jit tests intend.
    np.testing.assert_allclose(jax.jit(scale)(x, 5.0), x * 5.0)
def test_jit_array():
    """A jitted Dex primitive returning an integer table (via FToI)."""
    dex_fn = dex.eval(r'\x:((Fin 10)=>Float). for i. FToI $ x.i + 2.0')
    jitted_add_two = jax.jit(primitive(dex_fn))

    zeros = jnp.zeros((10, ), dtype=np.float32)
    expected = (zeros + 2.0).astype(np.int32)
    np.testing.assert_allclose(jitted_add_two(zeros), expected)
def test_jit_scalar():
    """A jitted Dex primitive on a rank-0 float32 input."""
    jitted_add_two = jax.jit(primitive(dex.eval(r'\x:Float. x + 2.0')))
    scalar_zero = jnp.zeros((), dtype=np.float32)
    np.testing.assert_allclose(jitted_add_two(scalar_zero), 2.0)
def test_impl_scalar(self):
    """Calling a Dex primitive directly (no jit) on a rank-0 float32 input."""
    dex_fn = dex.eval(r'\x:Float. x + 2.0')
    add_two = primitive(dex_fn)

    zero = jnp.zeros((), dtype=np.float32)
    result = add_two(zero)
    np.testing.assert_allclose(result, zero + 2.0)
def test_scalar_conversions(self):
    """Dex scalars convert to Python numbers, and Python numbers wrap as atoms."""
    assert float(dex.eval("5.0")) == 5.0
    assert int(dex.eval("5")) == 5
    # Python int/float -> dex.Atom, printed in Dex syntax.
    for value, text in ((5, "5"), (5.0, "5.")):
        assert str(dex.Atom(value)) == text
def test(self):
    """Evaluate ``dex_source`` and check the result with ``check_atom``.

    ``dex_source``, ``reference`` and ``args_iter`` are free variables here;
    presumably they are captured from an enclosing test factory -- TODO confirm.
    """
    atom = dex.eval(dex_source)
    return check_atom(atom, reference, args_iter)
def test_scalar_conversions():
    """Evaluated Dex scalars convert to Python float and int."""
    five_float = dex.eval("5.0")
    assert float(five_float) == 5.0
    five_int = dex.eval("5")
    assert int(five_int) == 5
import sys
import dex
from glob import glob
import cv2
from time import sleep
import logging as log
import datetime as dt

# Set up the model. NOTE(review): dex.eval() with no arguments presumably
# puts the age/sex model into inference mode -- confirm against the dex module.
dex.eval()

# Load the Haar cascade used for frontal-face detection.
cascPath = "haarcascade_frontalface_default.xml"
faceCascade = cv2.CascadeClassifier(cascPath)
log.basicConfig(filename='webcam_predict_age_sex.log', level=log.INFO)

if __name__ == '__main__':
    video_capture = cv2.VideoCapture(0)
    anterior = 0

    while True:
        if not video_capture.isOpened():
            print("Unable to load camera.")
            sleep(5)
            # Retry on the next iteration. The original `pass` fell through to
            # read() on an unopened camera and then crashed in cvtColor.
            continue

        # Capture frame-by-frame.
        ret, frame = video_capture.read()
        if not ret:
            # read() failed and returned no usable frame; cvtColor would
            # raise on it, so skip this iteration.
            continue
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
def test_impl_array():
    """Calling a Dex primitive directly (no jit) on a length-10 float32 vector."""
    add_two = primitive(dex.eval(r'\x:((Fin 10)=>Float). for i. x.i + 2.0'))
    # jnp.arange takes a stop value, not a shape tuple; this builds the
    # length-10 vector required by the (Fin 10)=>Float argument.
    x = jnp.arange(10, dtype=np.float32)
    np.testing.assert_allclose(add_two(x), x + 2.0)