def test_dependable_condition():
    """A comparison whose operand contains an IfExpr hoists that IfExpr first.

    The inner IfExpr is emitted into ``var1`` before the outer ``if``, which
    compares ``var1 + 2`` against ``1 / 2`` and assigns the result to ``var0``.
    """
    left = ast.BinNumExpr(
        ast.IfExpr(
            ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
            ast.NumVal(1),
            ast.NumVal(2)),
        ast.NumVal(2),
        ast.BinNumOpType.ADD)
    right = ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.DIV)
    bool_test = ast.CompExpr(left, right, ast.CompOpType.GTE)

    expr = ast.IfExpr(bool_test, ast.NumVal(1), ast.FeatureRef(0))

    expected_code = """
def score(input):
    if (1.0) == (1.0):
        var1 = 1.0
    else:
        var1 = 2.0
    if ((var1) + (2.0)) >= ((1.0) / (2.0)):
        var0 = 1.0
    else:
        var0 = input[0]
    return var0
"""

    interpreter = PythonInterpreter()
    # Call the directly imported helper, matching the other tests in this module.
    assert_code_equal(interpreter.interpret(expr), expected_code)
def test_nested_condition():
    """A condition node shared by an outer and an inner IfExpr is emitted twice.

    The same ``bool_test`` object guards both the outer IfExpr and the nested
    one. The expected output shows the inner occurrence recomputing its
    hoisted IfExpr (``var2``) inside the outer ``if`` branch, rather than
    reusing ``var1``.
    """
    left = ast.BinNumExpr(
        ast.IfExpr(
            ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
            ast.NumVal(1),
            ast.NumVal(2)),
        ast.NumVal(2),
        ast.BinNumOpType.ADD)

    bool_test = ast.CompExpr(ast.NumVal(1), left, ast.CompOpType.EQ)

    expr_nested = ast.IfExpr(bool_test, ast.FeatureRef(2), ast.NumVal(2))

    expr = ast.IfExpr(bool_test, expr_nested, ast.NumVal(2))

    expected_code = """
def score(input):
    if (1.0) == (1.0):
        var1 = 1.0
    else:
        var1 = 2.0
    if (1.0) == ((var1) + (2.0)):
        if (1.0) == (1.0):
            var2 = 1.0
        else:
            var2 = 2.0
        if (1.0) == ((var2) + (2.0)):
            var0 = input[2]
        else:
            var0 = 2.0
    else:
        var0 = 2.0
    return var0
"""

    interpreter = PythonInterpreter()
    # Call the directly imported helper, matching the other tests in this module.
    assert_code_equal(interpreter.interpret(expr), expected_code)
def test_raw_array():
    """A bare VectorVal is returned as a Python list literal of floats."""
    expr = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])

    expected_code = """
def score(input):
    return [3.0, 4.0]
"""

    interpreter = PythonInterpreter()
    # Call the directly imported helper, matching the other tests in this module.
    assert_code_equal(interpreter.interpret(expr), expected_code)
def test_abs_expr():
    """AbsExpr is rendered with the builtin ``abs``; no ``import math`` appears."""
    tree = ast.AbsExpr(ast.NumVal(-1.0))

    expected = """
def score(input):
    return abs(-1.0)
"""

    generated = PythonInterpreter().interpret(tree)
    assert_code_equal(generated, expected)
def test_atan_expr():
    """AtanExpr maps to ``math.atan`` and pulls in ``import math``."""
    tree = ast.AtanExpr(ast.NumVal(2.0))

    expected = """
import math
def score(input):
    return math.atan(2.0)
"""

    generated = PythonInterpreter().interpret(tree)
    assert_code_equal(generated, expected)
def test_pow_expr():
    """PowExpr maps to ``math.pow(base, exponent)``."""
    base, exponent = ast.NumVal(2.0), ast.NumVal(3.0)
    tree = ast.PowExpr(base, exponent)

    expected = """
import math
def score(input):
    return math.pow(2.0, 3.0)
"""

    generated = PythonInterpreter().interpret(tree)
    assert_code_equal(generated, expected)
def test_log1p_expr():
    """Log1pExpr maps to ``math.log1p`` when the native function is available."""
    expr = ast.Log1pExpr(ast.NumVal(2.0))

    interpreter = PythonInterpreter()

    expected_code = """
import math
def score(input):
    return math.log1p(2.0)
"""

    # Call the directly imported helper, matching the other tests in this module.
    assert_code_equal(interpreter.interpret(expr), expected_code)
def test_bin_num_expr():
    """Nested BinNumExprs are fully parenthesised, including negative literals."""
    quotient = ast.BinNumExpr(
        ast.FeatureRef(0),
        ast.NumVal(-2),
        ast.BinNumOpType.DIV)
    tree = ast.BinNumExpr(quotient, ast.NumVal(2), ast.BinNumOpType.MUL)

    expected = """
def score(input):
    return ((input[0]) / (-2.0)) * (2.0)
"""

    generated = PythonInterpreter().interpret(tree)
    assert_code_equal(generated, expected)
