def test_single_condition():
    """Assemble a two-estimator random forest regressor and compare its AST.

    The expected tree adds a constant-leaf subroutine to a subroutine that
    splits on feature 0 at 1.5, then multiplies the sum by 0.5 (presumably
    averaging over the two estimators — confirm against the assembler).
    """
    forest = ensemble.RandomForestRegressor(n_estimators=2, random_state=1)
    forest.fit([[1], [2]], [1, 2])

    actual = assemblers.RandomForestModelAssembler(forest).assemble()

    constant_tree = ast.SubroutineExpr(ast.NumVal(1.0))
    split_tree = ast.SubroutineExpr(
        ast.IfExpr(
            ast.CompExpr(ast.FeatureRef(0), ast.NumVal(1.5), ast.CompOpType.LTE),
            ast.NumVal(1.0),
            ast.NumVal(2.0),
        )
    )
    trees_sum = ast.BinNumExpr(constant_tree, split_tree, ast.BinNumOpType.ADD)
    expected = ast.BinNumExpr(trees_sum, ast.NumVal(0.5), ast.BinNumOpType.MUL)

    assert utils.cmp_exprs(actual, expected)
def test_if_expr():
    """RInterpreter renders an IfExpr comparing 1 with feature 0 as R if/else."""
    condition = ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    expected_code = """ score <- function(input) { if ((1.0) == (input[1])) { var0 <- 2.0 } else { var0 <- 3.0 } return(var0) } """

    generated = RInterpreter().interpret(tree)
    utils.assert_code_equal(generated, expected_code)
def test_if_expr():
    """PowershellInterpreter renders an IfExpr on feature 0 using `-eq`."""
    is_equal = ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(is_equal, ast.NumVal(2), ast.NumVal(3))

    expected_code = """ function Score([double[]] $InputVector) { [double]$var0 = 0 if ((1) -eq ($InputVector[0])) { $var0 = 2 } else { $var0 = 3 } return $var0 } """

    powershell = PowershellInterpreter()
    utils.assert_code_equal(powershell.interpret(tree), expected_code)
def test_if_expr():
    """GoInterpreter renders an IfExpr on feature 0 as a Go if/else."""
    tree = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ),
        ast.NumVal(2),
        ast.NumVal(3),
    )
    go = interpreters.GoInterpreter()

    expected_code = """ func score(input []float64) float64 { var var0 float64 if (1.0) == (input[0]) { var0 = 2.0 } else { var0 = 3.0 } return var0 }"""

    rendered = go.interpret(tree)
    utils.assert_code_equal(rendered, expected_code)
def test_multi_output():
    """RubyInterpreter renders vector-valued IfExpr branches as Ruby arrays."""
    always_true = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(v) for v in (1, 2)])
    when_false = ast.VectorVal([ast.NumVal(v) for v in (3, 4)])
    tree = ast.IfExpr(always_true, when_true, when_false)

    expected_code = """ def score(input) if (1.0) == (1.0) var0 = [1.0, 2.0] else var0 = [3.0, 4.0] end var0 end """

    utils.assert_code_equal(RubyInterpreter().interpret(tree), expected_code)
def test_count_exprs():
    """count_exprs counts every node of an expression tree, root included."""
    cases = [
        (ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.ADD), 3),
        (ast.ExpExpr(ast.NumVal(2)), 2),
        (ast.VectorVal([ast.NumVal(2), ast.TanhExpr(ast.NumVal(3))]), 4),
        (
            ast.IfExpr(
                ast.CompExpr(ast.NumVal(2), ast.NumVal(0), ast.CompOpType.GT),
                ast.NumVal(3),
                ast.NumVal(4),
            ),
            6,
        ),
        (ast.NumVal(1), 1),
    ]
    for tree, expected_count in cases:
        assert ast.count_exprs(tree) == expected_count
def test_if_expr():
    """DartInterpreter renders an IfExpr on feature 0 as a Dart if/else."""
    condition = ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    expected_code = """ double score(List<double> input) { double var0; if ((1.0) == (input[0])) { var0 = 2.0; } else { var0 = 3.0; } return var0; } """

    dart_code = DartInterpreter().interpret(tree)
    utils.assert_code_equal(dart_code, expected_code)
