示例#1
0
def test_bin_num_expr():
    """A DIV nested inside a MUL renders as parenthesised f64 arithmetic."""
    quotient = ast.BinNumExpr(
        ast.FeatureRef(0), ast.NumVal(-2), ast.BinNumOpType.DIV)
    product = ast.BinNumExpr(quotient, ast.NumVal(2), ast.BinNumOpType.MUL)

    expected_code = """
fn score(input: Vec<f64>) -> f64 {
    ((input[0]) / (-2.0_f64)) * (2.0_f64)
}
"""

    actual_code = RustInterpreter().interpret(product)
    assert_code_equal(actual_code, expected_code)
示例#2
0
class RustExecutor(BaseExecutor):
    """Run a model through generated Rust code compiled with ``rustc``.

    The model is assembled into an AST at construction time. ``prepare()``
    renders that AST to Rust, writes it into ``self._resource_tmp_dir`` and
    compiles it; ``predict()`` then invokes the compiled binary.
    """

    def __init__(self, model):
        self.model_name = "score"
        self.model = model
        self.interpreter = RustInterpreter()

        assembler_cls = get_assembler_cls(model)
        self.model_ast = assembler_cls(model).assemble()

        # Path to the compiled binary; stays None until prepare() succeeds.
        self.exec_path = None

    def predict(self, X):
        """Run the compiled binary with the values of ``X`` as CLI arguments.

        Each element of ``X`` is formatted with ``utils.format_arg`` and
        passed as a separate command-line argument (presumably one feature
        vector per call — confirm against callers).

        Raises:
            RuntimeError: if ``prepare()`` has not successfully built the
                binary yet.
        """
        if self.exec_path is None:
            raise RuntimeError("prepare() must be called before predict()")
        exec_args = [str(self.exec_path), *map(utils.format_arg, X)]
        return utils.predict_from_commandline(exec_args)

    def prepare(self):
        """Generate the Rust executor source and compile it with ``rustc``.

        Raises:
            subprocess.CalledProcessError: if ``rustc`` exits with a non-zero
                status, so a broken build is reported here rather than as a
                missing-binary error at predict time.
        """
        # Pick the execute snippet matching the model's output arity.
        if self.model_ast.output_size > 1:
            execute_code = EXECUTE_VECTOR
        else:
            execute_code = EXECUTE_SCALAR

        executor_code = EXECUTOR_CODE_TPL.format(
            model_code=self.interpreter.interpret(self.model_ast),
            execute_code=execute_code)

        executor_file_name = self._resource_tmp_dir / f"{self.model_name}.rs"
        utils.write_content_to_file(executor_code, executor_file_name)

        exec_path = self._resource_tmp_dir / self.model_name
        # check_call raises on a non-zero exit status instead of silently
        # leaving behind a missing or stale binary.
        subprocess.check_call([
            "rustc",
            str(executor_file_name),
            "-o",
            str(exec_path),
        ])
        # Publish the path only once the binary has actually been built.
        self.exec_path = exec_path
示例#3
0
def test_sigmoid_expr():
    """A sigmoid call emits a `sigmoid` helper function after `score`."""
    sigmoid_of_two = ast.SigmoidExpr(ast.NumVal(2.0))

    expected_code = """
fn score(input: Vec<f64>) -> f64 {
    sigmoid(2.0_f64)
}
fn sigmoid(x: f64) -> f64 {
    if x < 0.0_f64 {
        let z: f64 = x.exp();
        return z / (1.0_f64 + z);
    }
    1.0_f64 / (1.0_f64 + (-x).exp())
}
"""

    actual_code = RustInterpreter().interpret(sigmoid_of_two)
    assert_code_equal(actual_code, expected_code)
