def test_dependable_condition():
    """A conditional nested in a comparison operand is evaluated first.

    The expected output assigns the inner if-expression to ``var1`` before the
    outer if-statement that assigns ``var0``.
    """
    inner_if = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.NumVal(1),
        ast.NumVal(2))
    lhs = ast.BinNumExpr(inner_if, ast.NumVal(2), ast.BinNumOpType.ADD)
    rhs = ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.DIV)
    condition = ast.CompExpr(lhs, rhs, ast.CompOpType.GTE)
    expr = ast.IfExpr(condition, ast.NumVal(1), ast.FeatureRef(0))

    expected_code = """ public class Model { public static double score(double[] input) { double var0; double var1; if ((1.0) == (1.0)) { var1 = 1.0; } else { var1 = 2.0; } if (((var1) + (2.0)) >= ((1.0) / (2.0))) { var0 = 1.0; } else { var0 = input[0]; } return var0; } }"""

    java = JavaInterpreter()
    generated = java.interpret(expr)
    utils.assert_code_equal(generated, expected_code)
def test_softmax_expr():
    """Softmax over a constant vector calls a generated private ``softmax`` helper."""
    values = [ast.NumVal(2.0), ast.NumVal(3.0)]
    expr = ast.SoftmaxExpr(values)
    java = JavaInterpreter()

    expected_code = """ public class Model { public static double[] score(double[] input) { return softmax(new double[] {2.0, 3.0}); } private static double[] softmax(double[] x) { int size = x.length; double[] result = new double[size]; double max = x[0]; for (int i = 1; i < size; ++i) { if (x[i] > max) max = x[i]; } double sum = 0.0; for (int i = 0; i < size; ++i) { result[i] = Math.exp(x[i] - max); sum += result[i]; } for (int i = 0; i < size; ++i) result[i] /= sum; return result; } }"""

    generated = java.interpret(expr)
    utils.assert_code_equal(generated, expected_code)
def test_bin_vector_num_expr():
    """Vector-by-number multiplication is emitted as a ``mulVectorNumber`` call.

    The expected output also contains an ``addVectors`` helper alongside
    ``mulVectorNumber``.
    """
    vector = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    expr = ast.BinNumExpr if False else None  # placeholder never used
    expr = ast.BinVectorNumExpr(vector, ast.NumVal(1), ast.BinNumOpType.MUL)
    java = JavaInterpreter()

    expected_code = """ public class Model { public static double[] score(double[] input) { return mulVectorNumber(new double[] {1.0, 2.0}, 1.0); } private static double[] addVectors(double[] v1, double[] v2) { double[] result = new double[v1.length]; for (int i = 0; i < v1.length; i++) { result[i] = v1[i] + v2[i]; } return result; } private static double[] mulVectorNumber(double[] v1, double num) { double[] result = new double[v1.length]; for (int i = 0; i < v1.length; i++) { result[i] = v1[i] * num; } return result; } }"""

    generated = java.interpret(expr)
    utils.assert_code_equal(generated, expected_code)
def test_deep_mixed_exprs_exceeding_threshold():
    """Deep arithmetic inside nested if-expressions is split into subroutines.

    Four nested if-expressions are built. Each compares a four-level chain of
    additions against the constant 3. With a size threshold of 1, the expected
    output moves the innermost part of each chain into ``subroutine0`` through
    ``subroutine3``.
    """
    depth = 4
    # Previously this read the inner loop variable ``j`` after the loop
    # finished (where it happened to equal depth - 1). The constant is now
    # explicit. The value is identical, so the expected output is unchanged.
    comparison_value = depth - 1

    expr = ast.NumVal(1)
    for level in range(depth):
        inner = ast.NumVal(1)
        for _ in range(depth):
            inner = ast.BinNumExpr(
                ast.NumVal(level), inner, ast.BinNumOpType.ADD)
        expr = ast.IfExpr(
            ast.CompExpr(
                inner, ast.NumVal(comparison_value), ast.CompOpType.EQ),
            ast.NumVal(1),
            expr)

    interpreter = JavaInterpreter()
    interpreter.ast_size_check_frequency = 3
    interpreter.ast_size_per_subroutine_threshold = 1

    expected_code = """ public class Model { public static double score(double[] input) { double var0; if (((3.0) + ((3.0) + (subroutine0(input)))) == (3.0)) { var0 = 1.0; } else { if (((2.0) + ((2.0) + (subroutine1(input)))) == (3.0)) { var0 = 1.0; } else { if (((1.0) + ((1.0) + (subroutine2(input)))) == (3.0)) { var0 = 1.0; } else { if (((0.0) + ((0.0) + (subroutine3(input)))) == (3.0)) { var0 = 1.0; } else { var0 = 1.0; } } } } return var0; } public static double subroutine0(double[] input) { return (3.0) + ((3.0) + (1.0)); } public static double subroutine1(double[] input) { return (2.0) + ((2.0) + (1.0)); } public static double subroutine2(double[] input) { return (1.0) + ((1.0) + (1.0)); } public static double subroutine3(double[] input) { return (0.0) + ((0.0) + (1.0)); } }"""

    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_raw_array():
    """A bare vector value is returned directly as a Java array literal."""
    items = [ast.NumVal(3), ast.NumVal(4)]
    expr = ast.VectorVal(items)

    expected_code = """ public class Model { public static double[] score(double[] input) { return new double[] {3.0, 4.0}; } }"""

    java = JavaInterpreter()
    generated = java.interpret(expr)
    utils.assert_code_equal(generated, expected_code)