def test_multi_output():
    """CInterpreter copies vector results into the output buffer via memcpy."""
    always_true = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    left = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    right = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.IfExpr(always_true, left, right)

    expected_code = """ #include <string.h> void score(double * input, double * output) { double var0[2]; if ((1) == (1)) { memcpy(var0, (double[]){1, 2}, 2 * sizeof(double)); } else { memcpy(var0, (double[]){3, 4}, 2 * sizeof(double)); } memcpy(output, var0, 2 * sizeof(double)); }"""

    c_interpreter = interpreters.CInterpreter()
    utils.assert_code_equal(c_interpreter.interpret(tree), expected_code)
def test_deep_mixed_exprs_not_reaching_threshold():
    """Four nested IfExprs with shallow sum conditions render without helpers.

    Uses ast_size_check_frequency=3 and ast_size_per_subroutine_threshold=1;
    the expected Java is a single `score` method with no extracted subroutines.
    """

    def shallow_sum():
        # Builds 1 + (1 + 1): two ADD levels wrapped around a NumVal(1).
        total = ast.NumVal(1)
        for __ in range(2):
            total = ast.BinNumExpr(ast.NumVal(1), total, ast.BinNumOpType.ADD)
        return total

    tree = ast.NumVal(1)
    for _ in range(4):
        tree = ast.IfExpr(
            ast.CompExpr(shallow_sum(), ast.NumVal(1), ast.CompOpType.EQ),
            ast.NumVal(1),
            tree,
        )

    java = interpreters.JavaInterpreter()
    java.ast_size_check_frequency = 3
    java.ast_size_per_subroutine_threshold = 1

    expected_code = """ public class Model { public static double score(double[] input) { double var0; if (((1.0) + ((1.0) + (1.0))) == (1.0)) { var0 = 1.0; } else { if (((1.0) + ((1.0) + (1.0))) == (1.0)) { var0 = 1.0; } else { if (((1.0) + ((1.0) + (1.0))) == (1.0)) { var0 = 1.0; } else { if (((1.0) + ((1.0) + (1.0))) == (1.0)) { var0 = 1.0; } else { var0 = 1.0; } } } } return var0; } }"""

    utils.assert_code_equal(java.interpret(tree), expected_code)
def test_multi_output():
    """RInterpreter renders vector-valued IfExpr branches with `c(...)`."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    first_branch = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    second_branch = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.IfExpr(condition, first_branch, second_branch)

    expected_code = """ score <- function(input) { if ((1.0) == (1.0)) { var0 <- c(1.0, 2.0) } else { var0 <- c(3.0, 4.0) } return(var0) } """

    r_code = RInterpreter().interpret(tree)
    assert_code_equal(r_code, expected_code)
def test_if_expr():
    """JavascriptInterpreter renders EQ as strict `===` inside an if/else."""
    condition = ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    expected_code = """ function score(input) { var var0; if ((1.0) === (input[0])) { var0 = 2.0; } else { var0 = 3.0; } return var0; } """

    js_code = JavascriptInterpreter().interpret(tree)
    assert_code_equal(js_code, expected_code)
def test_if_expr():
    """RustInterpreter renders an IfExpr with `_f64`-suffixed literals."""
    tree = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ),
        ast.NumVal(2),
        ast.NumVal(3),
    )

    expected_code = """ fn score(input: Vec<f64>) -> f64 { let var0: f64; if (1.0_f64) == (input[0]) { var0 = 2.0_f64; } else { var0 = 3.0_f64; } var0 } """

    rust = RustInterpreter()
    assert_code_equal(rust.interpret(tree), expected_code)
def test_multi_output():
    """JavascriptInterpreter renders vector-valued branches as JS arrays."""
    always_true = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    vectors = [
        ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]),
        ast.VectorVal([ast.NumVal(3), ast.NumVal(4)]),
    ]
    tree = ast.IfExpr(always_true, *vectors)

    expected_code = """ function score(input) { var var0; if ((1) == (1)) { var0 = [1, 2]; } else { var0 = [3, 4]; } return var0; } """

    js = interpreters.JavascriptInterpreter()
    utils.assert_code_equal(js.interpret(tree), expected_code)
def kernel_ast(sup_vec_value):
    """Build `(sup_vec_value * x0) / norm` where `norm = sqrt(x0 * x0)`.

    A zero norm is replaced by 1.0 so the division is guarded. The norm node
    is marked `to_reuse=True` because it appears twice, once in the zero check
    and once as the divisor. NOTE(review): this presumably mirrors a
    cosine-style kernel expectation for a single-feature model; confirm
    against the callers.

    Args:
        sup_vec_value: Numeric value wrapped in `ast.NumVal` as the
            support-vector coefficient.

    Returns:
        The `ast.BinNumExpr` DIV node described above.
    """
    squared = ast.BinNumExpr(
        ast.FeatureRef(0), ast.FeatureRef(0), ast.BinNumOpType.MUL)
    norm = ast.SqrtExpr(squared, to_reuse=True)

    numerator = ast.BinNumExpr(
        ast.NumVal(sup_vec_value), ast.FeatureRef(0), ast.BinNumOpType.MUL)
    guarded_norm = ast.IfExpr(
        ast.CompExpr(norm, ast.NumVal(0.0), ast.CompOpType.EQ),
        ast.NumVal(1.0),
        norm,
    )
    return ast.BinNumExpr(numerator, guarded_norm, ast.BinNumOpType.DIV)
def test_multi_output():
    """PowershellInterpreter renders vector branches as `@($(...), ...)`."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    true_vec = ast.VectorVal([ast.NumVal(n) for n in (1, 2)])
    false_vec = ast.VectorVal([ast.NumVal(n) for n in (3, 4)])
    tree = ast.IfExpr(condition, true_vec, false_vec)

    expected_code = """ function Score([double[]] $InputVector) { [double[]]$var0 = @(0) if ((1) -eq (1)) { $var0 = @($(1), $(2)) } else { $var0 = @($(3), $(4)) } return $var0 } """

    generated = PowershellInterpreter().interpret(tree)
    utils.assert_code_equal(generated, expected_code)
