示例#1
0
def test_multi_output():
    """PythonInterpreter renders a vector-valued IfExpr as list assignments."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.IfExpr(condition, when_true, when_false)

    generated = PythonInterpreter().interpret(tree)

    # The literal deliberately ends with an indented blank tail.
    expected_source = """
def score(input):
    if (1.0) == (1.0):
        var0 = [1.0, 2.0]
    else:
        var0 = [3.0, 4.0]
    return var0
    """

    utils.assert_code_equal(generated, expected_source)
示例#2
0
def test_if_expr():
    """FSharpInterpreter renders a scalar IfExpr as an F# if/then/else."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    generated = FSharpInterpreter().interpret(tree)

    expected_source = """
let score (input : double list) =
    let func0 =
        if (1.0) = (input.[0]) then
            2.0
        else
            3.0
    func0
"""

    assert_code_equal(generated, expected_source)
示例#3
0
def test_deep_mixed_exprs_exceeding_threshold():
    """Check the Dart output for nested IfExprs with deep ADD chains.

    Four IfExpr levels are stacked; each condition compares a chain of four
    additions against a constant. The expected source stores part of every
    chain in its own helper variable.
    """
    tree = ast.NumVal(1)
    for level in range(4):
        chain = ast.NumVal(1)
        for j in range(4):
            chain = ast.BinNumExpr(
                ast.NumVal(level), chain, ast.BinNumOpType.ADD)

        # `j` still holds its final loop value (3) at this point, which is
        # why every expected comparison is against 3.0.
        condition = ast.CompExpr(chain, ast.NumVal(j), ast.CompOpType.EQ)
        tree = ast.IfExpr(condition, ast.NumVal(1), tree)

    generated = CustomDartInterpreter().interpret(tree)

    expected_source = """
double score(List<double> input) {
    double var0;
    double var1;
    var1 = (3.0) + ((3.0) + (1.0));
    if (((3.0) + ((3.0) + (var1))) == (3.0)) {
        var0 = 1.0;
    } else {
        double var2;
        var2 = (2.0) + ((2.0) + (1.0));
        if (((2.0) + ((2.0) + (var2))) == (3.0)) {
            var0 = 1.0;
        } else {
            double var3;
            var3 = (1.0) + ((1.0) + (1.0));
            if (((1.0) + ((1.0) + (var3))) == (3.0)) {
                var0 = 1.0;
            } else {
                double var4;
                var4 = (0.0) + ((0.0) + (1.0));
                if (((0.0) + ((0.0) + (var4))) == (3.0)) {
                    var0 = 1.0;
                } else {
                    var0 = 1.0;
                }
            }
        }
    }
    return var0;
}
"""

    utils.assert_code_equal(generated, expected_source)
示例#4
0
def test_ransac_custom_base_estimator():
    """RANSACModelAssembler assembles a wrapped decision tree as an IfExpr."""
    tree_regressor = DecisionTreeRegressor()
    ransac = RANSACRegressor(base_estimator=tree_regressor, random_state=1)
    ransac.fit([[1], [2], [3]], [1, 2, 3])

    produced = RANSACModelAssembler(ransac).assemble()

    # Expected tree: a single split on feature 0 at 2.5.
    split = ast.CompExpr(
        ast.FeatureRef(0), ast.NumVal(2.5), ast.CompOpType.LTE)
    wanted = ast.IfExpr(split, ast.NumVal(2.0), ast.NumVal(3.0))

    assert cmp_exprs(produced, wanted)
示例#5
0
def test_if_expr():
    """PythonInterpreter renders a scalar IfExpr as a Python if/else."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    generated = PythonInterpreter().interpret(tree)

    # The literal deliberately ends with an indented blank tail.
    expected_source = """
