Example #1
0
def test_nested_condition():
    """Check Haskell output for an if nested under a condition containing an if.

    The same condition object is used both by the outer IfExpr and by the
    IfExpr in its true branch.
    """
    one_eq_one = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    inner_if = ast.IfExpr(one_eq_one, ast.NumVal(1), ast.NumVal(2))
    left = ast.BinNumExpr(inner_if, ast.NumVal(2), ast.BinNumOpType.ADD)

    # Shared on purpose: both IfExpr nodes below test this very object.
    condition = ast.CompExpr(ast.NumVal(1), left, ast.CompOpType.EQ)
    nested = ast.IfExpr(condition, ast.FeatureRef(2), ast.NumVal(2))
    expr = ast.IfExpr(condition, nested, ast.NumVal(2))

    expected = """
module Model where
score :: [Double] -> Double
score input =
    func1
    where
        func0 =
            if ((1.0) == (1.0)) then
                1.0
            else
                2.0
        func1 =
            if ((1.0) == ((func0) + (2.0))) then
                if ((1.0) == ((func0) + (2.0))) then
                    (input) !! (2)
                else
                    2.0
            else
                2.0
"""

    actual = HaskellInterpreter().interpret(expr)
    utils.assert_code_equal(actual, expected)
Example #2
0
def test_if_expr():
    """Check C# output for an IfExpr comparing a constant to input feature 0."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    expr = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    expected = """
namespace ML {
    public static class Model {
        public static double Score(double[] input) {
            double var0;
            if ((1) == (input[0])) {
                var0 = 2;
            } else {
                var0 = 3;
            }
            return var0;
        }
    }
}
"""

    generated = CSharpInterpreter().interpret(expr)
    utils.assert_code_equal(generated, expected)
Example #3
0
def test_multi_output():
    """Check PHP output for an IfExpr whose branches are two-element vectors."""
    always_true = ast.CompExpr(
        ast.NumVal(1),
        ast.NumVal(1),
        ast.CompOpType.EQ)
    then_vector = ast.VectorVal([ast.NumVal(v) for v in (1, 2)])
    else_vector = ast.VectorVal([ast.NumVal(v) for v in (3, 4)])
    expr = ast.IfExpr(always_true, then_vector, else_vector)

    expected = """
<?php
function score(array $input) {
    $var0 = array();
    if ((1.0) === (1.0)) {
        $var0 = array(1.0, 2.0);
    } else {
        $var0 = array(3.0, 4.0);
    }
    return $var0;
}
"""

    generated = PhpInterpreter().interpret(expr)
    utils.assert_code_equal(generated, expected)
Example #4
0
def test_dependable_condition():
    """Check PowerShell output when a comparison operand contains an IfExpr.

    The expected script computes the inner if into ``$var1`` before the
    outer comparison uses it.
    """
    nested_if = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1),
                     ast.NumVal(1),
                     ast.CompOpType.EQ),
        ast.NumVal(1),
        ast.NumVal(2))
    left = ast.BinNumExpr(nested_if, ast.NumVal(2), ast.BinNumOpType.ADD)
    right = ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.DIV)

    expr = ast.IfExpr(
        ast.CompExpr(left, right, ast.CompOpType.GTE),
        ast.NumVal(1),
        ast.FeatureRef(0))

    expected = """
function Score([double[]] $InputVector) {
    [double]$var0 = 0.0
    [double]$var1 = 0.0
    if ((1.0) -eq (1.0)) {
        $var1 = 1.0
    } else {
        $var1 = 2.0
    }
    if ((($var1) + (2.0)) -ge ((1.0) / (2.0))) {
        $var0 = 1.0
    } else {
        $var0 = $InputVector[0]
    }
    return $var0
}
"""

    generated = PowershellInterpreter().interpret(expr)
    utils.assert_code_equal(generated, expected)
Example #5
0
def test_multi_output():
    """Check Java output for an IfExpr whose branches are two-element vectors."""
    first = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    second = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    test = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    expr = ast.IfExpr(test, first, second)

    expected = """