示例#4
0
def test_nested_condition():
    """A condition reused at two nesting levels is rendered at both levels.

    The same ``bool_test`` subtree guards both the outer and the inner
    ``IfExpr``; the expected output evaluates its inner ``if`` once at the
    top level (``var1``) and again inside the outer branch (``var2``).
    """
    one_eq_one = ast.CompExpr(
        ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    one_or_two = ast.IfExpr(one_eq_one, ast.NumVal(1), ast.NumVal(2))
    left = ast.BinNumExpr(one_or_two, ast.NumVal(2), ast.BinNumOpType.ADD)
    bool_test = ast.CompExpr(ast.NumVal(1), left, ast.CompOpType.EQ)

    inner = ast.IfExpr(bool_test, ast.FeatureRef(2), ast.NumVal(2))
    outer = ast.IfExpr(bool_test, inner, ast.NumVal(2))

    expected_code = """
fn score(input: Vec<f64>) -> f64 {
    let var0: f64;
    let var1: f64;
    if (1.0_f64) == (1.0_f64) {
        var1 = 1.0_f64;
    } else {
        var1 = 2.0_f64;
    }
    if (1.0_f64) == ((var1) + (2.0_f64)) {
        let var2: f64;
        if (1.0_f64) == (1.0_f64) {
            var2 = 1.0_f64;
        } else {
            var2 = 2.0_f64;
        }
        if (1.0_f64) == ((var2) + (2.0_f64)) {
            var0 = input[2];
        } else {
            var0 = 2.0_f64;
        }
    } else {
        var0 = 2.0_f64;
    }
    var0
}
"""

    actual_code = RustInterpreter().interpret(outer)
    assert_code_equal(actual_code, expected_code)
示例#5
0
def test_if_expr():
    """An IfExpr becomes a declared `var0` assigned in each branch."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    branch = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    expected_code = """
fn score(input: Vec<f64>) -> f64 {
    let var0: f64;
    if (1.0_f64) == (input[0]) {
        var0 = 2.0_f64;
    } else {
        var0 = 3.0_f64;
    }
    var0
}
"""

    actual_code = RustInterpreter().interpret(branch)
    assert_code_equal(actual_code, expected_code)
示例#6
0
def test_bin_vector_num_expr():
    """Vector-times-number renders via the emitted vector helper functions."""
    vector = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    scaled = ast.BinVectorNumExpr(
        vector, ast.NumVal(1), ast.BinNumOpType.MUL)

    expected_code = """
fn score(input: Vec<f64>) -> Vec<f64> {
    mul_vector_number(vec![1.0_f64, 2.0_f64], 1.0_f64)
}
fn add_vectors(v1: Vec<f64>, v2: Vec<f64>) -> Vec<f64> {
    v1.iter().zip(v2.iter()).map(|(&x, &y)| x + y).collect::<Vec<f64>>()
}
fn mul_vector_number(v1: Vec<f64>, num: f64) -> Vec<f64> {
    v1.iter().map(|&i| i * num).collect::<Vec<f64>>()
}
"""

    actual_code = RustInterpreter().interpret(scaled)
    assert_code_equal(actual_code, expected_code)
示例#7
0
def test_softmax_expr():
    """A softmax call emits a `softmax` helper returning `Vec<f64>`."""
    values = [ast.NumVal(2.0), ast.NumVal(3.0)]
    softmax = ast.SoftmaxExpr(values)

    expected_code = """
fn score(input: Vec<f64>) -> Vec<f64> {
    softmax(vec![2.0_f64, 3.0_f64])
}
fn softmax(x: Vec<f64>) -> Vec<f64> {
    let size: usize = x.len();
    let m: f64 = x.iter().fold(std::f64::MIN, |a, b| a.max(*b));
    let mut exps: Vec<f64> = vec![0.0_f64; size];
    let mut s: f64 = 0.0_f64;
    for (i, &v) in x.iter().enumerate() {
        exps[i] = (v - m).exp();
        s += exps[i];
    }
    exps.iter().map(|&i| i / s).collect::<Vec<f64>>()
}
"""

    actual_code = RustInterpreter().interpret(softmax)
    assert_code_equal(actual_code, expected_code)
示例#8
0
def test_multi_output():
    """An IfExpr over vectors declares `var0` as `Vec<f64>`."""
    always_true = ast.CompExpr(
        ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    first = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    second = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    branch = ast.IfExpr(always_true, first, second)

    expected_code = """
fn score(input: Vec<f64>) -> Vec<f64> {
    let var0: Vec<f64>;
    if (1.0_f64) == (1.0_f64) {
        var0 = vec![1.0_f64, 2.0_f64];
    } else {
        var0 = vec![3.0_f64, 4.0_f64];
    }
    var0
}
"""

    actual_code = RustInterpreter().interpret(branch)
    assert_code_equal(actual_code, expected_code)