def score(input):
    if (1.0) == (input[0]):
        var0 = 2.0
    else:
        var0 = 3.0
    return var0
    """

    utils.assert_code_equal(generated, expected_source)
示例#6
0
def test_single_condition():
    """TreeModelAssembler turns a two-sample tree into a single IfExpr."""
    regressor = DecisionTreeRegressor()
    regressor.fit([[1], [2]], [1, 2])

    produced = TreeModelAssembler(regressor).assemble()

    # Expected tree: one split on feature 0 at 1.5.
    split = ast.CompExpr(
        ast.FeatureRef(0), ast.NumVal(1.5), ast.CompOpType.LTE)
    wanted = ast.IfExpr(split, ast.NumVal(1.0), ast.NumVal(2.0))

    assert cmp_exprs(produced, wanted)
示例#7
0
def test_if_expr():
    """GoInterpreter renders a scalar IfExpr as a Go if/else block."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    generated = GoInterpreter().interpret(tree)

    expected_source = """
func score(input []float64) float64 {
    var var0 float64
    if (1.0) == (input[0]) {
        var0 = 2.0
    } else {
        var0 = 3.0
    }
    return var0
}"""

    utils.assert_code_equal(generated, expected_source)
示例#8
0
def test_multi_output():
    """FSharpInterpreter renders a vector-valued IfExpr with F# lists."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.IfExpr(condition, when_true, when_false)

    generated = FSharpInterpreter().interpret(tree)

    expected_source = """
let score (input : double list) =
    let func0 =
        if ((1.0) = (1.0)) then
            [1.0; 2.0]
        else
            [3.0; 4.0]
    func0
"""

    utils.assert_code_equal(generated, expected_source)
示例#9
0
def test_multi_output():
    """GoInterpreter renders a vector-valued IfExpr with []float64 slices."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.IfExpr(condition, when_true, when_false)

    generated = GoInterpreter().interpret(tree)

    expected_source = """
func score(input []float64) []float64 {
    var var0 []float64
    if (1.0) == (1.0) {
        var0 = []float64{1.0, 2.0}
    } else {
        var0 = []float64{3.0, 4.0}
    }
    return var0
}"""

    utils.assert_code_equal(generated, expected_source)
示例#10
0
def test_if_expr():
    """CInterpreter renders a scalar IfExpr as a C if/else block."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    generated = interpreters.CInterpreter().interpret(tree)

    expected_source = """
double score(double * input) {
    double var0;
    if ((1) == (input[0])) {
        var0 = 2;
    } else {
        var0 = 3;
    }
    return var0;
}"""

    utils.assert_code_equal(generated, expected_source)
示例#11
0
def test_if_expr():
    """RubyInterpreter renders a scalar IfExpr as a Ruby if/else/end."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    generated = RubyInterpreter().interpret(tree)

    expected_source = """
def score(input)
    if (1.0) == (input[0])
        var0 = 2.0
    else
        var0 = 3.0
    end
    var0
end
"""

    assert_code_equal(generated, expected_source)
示例#12
0
def test_if_expr():
    """RInterpreter renders a scalar IfExpr as an R if/else block.

    FeatureRef(0) is expected to appear as ``input[1]`` in the R source.
    """
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    generated = RInterpreter().interpret(tree)

    expected_source = """
score <- function(input) {
    if ((1.0) == (input[1])) {
        var0 <- 2.0
    } else {
        var0 <- 3.0
    }
    return(var0)
}
"""

    utils.assert_code_equal(generated, expected_source)
示例#13
0
def test_deep_mixed_exprs_not_reaching_threshold():
    """Check the Java output when ADD chains stay within the depth threshold.

    Each of the four nested conditions is a chain of two additions and the
    interpreter's ``bin_depth_threshold`` is set to 2; the expected source
    keeps every chain inline.
    """
    tree = ast.NumVal(1)
    for _level in range(4):
        chain = ast.NumVal(1)
        for _step in range(2):
            chain = ast.BinNumExpr(
                ast.NumVal(1), chain, ast.BinNumOpType.ADD)

        condition = ast.CompExpr(chain, ast.NumVal(1), ast.CompOpType.EQ)
        tree = ast.IfExpr(condition, ast.NumVal(1), tree)

    java = interpreters.JavaInterpreter()
    java.bin_depth_threshold = 2
    java.ast_size_per_subroutine_threshold = 1

    generated = java.interpret(tree)

    expected_source = """
public class Model {

    public static double score(double[] input) {
        double var0;
        if (((1) + ((1) + (1))) == (1)) {
            var0 = 1;
        } else {
            if (((1) + ((1) + (1))) == (1)) {
                var0 = 1;
            } else {
                if (((1) + ((1) + (1))) == (1)) {
                    var0 = 1;
                } else {
                    if (((1) + ((1) + (1))) == (1)) {
                        var0 = 1;
                    } else {
                        var0 = 1;
                    }
                }
            }
        }
        return var0;
    }
}"""

    utils.assert_code_equal(generated, expected_source)