public class Model {
    public static double[] score(double[] input) {
        double[] var0;
        if ((1.0) == (1.0)) {
            var0 = new double[] {1.0, 2.0};
        } else {
            var0 = new double[] {3.0, 4.0};
        }
        return var0;
    }
}"""

    java = interpreters.JavaInterpreter()
    utils.assert_code_equal(java.interpret(expr), expected)
Example #6
0
def test_deep_mixed_exprs_not_reaching_threshold():
    """Check CustomRInterpreter on four nested ifs with small sum conditions.

    Each level compares ``1 + (1 + 1)`` with 1. The expected output keeps
    every sum inline rather than extracting it into a variable.
    """
    expr = ast.NumVal(1)
    for _level in range(4):
        total = ast.NumVal(1)
        for _step in range(2):
            total = ast.BinNumExpr(ast.NumVal(1), total, ast.BinNumOpType.ADD)

        condition = ast.CompExpr(total, ast.NumVal(1), ast.CompOpType.EQ)
        expr = ast.IfExpr(condition, ast.NumVal(1), expr)

    expected = """
score <- function(input) {
    if (((1.0) + ((1.0) + (1.0))) == (1.0)) {
        var0 <- 1.0
    } else {
        if (((1.0) + ((1.0) + (1.0))) == (1.0)) {
            var0 <- 1.0
        } else {
            if (((1.0) + ((1.0) + (1.0))) == (1.0)) {
                var0 <- 1.0
            } else {
                if (((1.0) + ((1.0) + (1.0))) == (1.0)) {
                    var0 <- 1.0
                } else {
                    var0 <- 1.0
                }
            }
        }
    }
    return(var0)
}
"""

    generated = CustomRInterpreter().interpret(expr)
    assert_code_equal(generated, expected)
Example #7
0
def test_nested_condition():
    """Check R output for an if nested under a condition containing an if.

    The same condition object is used by the outer IfExpr and by the IfExpr
    in its true branch. The expected R code evaluates the inner if once per
    use (``var1`` and ``var2``).
    """
    inner_if = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.NumVal(1),
        ast.NumVal(2))
    left = ast.BinNumExpr(inner_if, ast.NumVal(2), ast.BinNumOpType.ADD)
    shared_test = ast.CompExpr(ast.NumVal(1), left, ast.CompOpType.EQ)

    expr = ast.IfExpr(
        shared_test,
        ast.IfExpr(shared_test, ast.FeatureRef(2), ast.NumVal(2)),
        ast.NumVal(2))

    expected = """
score <- function(input) {
    if ((1.0) == (1.0)) {
        var1 <- 1.0
    } else {
        var1 <- 2.0
    }
    if ((1.0) == ((var1) + (2.0))) {
        if ((1.0) == (1.0)) {
            var2 <- 1.0
        } else {
            var2 <- 2.0
        }
        if ((1.0) == ((var2) + (2.0))) {
            var0 <- input[3]
        } else {
            var0 <- 2.0
        }
    } else {
        var0 <- 2.0
    }
    return(var0)
}
"""

    generated = RInterpreter().interpret(expr)
    assert_code_equal(generated, expected)
Example #8
0
def test_dependable_condition():
    """Check Dart output when a comparison operand contains an IfExpr."""
    one_eq_one = ast.CompExpr(ast.NumVal(1),
                              ast.NumVal(1),
                              ast.CompOpType.EQ)
    left = ast.BinNumExpr(
        ast.IfExpr(one_eq_one, ast.NumVal(1), ast.NumVal(2)),
        ast.NumVal(2),
        ast.BinNumOpType.ADD)
    right = ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.DIV)

    gte_test = ast.CompExpr(left, right, ast.CompOpType.GTE)
    expr = ast.IfExpr(gte_test, ast.NumVal(1), ast.FeatureRef(0))

    expected = """
double score(List<double> input) {
    double var0;
    double var1;
    if ((1) == (1)) {
        var1 = 1;
    } else {
        var1 = 2;
    }
    if (((var1) + (2)) >= ((1) / (2))) {
        var0 = 1;
    } else {
        var0 = input[0];
    }
    return var0;
}
"""

    generated = DartInterpreter().interpret(expr)
    utils.assert_code_equal(generated, expected)
Example #9
0
def test_multi_output():
    """Check C# output for an IfExpr whose branches are two-element vectors."""
    branches = [
        ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]),
        ast.VectorVal([ast.NumVal(3), ast.NumVal(4)]),
    ]
    expr = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        *branches)

    expected = """
namespace ML {
    public static class Model {
        public static double[] Score(double[] input) {
            double[] var0;
            if ((1) == (1)) {
                var0 = new double[2] {1, 2};
            } else {
                var0 = new double[2] {3, 4};
            }
            return var0;
        }
    }
}
"""

    generated = CSharpInterpreter().interpret(expr)
    utils.assert_code_equal(generated, expected)