def test_sqrt_fallback_expr():
    """With native sqrt disabled, SqrtExpr falls back to ``math.pow(x, 0.5)``."""
    interp = PythonInterpreter()
    # Setting the function name to NotImplemented forces the fallback path.
    interp.sqrt_function_name = NotImplemented

    tree = ast.SqrtExpr(ast.NumVal(2.0))

    expected = """
import math
def score(input):
    return math.pow(2.0, 0.5)
"""

    assert_code_equal(interp.interpret(tree), expected)
def test_reused_expr():
    """A node marked ``to_reuse=True`` is computed once into a variable.

    The same ExpExpr object is used as both operands of the division; the
    generated code binds it to ``var0`` and references that variable twice.
    """
    shared = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True)
    tree = ast.BinNumExpr(shared, shared, ast.BinNumOpType.DIV)

    expected = """
import math
def score(input):
    var0 = math.exp(1.0)
    return (var0) / (var0)
"""

    generated = PythonInterpreter().interpret(tree)
    assert_code_equal(generated, expected)
def test_exp_fallback_expr():
    """With native exp disabled, ExpExpr falls back to ``math.pow(e, x)``."""
    interp = PythonInterpreter()
    # Setting the function name to NotImplemented forces the fallback path.
    interp.exponent_function_name = NotImplemented

    tree = ast.ExpExpr(ast.NumVal(2.0))

    expected = """
import math
def score(input):
    return math.pow(2.718281828459045, 2.0)
"""

    assert_code_equal(interp.interpret(tree), expected)
def test_deep_expression():
    """Generated code for a deeply nested sum is executable and correct.

    The test builds ``1 + 1 + ... + 1`` as a 120-level left-nested BinNumExpr
    chain, executes the generated module, and checks that ``score`` returns 121.
    NOTE(review): the depth is presumably chosen to stress nesting limits in
    the emitted Python; confirm the intended threshold.
    """
    tree = ast.NumVal(1)
    depth = 0
    while depth < 120:
        tree = ast.BinNumExpr(tree, ast.NumVal(1), ast.BinNumOpType.ADD)
        depth += 1

    source = PythonInterpreter().interpret(tree) + """
result = score(None)
"""

    namespace = {}
    # Run the generated code; ``score`` ignores its input because the
    # expression references no features, so passing None is safe.
    exec(source, namespace)

    assert namespace["result"] == 121
def test_if_expr():
    """An IfExpr becomes an if/else that assigns ``var0``, which is then returned."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    expected = """
def score(input):
    if (1.0) == (input[0]):
        var0 = 2.0
    else:
        var0 = 3.0
    return var0
"""

    generated = PythonInterpreter().interpret(tree)
    assert_code_equal(generated, expected)
def test_sigmoid_expr():
    """SigmoidExpr emits a ``sigmoid`` helper before ``score`` and calls it."""
    tree = ast.SigmoidExpr(ast.NumVal(2.0))

    expected = """
import math
def sigmoid(x):
    if x < 0.0:
        z = math.exp(x)
        return z / (1.0 + z)
    return 1.0 / (1.0 + math.exp(-x))
def score(input):
    return sigmoid(2.0)
"""

    generated = PythonInterpreter().interpret(tree)
    assert_code_equal(generated, expected)
def test_bin_vector_num_expr():
    """Vector-by-number multiplication uses the emitted ``mul_vector_number`` helper.

    Both vector helpers (``add_vectors`` and ``mul_vector_number``) are
    expected in the output, even though only one is called.
    """
    vector = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    tree = ast.BinVectorNumExpr(vector, ast.NumVal(1), ast.BinNumOpType.MUL)

    expected = """
def add_vectors(v1, v2):
    return [sum(i) for i in zip(v1, v2)]
def mul_vector_number(v1, num):
    return [i * num for i in v1]
def score(input):
    return mul_vector_number([1.0, 2.0], 1.0)
"""

    generated = PythonInterpreter().interpret(tree)
    assert_code_equal(generated, expected)
def test_multi_output():
    """An IfExpr whose branches are vectors assigns list literals to ``var0``."""
    expr = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]),
        ast.VectorVal([ast.NumVal(3), ast.NumVal(4)]))

    expected_code = """
def score(input):
    if (1.0) == (1.0):
        var0 = [1.0, 2.0]
    else:
        var0 = [3.0, 4.0]
    return var0
"""

    interpreter = PythonInterpreter()
    # Call the directly imported helper, matching the other tests in this module.
    assert_code_equal(interpreter.interpret(expr), expected_code)
def test_atan_fallback_expr():
    """With native atan disabled, AtanExpr expands into inline arithmetic.

    The expected output performs range reduction on ``abs(x)`` using the
    thresholds 2.414213562373095 and 0.66. It then evaluates a rational
    polynomial in ``var4 = var3 * var3`` and applies sign and offset
    corrections (``var5``, ``var6``, ``var7``). No ``import math`` is emitted.
    """
    interp = PythonInterpreter()
    # Setting the function name to NotImplemented forces the fallback path.
    interp.atan_function_name = NotImplemented

    tree = ast.AtanExpr(ast.NumVal(2.0))

    # Adjacent string literals are concatenated; the long return line is
    # split only to keep source lines short.
    expected = (
        """