示例#14
0
def test_if_expr():
    """DartInterpreter renders a scalar IfExpr as a Dart if/else block."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    generated = DartInterpreter().interpret(tree)

    expected_source = """
double score(List<double> input) {
    double var0;
    if ((1.0) == (input[0])) {
        var0 = 2.0;
    } else {
        var0 = 3.0;
    }
    return var0;
}
"""

    utils.assert_code_equal(generated, expected_source)
示例#15
0
def test_single_condition():
    """RandomForestModelAssembler averages two trees: (t1 + t2) * 0.5."""
    forest = ensemble.RandomForestRegressor(n_estimators=2, random_state=1)
    forest.fit([[1], [2]], [1, 2])

    produced = assemblers.RandomForestModelAssembler(forest).assemble()

    # First expected tree is a constant; the second splits on feature 0.
    constant_tree = ast.SubroutineExpr(ast.NumVal(1.0))
    split = ast.CompExpr(
        ast.FeatureRef(0), ast.NumVal(1.5), ast.CompOpType.LTE)
    split_tree = ast.SubroutineExpr(
        ast.IfExpr(split, ast.NumVal(1.0), ast.NumVal(2.0)))
    tree_sum = ast.BinNumExpr(
        constant_tree, split_tree, ast.BinNumOpType.ADD)
    wanted = ast.BinNumExpr(tree_sum, ast.NumVal(0.5), ast.BinNumOpType.MUL)

    assert utils.cmp_exprs(produced, wanted)
示例#16
0
def test_if_expr():
    """JavascriptInterpreter renders a scalar IfExpr using ``===``."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    generated = JavascriptInterpreter().interpret(tree)

    expected_source = """
function score(input) {
    var var0;
    if ((1.0) === (input[0])) {
        var0 = 2.0;
    } else {
        var0 = 3.0;
    }
    return var0;
}
"""

    assert_code_equal(generated, expected_source)
示例#17
0
def test_if_expr():
    """PowershellInterpreter renders a scalar IfExpr using ``-eq``."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    generated = PowershellInterpreter().interpret(tree)

    expected_source = """
function Score([double[]] $InputVector) {
    [double]$var0 = 0
    if ((1) -eq ($InputVector[0])) {
        $var0 = 2
    } else {
        $var0 = 3
    }
    return $var0
}
"""

    utils.assert_code_equal(generated, expected_source)
示例#18
0
def test_count_exprs():
    """ast.count_exprs reports the total number of nodes in a tree.

    Every expected total equals the number of nodes constructed for that
    tree, root included.
    """
    addition = ast.BinNumExpr(
        ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.ADD)
    assert ast.count_exprs(addition) == 3

    exponent = ast.ExpExpr(ast.NumVal(2))
    assert ast.count_exprs(exponent) == 2

    vector = ast.VectorVal([ast.NumVal(2), ast.TanhExpr(ast.NumVal(3))])
    assert ast.count_exprs(vector) == 4

    condition = ast.CompExpr(ast.NumVal(2), ast.NumVal(0), ast.CompOpType.GT)
    branch = ast.IfExpr(condition, ast.NumVal(3), ast.NumVal(4))
    assert ast.count_exprs(branch) == 6

    leaf = ast.NumVal(1)
    assert ast.count_exprs(leaf) == 1
示例#19
0
def test_multi_output():
    """RInterpreter renders a vector-valued IfExpr with ``c(...)`` vectors."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.IfExpr(condition, when_true, when_false)

    generated = RInterpreter().interpret(tree)

    expected_source = """
score <- function(input) {
    if ((1.0) == (1.0)) {
        var0 <- c(1.0, 2.0)
    } else {
        var0 <- c(3.0, 4.0)
    }
    return(var0)
}
"""

    utils.assert_code_equal(generated, expected_source)
