def test_softmax_expr(): expr = ast.SoftmaxExpr([ast.NumVal(2.0), ast.NumVal(3.0)]) expected_code = """ #include <string.h> void softmax(double *x, int size, double *result) { double max = x[0]; for (int i = 1; i < size; ++i) { if (x[i] > max) max = x[i]; } double sum = 0.0; for (int i = 0; i < size; ++i) { result[i] = exp(x[i] - max); sum += result[i]; } for (int i = 0; i < size; ++i) result[i] /= sum; } void score(double * input, double * output) { double var0[2]; softmax((double[]){2.0, 3.0}, 2, var0); memcpy(output, var0, 2 * sizeof(double)); } """ interpreter = CInterpreter() assert_code_equal(interpreter.interpret(expr), expected_code)
def test_dependable_condition(): left = ast.BinNumExpr( ast.IfExpr( ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ), ast.NumVal(1), ast.NumVal(2)), ast.NumVal(2), ast.BinNumOpType.ADD) right = ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.DIV) bool_test = ast.CompExpr(left, right, ast.CompOpType.GTE) expr = ast.IfExpr(bool_test, ast.NumVal(1), ast.FeatureRef(0)) expected_code = """ double score(double * input) { double var0; double var1; if ((1.0) == (1.0)) { var1 = 1.0; } else { var1 = 2.0; } if (((var1) + (2.0)) >= ((1.0) / (2.0))) { var0 = 1.0; } else { var0 = input[0]; } return var0; } """ interpreter = CInterpreter() assert_code_equal(interpreter.interpret(expr), expected_code)
class CExecutor(BaseExecutor): def __init__(self, model): self.model_name = "score" self.model = model self.interpreter = CInterpreter() assembler_cls = get_assembler_cls(model) self.model_ast = assembler_cls(model).assemble() self.exec_path = None def predict(self, X): exec_args = [str(self.exec_path), *map(utils.format_arg, X)] return utils.predict_from_commandline(exec_args) def prepare(self): if self.model_ast.output_size > 1: print_code = EXECUTE_AND_PRINT_VECTOR_TPL.format( size=self.model_ast.output_size) else: print_code = EXECUTE_AND_PRINT_SCALAR executor_code = EXECUTOR_CODE_TPL.format( model_code=self.interpreter.interpret(self.model_ast), print_code=print_code) file_name = self._resource_tmp_dir / f"{self.model_name}.c" utils.write_content_to_file(executor_code, file_name) self.exec_path = self._resource_tmp_dir / self.model_name flags = ["-std=c99", "-lm"] subprocess.call( ["gcc", str(file_name), "-o", str(self.exec_path), *flags])
def test_bin_vector_expr(): expr = ast.BinVectorExpr(ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]), ast.VectorVal([ast.NumVal(3), ast.NumVal(4)]), ast.BinNumOpType.ADD) expected_code = """ #include <string.h> void add_vectors(double *v1, double *v2, int size, double *result) { for(int i = 0; i < size; ++i) result[i] = v1[i] + v2[i]; } void mul_vector_number(double *v1, double num, int size, double *result) { for(int i = 0; i < size; ++i) result[i] = v1[i] * num; } void score(double * input, double * output) { double var0[2]; add_vectors((double[]){1.0, 2.0}, (double[]){3.0, 4.0}, 2, var0); memcpy(output, var0, 2 * sizeof(double)); } """ interpreter = CInterpreter() assert_code_equal(interpreter.interpret(expr), expected_code)
def test_raw_array(): expr = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)]) expected_code = """ #include <string.h> void score(double * input, double * output) { memcpy(output, (double[]){3.0, 4.0}, 2 * sizeof(double)); }""" interpreter = CInterpreter() utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_atan_expr(): expr = ast.AtanExpr(ast.NumVal(2.0)) expected_code = """ #include <math.h> double score(double * input) { return atan(2.0); } """ interpreter = CInterpreter() assert_code_equal(interpreter.interpret(expr), expected_code)
def test_pow_expr(): expr = ast.PowExpr(ast.NumVal(2.0), ast.NumVal(3.0)) expected_code = """ #include <math.h> double score(double * input) { return pow(2.0, 3.0); } """ interpreter = CInterpreter() assert_code_equal(interpreter.interpret(expr), expected_code)
def test_log1p_expr(): expr = ast.Log1pExpr(ast.NumVal(2.0)) interpreter = CInterpreter() expected_code = """ #include <math.h> double score(double * input) { return log1p(2.0); }""" utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_abs_expr(): expr = ast.AbsExpr(ast.NumVal(-1.0)) interpreter = CInterpreter() expected_code = """ #include <math.h> double score(double * input) { return fabs(-1.0); }""" utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_bin_num_expr(): expr = ast.BinNumExpr( ast.BinNumExpr(ast.FeatureRef(0), ast.NumVal(-2), ast.BinNumOpType.DIV), ast.NumVal(2), ast.BinNumOpType.MUL) interpreter = CInterpreter() expected_code = """ double score(double * input) { return ((input[0]) / (-2.0)) * (2.0); }""" utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) expected_code = """ #include <math.h> double score(double * input) { double var0; var0 = exp(1.0); return (var0) / (var0); } """ interpreter = CInterpreter() assert_code_equal(interpreter.interpret(expr), expected_code)
def test_if_expr(): expr = ast.IfExpr( ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ), ast.NumVal(2), ast.NumVal(3)) interpreter = CInterpreter() expected_code = """ double score(double * input) { double var0; if ((1.0) == (input[0])) { var0 = 2.0; } else { var0 = 3.0; } return var0; }""" utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_sigmoid_expr(): expr = ast.SigmoidExpr(ast.NumVal(2.0)) expected_code = """ #include <math.h> double sigmoid(double x) { if (x < 0.0) { double z = exp(x); return z / (1.0 + z); } return 1.0 / (1.0 + exp(-x)); } double score(double * input) { return sigmoid(2.0); } """ interpreter = CInterpreter() assert_code_equal(interpreter.interpret(expr), expected_code)
def test_multi_output(): expr = ast.IfExpr( ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ), ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]), ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])) expected_code = """ #include <string.h> void score(double * input, double * output) { double var0[2]; if ((1.0) == (1.0)) { memcpy(var0, (double[]){1.0, 2.0}, 2 * sizeof(double)); } else { memcpy(var0, (double[]){3.0, 4.0}, 2 * sizeof(double)); } memcpy(output, var0, 2 * sizeof(double)); }""" interpreter = CInterpreter() utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_abs_fallback_expr(): expr = ast.AbsExpr(ast.NumVal(-2.0)) interpreter = CInterpreter() interpreter.abs_function_name = NotImplemented expected_code = """ double score(double * input) { double var0; double var1; var1 = -2.0; if ((var1) < (0.0)) { var0 = (0.0) - (var1); } else { var0 = var1; } return var0; } """ assert_code_equal(interpreter.interpret(expr), expected_code)
def test_nested_condition(): left = ast.BinNumExpr( ast.IfExpr( ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ), ast.NumVal(1), ast.NumVal(2)), ast.NumVal(2), ast.BinNumOpType.ADD) bool_test = ast.CompExpr(ast.NumVal(1), left, ast.CompOpType.EQ) expr_nested = ast.IfExpr(bool_test, ast.FeatureRef(2), ast.NumVal(2)) expr = ast.IfExpr(bool_test, expr_nested, ast.NumVal(2)) expected_code = """ double score(double * input) { double var0; double var1; if ((1.0) == (1.0)) { var1 = 1.0; } else { var1 = 2.0; } if ((1.0) == ((var1) + (2.0))) { double var2; if ((1.0) == (1.0)) { var2 = 1.0; } else { var2 = 2.0; } if ((1.0) == ((var2) + (2.0))) { var0 = input[2]; } else { var0 = 2.0; } } else { var0 = 2.0; } return var0; }""" interpreter = CInterpreter() utils.assert_code_equal(interpreter.interpret(expr), expected_code)