def test_if_expr():
    """HaskellInterpreter renders an IfExpr as a `where`-bound if/then/else."""
    condition = ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    expected_code = """ module Model where score :: [Double] -> Double score input = func0 where func0 = if (1.0) == ((input) !! (0)) then 2.0 else 3.0 """

    haskell_code = HaskellInterpreter().interpret(tree)
    assert_code_equal(haskell_code, expected_code)
def test_if_expr():
    """PhpInterpreter renders EQ as `===` and initialises the result to null."""
    tree = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ),
        ast.NumVal(2),
        ast.NumVal(3),
    )

    expected_code = """ <?php function score(array $input) { $var0 = null; if ((1) === ($input[0])) { $var0 = 2; } else { $var0 = 3; } return $var0; } """

    php = PhpInterpreter()
    utils.assert_code_equal(php.interpret(tree), expected_code)
def test_multi_output():
    """DartInterpreter returns `List<double>` for vector-valued branches."""
    always_true = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    on_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    on_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.IfExpr(always_true, on_true, on_false)

    expected_code = """ List<double> score(List<double> input) { List<double> var0; if ((1.0) == (1.0)) { var0 = [1.0, 2.0]; } else { var0 = [3.0, 4.0]; } return var0; } """

    dart_code = DartInterpreter().interpret(tree)
    utils.assert_code_equal(dart_code, expected_code)
