Ejemplo n.º 1
0
def test_dependable_condition():
    """Comparison whose left operand itself contains an IfExpr.

    The inner conditional is emitted first into ``var1``, and the outer
    if-statement then reads ``var1`` in its test.
    """
    inner_if = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.NumVal(1),
        ast.NumVal(2))
    lhs = ast.BinNumExpr(inner_if, ast.NumVal(2), ast.BinNumOpType.ADD)
    rhs = ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.DIV)
    condition = ast.CompExpr(lhs, rhs, ast.CompOpType.GTE)

    tree = ast.IfExpr(condition, ast.NumVal(1), ast.FeatureRef(0))

    expected_code = """
def score(input):
    if (1.0) == (1.0):
        var1 = 1.0
    else:
        var1 = 2.0
    if ((var1) + (2.0)) >= ((1.0) / (2.0)):
        var0 = 1.0
    else:
        var0 = input[0]
    return var0
    """

    utils.assert_code_equal(PythonInterpreter().interpret(tree), expected_code)
Ejemplo n.º 2
0
def test_nested_condition():
    """The same condition node used as both the outer and nested IfExpr test.

    Each use of the embedded IfExpr is materialised separately: ``var1`` for
    the outer test and ``var2`` inside the nested branch.
    """
    embedded_if = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.NumVal(1),
        ast.NumVal(2))
    summed = ast.BinNumExpr(embedded_if, ast.NumVal(2), ast.BinNumOpType.ADD)
    shared_test = ast.CompExpr(ast.NumVal(1), summed, ast.CompOpType.EQ)

    inner = ast.IfExpr(shared_test, ast.FeatureRef(2), ast.NumVal(2))
    outer = ast.IfExpr(shared_test, inner, ast.NumVal(2))

    expected_code = """
def score(input):
    if (1.0) == (1.0):
        var1 = 1.0
    else:
        var1 = 2.0
    if (1.0) == ((var1) + (2.0)):
        if (1.0) == (1.0):
            var2 = 1.0
        else:
            var2 = 2.0
        if (1.0) == ((var2) + (2.0)):
            var0 = input[2]
        else:
            var0 = 2.0
    else:
        var0 = 2.0
    return var0
    """

    interp = PythonInterpreter()
    generated = interp.interpret(outer)
    utils.assert_code_equal(generated, expected_code)
Ejemplo n.º 3
0
def test_raw_array():
    """A bare VectorVal is returned as a Python list literal."""
    vector = ast.VectorVal([ast.NumVal(value) for value in (3, 4)])

    expected_code = """
def score(input):
    return [3.0, 4.0]
    """

    generated = PythonInterpreter().interpret(vector)
    utils.assert_code_equal(generated, expected_code)
Ejemplo n.º 4
0
def test_abs_expr():
    """AbsExpr is rendered with the builtin ``abs``."""
    absolute = ast.AbsExpr(ast.NumVal(-1.0))

    expected_code = """
def score(input):
    return abs(-1.0)
"""

    generated = PythonInterpreter().interpret(absolute)
    assert_code_equal(generated, expected_code)
Ejemplo n.º 5
0
def test_atan_expr():
    """AtanExpr is rendered as ``math.atan`` and pulls in ``import math``."""
    arctangent = ast.AtanExpr(ast.NumVal(2.0))

    expected_code = """
import math
def score(input):
    return math.atan(2.0)
"""

    generated = PythonInterpreter().interpret(arctangent)
    assert_code_equal(generated, expected_code)
Ejemplo n.º 6
0
def test_pow_expr():
    """PowExpr is rendered as ``math.pow(base, exponent)``."""
    base, exponent = ast.NumVal(2.0), ast.NumVal(3.0)
    power = ast.PowExpr(base, exponent)

    expected_code = """
import math
def score(input):
    return math.pow(2.0, 3.0)
"""

    generated = PythonInterpreter().interpret(power)
    assert_code_equal(generated, expected_code)
Ejemplo n.º 7
0
def test_log1p_expr():
    """Log1pExpr is rendered as ``math.log1p`` when the native function is available."""
    interp = PythonInterpreter()
    log_one_plus = ast.Log1pExpr(ast.NumVal(2.0))

    expected_code = """
import math
def score(input):
    return math.log1p(2.0)
    """

    generated = interp.interpret(log_one_plus)
    utils.assert_code_equal(generated, expected_code)
Ejemplo n.º 8
0
def test_bin_num_expr():
    """Nested binary arithmetic is fully parenthesised, including negative literals."""
    quotient = ast.BinNumExpr(
        ast.FeatureRef(0), ast.NumVal(-2), ast.BinNumOpType.DIV)
    product = ast.BinNumExpr(quotient, ast.NumVal(2), ast.BinNumOpType.MUL)

    expected_code = """
def score(input):
    return ((input[0]) / (-2.0)) * (2.0)
"""

    generated = PythonInterpreter().interpret(product)
    assert_code_equal(generated, expected_code)