def test_atan_expr():
    """An arctangent expression maps onto ``Math.atan``."""
    argument = ast.NumVal(2.0)
    expr = ast.AtanExpr(argument)
    java = JavaInterpreter()

    expected_code = """ public class Model { public static double score(double[] input) { return Math.atan(2.0); } }"""

    generated = java.interpret(expr)
    utils.assert_code_equal(generated, expected_code)
def test_package_name():
    """A ``package_name`` argument produces a leading ``package`` declaration."""
    java = JavaInterpreter(package_name="foo.bar")
    expr = ast.NumVal(1)

    expected_code = """ package foo.bar; public class Model { public static double score(double[] input) { return 1.0; } }"""

    generated = java.interpret(expr)
    utils.assert_code_equal(generated, expected_code)
def test_ignores_subroutine_expr():
    """A nested binary expression is inlined rather than split into a method.

    NOTE(review): despite the name, the tree built here contains no explicit
    subroutine node, only nested ``BinNumExpr``s. Confirm whether the test
    still covers what its name describes.
    """
    addition = ast.BinNumExpr(
        ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.ADD)
    expr = ast.BinNumExpr(
        ast.FeatureRef(0), addition, ast.BinNumOpType.MUL)

    expected_code = """ public class Model { public static double score(double[] input) { return (input[0]) * ((1.0) + (2.0)); } }"""

    java = JavaInterpreter()
    generated = java.interpret(expr)
    utils.assert_code_equal(generated, expected_code)
def test_reused_expr():
    """A node flagged ``to_reuse`` and used twice is computed once into a variable."""
    # The same node object is passed as both operands on purpose.
    shared = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True)
    expr = ast.BinNumExpr(shared, shared, ast.BinNumOpType.DIV)
    java = JavaInterpreter()

    expected_code = """ public class Model { public static double score(double[] input) { double var0; var0 = Math.exp(1.0); return (var0) / (var0); } }"""

    generated = java.interpret(expr)
    utils.assert_code_equal(generated, expected_code)
def test_bin_num_expr():
    """Nested arithmetic is fully parenthesised and keeps the negative literal."""
    quotient = ast.BinNumExpr(
        ast.FeatureRef(0), ast.NumVal(-2), ast.BinNumOpType.DIV)
    expr = ast.BinNumExpr(quotient, ast.NumVal(2), ast.BinNumOpType.MUL)
    java = JavaInterpreter()

    expected_code = """ public class Model { public static double score(double[] input) { return ((input[0]) / (-2.0)) * (2.0); } }"""

    generated = java.interpret(expr)
    utils.assert_code_equal(generated, expected_code)