Example #10
0
def test_depth_threshold_without_bin_expr():
    """Check CustomDartInterpreter on four nested ifs with no binary exprs.

    The expected Dart code is a plain nested if/else chain writing ``var0``.
    """
    expr = ast.NumVal(1)
    for _ in range(4):
        always_true = ast.CompExpr(
            ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
        expr = ast.IfExpr(always_true, ast.NumVal(1), expr)

    dart = CustomDartInterpreter()

    expected = """
double score(List<double> input) {
    double var0;
    if ((1) == (1)) {
        var0 = 1;
    } else {
        if ((1) == (1)) {
            var0 = 1;
        } else {
            if ((1) == (1)) {
                var0 = 1;
            } else {
                if ((1) == (1)) {
                    var0 = 1;
                } else {
                    var0 = 1;
                }
            }
        }
    }
    return var0;
}
"""

    utils.assert_code_equal(dart.interpret(expr), expected)
Example #11
0
def test_multi_output():
    """Check Dart output for a SubroutineExpr wrapping a vector-valued IfExpr."""
    condition = ast.CompExpr(
        ast.NumVal(1),
        ast.NumVal(1),
        ast.CompOpType.EQ)
    body = ast.IfExpr(
        condition,
        ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]),
        ast.VectorVal([ast.NumVal(3), ast.NumVal(4)]))
    expr = ast.SubroutineExpr(body)

    expected = """
List<double> score(List<double> input) {
    List<double> var0;
    if ((1) == (1)) {
        var0 = [1, 2];
    } else {
        var0 = [3, 4];
    }
    return var0;
}
"""

    generated = DartInterpreter().interpret(expr)
    utils.assert_code_equal(generated, expected)
Example #12
0
def test_bin_vector_num_expr():
    """Check C output for multiplying a two-element vector by a number.

    The expected code emits vector helper functions, calls
    ``mul_vector_number`` into a temp array, then copies it to ``output``.
    """
    vector = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    expr = ast.BinVectorNumExpr(vector, ast.NumVal(1), ast.BinNumOpType.MUL)

    c_interpreter = interpreters.CInterpreter()

    expected = """
#include <string.h>
void add_vectors(double *v1, double *v2, int size, double *result) {
    for(int i = 0; i < size; ++i)
        result[i] = v1[i] + v2[i];
}
void mul_vector_number(double *v1, double num, int size, double *result) {
    for(int i = 0; i < size; ++i)
        result[i] = v1[i] * num;
}
void score(double * input, double * output) {
    double var0[2];
    mul_vector_number((double[]){1.0, 2.0}, 1.0, 2, var0);
    memcpy(output, var0, 2 * sizeof(double));
}"""
    generated = c_interpreter.interpret(expr)
    utils.assert_code_equal(generated, expected)
Example #13
0
def test_tanh_expr():
    """Check Visual Basic output for TanhExpr, including the Tanh helper."""
    argument = ast.NumVal(2.0)
    expr = ast.TanhExpr(argument)

    expected = """
Module Model
Function Tanh(ByVal number As Double) As Double
    If number > 44.0 Then  ' exp(2*x) <= 2^127
        Tanh = 1.0
        Exit Function
    End If
    If number < -44.0 Then
        Tanh = -1.0
        Exit Function
    End If
    Tanh = (Math.Exp(2 * number) - 1) / (Math.Exp(2 * number) + 1)
End Function
Function score(ByRef input_vector() As Double) As Double
    score = Tanh(2.0)
End Function
End Module
"""

    generated = VisualBasicInterpreter().interpret(expr)
    utils.assert_code_equal(generated, expected)
