def test_nested_condition():
    """Compare HaskellInterpreter output for two nested ``IfExpr`` nodes.

    Both ``IfExpr`` levels receive the same condition object, and that
    condition's right-hand side contains a further ``IfExpr``.
    """
    innermost_if = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.NumVal(1),
        ast.NumVal(2))
    lhs_sum = ast.BinNumExpr(innermost_if, ast.NumVal(2), ast.BinNumOpType.ADD)
    # The same condition instance is passed to the inner and the outer IfExpr.
    shared_cond = ast.CompExpr(ast.NumVal(1), lhs_sum, ast.CompOpType.EQ)
    inner_if = ast.IfExpr(shared_cond, ast.FeatureRef(2), ast.NumVal(2))
    tree = ast.IfExpr(shared_cond, inner_if, ast.NumVal(2))

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ module Model where score :: [Double] -> Double score input = func1 where func0 = if ((1.0) == (1.0)) then 1.0 else 2.0 func1 = if ((1.0) == ((func0) + (2.0))) then if ((1.0) == ((func0) + (2.0))) then (input) !! (2) else 2.0 else 2.0 """

    actual_code = HaskellInterpreter().interpret(tree)
    utils.assert_code_equal(actual_code, expected_code)
def test_if_expr():
    """Compare CSharpInterpreter output for one ``IfExpr`` over a feature."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    tree = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ namespace ML { public static class Model { public static double Score(double[] input) { double var0; if ((1) == (input[0])) { var0 = 2; } else { var0 = 3; } return var0; } } } """

    generated = CSharpInterpreter().interpret(tree)
    utils.assert_code_equal(generated, expected_code)
def test_multi_output():
    """Compare PhpInterpreter output for an ``IfExpr`` with vector branches."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.IfExpr(condition, when_true, when_false)

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ <?php function score(array $input) { $var0 = array(); if ((1.0) === (1.0)) { $var0 = array(1.0, 2.0); } else { $var0 = array(3.0, 4.0); } return $var0; } """

    generated = PhpInterpreter().interpret(tree)
    utils.assert_code_equal(generated, expected_code)
def test_dependable_condition():
    """Compare PowershellInterpreter output when a condition embeds an ``IfExpr``.

    The left operand of the ``GTE`` comparison is a sum whose first term is
    itself an ``IfExpr``; the right operand is a plain division.
    """
    embedded_if = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.NumVal(1),
        ast.NumVal(2))
    lhs = ast.BinNumExpr(embedded_if, ast.NumVal(2), ast.BinNumOpType.ADD)
    rhs = ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.DIV)
    condition = ast.CompExpr(lhs, rhs, ast.CompOpType.GTE)
    tree = ast.IfExpr(condition, ast.NumVal(1), ast.FeatureRef(0))

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ function Score([double[]] $InputVector) { [double]$var0 = 0.0 [double]$var1 = 0.0 if ((1.0) -eq (1.0)) { $var1 = 1.0 } else { $var1 = 2.0 } if ((($var1) + (2.0)) -ge ((1.0) / (2.0))) { $var0 = 1.0 } else { $var0 = $InputVector[0] } return $var0 } """

    generated = PowershellInterpreter().interpret(tree)
    utils.assert_code_equal(generated, expected_code)
def test_multi_output():
    """Compare JavaInterpreter output for an ``IfExpr`` with vector branches."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.IfExpr(condition, when_true, when_false)

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ public class Model { public static double[] score(double[] input) { double[] var0; if ((1.0) == (1.0)) { var0 = new double[] {1.0, 2.0}; } else { var0 = new double[] {3.0, 4.0}; } return var0; } }"""

    generated = interpreters.JavaInterpreter().interpret(tree)
    utils.assert_code_equal(generated, expected_code)
def test_deep_mixed_exprs_not_reaching_threshold():
    """Compare CustomRInterpreter output for four chained ``IfExpr`` levels.

    Each level's condition compares a two-deep chain of additions with 1;
    the expected listing keeps every condition inline.
    """
    def _addition_chain():
        # Build 1 + (1 + 1) as nested BinNumExpr nodes, fresh on every call.
        node = ast.NumVal(1)
        for _ in range(2):
            node = ast.BinNumExpr(ast.NumVal(1), node, ast.BinNumOpType.ADD)
        return node

    tree = ast.NumVal(1)
    for _ in range(4):
        condition = ast.CompExpr(
            _addition_chain(), ast.NumVal(1), ast.CompOpType.EQ)
        # The tree built so far becomes the else-branch of the new level.
        tree = ast.IfExpr(condition, ast.NumVal(1), tree)

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ score <- function(input) { if (((1.0) + ((1.0) + (1.0))) == (1.0)) { var0 <- 1.0 } else { if (((1.0) + ((1.0) + (1.0))) == (1.0)) { var0 <- 1.0 } else { if (((1.0) + ((1.0) + (1.0))) == (1.0)) { var0 <- 1.0 } else { if (((1.0) + ((1.0) + (1.0))) == (1.0)) { var0 <- 1.0 } else { var0 <- 1.0 } } } } return(var0) } """

    generated = CustomRInterpreter().interpret(tree)
    assert_code_equal(generated, expected_code)