def test_nested_condition():
    """A comparison node shared by an outer and an inner if-expression.

    The expected output evaluates the embedded conditional separately for each
    use: once into ``var1`` at function scope, and again into ``var2`` inside
    the outer if-branch.
    """
    inner_if = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.NumVal(1),
        ast.NumVal(2))
    summed = ast.BinNumExpr(inner_if, ast.NumVal(2), ast.BinNumOpType.ADD)
    # One comparison object is deliberately reused by both if-expressions.
    shared_test = ast.CompExpr(ast.NumVal(1), summed, ast.CompOpType.EQ)
    nested = ast.IfExpr(shared_test, ast.FeatureRef(2), ast.NumVal(2))
    expr = ast.IfExpr(shared_test, nested, ast.NumVal(2))

    expected_code = """ public class Model { public static double score(double[] input) { double var0; double var1; if ((1.0) == (1.0)) { var1 = 1.0; } else { var1 = 2.0; } if ((1.0) == ((var1) + (2.0))) { double var2; if ((1.0) == (1.0)) { var2 = 1.0; } else { var2 = 2.0; } if ((1.0) == ((var2) + (2.0))) { var0 = input[2]; } else { var0 = 2.0; } } else { var0 = 2.0; } return var0; } }"""

    java = JavaInterpreter()
    generated = java.interpret(expr)
    utils.assert_code_equal(generated, expected_code)
def test_sigmoid_expr():
    """A sigmoid expression calls a generated private ``sigmoid`` helper."""
    argument = ast.NumVal(2.0)
    expr = ast.SigmoidExpr(argument)
    java = JavaInterpreter()

    expected_code = """ public class Model { public static double score(double[] input) { return sigmoid(2.0); } private static double sigmoid(double x) { if (x < 0.0) { double z = Math.exp(x); return z / (1.0 + z); } return 1.0 / (1.0 + Math.exp(-x)); } }"""

    generated = java.interpret(expr)
    utils.assert_code_equal(generated, expected_code)
def test_depth_threshold_with_bin_expr():
    """A deep chain of additions is partly moved into ``subroutine0``.

    Four additions are wrapped around ``NumVal(1)``. With the check frequency
    set to 3 and the threshold set to 1, the expected output keeps two
    additions inline and moves the rest into a subroutine.
    """
    expr = ast.NumVal(1)
    remaining = 4
    while remaining:
        expr = ast.BinNumExpr(ast.NumVal(1), expr, ast.BinNumOpType.ADD)
        remaining -= 1

    java = JavaInterpreter()
    java.ast_size_check_frequency = 3
    java.ast_size_per_subroutine_threshold = 1

    expected_code = """ public class Model { public static double score(double[] input) { return (1.0) + ((1.0) + (subroutine0(input))); } public static double subroutine0(double[] input) { return (1.0) + ((1.0) + (1.0)); } }"""

    generated = java.interpret(expr)
    utils.assert_code_equal(generated, expected_code)
def test_multi_output():
    """An if-expression with vector branches yields a ``double[]`` result variable."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    expr = ast.IfExpr(condition, when_true, when_false)

    expected_code = """ public class Model { public static double[] score(double[] input) { double[] var0; if ((1.0) == (1.0)) { var0 = new double[] {1.0, 2.0}; } else { var0 = new double[] {3.0, 4.0}; } return var0; } }"""

    java = JavaInterpreter()
    generated = java.interpret(expr)
    utils.assert_code_equal(generated, expected_code)
def test_if_expr():
    """A simple if-expression becomes an if/else that assigns ``var0``."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    expr = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))
    java = JavaInterpreter()

    expected_code = """ public class Model { public static double score(double[] input) { double var0; if ((1.0) == (input[0])) { var0 = 2.0; } else { var0 = 3.0; } return var0; } } """

    generated = java.interpret(expr)
    utils.assert_code_equal(generated, expected_code)
def test_depth_threshold_without_bin_expr():
    """Nested if-expressions without arithmetic are not split into subroutines.

    Even with a size threshold of 1, the expected output keeps all four
    nested conditionals inline in ``score``.
    """
    expr = ast.NumVal(1)
    for _ in range(4):
        condition = ast.CompExpr(
            ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
        expr = ast.IfExpr(condition, ast.NumVal(1), expr)

    java = JavaInterpreter()
    java.ast_size_check_frequency = 2
    java.ast_size_per_subroutine_threshold = 1

    expected_code = """ public class Model { public static double score(double[] input) { double var0; if ((1.0) == (1.0)) { var0 = 1.0; } else { if ((1.0) == (1.0)) { var0 = 1.0; } else { if ((1.0) == (1.0)) { var0 = 1.0; } else { if ((1.0) == (1.0)) { var0 = 1.0; } else { var0 = 1.0; } } } } return var0; } }"""

    generated = java.interpret(expr)
    utils.assert_code_equal(generated, expected_code)