Beispiel #1
0
def test_single_condition():
    """Check that a two-tree forest is assembled as the mean of its trees.

    The expected AST is (tree_0 + tree_1) * 0.5. The first tree is a single
    leaf (1.0); the second splits on feature 0 at 1.5.
    """
    forest = ensemble.RandomForestRegressor(n_estimators=2, random_state=1)
    forest.fit([[1], [2]], [1, 2])

    actual = assemblers.RandomForestModelAssembler(forest).assemble()

    leaf_tree = ast.SubroutineExpr(ast.NumVal(1.0))
    split_tree = ast.SubroutineExpr(
        ast.IfExpr(
            ast.CompExpr(
                ast.FeatureRef(0), ast.NumVal(1.5), ast.CompOpType.LTE),
            ast.NumVal(1.0),
            ast.NumVal(2.0)))
    trees_sum = ast.BinNumExpr(leaf_tree, split_tree, ast.BinNumOpType.ADD)
    # Two estimators, so their sum is scaled by 1/2.
    expected = ast.BinNumExpr(
        trees_sum, ast.NumVal(0.5), ast.BinNumOpType.MUL)

    assert utils.cmp_exprs(actual, expected)
Beispiel #2
0
def test_if_expr():
    """Render an IfExpr comparing a constant to feature 0 as R code.

    R indexing is 1-based, so feature 0 appears as ``input[1]``.
    """
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    model_ast = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    r_interpreter = RInterpreter()

    expected = """
score <- function(input) {
    if ((1.0) == (input[1])) {
        var0 <- 2.0
    } else {
        var0 <- 3.0
    }
    return(var0)
}
"""

    generated = r_interpreter.interpret(model_ast)
    utils.assert_code_equal(generated, expected)
Beispiel #3
0
def test_if_expr():
    """Render an IfExpr comparing a constant to feature 0 as PowerShell."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    model_ast = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    ps_interpreter = PowershellInterpreter()

    # EQ is rendered with PowerShell's -eq operator.
    expected = """
function Score([double[]] $InputVector) {
    [double]$var0 = 0
    if ((1) -eq ($InputVector[0])) {
        $var0 = 2
    } else {
        $var0 = 3
    }
    return $var0
}
"""

    generated = ps_interpreter.interpret(model_ast)
    utils.assert_code_equal(generated, expected)
Beispiel #4
0
def test_if_expr():
    """Render an IfExpr comparing a constant to feature 0 as Go code."""
    go_interpreter = interpreters.GoInterpreter()

    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    model_ast = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    # Go `if` conditions are emitted without surrounding parentheses.
    expected = """
func score(input []float64) float64 {
    var var0 float64
    if (1.0) == (input[0]) {
        var0 = 2.0
    } else {
        var0 = 3.0
    }
    return var0
}"""
    generated = go_interpreter.interpret(model_ast)
    utils.assert_code_equal(generated, expected)
Beispiel #5
0
def test_multi_output():
    """Render an IfExpr with vector branches as Ruby code."""
    always_true = ast.CompExpr(
        ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    then_vector = ast.VectorVal([ast.NumVal(v) for v in (1, 2)])
    else_vector = ast.VectorVal([ast.NumVal(v) for v in (3, 4)])
    model_ast = ast.IfExpr(always_true, then_vector, else_vector)

    expected = """
def score(input)
    if (1.0) == (1.0)
        var0 = [1.0, 2.0]
    else
        var0 = [3.0, 4.0]
    end
    var0
end
"""

    ruby_interpreter = RubyInterpreter()
    generated = ruby_interpreter.interpret(model_ast)
    utils.assert_code_equal(generated, expected)
Beispiel #6
0
def test_count_exprs():
    """Check that ast.count_exprs counts every node, including leaves."""
    cases = [
        # BinNumExpr plus its two NumVal operands.
        (ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2),
                        ast.BinNumOpType.ADD), 3),
        # ExpExpr plus its NumVal argument.
        (ast.ExpExpr(ast.NumVal(2)), 2),
        # VectorVal, NumVal, TanhExpr and the TanhExpr's NumVal.
        (ast.VectorVal([ast.NumVal(2),
                        ast.TanhExpr(ast.NumVal(3))]), 4),
        # IfExpr, CompExpr with two NumVals, and two NumVal branches.
        (ast.IfExpr(
            ast.CompExpr(ast.NumVal(2), ast.NumVal(0), ast.CompOpType.GT),
            ast.NumVal(3),
            ast.NumVal(4),
        ), 6),
        # A lone leaf.
        (ast.NumVal(1), 1),
    ]

    for node, expected_count in cases:
        assert ast.count_exprs(node) == expected_count
Beispiel #7
0
def test_if_expr():
    """Render an IfExpr comparing a constant to feature 0 as Dart code."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    model_ast = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    dart_interpreter = DartInterpreter()

    expected = """
double score(List<double> input) {
    double var0;
    if ((1.0) == (input[0])) {
        var0 = 2.0;
    } else {
        var0 = 3.0;
    }
    return var0;
}
"""

    generated = dart_interpreter.interpret(model_ast)
    utils.assert_code_equal(generated, expected)
