Esempio n. 1
0
def test_function_call():
    """Call functions looked up on a ``dex.Module`` with evaluated atoms.

    ``addOne`` is defined by the module source below. ``sum`` is not defined
    in that source, so it is presumably provided by the Dex prelude.
    """
    source = dedent("""
  def addOne (x: Float) : Float = x + 1.0
  """)
    module = dex.Module(source)
    scalar = dex.eval("2.5")
    table = dex.eval("[2, 3, 4]")

    incremented = module.addOne(scalar)
    assert str(incremented) == "3.5"

    total = module.sum(table)
    assert str(total) == "9"
Esempio n. 2
0
def test_abstract_eval_simple():
    """Check the abstract output of a Dex primitive via ``jax.eval_shape``.

    The input is described only by a ``ShapeDtypeStruct``; the expected
    result keeps the ``(10,)`` shape and has dtype ``int32``.
    """
    dex_fn = dex.eval(r'\x:((Fin 10)=>Float). for i. FToI $ x.i + 2.0')
    add_two = primitive(dex_fn)

    input_spec = jax.ShapeDtypeStruct((10, ), np.float32)
    result_spec = jax.eval_shape(add_two, input_spec)

    assert result_spec.shape == (10, )
    assert result_spec.dtype == np.int32
Esempio n. 3
0
 def test_vmap_jit(self):
     """Run a Dex primitive under ``jax.jit(jax.vmap(...))``.

     The batch is built with ``jnp.linspace`` between two length-2 arrays
     using ``num=4``, and the result is compared with ``batch + 2.0``.
     """
     dex_fn = dex.eval(r'\x:((Fin 2)=>Float). for i. x.i + 2.0')
     add_two = primitive(dex_fn)

     lower = jnp.array([0, 3])
     upper = jnp.array([5, 8])
     batch = jnp.linspace(lower, upper, num=4, dtype=jnp.float32)

     batched_add_two = jax.jit(jax.vmap(add_two))
     np.testing.assert_allclose(batched_add_two(batch), batch + 2.0)
Esempio n. 4
0
 def test_eval(self):
     """Check that each evaluated literal prints back as its source text."""
     for literal in ("2.5", "4", "[2, 3, 4]"):
         printed = str(dex.eval(literal))
         assert printed == literal
Esempio n. 5
0
 def test_vmap_nonzero_index(self):
     """Map a Dex primitive over axis 1 of its input and output.

     Uses ``in_axes=1`` / ``out_axes=1`` so the mapped dimension is not the
     leading one; the result is compared with ``batch + 2.0``.
     """
     dex_fn = dex.eval(r'\x:((Fin 4)=>Float). for i. x.i + 2.0')
     add_two = primitive(dex_fn)

     lower = jnp.array([0, 3])
     upper = jnp.array([5, 8])
     batch = jnp.linspace(lower, upper, num=4, dtype=jnp.float32)

     mapped = jax.vmap(add_two, in_axes=1, out_axes=1)
     np.testing.assert_allclose(mapped(batch), batch + 2.0)
Esempio n. 6
0
 def test_vmap_unbatched_array(self):
     """Map a binary Dex primitive over only its second argument.

     ``in_axes=[None, 0]`` leaves the first argument unmapped while the
     second is mapped over its leading axis; the result is compared with
     the broadcast sum ``fixed + batched``.
     """
     dex_fn = dex.eval(
         r'\x:((Fin 10)=>Float) y:((Fin 10)=>Float). for i. x.i + y.i')
     add_arrays = primitive(dex_fn)

     fixed = jnp.arange(10, dtype=np.float32)
     batched = jnp.linspace(
         jnp.arange(10), jnp.arange(10, 20), num=5, dtype=jnp.float32)

     mapped = jax.vmap(add_arrays, in_axes=[None, 0])
     np.testing.assert_allclose(mapped(fixed, batched), fixed + batched)