def score(input):
    var1 = 2.0
    var2 = abs(var1)
    if (var2) > (2.414213562373095):
        var0 = (1.0) / (var2)
    else:
        if (var2) > (0.66):
            var0 = ((var2) - (1.0)) / ((var2) + (1.0))
        else:
            var0 = var2
    var3 = var0
    var4 = (var3) * (var3)
    if (var2) > (2.414213562373095):
        var5 = -1.0
    else:
        var5 = 1.0
    if (var2) <= (0.66):
        var6 = 0.0
    else:
        if (var2) > (2.414213562373095):
            var6 = 1.5707963267948968
        else:
            var6 = 0.7853981633974484
    if (var1) < (0.0):
        var7 = -1.0
    else:
        var7 = 1.0
    return (((((var3) * ((var4) * ((((var4) * (((var4) * (((var4) * """
        """(((var4) * (-0.8750608600031904)) - (16.157537187333652))) - """
        """(75.00855792314705))) - (122.88666844901361))) - """
        """(64.85021904942025)) / ((194.5506571482614) + ((var4) * """
        """((485.3903996359137) + ((var4) * ((432.88106049129027) + """
        """((var4) * ((165.02700983169885) + ((var4) * """
        """((24.858464901423062) + (var4))))))))))))) + (var3)) * """
        """(var5)) + (var6)) * (var7)""")

    assert_code_equal(interp.interpret(tree), expected)
def test_softmax_expr():
    """SoftmaxExpr emits a numerically shifted ``softmax`` helper and calls it."""
    tree = ast.SoftmaxExpr([ast.NumVal(2.0), ast.NumVal(3.0)])

    expected = """
import math
def softmax(x):
    m = max(x)
    exps = [math.exp(i - m) for i in x]
    s = sum(exps)
    for idx, _ in enumerate(exps):
        exps[idx] /= s
    return exps
def score(input):
    return softmax([2.0, 3.0])
"""

    generated = PythonInterpreter().interpret(tree)
    assert_code_equal(generated, expected)
def test_bin_vector_expr():
    """Vector + vector addition uses the emitted ``add_vectors`` helper.

    Both vector helpers (``add_vectors`` and ``mul_vector_number``) are
    expected in the output, even though only one is called.
    """
    expr = ast.BinVectorExpr(
        ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]),
        ast.VectorVal([ast.NumVal(3), ast.NumVal(4)]),
        ast.BinNumOpType.ADD)

    interpreter = PythonInterpreter()

    expected_code = """
def add_vectors(v1, v2):
    return [sum(i) for i in zip(v1, v2)]
def mul_vector_number(v1, num):
    return [i * num for i in v1]
def score(input):
    return add_vectors([1.0, 2.0], [3.0, 4.0])
"""

    # Call the directly imported helper, matching the other tests in this module.
    assert_code_equal(interpreter.interpret(expr), expected_code)
def test_log1p_fallback_expr():
    """With native log1p disabled, Log1pExpr expands using ``math.log``.

    The expected output computes ``u = 1 + x`` and ``d = u - 1``. It returns
    ``x`` when ``d == 0`` and ``x * log(u) / d`` otherwise.
    """
    interp = PythonInterpreter()
    # Setting the function name to NotImplemented forces the fallback path.
    interp.log1p_function_name = NotImplemented

    tree = ast.Log1pExpr(ast.NumVal(2.0))

    expected = """
import math
def score(input):
    var1 = 2.0
    var2 = (1.0) + (var1)
    var3 = (var2) - (1.0)
    if (var3) == (0.0):
        var0 = var1
    else:
        var0 = ((var1) * (math.log(var2))) / (var3)
    return var0
"""

    assert_code_equal(interp.interpret(tree), expected)
def test_tanh_fallback_expr():
    """With native tanh disabled, TanhExpr expands via ``math.exp`` with clamping.

    The expected output returns 1.0 above 44.0 and -1.0 below -44.0.
    Otherwise it evaluates ``1 - 2 / (exp(2x) + 1)``.
    """
    interp = PythonInterpreter()
    # Setting the function name to NotImplemented forces the fallback path.
    interp.tanh_function_name = NotImplemented

    tree = ast.TanhExpr(ast.NumVal(2.0))

    expected = """
import math
def score(input):
    var1 = 2.0
    if (var1) > (44.0):
        var0 = 1.0
    else:
        if (var1) < (-44.0):
            var0 = -1.0
        else:
            var0 = (1.0) - ((2.0) / ((math.exp((2.0) * (var1))) + (1.0)))
    return var0
"""

    assert_code_equal(interp.interpret(tree), expected)