예제 #1
0
def test_deep_mixed_exprs_exceeding_threshold():
    """Check Java output for nested if-expressions with deep additions.

    Builds a chain of four ``IfExpr`` nodes. Each condition compares a
    four-level addition chain against a constant, and each else-branch is
    the previously built ``IfExpr``. The generated code is expected to move
    the innermost part of every addition chain into its own subroutine.
    """
    depth = 4

    tree = ast.NumVal(1)
    for outer in range(depth):
        # Four nested additions of the outer index on top of the constant 1.
        addition = ast.NumVal(1)
        for _ in range(depth):
            addition = ast.BinNumExpr(
                ast.NumVal(outer), addition, ast.BinNumOpType.ADD)

        # Every condition compares against the final inner-loop index
        # (depth - 1 == 3), regardless of the outer index.
        condition = ast.CompExpr(
            addition, ast.NumVal(depth - 1), ast.CompOpType.EQ)
        tree = ast.IfExpr(condition, ast.NumVal(1), tree)

    # NOTE(review): the exact meaning of these two settings is defined by the
    # interpreter, which is not visible here; with these values the expected
    # output below contains four subroutines.
    java = JavaInterpreter()
    java.ast_size_check_frequency = 3
    java.ast_size_per_subroutine_threshold = 1

    expected_code = """
public class Model {
    public static double score(double[] input) {
        double var0;
        if (((3.0) + ((3.0) + (subroutine0(input)))) == (3.0)) {
            var0 = 1.0;
        } else {
            if (((2.0) + ((2.0) + (subroutine1(input)))) == (3.0)) {
                var0 = 1.0;
            } else {
                if (((1.0) + ((1.0) + (subroutine2(input)))) == (3.0)) {
                    var0 = 1.0;
                } else {
                    if (((0.0) + ((0.0) + (subroutine3(input)))) == (3.0)) {
                        var0 = 1.0;
                    } else {
                        var0 = 1.0;
                    }
                }
            }
        }
        return var0;
    }
    public static double subroutine0(double[] input) {
        return (3.0) + ((3.0) + (1.0));
    }
    public static double subroutine1(double[] input) {
        return (2.0) + ((2.0) + (1.0));
    }
    public static double subroutine2(double[] input) {
        return (1.0) + ((1.0) + (1.0));
    }
    public static double subroutine3(double[] input) {
        return (0.0) + ((0.0) + (1.0));
    }
}"""

    generated = java.interpret(tree)
    utils.assert_code_equal(generated, expected_code)
예제 #2
0
def test_depth_threshold_with_bin_expr():
    """Check Java output for a four-level chain of binary additions.

    The expression is ``1 + (1 + (1 + (1 + 1)))`` built from ``BinNumExpr``
    nodes. The expected code keeps the two outer additions in ``score`` and
    places the remaining ones in ``subroutine0``.
    """
    levels = 4

    total = ast.NumVal(1)
    for _ in range(levels):
        left = ast.NumVal(1)
        total = ast.BinNumExpr(left, total, ast.BinNumOpType.ADD)

    # NOTE(review): the exact meaning of these two settings is defined by the
    # interpreter, which is not visible here; with these values the expected
    # output below contains one subroutine.
    java = JavaInterpreter()
    java.ast_size_check_frequency = 3
    java.ast_size_per_subroutine_threshold = 1

    expected_code = """
public class Model {
    public static double score(double[] input) {
        return (1.0) + ((1.0) + (subroutine0(input)));
    }
    public static double subroutine0(double[] input) {
        return (1.0) + ((1.0) + (1.0));
    }
}"""

    generated = java.interpret(total)
    utils.assert_code_equal(generated, expected_code)
예제 #3
0
def test_depth_threshold_without_bin_expr():
    """Check Java output for nested if-expressions with no binary additions.

    Builds four nested ``IfExpr`` nodes whose conditions are all ``1 == 1``.
    The expected code is a single ``score`` method with nested if/else
    statements and no subroutines.
    """
    levels = 4

    tree = ast.NumVal(1)
    for _ in range(levels):
        always_equal = ast.CompExpr(
            ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
        tree = ast.IfExpr(always_equal, ast.NumVal(1), tree)

    # NOTE(review): the exact meaning of these two settings is defined by the
    # interpreter, which is not visible here; with these values the expected
    # output below contains no subroutines.
    java = JavaInterpreter()
    java.ast_size_check_frequency = 2
    java.ast_size_per_subroutine_threshold = 1

    expected_code = """
public class Model {
    public static double score(double[] input) {
        double var0;
        if ((1.0) == (1.0)) {
            var0 = 1.0;
        } else {
            if ((1.0) == (1.0)) {
                var0 = 1.0;
            } else {
                if ((1.0) == (1.0)) {
                    var0 = 1.0;
                } else {
                    if ((1.0) == (1.0)) {
                        var0 = 1.0;
                    } else {
                        var0 = 1.0;
                    }
                }
            }
        }
        return var0;
    }
}"""

    generated = java.interpret(tree)
    utils.assert_code_equal(generated, expected_code)