def test_multi_output():
    """PythonInterpreter renders vector-valued branches as Python lists."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    branches = (
        ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]),
        ast.VectorVal([ast.NumVal(3), ast.NumVal(4)]),
    )
    tree = ast.IfExpr(condition, *branches)

    expected_code = """ def score(input): if (1) == (1): var0 = [1, 2] else: var0 = [3, 4] return var0 """

    python_code = interpreters.PythonInterpreter().interpret(tree)
    utils.assert_code_equal(python_code, expected_code)
def test_multi_output():
    """GoInterpreter renders vector-valued branches as `[]float64{...}`."""
    always_true = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    then_vec = ast.VectorVal([ast.NumVal(x) for x in (1, 2)])
    else_vec = ast.VectorVal([ast.NumVal(x) for x in (3, 4)])
    tree = ast.IfExpr(always_true, then_vec, else_vec)

    expected_code = """ func score(input []float64) []float64 { var var0 []float64 if (1.0) == (1.0) { var0 = []float64{1.0, 2.0} } else { var0 = []float64{3.0, 4.0} } return var0 } """

    go_code = GoInterpreter().interpret(tree)
    assert_code_equal(go_code, expected_code)
def test_deep_mixed_exprs_exceeding_threshold():
    """Deep sum conditions get split into `let funcN` bindings in F#.

    Each of the four nested IfExprs compares `i + (i + (i + (i + 1)))` with 3,
    where `i` is the branch index (0..3). The expected F# hoists the innermost
    two levels of each sum into func0..func3 and binds the whole if-chain to
    func4. CustomFSharpInterpreter is defined elsewhere in the test module.
    """
    expr = ast.NumVal(1)
    for i in range(4):
        inner = ast.NumVal(1)
        for _ in range(4):
            inner = ast.BinNumExpr(ast.NumVal(i), inner, ast.BinNumOpType.ADD)
        # Compare against an explicit 3 instead of reading the inner loop's
        # variable after the loop has finished. The expected F# below shows
        # `= (3.0)` in every branch.
        expr = ast.IfExpr(
            ast.CompExpr(inner, ast.NumVal(3), ast.CompOpType.EQ),
            ast.NumVal(1),
            expr)

    interpreter = CustomFSharpInterpreter()

    expected_code = """ let score (input : double list) = let func0 = (3.0) + ((3.0) + (1.0)) let func1 = (2.0) + ((2.0) + (1.0)) let func2 = (1.0) + ((1.0) + (1.0)) let func3 = (0.0) + ((0.0) + (1.0)) let func4 = if (((3.0) + ((3.0) + (func0))) = (3.0)) then 1.0 else if (((2.0) + ((2.0) + (func1))) = (3.0)) then 1.0 else if (((1.0) + ((1.0) + (func2))) = (3.0)) then 1.0 else if (((0.0) + ((0.0) + (func3))) = (3.0)) then 1.0 else 1.0 func4 """

    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_deep_mixed_exprs_exceeding_threshold():
    """Deep sum conditions get split into `varN` temporaries in R.

    Each of the four nested IfExprs compares a four-level sum of ones with 1.
    The expected R assigns the innermost part of each sum to var1..var4 before
    its `if`. CustomRInterpreter is defined elsewhere in the test module.
    """
    expr = ast.NumVal(1)
    # Neither loop index is used. Distinct throwaway names avoid the inner loop
    # shadowing the outer loop's variable.
    for _branch in range(4):
        inner = ast.NumVal(1)
        for _level in range(4):
            inner = ast.BinNumExpr(ast.NumVal(1), inner, ast.BinNumOpType.ADD)
        expr = ast.IfExpr(
            ast.CompExpr(inner, ast.NumVal(1), ast.CompOpType.EQ),
            ast.NumVal(1),
            expr)

    interpreter = CustomRInterpreter()

    expected_code = """ score <- function(input) { var1 <- (1) + ((1) + (1)) if (((1) + ((1) + (var1))) == (1)) { var0 <- 1 } else { var2 <- (1) + ((1) + (1)) if (((1) + ((1) + (var2))) == (1)) { var0 <- 1 } else { var3 <- (1) + ((1) + (1)) if (((1) + ((1) + (var3))) == (1)) { var0 <- 1 } else { var4 <- (1) + ((1) + (1)) if (((1) + ((1) + (var4))) == (1)) { var0 <- 1 } else { var0 <- 1 } } } } return(var0) } """

    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_multi_output():
    """PhpInterpreter renders vector-valued branches with `array(...)`."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    yes = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    no = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.IfExpr(condition, yes, no)

    expected_code = """ <?php function score(array $input) { $var0 = array(); if ((1.0) === (1.0)) { $var0 = array(1.0, 2.0); } else { $var0 = array(3.0, 4.0); } return $var0; } """

    php_code = PhpInterpreter().interpret(tree)
    utils.assert_code_equal(php_code, expected_code)
def test_multi_output():
    """HaskellInterpreter returns `[Double]` for vector-valued branches."""
    tree = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.VectorVal([ast.NumVal(k) for k in (1, 2)]),
        ast.VectorVal([ast.NumVal(k) for k in (3, 4)]),
    )

    expected_code = """ module Model where score :: [Double] -> [Double] score input = func0 where func0 = if (1.0) == (1.0) then [1.0, 2.0] else [3.0, 4.0] """

    haskell = HaskellInterpreter()
    assert_code_equal(haskell.interpret(tree), expected_code)