Ejemplo n.º 9
0
def test_sqrt_fallback_expr():
    """With no native sqrt name, SqrtExpr falls back to ``math.pow(x, 0.5)``."""
    interp = PythonInterpreter()
    # Disable the native function so the fallback path is exercised.
    interp.sqrt_function_name = NotImplemented

    square_root = ast.SqrtExpr(ast.NumVal(2.0))

    expected_code = """
import math
def score(input):
    return math.pow(2.0, 0.5)
"""

    assert_code_equal(interp.interpret(square_root), expected_code)
Ejemplo n.º 10
0
def test_reused_expr():
    """A subexpression flagged ``to_reuse`` is computed once into ``var0``."""
    shared = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True)
    ratio = ast.BinNumExpr(shared, shared, ast.BinNumOpType.DIV)

    expected_code = """
import math
def score(input):
    var0 = math.exp(1.0)
    return (var0) / (var0)
"""

    generated = PythonInterpreter().interpret(ratio)
    assert_code_equal(generated, expected_code)
Ejemplo n.º 11
0
def test_exp_fallback_expr():
    """With no native exp name, ExpExpr falls back to ``math.pow(e, x)``."""
    interp = PythonInterpreter()
    # Disable the native function so the fallback path is exercised.
    interp.exponent_function_name = NotImplemented

    exponential = ast.ExpExpr(ast.NumVal(2.0))

    expected_code = """
import math
def score(input):
    return math.pow(2.718281828459045, 2.0)
"""

    assert_code_equal(interp.interpret(exponential), expected_code)
Ejemplo n.º 12
0
def test_deep_expression():
    """A 120-level left-nested addition chain generates runnable code.

    The generated module is executed and ``score`` must return
    1 + 120 * 1 == 121.
    """
    depth = 120
    chain = ast.NumVal(1)
    for _level in range(depth):
        chain = ast.BinNumExpr(chain, ast.NumVal(1), ast.BinNumOpType.ADD)

    source = PythonInterpreter().interpret(chain) + """
result = score(None)
"""

    namespace = {}
    # Test-only: executes code this test just generated, not external input.
    exec(source, namespace)

    assert namespace["result"] == 121
Ejemplo n.º 13
0
def test_if_expr():
    """A simple IfExpr becomes an if/else that assigns ``var0``."""
    predicate = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    branch = ast.IfExpr(predicate, ast.NumVal(2), ast.NumVal(3))

    expected_code = """
def score(input):
    if (1.0) == (input[0]):
        var0 = 2.0
    else:
        var0 = 3.0
    return var0
"""

    generated = PythonInterpreter().interpret(branch)
    assert_code_equal(generated, expected_code)
Ejemplo n.º 14
0
def test_sigmoid_expr():
    """SigmoidExpr emits a ``sigmoid`` helper (plus ``import math``) and calls it."""
    logistic = ast.SigmoidExpr(ast.NumVal(2.0))

    expected_code = """
import math
def sigmoid(x):
    if x < 0.0:
        z = math.exp(x)
        return z / (1.0 + z)
    return 1.0 / (1.0 + math.exp(-x))
def score(input):
    return sigmoid(2.0)
"""

    generated = PythonInterpreter().interpret(logistic)
    assert_code_equal(generated, expected_code)
Ejemplo n.º 15
0
def test_bin_vector_num_expr():
    """Vector-by-number multiplication calls ``mul_vector_number``.

    Both vector helpers appear in the output even though only one is used.
    """
    vector = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    scaled = ast.BinVectorNumExpr(vector, ast.NumVal(1), ast.BinNumOpType.MUL)

    expected_code = """
def add_vectors(v1, v2):
    return [sum(i) for i in zip(v1, v2)]
def mul_vector_number(v1, num):
    return [i * num for i in v1]
def score(input):
    return mul_vector_number([1.0, 2.0], 1.0)
"""

    generated = PythonInterpreter().interpret(scaled)
    assert_code_equal(generated, expected_code)
Ejemplo n.º 16
0
def test_multi_output():
    """An IfExpr whose branches are vectors assigns list literals to ``var0``."""
    predicate = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    branch = ast.IfExpr(predicate, when_true, when_false)

    expected_code = """
def score(input):
    if (1.0) == (1.0):
        var0 = [1.0, 2.0]
    else:
        var0 = [3.0, 4.0]
    return var0
    """

    generated = PythonInterpreter().interpret(branch)
    utils.assert_code_equal(generated, expected_code)