Esempio n. 7
0
    def test_tuple_return(self):
        """Compare a compiled tuple-returning Dex function with NumPy.

        The Dex function and the reference both produce three arrays; the
        lengths of the two results are compared first, then each pair of
        arrays element-wise.
        """
        dex_func = dex.eval(r"\x: ((Fin 10) => Float). (x, 2. .* x, 3. .* x)")

        def reference(x):
            return (x, 2 * x, 3 * x)

        inputs = np.arange(10, dtype=np.float32)

        compiled = dex_func.compile()
        actual = compiled(inputs)
        expected = reference(inputs)

        self.assertEqual(len(actual), len(expected))
        for got, want in zip(actual, expected):
            np.testing.assert_allclose(got, want)
Esempio n. 8
0
    def test_grad(self):
        """Compare ``jax.grad`` of a Dex primitive with a pure-JAX reference."""
        dex_fn = dex.eval(r'\x:((Fin 10) => Float). '
                          'sum $ for i. (i_to_f $ ordinal i) * x.i * x.i')
        f_dex = primitive(dex_fn)

        def f_jax(x):
            # Reference: index-weighted sum of squares.
            weights = jnp.arange(10.)
            return jnp.sum(weights * x**2)

        point = jnp.linspace(-0.2, 0.5, num=10)

        actual = jax.grad(f_dex)(point)
        expected = jax.grad(f_jax)(point)

        np.testing.assert_allclose(actual, expected)
Esempio n. 9
0
    def test_jvp(self):
        """Compare ``jax.jvp`` of a binary Dex primitive with a JAX reference.

        Both the primal outputs and the tangent outputs must agree.
        """
        dex_fn = dex.eval(r'\x:((Fin 10) => Float) y:((Fin 10) => Float). '
                          'for i. x.i * x.i + 2.0 * y.i')
        f_dex = primitive(dex_fn)

        def f_jax(x, y):
            return x**2 + 2 * y

        primals = (
            jnp.arange(10.),
            jnp.linspace(-0.2, 0.5, num=10),
        )
        tangents = (
            jnp.linspace(0.1, 0.3, num=10),
            jnp.linspace(2.0, -5.0, num=10),
        )

        actual_out, actual_tangent = jax.jvp(f_dex, primals, tangents)
        expected_out, expected_tangent = jax.jvp(f_jax, primals, tangents)

        np.testing.assert_allclose(actual_out, expected_out)
        np.testing.assert_allclose(actual_tangent, expected_tangent)
Esempio n. 10
0
    def test_grad_binary_function_jit(self):
        """Compare jitted gradients of a binary Dex primitive and a JAX reference.

        ``jax.grad`` is called with its default arguments inside each
        wrapper, and the wrappers themselves are what get jitted.
        """
        dex_fn = dex.eval(r'\x:((Fin 10) => Float) y:((Fin 10) => Float). '
                          'sum $ for i. x.i * x.i + 2.0 * y.i')
        f_dex = primitive(dex_fn)

        def f_jax(x, y):
            return jnp.sum(x**2 + 2 * y)

        def dex_gradient(x, y):
            return jax.grad(f_dex)(x, y)

        def jax_gradient(x, y):
            return jax.grad(f_jax)(x, y)

        first = jnp.arange(10.)
        second = jnp.linspace(-0.2, 0.5, num=10)

        actual = jax.jit(dex_gradient)(first, second)
        expected = jax.jit(jax_gradient)(first, second)
        np.testing.assert_allclose(actual, expected)
Esempio n. 11
0
def test_jit_scale():
    """Scale a length-10 float array by a scalar through a Dex primitive.

    The result of ``scale(x, 5.0)`` is compared with ``x * 5.0``.

    NOTE(review): despite the test name, the primitive is called directly
    and not through ``jax.jit`` -- confirm whether jitting was intended.
    """
    scale = primitive(
        dex.eval(r'\x:((Fin 10)=>Float) y:Float. for i. x.i * y'))
    # ``jnp.arange`` takes scalar start/stop/step arguments, not a shape
    # tuple; pass the length directly to get the ten values 0..9.
    x = jnp.arange(10, dtype=np.float32)
    np.testing.assert_allclose(scale(x, 5.0), x * 5.0)