Example #14
0
def test_deep_mixed_exprs_exceeding_threshold():
    """Check CustomPythonInterpreter on four nested ifs with deep sum conditions.

    Level ``i`` (0..3) compares ``i + (i + (i + (i + 1)))`` with 3. In the
    expected output part of each sum is extracted into a ``varN`` assignment
    before its ``if``.
    """
    # Right-hand side of every comparison (rendered as ``3.0`` below). It is
    # kept explicit instead of reading the inner loop variable after the loop
    # has finished, which only worked because ``range(4)`` left it at 3.
    comparand = 3

    expr = ast.NumVal(1)
    for i in range(4):
        inner = ast.NumVal(1)
        for _ in range(4):
            inner = ast.BinNumExpr(ast.NumVal(i), inner, ast.BinNumOpType.ADD)

        expr = ast.IfExpr(
            ast.CompExpr(inner, ast.NumVal(comparand), ast.CompOpType.EQ),
            ast.NumVal(1), expr)

    interpreter = CustomPythonInterpreter()

    expected_code = """
def score(input):
    var1 = (3.0) + ((3.0) + (1.0))
    if ((3.0) + ((3.0) + (var1))) == (3.0):
        var0 = 1.0
    else:
        var2 = (2.0) + ((2.0) + (1.0))
        if ((2.0) + ((2.0) + (var2))) == (3.0):
            var0 = 1.0
        else:
            var3 = (1.0) + ((1.0) + (1.0))
            if ((1.0) + ((1.0) + (var3))) == (3.0):
                var0 = 1.0
            else:
                var4 = (0.0) + ((0.0) + (1.0))
                if ((0.0) + ((0.0) + (var4))) == (3.0):
                    var0 = 1.0
                else:
                    var0 = 1.0
    return var0
    """

    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
Example #15
0
def test_dependable_condition():
    """Check Python output when a comparison operand contains an IfExpr."""
    branchy = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.NumVal(1),
        ast.NumVal(2))
    lhs = ast.BinNumExpr(branchy, ast.NumVal(2), ast.BinNumOpType.ADD)
    rhs = ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.DIV)
    expr = ast.IfExpr(
        ast.CompExpr(lhs, rhs, ast.CompOpType.GTE),
        ast.NumVal(1),
        ast.FeatureRef(0))

    expected = """
def score(input):
    if (1.0) == (1.0):
        var1 = 1.0
    else:
        var1 = 2.0
    if ((var1) + (2.0)) >= ((1.0) / (2.0)):
        var0 = 1.0
    else:
        var0 = input[0]
    return var0
"""

    generated = PythonInterpreter().interpret(expr)
    assert_code_equal(generated, expected)
Example #16
0
def test_log1p_expr():
    """Check C# output for Log1pExpr, including the emitted helper methods.

    The expected code includes ``Log1p`` and ``ChebyshevBroucke``.
    """
    expr = ast.Log1pExpr(ast.NumVal(2.0))
    csharp = CSharpInterpreter()

    expected = """
using static System.Math;
namespace ML {
    public static class Model {
        public static double Score(double[] input) {
            return Log1p(2.0);
        }
        private static double Log1p(double x) {
            if (x == 0.0)
                return 0.0;
            if (x == -1.0)
                return double.NegativeInfinity;
            if (x < -1.0)
                return double.NaN;
            double xAbs = Abs(x);
            if (xAbs < 0.5 * double.Epsilon)
                return x;
            if ((x > 0.0 && x < 1e-8) || (x > -1e-9 && x < 0.0))
                return x * (1.0 - x * 0.5);
            if (xAbs < 0.375) {
                double[] coeffs = {
                     0.10378693562743769800686267719098e+1,
                    -0.13364301504908918098766041553133e+0,
                     0.19408249135520563357926199374750e-1,
                    -0.30107551127535777690376537776592e-2,
                     0.48694614797154850090456366509137e-3,
                    -0.81054881893175356066809943008622e-4,
                     0.13778847799559524782938251496059e-4,
                    -0.23802210894358970251369992914935e-5,
                     0.41640416213865183476391859901989e-6,
                    -0.73595828378075994984266837031998e-7,
                     0.13117611876241674949152294345011e-7,
                    -0.23546709317742425136696092330175e-8,
                     0.42522773276034997775638052962567e-9,
                    -0.77190894134840796826108107493300e-10,
                     0.14075746481359069909215356472191e-10,
                    -0.25769072058024680627537078627584e-11,
                     0.47342406666294421849154395005938e-12,
                    -0.87249012674742641745301263292675e-13,
                     0.16124614902740551465739833119115e-13,
                    -0.29875652015665773006710792416815e-14,
                     0.55480701209082887983041321697279e-15,
                    -0.10324619158271569595141333961932e-15};
                return x * (1.0 - x * ChebyshevBroucke(x / 0.375, coeffs));
            }
            return Log(1.0 + x);
        }
        private static double ChebyshevBroucke(double x, double[] coeffs) {
            double b0, b1, b2, x2;
            b2 = b1 = b0 = 0.0;
            x2 = x * 2;
            for (int i = coeffs.Length - 1; i >= 0; --i) {
                b2 = b1;
                b1 = b0;
                b0 = x2 * b1 - b2 + coeffs[i];
            }
            return (b0 - b2) * 0.5;
        }
    }
}
"""

    assert_code_equal(csharp.interpret(expr), expected)