Ejemplo n.º 17
0
def test_atan_fallback_expr():
    """With no native atan name, AtanExpr expands into an inline approximation.

    The expected output is assembled from adjacent string literals because
    the final return expression is too long for a single source line.
    """
    interp = PythonInterpreter()
    # Disable the native function so the fallback path is exercised.
    interp.atan_function_name = NotImplemented

    arctangent = ast.AtanExpr(ast.NumVal(2.0))

    expected_code = (
        """
def score(input):
    var1 = 2.0
    var2 = abs(var1)
    if (var2) > (2.414213562373095):
        var0 = (1.0) / (var2)
    else:
        if (var2) > (0.66):
            var0 = ((var2) - (1.0)) / ((var2) + (1.0))
        else:
            var0 = var2
    var3 = var0
    var4 = (var3) * (var3)
    if (var2) > (2.414213562373095):
        var5 = -1.0
    else:
        var5 = 1.0
    if (var2) <= (0.66):
        var6 = 0.0
    else:
        if (var2) > (2.414213562373095):
            var6 = 1.5707963267948968
        else:
            var6 = 0.7853981633974484
    if (var1) < (0.0):
        var7 = -1.0
    else:
        var7 = 1.0
    return (((((var3) * ((var4) * ((((var4) * (((var4) * (((var4) * """
        """(((var4) * (-0.8750608600031904)) - (16.157537187333652))) - """
        """(75.00855792314705))) - (122.88666844901361))) - """
        """(64.85021904942025)) / ((194.5506571482614) + ((var4) * """
        """((485.3903996359137) + ((var4) * ((432.88106049129027) + """
        """((var4) * ((165.02700983169885) + ((var4) * """
        """((24.858464901423062) + (var4))))))))))))) + (var3)) * """
        """(var5)) + (var6)) * (var7)""")

    generated = interp.interpret(arctangent)
    assert_code_equal(generated, expected_code)
Ejemplo n.º 18
0
def test_softmax_expr():
    """SoftmaxExpr emits a ``softmax`` helper that subtracts the max before exp."""
    components = [ast.NumVal(2.0), ast.NumVal(3.0)]
    normalised = ast.SoftmaxExpr(components)

    expected_code = """
import math
def softmax(x):
    m = max(x)
    exps = [math.exp(i - m) for i in x]
    s = sum(exps)
    for idx, _ in enumerate(exps):
        exps[idx] /= s
    return exps
def score(input):
    return softmax([2.0, 3.0])
"""

    generated = PythonInterpreter().interpret(normalised)
    assert_code_equal(generated, expected_code)
Ejemplo n.º 19
0
def test_bin_vector_expr():
    """Element-wise vector addition calls the emitted ``add_vectors`` helper."""
    interp = PythonInterpreter()

    first = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    second = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    total = ast.BinVectorExpr(first, second, ast.BinNumOpType.ADD)

    expected_code = """
def add_vectors(v1, v2):
    return [sum(i) for i in zip(v1, v2)]
def mul_vector_number(v1, num):
    return [i * num for i in v1]
def score(input):
    return add_vectors([1.0, 2.0], [3.0, 4.0])
    """

    generated = interp.interpret(total)
    utils.assert_code_equal(generated, expected_code)
Ejemplo n.º 20
0
def test_log1p_fallback_expr():
    """With no native log1p name, Log1pExpr expands into a ``math.log`` formula.

    The expansion guards against a zero denominator by returning the input
    unchanged when ``(1 + x) - 1 == 0``.
    """
    interp = PythonInterpreter()
    # Disable the native function so the fallback path is exercised.
    interp.log1p_function_name = NotImplemented

    log_one_plus = ast.Log1pExpr(ast.NumVal(2.0))

    expected_code = """
import math
def score(input):
    var1 = 2.0
    var2 = (1.0) + (var1)
    var3 = (var2) - (1.0)
    if (var3) == (0.0):
        var0 = var1
    else:
        var0 = ((var1) * (math.log(var2))) / (var3)
    return var0
"""

    assert_code_equal(interp.interpret(log_one_plus), expected_code)
Ejemplo n.º 21
0
def test_tanh_fallback_expr():
    """With no native tanh name, TanhExpr expands into an exp-based formula.

    Inputs above 44.0 or below -44.0 are clamped to 1.0 / -1.0 in the
    generated code.
    """
    interp = PythonInterpreter()
    # Disable the native function so the fallback path is exercised.
    interp.tanh_function_name = NotImplemented

    hyperbolic = ast.TanhExpr(ast.NumVal(2.0))

    expected_code = """
import math
def score(input):
    var1 = 2.0
    if (var1) > (44.0):
        var0 = 1.0
    else:
        if (var1) < (-44.0):
            var0 = -1.0
        else:
            var0 = (1.0) - ((2.0) / ((math.exp((2.0) * (var1))) + (1.0)))
    return var0
"""

    assert_code_equal(interp.interpret(hyperbolic), expected_code)