def test_nested_condition():
    """Compare RInterpreter output for two nested ``IfExpr`` nodes.

    Both ``IfExpr`` levels receive the same condition object, and that
    condition's right-hand side contains a further ``IfExpr``.
    """
    innermost_if = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.NumVal(1),
        ast.NumVal(2))
    lhs_sum = ast.BinNumExpr(innermost_if, ast.NumVal(2), ast.BinNumOpType.ADD)
    # The same condition instance is passed to the inner and the outer IfExpr.
    shared_cond = ast.CompExpr(ast.NumVal(1), lhs_sum, ast.CompOpType.EQ)
    inner_if = ast.IfExpr(shared_cond, ast.FeatureRef(2), ast.NumVal(2))
    tree = ast.IfExpr(shared_cond, inner_if, ast.NumVal(2))

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ score <- function(input) { if ((1.0) == (1.0)) { var1 <- 1.0 } else { var1 <- 2.0 } if ((1.0) == ((var1) + (2.0))) { if ((1.0) == (1.0)) { var2 <- 1.0 } else { var2 <- 2.0 } if ((1.0) == ((var2) + (2.0))) { var0 <- input[3] } else { var0 <- 2.0 } } else { var0 <- 2.0 } return(var0) } """

    generated = RInterpreter().interpret(tree)
    assert_code_equal(generated, expected_code)
def test_dependable_condition():
    """Compare DartInterpreter output when a condition embeds an ``IfExpr``.

    The left operand of the ``GTE`` comparison is a sum whose first term is
    itself an ``IfExpr``; the right operand is a plain division.
    """
    embedded_if = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.NumVal(1),
        ast.NumVal(2))
    lhs = ast.BinNumExpr(embedded_if, ast.NumVal(2), ast.BinNumOpType.ADD)
    rhs = ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.DIV)
    condition = ast.CompExpr(lhs, rhs, ast.CompOpType.GTE)
    tree = ast.IfExpr(condition, ast.NumVal(1), ast.FeatureRef(0))

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ double score(List<double> input) { double var0; double var1; if ((1) == (1)) { var1 = 1; } else { var1 = 2; } if (((var1) + (2)) >= ((1) / (2))) { var0 = 1; } else { var0 = input[0]; } return var0; } """

    generated = DartInterpreter().interpret(tree)
    utils.assert_code_equal(generated, expected_code)
def test_multi_output():
    """Compare CSharpInterpreter output for an ``IfExpr`` with vector branches."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.IfExpr(condition, when_true, when_false)

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ namespace ML { public static class Model { public static double[] Score(double[] input) { double[] var0; if ((1) == (1)) { var0 = new double[2] {1, 2}; } else { var0 = new double[2] {3, 4}; } return var0; } } } """

    generated = CSharpInterpreter().interpret(tree)
    utils.assert_code_equal(generated, expected_code)