Esempio n. 12
0
def test_jit_array():
    """Run an array-valued Dex primitive under ``jax.jit``.

    The expected value is ``x + 2.0`` cast to ``int32``.
    """
    dex_fn = dex.eval(r'\x:((Fin 10)=>Float). for i. FToI $ x.i + 2.0')
    add_two = primitive(dex_fn)

    zeros = jnp.zeros((10, ), dtype=np.float32)
    jitted = jax.jit(add_two)

    np.testing.assert_allclose(jitted(zeros), (zeros + 2.0).astype(np.int32))
Esempio n. 13
0
def test_jit_scalar():
    """Run a scalar Dex primitive under ``jax.jit`` and expect ``2.0``."""
    dex_fn = dex.eval(r'\x:Float. x + 2.0')
    add_two = primitive(dex_fn)

    zero = jnp.zeros((), dtype=np.float32)
    jitted = jax.jit(add_two)

    np.testing.assert_allclose(jitted(zero), 2.0)
Esempio n. 14
0
 def test_impl_scalar(self):
     """Call a scalar Dex primitive directly and compare with ``x + 2.0``."""
     dex_fn = dex.eval(r'\x:Float. x + 2.0')
     add_two = primitive(dex_fn)

     zero = jnp.zeros((), dtype=np.float32)
     actual = add_two(zero)

     np.testing.assert_allclose(actual, zero + 2.0)
Esempio n. 15
0
 def test_scalar_conversions(self):
     """Round-trip scalars between Python values and Dex atoms.

     Covers ``float()`` / ``int()`` on evaluated atoms and ``str()`` on
     atoms built directly from Python numbers with ``dex.Atom``.
     """
     as_float = float(dex.eval("5.0"))
     assert as_float == 5.0

     as_int = int(dex.eval("5"))
     assert as_int == 5

     int_text = str(dex.Atom(5))
     assert int_text == "5"

     float_text = str(dex.Atom(5.0))
     assert float_text == "5."
Esempio n. 16
0
 def test(self):
     """Evaluate ``dex_source`` and pass the atom to ``check_atom``.

     ``dex_source``, ``reference`` and ``args_iter`` are not defined in
     this block; they come from an enclosing scope.
     """
     atom = dex.eval(dex_source)
     return check_atom(atom, reference, args_iter)
Esempio n. 17
0
def test_scalar_conversions():
    """Convert evaluated Dex scalars with ``float()`` and ``int()``."""
    as_float = float(dex.eval("5.0"))
    assert as_float == 5.0

    as_int = int(dex.eval("5"))
    assert as_int == 5
Esempio n. 18
0
import sys
import dex
from glob import glob

import cv2
from time import sleep
import logging as log
import datetime as dt

# Call the `dex` package's eval() once at import time.
# NOTE(review): presumably this puts the underlying model into inference
# mode -- confirm against the `dex` package used by this script.
dex.eval()

# Build an OpenCV cascade classifier from the frontal-face Haar cascade file.
# The path is relative, so the XML file must be resolvable from the directory
# the script is run from.
cascPath = "haarcascade_frontalface_default.xml"
faceCascade = cv2.CascadeClassifier(cascPath)
# Configure the root logger to write records of level INFO and above to a file.
log.basicConfig(filename='webcam_predict_age_sex.log', level=log.INFO)

if __name__ == '__main__':

    # Open capture device 0.
    video_capture = cv2.VideoCapture(0)
    anterior = 0

    while True:
        if not video_capture.isOpened():
            print("Unable to load camera.")
            sleep(5)
            # Retry on the next iteration. Falling through (the previous
            # `pass`) would read from a closed device and hand a missing
            # frame to cvtColor below.
            continue

        # capture frame-by-frame
        ret, frame = video_capture.read()
        # Convert the BGR frame to grayscale.
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
Esempio n. 19
0
def test_impl_array():
    """Call an array-valued Dex primitive directly (no ``jax.jit``).

    The primitive adds 2.0 to every element of a length-10 float array and
    the result is compared with ``x + 2.0``.
    """
    add_two = primitive(dex.eval(r'\x:((Fin 10)=>Float). for i. x.i + 2.0'))
    # ``jnp.arange`` takes scalar start/stop/step arguments, not a shape
    # tuple; pass the length directly to get the ten values 0..9.
    x = jnp.arange(10, dtype=np.float32)
    np.testing.assert_allclose(add_two(x), x + 2.0)