示例#20
0
def test_multi_output():
    """RubyInterpreter renders a vector-valued IfExpr with array literals."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.IfExpr(condition, when_true, when_false)

    generated = RubyInterpreter().interpret(tree)

    expected_source = """
def score(input)
    if (1.0) == (1.0)
        var0 = [1.0, 2.0]
    else
        var0 = [3.0, 4.0]
    end
    var0
end
"""

    assert_code_equal(generated, expected_source)
示例#21
0
def test_multi_output():
    """PythonInterpreter renders a wrapped vector IfExpr with numpy arrays.

    The IfExpr is wrapped in a SubroutineExpr and the expected source
    imports numpy and builds each branch with ``np.asarray``.
    """
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.SubroutineExpr(ast.IfExpr(condition, when_true, when_false))

    generated = interpreters.PythonInterpreter().interpret(tree)

    expected_source = """
import numpy as np
def score(input):
    if (1) == (1):
        var0 = np.asarray([1, 2])
    else:
        var0 = np.asarray([3, 4])
    return var0
"""

    utils.assert_code_equal(generated, expected_source)
示例#22
0
def test_multi_output():
    """CInterpreter renders a vector-valued IfExpr using memcpy.

    The expected function writes its result through an ``output`` pointer
    instead of returning a value.
    """
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.IfExpr(condition, when_true, when_false)

    generated = CInterpreter().interpret(tree)

    expected_source = """
#include <string.h>
void score(double * input, double * output) {
    double var0[2];
    if ((1.0) == (1.0)) {
        memcpy(var0, (double[]){1.0, 2.0}, 2 * sizeof(double));
    } else {
        memcpy(var0, (double[]){3.0, 4.0}, 2 * sizeof(double));
    }
    memcpy(output, var0, 2 * sizeof(double));
}"""

    utils.assert_code_equal(generated, expected_source)
示例#23
0
def test_multi_output():
    """DartInterpreter renders a vector-valued IfExpr with List<double>."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.IfExpr(condition, when_true, when_false)

    generated = DartInterpreter().interpret(tree)

    expected_source = """
List<double> score(List<double> input) {
    List<double> var0;
    if ((1.0) == (1.0)) {
        var0 = [1.0, 2.0];
    } else {
        var0 = [3.0, 4.0];
    }
    return var0;
}
"""

    utils.assert_code_equal(generated, expected_source)
示例#24
0
def test_if_expr():
    """RustInterpreter renders a scalar IfExpr with ``_f64`` literals."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    generated = RustInterpreter().interpret(tree)

    expected_source = """
fn score(input: Vec<f64>) -> f64 {
    let var0: f64;
    if (1.0_f64) == (input[0]) {
        var0 = 2.0_f64;
    } else {
        var0 = 3.0_f64;
    }
    var0
}
"""

    assert_code_equal(generated, expected_source)
示例#25
0
 def kernel_ast(sup_vec_value):
     """Build the AST for (sup_vec_value * x0) / norm.

     ``norm`` is sqrt(x0 * x0) with x0 = FeatureRef(0); it is replaced by
     1.0 when it compares equal to 0.0.
     """
     squared = ast.BinNumExpr(
         ast.FeatureRef(0), ast.FeatureRef(0), ast.BinNumOpType.MUL)
     # The same node object is used in the comparison and as the fallback
     # branch below, so it is created once with to_reuse=True.
     feature_norm = ast.SqrtExpr(squared, to_reuse=True)

     numerator = ast.BinNumExpr(
         ast.NumVal(sup_vec_value), ast.FeatureRef(0), ast.BinNumOpType.MUL)
     norm_is_zero = ast.CompExpr(
         feature_norm, ast.NumVal(0.0), ast.CompOpType.EQ)
     denominator = ast.IfExpr(norm_is_zero, ast.NumVal(1.0), feature_norm)

     return ast.BinNumExpr(numerator, denominator, ast.BinNumOpType.DIV)
示例#26
0
def test_multi_output():
    """JavascriptInterpreter renders a vector-valued IfExpr with arrays."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.IfExpr(condition, when_true, when_false)

    generated = interpreters.JavascriptInterpreter().interpret(tree)

    expected_source = """
function score(input) {
    var var0;
    if ((1) == (1)) {
        var0 = [1, 2];
    } else {
        var0 = [3, 4];
    }
    return var0;
}
"""

    utils.assert_code_equal(generated, expected_source)