Example #17
0
def test_deep_mixed_exprs_exceeding_threshold():
    """Check R output with small thresholds on four nested deep-sum conditions.

    Level ``i`` (0..3) compares ``i + (i + (i + (i + 1)))`` with 3. The
    interpreter's ``bin_depth_threshold``, ``ast_size_check_frequency`` and
    ``ast_size_per_subroutine_threshold`` are lowered. The expected output
    moves part of each sum into its own ``subroutineN`` function.
    """
    # Right-hand side of every comparison (rendered as ``3.0`` below). It is
    # kept explicit instead of reading the inner loop variable after the loop
    # has finished, which only worked because ``range(4)`` left it at 3.
    comparand = 3

    expr = ast.NumVal(1)
    for i in range(4):
        inner = ast.NumVal(1)
        for _ in range(4):
            inner = ast.BinNumExpr(ast.NumVal(i), inner, ast.BinNumOpType.ADD)

        expr = ast.IfExpr(
            ast.CompExpr(inner, ast.NumVal(comparand), ast.CompOpType.EQ),
            ast.NumVal(1), expr)

    interpreter = RInterpreter()
    interpreter.bin_depth_threshold = 1
    interpreter.ast_size_check_frequency = 2
    interpreter.ast_size_per_subroutine_threshold = 6

    expected_code = """
score <- function(input) {
    var1 <- subroutine0(input)
    if (((3.0) + (var1)) == (3.0)) {
        var0 <- 1.0
    } else {
        var2 <- subroutine1(input)
        if (((2.0) + (var2)) == (3.0)) {
            var0 <- 1.0
        } else {
            var3 <- subroutine2(input)
            if (((1.0) + (var3)) == (3.0)) {
                var0 <- 1.0
            } else {
                var4 <- subroutine3(input)
                if (((0.0) + (var4)) == (3.0)) {
                    var0 <- 1.0
                } else {
                    var0 <- 1.0
                }
            }
        }
    }
    return(var0)
}
subroutine0 <- function(input) {
    var0 <- (3.0) + (1.0)
    var1 <- (3.0) + (var0)
    return((3.0) + (var1))
}
subroutine1 <- function(input) {
    var0 <- (2.0) + (1.0)
    var1 <- (2.0) + (var0)
    return((2.0) + (var1))
}
subroutine2 <- function(input) {
    var0 <- (1.0) + (1.0)
    var1 <- (1.0) + (var0)
    return((1.0) + (var1))
}
subroutine3 <- function(input) {
    var0 <- (0.0) + (1.0)
    var1 <- (0.0) + (var0)
    return((0.0) + (var1))
}
"""

    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
Example #18
0
def test_log1p_expr():
    """Check Ruby output for Log1pExpr, including the emitted helper methods.

    The expected code includes ``log1p`` and ``chebyshev_broucke``.
    """
    ruby = RubyInterpreter()

    expected = """
def score(input)
    log1p(2.0)
end
def log1p(x)
    if x == 0.0
        return 0.0
    end
    if x == -1.0
        return -Float::INFINITY
    end
    if x < -1.0
        return Float::NAN
    end
    x_abs = x.abs
    if x_abs < 0.5 * Float::EPSILON
        return x
    end
    if (x > 0.0 && x < 1e-8) || (x > -1e-9 && x < 0.0)
        return x * (1.0 - x * 0.5)
    end
    if x_abs < 0.375
        coeffs = [
             0.10378693562743769800686267719098e+1,
            -0.13364301504908918098766041553133e+0,
             0.19408249135520563357926199374750e-1,
            -0.30107551127535777690376537776592e-2,
             0.48694614797154850090456366509137e-3,
            -0.81054881893175356066809943008622e-4,
             0.13778847799559524782938251496059e-4,
            -0.23802210894358970251369992914935e-5,
             0.41640416213865183476391859901989e-6,
            -0.73595828378075994984266837031998e-7,
             0.13117611876241674949152294345011e-7,
            -0.23546709317742425136696092330175e-8,
             0.42522773276034997775638052962567e-9,
            -0.77190894134840796826108107493300e-10,
             0.14075746481359069909215356472191e-10,
            -0.25769072058024680627537078627584e-11,
             0.47342406666294421849154395005938e-12,
            -0.87249012674742641745301263292675e-13,
             0.16124614902740551465739833119115e-13,
            -0.29875652015665773006710792416815e-14,
             0.55480701209082887983041321697279e-15,
            -0.10324619158271569595141333961932e-15]
        return x * (1.0 - x * chebyshev_broucke(x / 0.375, coeffs))
    end
    return Math.log(1.0 + x)
end
def chebyshev_broucke(x, coeffs)
    b2 = b1 = b0 = 0.0
    x2 = x * 2
    coeffs.reverse_each do |i|
        b2 = b1
        b1 = b0
        b0 = x2 * b1 - b2 + i
    end
    (b0 - b2) * 0.5
end
"""

    generated = ruby.interpret(ast.Log1pExpr(ast.NumVal(2.0)))
    assert_code_equal(generated, expected)
