def test_bin_num_expr():
    """Nested binary numeric ops render fully parenthesized with `_f64` literals."""
    # Inner node: input[0] / -2
    quotient = ast.BinNumExpr(
        ast.FeatureRef(0),
        ast.NumVal(-2),
        ast.BinNumOpType.DIV,
    )
    # Outer node: (inner) * 2
    expr = ast.BinNumExpr(quotient, ast.NumVal(2), ast.BinNumOpType.MUL)

    expected_code = """ fn score(input: Vec<f64>) -> f64 { ((input[0]) / (-2.0_f64)) * (2.0_f64) } """

    actual_code = RustInterpreter().interpret(expr)
    assert_code_equal(actual_code, expected_code)
class RustExecutor(BaseExecutor):
    """Translate a model to Rust, compile it with ``rustc`` and run the binary.

    ``prepare`` must be called before ``predict``: it writes the generated
    source into ``self._resource_tmp_dir`` (presumably provided by
    ``BaseExecutor`` -- not visible here) and compiles it to
    ``self.exec_path``.
    """

    def __init__(self, model):
        self.model_name = "score"
        self.model = model
        self.interpreter = RustInterpreter()

        assembler_cls = get_assembler_cls(model)
        self.model_ast = assembler_cls(model).assemble()

        # Path of the compiled binary; stays None until prepare() succeeds.
        self.exec_path = None

    def predict(self, X):
        """Run the compiled binary with each element of X as a command-line argument."""
        exec_args = [str(self.exec_path), *map(utils.format_arg, X)]
        return utils.predict_from_commandline(exec_args)

    def prepare(self):
        """Generate, write and compile the Rust executor program.

        Raises:
            subprocess.CalledProcessError: if ``rustc`` exits with a non-zero
                status (previously this failure was silently ignored).
        """
        # Models with more than one output use the EXECUTE_VECTOR template.
        if self.model_ast.output_size > 1:
            execute_code = EXECUTE_VECTOR
        else:
            execute_code = EXECUTE_SCALAR

        executor_code = EXECUTOR_CODE_TPL.format(
            model_code=self.interpreter.interpret(self.model_ast),
            execute_code=execute_code)

        executor_file_name = self._resource_tmp_dir / f"{self.model_name}.rs"
        utils.write_content_to_file(executor_code, executor_file_name)

        exec_path = self._resource_tmp_dir / self.model_name
        # check_call (not call) so a compile error surfaces here instead of
        # later, when predict() would try to run a binary that was never built.
        subprocess.check_call([
            "rustc",
            str(executor_file_name),
            "-o",
            str(exec_path),
        ])
        self.exec_path = exec_path
def test_sigmoid_expr():
    """A SigmoidExpr renders as a call plus an emitted `sigmoid` helper fn."""
    argument = ast.NumVal(2.0)
    expr = ast.SigmoidExpr(argument)

    expected_code = """ fn score(input: Vec<f64>) -> f64 { sigmoid(2.0_f64) } fn sigmoid(x: f64) -> f64 { if x < 0.0_f64 { let z: f64 = x.exp(); return z / (1.0_f64 + z); } 1.0_f64 / (1.0_f64 + (-x).exp()) } """

    rendered = RustInterpreter().interpret(expr)
    assert_code_equal(rendered, expected_code)
