예제 #1
0
def test_bin_vector_expr():
    """Rust: element-wise ADD of two literal vectors becomes ``add_vectors``.

    The expected output also contains the ``mul_vector_number`` helper,
    even though only addition is used by the expression.
    """
    left_vector = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    right_vector = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    expr = ast.BinVectorExpr(left_vector, right_vector, ast.BinNumOpType.ADD)

    expected_code = """
fn score(input: Vec<f64>) -> Vec<f64> {
    add_vectors(vec![1.0_f64, 2.0_f64], vec![3.0_f64, 4.0_f64])
}
fn add_vectors(v1: Vec<f64>, v2: Vec<f64>) -> Vec<f64> {
    v1.iter().zip(v2.iter()).map(|(&x, &y)| x + y).collect::<Vec<f64>>()
}
fn mul_vector_number(v1: Vec<f64>, num: f64) -> Vec<f64> {
    v1.iter().map(|&i| i * num).collect::<Vec<f64>>()
}
"""

    rust = RustInterpreter()
    actual_code = rust.interpret(expr)
    assert_code_equal(actual_code, expected_code)
예제 #2
0
def test_bin_vector_expr():
    """Haskell: element-wise ADD of two literal vectors calls ``addVectors``.

    The expected module declares both vector helpers (``addVectors`` and
    ``mulVectorNumber``) ahead of ``score``.
    """
    lhs = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    rhs = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    expr = ast.BinVectorExpr(lhs, rhs, ast.BinNumOpType.ADD)

    expected_code = """
module Model where
addVectors :: [Double] -> [Double] -> [Double]
addVectors v1 v2 = zipWith (+) v1 v2
mulVectorNumber :: [Double] -> Double -> [Double]
mulVectorNumber v1 num = [i * num | i <- v1]
score :: [Double] -> [Double]
score input =
    addVectors ([1.0, 2.0]) ([3.0, 4.0])
"""

    haskell = HaskellInterpreter()
    generated = haskell.interpret(expr)
    utils.assert_code_equal(generated, expected_code)
예제 #3
0
def test_bin_vector_expr():
    """Ruby: element-wise ADD of two literal vectors calls ``add_vectors``.

    The expected output defines ``score`` first, followed by the
    ``add_vectors`` and ``mul_vector_number`` helpers.
    """
    first = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    second = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    expr = ast.BinVectorExpr(first, second, ast.BinNumOpType.ADD)

    expected_code = """
def score(input)
    add_vectors([1.0, 2.0], [3.0, 4.0])
end
def add_vectors(v1, v2)
    v1.zip(v2).map { |x, y| x + y }
end
def mul_vector_number(v1, num)
    v1.map { |i| i * num }
end
"""

    ruby = RubyInterpreter()
    rendered = ruby.interpret(expr)
    assert_code_equal(rendered, expected_code)
예제 #4
0
def test_multi_class():
    """Two-tree random forest classifier assembles to a scaled vector sum.

    Each expected tree is a single split on feature 0 whose leaves are
    two-element class vectors. The expression adds the two trees'
    vectors, then multiplies the sum by 0.5.
    """
    estimator = ensemble.RandomForestClassifier(n_estimators=2, random_state=13)
    estimator.fit([[1], [2], [3]], [1, -1, 1])

    actual = assemblers.RandomForestModelAssembler(estimator).assemble()

    def tree_split(threshold, if_true, if_false):
        # Fresh nodes are built on every call, so no AST objects are shared
        # between the two trees.
        return ast.IfExpr(
            ast.CompExpr(
                ast.FeatureRef(0),
                ast.NumVal(threshold),
                ast.CompOpType.LTE),
            ast.VectorVal([ast.NumVal(value) for value in if_true]),
            ast.VectorVal([ast.NumVal(value) for value in if_false]))

    trees_sum = ast.BinVectorExpr(
        tree_split(1.5, (0.0, 1.0), (1.0, 0.0)),
        tree_split(2.5, (1.0, 0.0), (0.0, 1.0)),
        ast.BinNumOpType.ADD)
    expected = ast.BinVectorNumExpr(
        trees_sum,
        ast.NumVal(0.5),
        ast.BinNumOpType.MUL)

    assert utils.cmp_exprs(actual, expected)