Example #19
0
def test_log1p_expr():
    """Check Visual Basic output for Log1pExpr, including the helper functions.

    The expected module defines ``ChebyshevBroucke`` and ``Log1p`` before
    ``Score``.
    """
    expr = ast.Log1pExpr(ast.NumVal(2.0))

    expected = """
Module Model
Function ChebyshevBroucke(ByVal x As Double, _
                          ByRef coeffs() As Double) As Double
    Dim b2 as Double
    Dim b1 as Double
    Dim b0 as Double
    Dim x2 as Double
    b2 = 0.0
    b1 = 0.0
    b0 = 0.0
    x2 = x * 2
    Dim i as Integer
    For i = UBound(coeffs) - 1 To 0 Step -1
        b2 = b1
        b1 = b0
        b0 = x2 * b1 - b2 + coeffs(i)
    Next i
    ChebyshevBroucke = (b0 - b2) * 0.5
End Function
Function Log1p(ByVal x As Double) As Double
    If x = 0.0 Then
        Log1p = 0.0
        Exit Function
    End If
    If x = -1.0 Then
        On Error Resume Next
        Log1p = -1.0 / 0.0
        Exit Function
    End If
    If x < -1.0 Then
        On Error Resume Next
        Log1p = 0.0 / 0.0
        Exit Function
    End If
    Dim xAbs As Double
    xAbs = Math.Abs(x)
    If xAbs < 0.5 * 4.94065645841247e-324 Then
        Log1p = x
        Exit Function
    End If
    If (x > 0.0 AND x < 1e-8) OR (x > -1e-9 AND x < 0.0) Then
        Log1p = x * (1.0 - x * 0.5)
        Exit Function
    End If
    If xAbs < 0.375 Then
        Dim coeffs(22) As Double
        coeffs(0)  =  0.10378693562743769800686267719098e+1
        coeffs(1)  = -0.13364301504908918098766041553133e+0
        coeffs(2)  =  0.19408249135520563357926199374750e-1
        coeffs(3)  = -0.30107551127535777690376537776592e-2
        coeffs(4)  =  0.48694614797154850090456366509137e-3
        coeffs(5)  = -0.81054881893175356066809943008622e-4
        coeffs(6)  =  0.13778847799559524782938251496059e-4
        coeffs(7)  = -0.23802210894358970251369992914935e-5
        coeffs(8)  =  0.41640416213865183476391859901989e-6
        coeffs(9)  = -0.73595828378075994984266837031998e-7
        coeffs(10) =  0.13117611876241674949152294345011e-7
        coeffs(11) = -0.23546709317742425136696092330175e-8
        coeffs(12) =  0.42522773276034997775638052962567e-9
        coeffs(13) = -0.77190894134840796826108107493300e-10
        coeffs(14) =  0.14075746481359069909215356472191e-10
        coeffs(15) = -0.25769072058024680627537078627584e-11
        coeffs(16) =  0.47342406666294421849154395005938e-12
        coeffs(17) = -0.87249012674742641745301263292675e-13
        coeffs(18) =  0.16124614902740551465739833119115e-13
        coeffs(19) = -0.29875652015665773006710792416815e-14
        coeffs(20) =  0.55480701209082887983041321697279e-15
        coeffs(21) = -0.10324619158271569595141333961932e-15
        Log1p = x * (1.0 - x * ChebyshevBroucke(x / 0.375, coeffs))
        Exit Function
    End If
    Log1p = Math.log(1.0 + x)
End Function
Function Score(ByRef inputVector() As Double) As Double
    Score = Log1p(2.0)
End Function
End Module
"""

    vb = VisualBasicInterpreter()
    utils.assert_code_equal(vb.interpret(expr), expected)