Beispiel #8
0
def test_multi_output():
    """Render an IfExpr with vector branches as C code.

    Vector results are copied into the output buffer with memcpy, which is
    why the expected source includes <string.h>.
    """
    always_true = ast.CompExpr(
        ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    then_vector = ast.VectorVal([ast.NumVal(v) for v in (1, 2)])
    else_vector = ast.VectorVal([ast.NumVal(v) for v in (3, 4)])
    model_ast = ast.IfExpr(always_true, then_vector, else_vector)

    expected = """
#include <string.h>
void score(double * input, double * output) {
    double var0[2];
    if ((1) == (1)) {
        memcpy(var0, (double[]){1, 2}, 2 * sizeof(double));
    } else {
        memcpy(var0, (double[]){3, 4}, 2 * sizeof(double));
    }
    memcpy(output, var0, 2 * sizeof(double));
}"""
    c_interpreter = interpreters.CInterpreter()
    generated = c_interpreter.interpret(model_ast)
    utils.assert_code_equal(generated, expected)
Beispiel #9
0
def test_deep_mixed_exprs_not_reaching_threshold():
    """Render a chain of four nested IfExprs as Java.

    Each condition has a small (1) + ((1) + (1)) arithmetic left-hand side.
    The expected output contains no extracted subroutines.
    """
    def make_condition():
        # Two nested additions built around a NumVal(1) leaf.
        total = ast.NumVal(1)
        for _ in range(2):
            total = ast.BinNumExpr(ast.NumVal(1), total, ast.BinNumOpType.ADD)
        return ast.CompExpr(total, ast.NumVal(1), ast.CompOpType.EQ)

    model_ast = ast.NumVal(1)
    for _ in range(4):
        # Each iteration wraps the previous tree as the else-branch.
        model_ast = ast.IfExpr(make_condition(), ast.NumVal(1), model_ast)

    java_interpreter = interpreters.JavaInterpreter()
    java_interpreter.ast_size_check_frequency = 3
    java_interpreter.ast_size_per_subroutine_threshold = 1

    expected = """
public class Model {
    public static double score(double[] input) {
        double var0;
        if (((1.0) + ((1.0) + (1.0))) == (1.0)) {
            var0 = 1.0;
        } else {
            if (((1.0) + ((1.0) + (1.0))) == (1.0)) {
                var0 = 1.0;
            } else {
                if (((1.0) + ((1.0) + (1.0))) == (1.0)) {
                    var0 = 1.0;
                } else {
                    if (((1.0) + ((1.0) + (1.0))) == (1.0)) {
                        var0 = 1.0;
                    } else {
                        var0 = 1.0;
                    }
                }
            }
        }
        return var0;
    }
}"""

    generated = java_interpreter.interpret(model_ast)
    utils.assert_code_equal(generated, expected)
Beispiel #10
0
def test_multi_output():
    """Render an IfExpr with vector branches as R code using c(...)."""
    always_true = ast.CompExpr(
        ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    then_vector = ast.VectorVal([ast.NumVal(v) for v in (1, 2)])
    else_vector = ast.VectorVal([ast.NumVal(v) for v in (3, 4)])
    model_ast = ast.IfExpr(always_true, then_vector, else_vector)

    expected = """
score <- function(input) {
    if ((1.0) == (1.0)) {
        var0 <- c(1.0, 2.0)
    } else {
        var0 <- c(3.0, 4.0)
    }
    return(var0)
}
"""

    r_interpreter = RInterpreter()
    generated = r_interpreter.interpret(model_ast)
    assert_code_equal(generated, expected)
def test_if_expr():
    """Render an IfExpr comparing a constant to feature 0 as JavaScript.

    EQ is emitted as the strict equality operator ``===``.
    """
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    model_ast = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    js_interpreter = JavascriptInterpreter()

    expected = """
function score(input) {
    var var0;
    if ((1.0) === (input[0])) {
        var0 = 2.0;
    } else {
        var0 = 3.0;
    }
    return var0;
}
"""

    generated = js_interpreter.interpret(model_ast)
    assert_code_equal(generated, expected)
Beispiel #12
0
def test_if_expr():
    """Render an IfExpr comparing a constant to feature 0 as Rust code.

    Numeric literals carry the explicit ``_f64`` suffix.
    """
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    model_ast = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    rust_interpreter = RustInterpreter()

    expected = """
fn score(input: Vec<f64>) -> f64 {
    let var0: f64;
    if (1.0_f64) == (input[0]) {
        var0 = 2.0_f64;
    } else {
        var0 = 3.0_f64;
    }
    var0
}
"""

    generated = rust_interpreter.interpret(model_ast)
    assert_code_equal(generated, expected)
Beispiel #13
0
def test_multi_output():
    """Render an IfExpr with vector branches as JavaScript arrays."""
    always_true = ast.CompExpr(
        ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    then_vector = ast.VectorVal([ast.NumVal(v) for v in (1, 2)])
    else_vector = ast.VectorVal([ast.NumVal(v) for v in (3, 4)])
    model_ast = ast.IfExpr(always_true, then_vector, else_vector)

    expected = """
function score(input) {
    var var0;
    if ((1) == (1)) {
        var0 = [1, 2];
    } else {
        var0 = [3, 4];
    }
    return var0;
}
"""

    js_interpreter = interpreters.JavascriptInterpreter()
    generated = js_interpreter.interpret(model_ast)
    utils.assert_code_equal(generated, expected)
Beispiel #14
0
 def kernel_ast(sup_vec_value):
     """Build the AST (sup_vec_value * x0) / norm.

     Here x0 is feature 0 and norm is sqrt(x0 * x0). The denominator is
     replaced by 1.0 when norm equals 0.0, which avoids a division by zero.
     """
     # One node object, marked to_reuse=True, serves both as the zero check
     # and as the divisor.
     feature_norm = ast.SqrtExpr(
         ast.BinNumExpr(
             ast.FeatureRef(0), ast.FeatureRef(0), ast.BinNumOpType.MUL),
         to_reuse=True)
     numerator = ast.BinNumExpr(
         ast.NumVal(sup_vec_value), ast.FeatureRef(0), ast.BinNumOpType.MUL)
     safe_denominator = ast.IfExpr(
         ast.CompExpr(feature_norm, ast.NumVal(0.0), ast.CompOpType.EQ),
         ast.NumVal(1.0),
         feature_norm)
     return ast.BinNumExpr(numerator, safe_denominator, ast.BinNumOpType.DIV)
Beispiel #15
0
def test_multi_output():
    """Render an IfExpr with vector branches as PowerShell arrays."""
    always_true = ast.CompExpr(
        ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    then_vector = ast.VectorVal([ast.NumVal(v) for v in (1, 2)])
    else_vector = ast.VectorVal([ast.NumVal(v) for v in (3, 4)])
    model_ast = ast.IfExpr(always_true, then_vector, else_vector)

    expected = """
function Score([double[]] $InputVector) {
    [double[]]$var0 = @(0)
    if ((1) -eq (1)) {
        $var0 = @($(1), $(2))
    } else {
        $var0 = @($(3), $(4))
    }
    return $var0
}
"""

    ps_interpreter = PowershellInterpreter()
    generated = ps_interpreter.interpret(model_ast)
    utils.assert_code_equal(generated, expected)
Beispiel #16
0
def test_if_expr():
    """Render an IfExpr comparing a constant to feature 0 as Haskell.

    The expression is emitted as a ``where``-bound helper (func0). Feature
    access uses the list-indexing operator ``!!``.
    """
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    model_ast = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    haskell_interpreter = HaskellInterpreter()

    expected = """
module Model where
score :: [Double] -> Double
score input =
    func0
    where
        func0 =
            if (1.0) == ((input) !! (0)) then
                2.0
            else
                3.0
"""

    generated = haskell_interpreter.interpret(model_ast)
    assert_code_equal(generated, expected)
Beispiel #17
0
def test_if_expr():
    """Render an IfExpr comparing a constant to feature 0 as PHP code."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    model_ast = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    php_interpreter = PhpInterpreter()

    expected = """
<?php
function score(array $input) {
    $var0 = null;
    if ((1) === ($input[0])) {
        $var0 = 2;
    } else {
        $var0 = 3;
    }
    return $var0;
}
"""

    generated = php_interpreter.interpret(model_ast)
    utils.assert_code_equal(generated, expected)
Beispiel #18
0
def test_multi_output():
    """Render an IfExpr with vector branches as Dart returning List<double>."""
    always_true = ast.CompExpr(
        ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    then_vector = ast.VectorVal([ast.NumVal(v) for v in (1, 2)])
    else_vector = ast.VectorVal([ast.NumVal(v) for v in (3, 4)])
    model_ast = ast.IfExpr(always_true, then_vector, else_vector)

    expected = """
List<double> score(List<double> input) {
    List<double> var0;
    if ((1.0) == (1.0)) {
        var0 = [1.0, 2.0];
    } else {
        var0 = [3.0, 4.0];
    }
    return var0;
}
"""

    dart_interpreter = DartInterpreter()
    generated = dart_interpreter.interpret(model_ast)
    utils.assert_code_equal(generated, expected)
Beispiel #19
0
def test_multi_output():
    """Render an IfExpr with vector branches as Python lists."""
    always_true = ast.CompExpr(
        ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    then_vector = ast.VectorVal([ast.NumVal(v) for v in (1, 2)])
    else_vector = ast.VectorVal([ast.NumVal(v) for v in (3, 4)])
    model_ast = ast.IfExpr(always_true, then_vector, else_vector)

    expected = """
def score(input):
    if (1) == (1):
        var0 = [1, 2]
    else:
        var0 = [3, 4]
    return var0
    """

    py_interpreter = interpreters.PythonInterpreter()
    generated = py_interpreter.interpret(model_ast)
    utils.assert_code_equal(generated, expected)
Beispiel #20
0
def test_multi_output():
    """Render an IfExpr with vector branches as Go []float64 literals."""
    always_true = ast.CompExpr(
        ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    then_vector = ast.VectorVal([ast.NumVal(v) for v in (1, 2)])
    else_vector = ast.VectorVal([ast.NumVal(v) for v in (3, 4)])
    model_ast = ast.IfExpr(always_true, then_vector, else_vector)

    expected = """
func score(input []float64) []float64 {
    var var0 []float64
    if (1.0) == (1.0) {
        var0 = []float64{1.0, 2.0}
    } else {
        var0 = []float64{3.0, 4.0}
    }
    return var0
}
"""

    go_interpreter = GoInterpreter()
    generated = go_interpreter.interpret(model_ast)
    assert_code_equal(generated, expected)
Beispiel #21
0
def test_deep_mixed_exprs_exceeding_threshold():
    """Render four nested IfExprs with deep arithmetic conditions as F#.

    Each of the four outer levels (index i) compares a left-hand side of
    four nested additions of NumVal(i) against 3. The expected output shows
    the innermost part of each sum extracted into a helper binding
    (func0..func3), with the whole if-chain bound to func4.
    """
    # Compare against an explicit constant rather than reading an inner
    # loop variable after its loop has ended. The value is 3, matching the
    # `= (3.0)` comparisons in the expected code.
    compared_value = 3

    expr = ast.NumVal(1)
    for i in range(4):
        inner = ast.NumVal(1)
        for _ in range(4):
            inner = ast.BinNumExpr(ast.NumVal(i), inner, ast.BinNumOpType.ADD)

        expr = ast.IfExpr(
            ast.CompExpr(
                inner, ast.NumVal(compared_value), ast.CompOpType.EQ),
            ast.NumVal(1), expr)

    # CustomFSharpInterpreter is defined elsewhere in this test module.
    # NOTE(review): presumably it is configured with a low subroutine-size
    # threshold — confirm there.
    interpreter = CustomFSharpInterpreter()

    expected_code = """
let score (input : double list) =
    let func0 =
        (3.0) + ((3.0) + (1.0))
    let func1 =
        (2.0) + ((2.0) + (1.0))
    let func2 =
        (1.0) + ((1.0) + (1.0))
    let func3 =
        (0.0) + ((0.0) + (1.0))
    let func4 =
        if (((3.0) + ((3.0) + (func0))) = (3.0)) then
            1.0
        else
            if (((2.0) + ((2.0) + (func1))) = (3.0)) then
                1.0
            else
                if (((1.0) + ((1.0) + (func2))) = (3.0)) then
                    1.0
                else
                    if (((0.0) + ((0.0) + (func3))) = (3.0)) then
                        1.0
                    else
                        1.0
    func4
    """

    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
Beispiel #22
0
def test_deep_mixed_exprs_exceeding_threshold():
    """Render four nested IfExprs with deep arithmetic conditions as R.

    Each condition's left-hand side is four nested additions of NumVal(1).
    The expected output hoists the innermost part of each sum into its own
    variable (var1..var4), declared just before the `if` that uses it.
    """
    expr = ast.NumVal(1)
    # Neither loop index is used. Distinct throwaway names keep the inner
    # loop from shadowing the outer one.
    for _ in range(4):
        inner = ast.NumVal(1)
        for __ in range(4):
            inner = ast.BinNumExpr(ast.NumVal(1), inner, ast.BinNumOpType.ADD)

        expr = ast.IfExpr(
            ast.CompExpr(inner, ast.NumVal(1), ast.CompOpType.EQ),
            ast.NumVal(1), expr)

    # CustomRInterpreter is defined elsewhere in this test module.
    interpreter = CustomRInterpreter()

    expected_code = """
score <- function(input) {
    var1 <- (1) + ((1) + (1))
    if (((1) + ((1) + (var1))) == (1)) {
        var0 <- 1
    } else {
        var2 <- (1) + ((1) + (1))
        if (((1) + ((1) + (var2))) == (1)) {
            var0 <- 1
        } else {
            var3 <- (1) + ((1) + (1))
            if (((1) + ((1) + (var3))) == (1)) {
                var0 <- 1
            } else {
                var4 <- (1) + ((1) + (1))
                if (((1) + ((1) + (var4))) == (1)) {
                    var0 <- 1
                } else {
                    var0 <- 1
                }
            }
        }
    }
    return(var0)
}
"""

    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
Beispiel #23
0
def test_multi_output():
    """Render an IfExpr with vector branches as PHP array(...) literals."""
    always_true = ast.CompExpr(
        ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    then_vector = ast.VectorVal([ast.NumVal(v) for v in (1, 2)])
    else_vector = ast.VectorVal([ast.NumVal(v) for v in (3, 4)])
    model_ast = ast.IfExpr(always_true, then_vector, else_vector)

    expected = """
<?php
function score(array $input) {
    $var0 = array();
    if ((1.0) === (1.0)) {
        $var0 = array(1.0, 2.0);
    } else {
        $var0 = array(3.0, 4.0);
    }
    return $var0;
}
"""

    php_interpreter = PhpInterpreter()
    generated = php_interpreter.interpret(model_ast)
    utils.assert_code_equal(generated, expected)
Beispiel #24
0
def test_multi_output():
    """Render an IfExpr with vector branches as Haskell returning [Double]."""
    always_true = ast.CompExpr(
        ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    then_vector = ast.VectorVal([ast.NumVal(v) for v in (1, 2)])
    else_vector = ast.VectorVal([ast.NumVal(v) for v in (3, 4)])
    model_ast = ast.IfExpr(always_true, then_vector, else_vector)

    expected = """
module Model where
score :: [Double] -> [Double]
score input =
    func0
    where
        func0 =
            if (1.0) == (1.0) then
                [1.0, 2.0]
            else
                [3.0, 4.0]
"""

    haskell_interpreter = HaskellInterpreter()
    generated = haskell_interpreter.interpret(model_ast)
    assert_code_equal(generated, expected)
Beispiel #25
0
def test_multi_output():
    """Render an IfExpr with vector branches as Java returning double[]."""
    always_true = ast.CompExpr(
        ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    then_vector = ast.VectorVal([ast.NumVal(v) for v in (1, 2)])
    else_vector = ast.VectorVal([ast.NumVal(v) for v in (3, 4)])
    model_ast = ast.IfExpr(always_true, then_vector, else_vector)

    expected = """
public class Model {
    public static double[] score(double[] input) {
        double[] var0;
        if ((1.0) == (1.0)) {
            var0 = new double[] {1.0, 2.0};
        } else {
            var0 = new double[] {3.0, 4.0};
        }
        return var0;
    }
}"""

    java_interpreter = JavaInterpreter()
    generated = java_interpreter.interpret(model_ast)
    utils.assert_code_equal(generated, expected)
Beispiel #26
0
def test_if_expr():
    """Render an IfExpr comparing a constant to feature 0 as Visual Basic.

    The result is returned by assigning it to the function name (score).
    """
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    model_ast = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    vb_interpreter = VisualBasicInterpreter()

    expected = """
Module Model
Function score(ByRef input_vector() As Double) As Double
    Dim var0 As Double
    If (1) == (input_vector(0)) Then
        var0 = 2
    Else
        var0 = 3
    End If
    score = var0
End Function
End Module
"""

    generated = vb_interpreter.interpret(model_ast)
    utils.assert_code_equal(generated, expected)
Beispiel #27
0
def test_depth_threshold_without_bin_expr():
    """Render nested IfExprs that contain no BinNumExpr as Java.

    The bin-depth and subroutine-size thresholds are lowered, but the
    expected output is still a single method with no extracted subroutines.
    """
    expr = ast.NumVal(1)
    # The loop index is not needed; `_` marks it as intentionally unused.
    for _ in range(4):
        expr = ast.IfExpr(
            ast.CompExpr(
                ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
            ast.NumVal(1),
            expr)

    interpreter = interpreters.JavaInterpreter()
    interpreter.bin_depth_threshold = 2
    interpreter.ast_size_per_subroutine_threshold = 1

    expected_code = """
public class Model {

    public static double score(double[] input) {
        double var0;
        if ((1) == (1)) {
            var0 = 1;
        } else {
            if ((1) == (1)) {
                var0 = 1;
            } else {
                if ((1) == (1)) {
                    var0 = 1;
                } else {
                    if ((1) == (1)) {
                        var0 = 1;
                    } else {
                        var0 = 1;
                    }
                }
            }
        }
        return var0;
    }
}"""

    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
Beispiel #28
0
def test_deep_mixed_exprs_not_reaching_threshold():
    """Render four nested IfExprs with shallow arithmetic conditions as Dart.

    Each condition's left-hand side is (1) + ((1) + (1)). The expected
    output contains no extracted variables or subroutines.
    """
    expr = ast.NumVal(1)
    # Neither loop index is used. Distinct throwaway names keep the inner
    # loop from shadowing the outer one.
    for _ in range(4):
        inner = ast.NumVal(1)
        for __ in range(2):
            inner = ast.BinNumExpr(ast.NumVal(1), inner, ast.BinNumOpType.ADD)

        expr = ast.IfExpr(
            ast.CompExpr(
                inner, ast.NumVal(1), ast.CompOpType.EQ),
            ast.NumVal(1),
            expr)

    # CustomDartInterpreter is defined elsewhere in this test module.
    interpreter = CustomDartInterpreter()

    expected_code = """
double score(List<double> input) {
    double var0;
    if (((1) + ((1) + (1))) == (1)) {
        var0 = 1;
    } else {
        if (((1) + ((1) + (1))) == (1)) {
            var0 = 1;
        } else {
            if (((1) + ((1) + (1))) == (1)) {
                var0 = 1;
            } else {
                if (((1) + ((1) + (1))) == (1)) {
                    var0 = 1;
                } else {
                    var0 = 1;
                }
            }
        }
    }
    return var0;
}
"""

    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
Beispiel #29
0
def test_if_expr():
    """Render an IfExpr comparing a constant to feature 0 as a Java class."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    model_ast = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    java_interpreter = JavaInterpreter()

    expected = """
public class Model {
    public static double score(double[] input) {
        double var0;
        if ((1.0) == (input[0])) {
            var0 = 2.0;
        } else {
            var0 = 3.0;
        }
        return var0;
    }
}
"""

    generated = java_interpreter.interpret(model_ast)
    utils.assert_code_equal(generated, expected)
Beispiel #30
0
def test_multi_output():
    """Render a SubroutineExpr-wrapped IfExpr with vector branches as Go.

    The expected output inlines the subroutine into score's body.
    """
    always_true = ast.CompExpr(
        ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    then_vector = ast.VectorVal([ast.NumVal(v) for v in (1, 2)])
    else_vector = ast.VectorVal([ast.NumVal(v) for v in (3, 4)])
    model_ast = ast.SubroutineExpr(
        ast.IfExpr(always_true, then_vector, else_vector))

    expected = """
func score(input []float64) []float64 {
    var var0 []float64
    if (1) == (1) {
        var0 = []float64{1, 2}
    } else {
        var0 = []float64{3, 4}
    }
    return var0
}"""
    go_interpreter = interpreters.GoInterpreter()
    generated = go_interpreter.interpret(model_ast)
    utils.assert_code_equal(generated, expected)