예제 #5
0
파일: test_c.py 프로젝트: rspadim/m2cgen
def test_bin_vector_expr():
    """C: vector addition is lowered to an ``add_vectors`` call.

    The expected code writes the result into a local ``var0`` array and
    then ``memcpy``s it into ``output``. Both vector helpers are emitted
    ahead of ``score``.
    """
    operand_a = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    operand_b = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    expr = ast.BinVectorExpr(operand_a, operand_b, ast.BinNumOpType.ADD)

    c_interpreter = interpreters.CInterpreter()

    expected_code = """
#include <string.h>
void add_vectors(double *v1, double *v2, int size, double *result) {
    for(int i = 0; i < size; ++i)
        result[i] = v1[i] + v2[i];
}
void mul_vector_number(double *v1, double num, int size, double *result) {
    for(int i = 0; i < size; ++i)
        result[i] = v1[i] * num;
}
void score(double * input, double * output) {
    double var0[2];
    add_vectors((double[]){1.0, 2.0}, (double[]){3.0, 4.0}, 2, var0);
    memcpy(output, var0, 2 * sizeof(double));
}"""

    emitted = c_interpreter.interpret(expr)
    utils.assert_code_equal(emitted, expected_code)
예제 #6
0
파일: test_ast.py 프로젝트: mrshu/m2cgen
def test_count_all_exprs_types():
    """Check ``ast.count_exprs`` on a tree mixing many expression kinds.

    The tree combines vector, unary math, binary, subroutine and
    conditional nodes. The expected count is 24.
    """
    computed_items = ast.VectorVal([
        ast.ExpExpr(ast.NumVal(2)),
        ast.PowExpr(ast.NumVal(2), ast.NumVal(3)),
        ast.TanhExpr(ast.NumVal(1)),
        ast.BinNumExpr(ast.NumVal(0), ast.FeatureRef(0),
                       ast.BinNumOpType.ADD),
    ])
    plain_items = ast.VectorVal([
        ast.NumVal(1),
        ast.NumVal(2),
        ast.NumVal(3),
        ast.FeatureRef(1),
    ])
    vector_part = ast.BinVectorExpr(
        computed_items, plain_items, ast.BinNumOpType.SUB)

    scalar_part = ast.SubroutineExpr(
        ast.IfExpr(
            ast.CompExpr(ast.NumVal(2), ast.NumVal(0), ast.CompOpType.GT),
            ast.NumVal(3),
            ast.NumVal(4)))

    expr = ast.BinVectorNumExpr(vector_part, scalar_part, ast.BinNumOpType.MUL)

    assert ast.count_exprs(expr) == 24
예제 #7
0
EXPR_WITH_ALL_EXPRS = ast.BinVectorNumExpr(
    ast.BinVectorExpr(
        ast.VectorVal([
            ast.AbsExpr(ast.NumVal(-2)),
            ast.AtanExpr(ast.NumVal(2)),
            ast.ExpExpr(ast.NumVal(2)),
            ast.LogExpr(ast.NumVal(2)),
            ast.Log1pExpr(ast.NumVal(2)),
            ast.SigmoidExpr(ast.NumVal(2)),
            ast.SqrtExpr(ast.NumVal(2)),
            ast.PowExpr(ast.NumVal(2), ast.NumVal(3)),
            ast.TanhExpr(ast.NumVal(1)),
            ast.BinNumExpr(
                ast.NumVal(0),
                ast.FeatureRef(0),
                ast.BinNumOpType.ADD)
        ]),
        ast.IdExpr(
            ast.SoftmaxExpr([
                ast.NumVal(1),
                ast.NumVal(2),
                ast.NumVal(3),
                ast.NumVal(4),
                ast.NumVal(5),
                ast.NumVal(6),
                ast.NumVal(7),
                ast.NumVal(8),
                ast.NumVal(9),
                ast.FeatureRef(1)
            ])),
        ast.BinNumOpType.SUB),
    ast.IfExpr(