def test_depth_threshold_without_bin_expr():
    """Compare CustomDartInterpreter output for four chained ``IfExpr`` levels.

    No ``BinNumExpr`` appears anywhere in the tree; every condition is the
    plain comparison ``1 == 1``.
    """
    tree = ast.NumVal(1)
    for _ in range(4):
        condition = ast.CompExpr(
            ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
        # The tree built so far becomes the else-branch of the new level.
        tree = ast.IfExpr(condition, ast.NumVal(1), tree)

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ double score(List<double> input) { double var0; if ((1) == (1)) { var0 = 1; } else { if ((1) == (1)) { var0 = 1; } else { if ((1) == (1)) { var0 = 1; } else { if ((1) == (1)) { var0 = 1; } else { var0 = 1; } } } } return var0; } """

    generated = CustomDartInterpreter().interpret(tree)
    utils.assert_code_equal(generated, expected_code)
def test_multi_output():
    """Compare DartInterpreter output for a vector ``IfExpr`` in a ``SubroutineExpr``."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    tree = ast.SubroutineExpr(ast.IfExpr(condition, when_true, when_false))

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ List<double> score(List<double> input) { List<double> var0; if ((1) == (1)) { var0 = [1, 2]; } else { var0 = [3, 4]; } return var0; } """

    generated = DartInterpreter().interpret(tree)
    utils.assert_code_equal(generated, expected_code)
def test_bin_vector_num_expr():
    """Compare CInterpreter output for a vector multiplied by a number."""
    vector = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    tree = ast.BinVectorNumExpr(vector, ast.NumVal(1), ast.BinNumOpType.MUL)

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ #include <string.h> void add_vectors(double *v1, double *v2, int size, double *result) { for(int i = 0; i < size; ++i) result[i] = v1[i] + v2[i]; } void mul_vector_number(double *v1, double num, int size, double *result) { for(int i = 0; i < size; ++i) result[i] = v1[i] * num; } void score(double * input, double * output) { double var0[2]; mul_vector_number((double[]){1.0, 2.0}, 1.0, 2, var0); memcpy(output, var0, 2 * sizeof(double)); }"""

    generated = interpreters.CInterpreter().interpret(tree)
    utils.assert_code_equal(generated, expected_code)
def test_tanh_expr():
    """Compare VisualBasicInterpreter output for ``TanhExpr`` of a constant."""
    tree = ast.TanhExpr(ast.NumVal(2.0))

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ Module Model Function Tanh(ByVal number As Double) As Double If number > 44.0 Then ' exp(2*x) <= 2^127 Tanh = 1.0 Exit Function End If If number < -44.0 Then Tanh = -1.0 Exit Function End If Tanh = (Math.Exp(2 * number) - 1) / (Math.Exp(2 * number) + 1) End Function Function score(ByRef input_vector() As Double) As Double score = Tanh(2.0) End Function End Module """

    generated = VisualBasicInterpreter().interpret(tree)
    utils.assert_code_equal(generated, expected_code)
def test_deep_mixed_exprs_exceeding_threshold():
    """Compare CustomPythonInterpreter output for deep mixed if/addition trees.

    Four ``IfExpr`` levels are chained; each condition compares a four-deep
    chain of additions with a constant. The expected listing assigns part of
    each chain to a separate variable.
    """
    tree = ast.NumVal(1)
    for level in range(4):
        chain = ast.NumVal(1)
        for depth in range(4):
            chain = ast.BinNumExpr(
                ast.NumVal(level), chain, ast.BinNumOpType.ADD)
        # `depth` is read after its loop has finished, so it is always 3 here.
        condition = ast.CompExpr(chain, ast.NumVal(depth), ast.CompOpType.EQ)
        # The tree built so far becomes the else-branch of the new level.
        tree = ast.IfExpr(condition, ast.NumVal(1), tree)

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ def score(input): var1 = (3.0) + ((3.0) + (1.0)) if ((3.0) + ((3.0) + (var1))) == (3.0): var0 = 1.0 else: var2 = (2.0) + ((2.0) + (1.0)) if ((2.0) + ((2.0) + (var2))) == (3.0): var0 = 1.0 else: var3 = (1.0) + ((1.0) + (1.0)) if ((1.0) + ((1.0) + (var3))) == (3.0): var0 = 1.0 else: var4 = (0.0) + ((0.0) + (1.0)) if ((0.0) + ((0.0) + (var4))) == (3.0): var0 = 1.0 else: var0 = 1.0 return var0 """

    generated = CustomPythonInterpreter().interpret(tree)
    utils.assert_code_equal(generated, expected_code)
def test_dependable_condition():
    """Compare PythonInterpreter output when a condition embeds an ``IfExpr``.

    The left operand of the ``GTE`` comparison is a sum whose first term is
    itself an ``IfExpr``; the right operand is a plain division.
    """
    embedded_if = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.NumVal(1),
        ast.NumVal(2))
    lhs = ast.BinNumExpr(embedded_if, ast.NumVal(2), ast.BinNumOpType.ADD)
    rhs = ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.DIV)
    condition = ast.CompExpr(lhs, rhs, ast.CompOpType.GTE)
    tree = ast.IfExpr(condition, ast.NumVal(1), ast.FeatureRef(0))

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ def score(input): if (1.0) == (1.0): var1 = 1.0 else: var1 = 2.0 if ((var1) + (2.0)) >= ((1.0) / (2.0)): var0 = 1.0 else: var0 = input[0] return var0 """

    generated = PythonInterpreter().interpret(tree)
    assert_code_equal(generated, expected_code)
def test_log1p_expr():
    """Compare CSharpInterpreter output for ``Log1pExpr`` of a constant.

    The expected listing contains ``Log1p`` and ``ChebyshevBroucke`` helper
    methods in addition to ``Score``.
    """
    tree = ast.Log1pExpr(ast.NumVal(2.0))

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ using static System.Math; namespace ML { public static class Model { public static double Score(double[] input) { return Log1p(2.0); } private static double Log1p(double x) { if (x == 0.0) return 0.0; if (x == -1.0) return double.NegativeInfinity; if (x < -1.0) return double.NaN; double xAbs = Abs(x); if (xAbs < 0.5 * double.Epsilon) return x; if ((x > 0.0 && x < 1e-8) || (x > -1e-9 && x < 0.0)) return x * (1.0 - x * 0.5); if (xAbs < 0.375) { double[] coeffs = { 0.10378693562743769800686267719098e+1, -0.13364301504908918098766041553133e+0, 0.19408249135520563357926199374750e-1, -0.30107551127535777690376537776592e-2, 0.48694614797154850090456366509137e-3, -0.81054881893175356066809943008622e-4, 0.13778847799559524782938251496059e-4, -0.23802210894358970251369992914935e-5, 0.41640416213865183476391859901989e-6, -0.73595828378075994984266837031998e-7, 0.13117611876241674949152294345011e-7, -0.23546709317742425136696092330175e-8, 0.42522773276034997775638052962567e-9, -0.77190894134840796826108107493300e-10, 0.14075746481359069909215356472191e-10, -0.25769072058024680627537078627584e-11, 0.47342406666294421849154395005938e-12, -0.87249012674742641745301263292675e-13, 0.16124614902740551465739833119115e-13, -0.29875652015665773006710792416815e-14, 0.55480701209082887983041321697279e-15, -0.10324619158271569595141333961932e-15}; return x * (1.0 - x * ChebyshevBroucke(x / 0.375, coeffs)); } return Log(1.0 + x); } private static double ChebyshevBroucke(double x, double[] coeffs) { double b0, b1, b2, x2; b2 = b1 = b0 = 0.0; x2 = x * 2; for (int i = coeffs.Length - 1; i >= 0; --i) { b2 = b1; b1 = b0; b0 = x2 * b1 - b2 + coeffs[i]; } return (b0 - b2) * 0.5; } } } """

    generated = CSharpInterpreter().interpret(tree)
    assert_code_equal(generated, expected_code)
def test_deep_mixed_exprs_exceeding_threshold():
    """Compare RInterpreter output for deep trees under lowered thresholds.

    Three threshold attributes are overridden on the interpreter instance
    before interpreting; the expected listing moves each addition chain into
    its own ``subroutineN`` function.
    """
    tree = ast.NumVal(1)
    for level in range(4):
        chain = ast.NumVal(1)
        for depth in range(4):
            chain = ast.BinNumExpr(
                ast.NumVal(level), chain, ast.BinNumOpType.ADD)
        # `depth` is read after its loop has finished, so it is always 3 here.
        condition = ast.CompExpr(chain, ast.NumVal(depth), ast.CompOpType.EQ)
        # The tree built so far becomes the else-branch of the new level.
        tree = ast.IfExpr(condition, ast.NumVal(1), tree)

    r_interpreter = RInterpreter()
    r_interpreter.bin_depth_threshold = 1
    r_interpreter.ast_size_check_frequency = 2
    r_interpreter.ast_size_per_subroutine_threshold = 6

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ score <- function(input) { var1 <- subroutine0(input) if (((3.0) + (var1)) == (3.0)) { var0 <- 1.0 } else { var2 <- subroutine1(input) if (((2.0) + (var2)) == (3.0)) { var0 <- 1.0 } else { var3 <- subroutine2(input) if (((1.0) + (var3)) == (3.0)) { var0 <- 1.0 } else { var4 <- subroutine3(input) if (((0.0) + (var4)) == (3.0)) { var0 <- 1.0 } else { var0 <- 1.0 } } } } return(var0) } subroutine0 <- function(input) { var0 <- (3.0) + (1.0) var1 <- (3.0) + (var0) return((3.0) + (var1)) } subroutine1 <- function(input) { var0 <- (2.0) + (1.0) var1 <- (2.0) + (var0) return((2.0) + (var1)) } subroutine2 <- function(input) { var0 <- (1.0) + (1.0) var1 <- (1.0) + (var0) return((1.0) + (var1)) } subroutine3 <- function(input) { var0 <- (0.0) + (1.0) var1 <- (0.0) + (var0) return((0.0) + (var1)) } """

    generated = r_interpreter.interpret(tree)
    utils.assert_code_equal(generated, expected_code)
def test_log1p_expr():
    """Compare RubyInterpreter output for ``Log1pExpr`` of a constant.

    The expected listing contains ``log1p`` and ``chebyshev_broucke`` helper
    methods in addition to ``score``.
    """
    tree = ast.Log1pExpr(ast.NumVal(2.0))

    # NOTE(review): literal reproduced exactly as it appears in this view;
    # confirm its line breaks against the generator output.
    expected_code = """ def score(input) log1p(2.0) end def log1p(x) if x == 0.0 return 0.0 end if x == -1.0 return -Float::INFINITY end if x < -1.0 return Float::NAN end x_abs = x.abs if x_abs < 0.5 * Float::EPSILON return x end if (x > 0.0 && x < 1e-8) || (x > -1e-9 && x < 0.0) return x * (1.0 - x * 0.5) end if x_abs < 0.375 coeffs = [ 0.10378693562743769800686267719098e+1, -0.13364301504908918098766041553133e+0, 0.19408249135520563357926199374750e-1, -0.30107551127535777690376537776592e-2, 0.48694614797154850090456366509137e-3, -0.81054881893175356066809943008622e-4, 0.13778847799559524782938251496059e-4, -0.23802210894358970251369992914935e-5, 0.41640416213865183476391859901989e-6, -0.73595828378075994984266837031998e-7, 0.13117611876241674949152294345011e-7, -0.23546709317742425136696092330175e-8, 0.42522773276034997775638052962567e-9, -0.77190894134840796826108107493300e-10, 0.14075746481359069909215356472191e-10, -0.25769072058024680627537078627584e-11, 0.47342406666294421849154395005938e-12, -0.87249012674742641745301263292675e-13, 0.16124614902740551465739833119115e-13, -0.29875652015665773006710792416815e-14, 0.55480701209082887983041321697279e-15, -0.10324619158271569595141333961932e-15] return x * (1.0 - x * chebyshev_broucke(x / 0.375, coeffs)) end return Math.log(1.0 + x) end def chebyshev_broucke(x, coeffs) b2 = b1 = b0 = 0.0 x2 = x * 2 coeffs.reverse_each do |i| b2 = b1 b1 = b0 b0 = x2 * b1 - b2 + i end (b0 - b2) * 0.5 end """

    generated = RubyInterpreter().interpret(tree)
    assert_code_equal(generated, expected_code)
def test_log1p_expr():
    """Compare VisualBasicInterpreter output for ``Log1pExpr`` of a constant.

    The expected listing contains ``ChebyshevBroucke`` and ``Log1p`` helper
    functions in addition to ``Score``.
    """
    tree = ast.Log1pExpr(ast.NumVal(2.0))

    # NOTE(review): literal reproduced exactly as it appears in this view,
    # including the single line break after ``coeffs(20) =``; confirm its
    # line breaks against the generator output.
    expected_code = """ Module Model Function ChebyshevBroucke(ByVal x As Double, _ ByRef coeffs() As Double) As Double Dim b2 as Double Dim b1 as Double Dim b0 as Double Dim x2 as Double b2 = 0.0 b1 = 0.0 b0 = 0.0 x2 = x * 2 Dim i as Integer For i = UBound(coeffs) - 1 To 0 Step -1 b2 = b1 b1 = b0 b0 = x2 * b1 - b2 + coeffs(i) Next i ChebyshevBroucke = (b0 - b2) * 0.5 End Function Function Log1p(ByVal x As Double) As Double If x = 0.0 Then Log1p = 0.0 Exit Function End If If x = -1.0 Then On Error Resume Next Log1p = -1.0 / 0.0 Exit Function End If If x < -1.0 Then On Error Resume Next Log1p = 0.0 / 0.0 Exit Function End If Dim xAbs As Double xAbs = Math.Abs(x) If xAbs < 0.5 * 4.94065645841247e-324 Then Log1p = x Exit Function End If If (x > 0.0 AND x < 1e-8) OR (x > -1e-9 AND x < 0.0) Then Log1p = x * (1.0 - x * 0.5) Exit Function End If If xAbs < 0.375 Then Dim coeffs(22) As Double coeffs(0) = 0.10378693562743769800686267719098e+1 coeffs(1) = -0.13364301504908918098766041553133e+0 coeffs(2) = 0.19408249135520563357926199374750e-1 coeffs(3) = -0.30107551127535777690376537776592e-2 coeffs(4) = 0.48694614797154850090456366509137e-3 coeffs(5) = -0.81054881893175356066809943008622e-4 coeffs(6) = 0.13778847799559524782938251496059e-4 coeffs(7) = -0.23802210894358970251369992914935e-5 coeffs(8) = 0.41640416213865183476391859901989e-6 coeffs(9) = -0.73595828378075994984266837031998e-7 coeffs(10) = 0.13117611876241674949152294345011e-7 coeffs(11) = -0.23546709317742425136696092330175e-8 coeffs(12) = 0.42522773276034997775638052962567e-9 coeffs(13) = -0.77190894134840796826108107493300e-10 coeffs(14) = 0.14075746481359069909215356472191e-10 coeffs(15) = -0.25769072058024680627537078627584e-11 coeffs(16) = 0.47342406666294421849154395005938e-12 coeffs(17) = -0.87249012674742641745301263292675e-13 coeffs(18) = 0.16124614902740551465739833119115e-13 coeffs(19) = -0.29875652015665773006710792416815e-14 coeffs(20) = 
0.55480701209082887983041321697279e-15 coeffs(21) = -0.10324619158271569595141333961932e-15 Log1p = x * (1.0 - x * ChebyshevBroucke(x / 0.375, coeffs)) Exit Function End If Log1p = Math.log(1.0 + x) End Function Function Score(ByRef inputVector() As Double) As Double Score = Log1p(2.0) End Function End Module """

    generated = VisualBasicInterpreter().interpret(tree)
    utils.assert_code_equal(generated, expected_code)