def test_multi_output():
    """JavaInterpreter returns `double[]` built with `new double[] {...}`."""
    always_true = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    first = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    second = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.IfExpr(always_true, first, second)

    expected_code = """ public class Model { public static double[] score(double[] input) { double[] var0; if ((1.0) == (1.0)) { var0 = new double[] {1.0, 2.0}; } else { var0 = new double[] {3.0, 4.0}; } return var0; } }"""

    java_code = JavaInterpreter().interpret(tree)
    utils.assert_code_equal(java_code, expected_code)
def test_if_expr():
    """VisualBasicInterpreter renders an IfExpr inside `Module Model`."""
    condition = ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    # NOTE(review): Visual Basic's equality operator is `=`, not `==`. Confirm
    # that the generator's `==` output (pinned below) is intended.
    expected_code = """ Module Model Function score(ByRef input_vector() As Double) As Double Dim var0 As Double If (1) == (input_vector(0)) Then var0 = 2 Else var0 = 3 End If score = var0 End Function End Module """

    vb_code = VisualBasicInterpreter().interpret(tree)
    utils.assert_code_equal(vb_code, expected_code)
def test_depth_threshold_without_bin_expr():
    """Nested IfExprs without BinNumExprs stay inline despite low thresholds.

    With bin_depth_threshold=2 and ast_size_per_subroutine_threshold=1, the
    expected Java still nests all four if/else blocks in a single method.
    """
    tree = ast.NumVal(1)
    for _ in range(4):
        constant_check = ast.CompExpr(
            ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
        tree = ast.IfExpr(constant_check, ast.NumVal(1), tree)

    java = interpreters.JavaInterpreter()
    java.bin_depth_threshold = 2
    java.ast_size_per_subroutine_threshold = 1

    expected_code = """ public class Model { public static double score(double[] input) { double var0; if ((1) == (1)) { var0 = 1; } else { if ((1) == (1)) { var0 = 1; } else { if ((1) == (1)) { var0 = 1; } else { if ((1) == (1)) { var0 = 1; } else { var0 = 1; } } } } return var0; } }"""

    utils.assert_code_equal(java.interpret(tree), expected_code)
def test_deep_mixed_exprs_not_reaching_threshold():
    """Shallow sum conditions stay inline in the generated Dart.

    Each of the four nested IfExprs compares `1 + (1 + 1)` with 1. The expected
    Dart contains no extracted temporaries. CustomDartInterpreter is defined
    elsewhere in the test module.
    """
    expr = ast.NumVal(1)
    # Neither loop index is used. Distinct throwaway names avoid the inner loop
    # shadowing the outer loop's variable.
    for _branch in range(4):
        inner = ast.NumVal(1)
        for _level in range(2):
            inner = ast.BinNumExpr(ast.NumVal(1), inner, ast.BinNumOpType.ADD)
        expr = ast.IfExpr(
            ast.CompExpr(
                inner,
                ast.NumVal(1),
                ast.CompOpType.EQ),
            ast.NumVal(1),
            expr)

    interpreter = CustomDartInterpreter()

    expected_code = """ double score(List<double> input) { double var0; if (((1) + ((1) + (1))) == (1)) { var0 = 1; } else { if (((1) + ((1) + (1))) == (1)) { var0 = 1; } else { if (((1) + ((1) + (1))) == (1)) { var0 = 1; } else { if (((1) + ((1) + (1))) == (1)) { var0 = 1; } else { var0 = 1; } } } } return var0; } """

    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_if_expr():
    """JavaInterpreter wraps an IfExpr on feature 0 in `class Model`."""
    condition = ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))
    java = JavaInterpreter()

    expected_code = """ public class Model { public static double score(double[] input) { double var0; if ((1.0) == (input[0])) { var0 = 2.0; } else { var0 = 3.0; } return var0; } } """

    generated = java.interpret(tree)
    utils.assert_code_equal(generated, expected_code)
def test_multi_output():
    """A SubroutineExpr-wrapped vector IfExpr renders as Go `[]float64`."""
    always_true = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    left = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    right = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.SubroutineExpr(ast.IfExpr(always_true, left, right))

    expected_code = """ func score(input []float64) []float64 { var var0 []float64 if (1) == (1) { var0 = []float64{1, 2} } else { var0 = []float64{3, 4} } return var0 }"""

    go_code = interpreters.GoInterpreter().interpret(tree)
    utils.assert_code_equal(go_code, expected_code)