示例#27
0
def test_if_expr():
    """PhpInterpreter renders a scalar IfExpr using ``===``."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    generated = PhpInterpreter().interpret(tree)

    expected_source = """
<?php
function score(array $input) {
    $var0 = null;
    if ((1) === ($input[0])) {
        $var0 = 2;
    } else {
        $var0 = 3;
    }
    return $var0;
}
"""

    utils.assert_code_equal(generated, expected_source)
示例#28
0
def test_if_expr():
    """HaskellInterpreter renders a scalar IfExpr in a ``where`` clause."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    generated = HaskellInterpreter().interpret(tree)

    expected_source = """
module Model where
score :: [Double] -> Double
score input =
    func0
    where
        func0 =
            if (1.0) == ((input) !! (0)) then
                2.0
            else
                3.0
"""

    assert_code_equal(generated, expected_source)
示例#29
0
def test_multi_output():
    """PowershellInterpreter renders a vector IfExpr with ``@(...)`` arrays."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.IfExpr(condition, when_true, when_false)

    generated = PowershellInterpreter().interpret(tree)

    expected_source = """
function Score([double[]] $InputVector) {
    [double[]]$var0 = @(0)
    if ((1) -eq (1)) {
        $var0 = @($(1), $(2))
    } else {
        $var0 = @($(3), $(4))
    }
    return $var0
}
"""

    utils.assert_code_equal(generated, expected_source)
示例#30
0
def test_deep_mixed_exprs_exceeding_threshold():
    """Check the F# output for nested IfExprs with deep ADD chains.

    Four IfExpr levels are stacked; each condition compares a chain of four
    additions against a constant. The expected source moves part of every
    chain into its own ``let`` binding ahead of the nested if/else.
    """
    tree = ast.NumVal(1)
    for level in range(4):
        chain = ast.NumVal(1)
        for j in range(4):
            chain = ast.BinNumExpr(
                ast.NumVal(level), chain, ast.BinNumOpType.ADD)

        # `j` still holds its final loop value (3) at this point, which is
        # why every expected comparison is against 3.0.
        condition = ast.CompExpr(chain, ast.NumVal(j), ast.CompOpType.EQ)
        tree = ast.IfExpr(condition, ast.NumVal(1), tree)

    generated = CustomFSharpInterpreter().interpret(tree)

    # The literal deliberately ends with an indented blank tail.
    expected_source = """
let score (input : double list) =
    let func0 =
        (3.0) + ((3.0) + (1.0))
    let func1 =
        (2.0) + ((2.0) + (1.0))
    let func2 =
        (1.0) + ((1.0) + (1.0))
    let func3 =
        (0.0) + ((0.0) + (1.0))
    let func4 =
        if (((3.0) + ((3.0) + (func0))) = (3.0)) then
            1.0
        else
            if (((2.0) + ((2.0) + (func1))) = (3.0)) then
                1.0
            else
                if (((1.0) + ((1.0) + (func2))) = (3.0)) then
                    1.0
                else
                    if (((0.0) + ((0.0) + (func3))) = (3.0)) then
                        1.0
                    else
                        1.0
    func4
    """

    utils.assert_code_equal(generated, expected_source)