def test_nested_condition():
    """An IfExpr nested in both the condition and a branch gets fresh vars per level."""
    # (if 1 == 1 then 1 else 2) + 2
    inner_if = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.NumVal(1),
        ast.NumVal(2),
    )
    lhs_sum = ast.BinNumExpr(inner_if, ast.NumVal(2), ast.BinNumOpType.ADD)

    # The same condition object is shared by the outer and the nested IfExpr.
    condition = ast.CompExpr(ast.NumVal(1), lhs_sum, ast.CompOpType.EQ)
    then_branch = ast.IfExpr(condition, ast.FeatureRef(2), ast.NumVal(2))
    expr = ast.IfExpr(condition, then_branch, ast.NumVal(2))

    expected_code = """ fn score(input: Vec<f64>) -> f64 { let var0: f64; let var1: f64; if (1.0_f64) == (1.0_f64) { var1 = 1.0_f64; } else { var1 = 2.0_f64; } if (1.0_f64) == ((var1) + (2.0_f64)) { let var2: f64; if (1.0_f64) == (1.0_f64) { var2 = 1.0_f64; } else { var2 = 2.0_f64; } if (1.0_f64) == ((var2) + (2.0_f64)) { var0 = input[2]; } else { var0 = 2.0_f64; } } else { var0 = 2.0_f64; } var0 } """

    generated = RustInterpreter().interpret(expr)
    assert_code_equal(generated, expected_code)
def test_if_expr():
    """A scalar IfExpr becomes a declared `var0` assigned in each branch."""
    condition = ast.CompExpr(
        ast.NumVal(1),
        ast.FeatureRef(0),
        ast.CompOpType.EQ,
    )
    expr = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    expected_code = """ fn score(input: Vec<f64>) -> f64 { let var0: f64; if (1.0_f64) == (input[0]) { var0 = 2.0_f64; } else { var0 = 3.0_f64; } var0 } """

    generated = RustInterpreter().interpret(expr)
    assert_code_equal(generated, expected_code)
def test_bin_vector_num_expr():
    """Vector * number renders via the emitted `mul_vector_number` helper."""
    vector = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    scalar = ast.NumVal(1)
    expr = ast.BinVectorNumExpr(vector, scalar, ast.BinNumOpType.MUL)

    expected_code = """ fn score(input: Vec<f64>) -> Vec<f64> { mul_vector_number(vec![1.0_f64, 2.0_f64], 1.0_f64) } fn add_vectors(v1: Vec<f64>, v2: Vec<f64>) -> Vec<f64> { v1.iter().zip(v2.iter()).map(|(&x, &y)| x + y).collect::<Vec<f64>>() } fn mul_vector_number(v1: Vec<f64>, num: f64) -> Vec<f64> { v1.iter().map(|&i| i * num).collect::<Vec<f64>>() } """

    generated = RustInterpreter().interpret(expr)
    assert_code_equal(generated, expected_code)
def test_softmax_expr():
    """A SoftmaxExpr renders as a call plus an emitted `softmax` helper fn."""
    logits = [ast.NumVal(2.0), ast.NumVal(3.0)]
    expr = ast.SoftmaxExpr(logits)

    expected_code = """ fn score(input: Vec<f64>) -> Vec<f64> { softmax(vec![2.0_f64, 3.0_f64]) } fn softmax(x: Vec<f64>) -> Vec<f64> { let size: usize = x.len(); let m: f64 = x.iter().fold(std::f64::MIN, |a, b| a.max(*b)); let mut exps: Vec<f64> = vec![0.0_f64; size]; let mut s: f64 = 0.0_f64; for (i, &v) in x.iter().enumerate() { exps[i] = (v - m).exp(); s += exps[i]; } exps.iter().map(|&i| i / s).collect::<Vec<f64>>() } """

    rendered = RustInterpreter().interpret(expr)
    assert_code_equal(rendered, expected_code)
def test_multi_output():
    """An IfExpr with vector branches declares `var0` as `Vec<f64>`."""
    always_true = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    expr = ast.IfExpr(always_true, when_true, when_false)

    expected_code = """ fn score(input: Vec<f64>) -> Vec<f64> { let var0: Vec<f64>; if (1.0_f64) == (1.0_f64) { var0 = vec![1.0_f64, 2.0_f64]; } else { var0 = vec![3.0_f64, 4.0_f64]; } var0 } """

    rendered = RustInterpreter().interpret(expr)
    assert_code_equal(rendered, expected_code)