Ejemplo n.º 1
0
def test_reused_expr():
    """A reusable sub-expression is hoisted into one R local variable."""
    shared = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True)
    # The very same node instance is used for both operands, so the
    # generated code is expected to compute it once (var0) and reference it
    # twice.
    division = ast.BinNumExpr(shared, shared, ast.BinNumOpType.DIV)

    expected_code = """
score <- function(input) {
    var0 <- exp(1.0)
    return((var0) / (var0))
}
"""

    actual_code = RInterpreter().interpret(division)
    utils.assert_code_equal(actual_code, expected_code)
Ejemplo n.º 2
0
def test_bin_vector_num_expr():
    """A vector-by-scalar operation renders as an R vector times a number."""
    vector = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    scalar = ast.NumVal(1)
    product = ast.BinVectorNumExpr(vector, scalar, ast.BinNumOpType.MUL)

    expected_code = """
score <- function(input) {
    return((c(1.0, 2.0)) * (1.0))
}
"""

    rendered = RInterpreter().interpret(product)
    utils.assert_code_equal(rendered, expected_code)
Ejemplo n.º 3
0
def test_bin_num_expr():
    """Nested binary numeric expressions are fully parenthesized in R."""
    quotient = ast.BinNumExpr(
        ast.FeatureRef(0), ast.NumVal(-2), ast.BinNumOpType.DIV)
    product = ast.BinNumExpr(quotient, ast.NumVal(2), ast.BinNumOpType.MUL)

    # FeatureRef(0) is expected to appear as input[1]: R indexing is 1-based.
    expected_code = """
score <- function(input) {
    return(((input[1]) / (-2.0)) * (2.0))
}
"""

    rendered = RInterpreter().interpret(product)
    utils.assert_code_equal(rendered, expected_code)
Ejemplo n.º 4
0
def test_bin_vector_expr():
    """Element-wise vector addition renders as the sum of two R vectors.

    Integer-valued ``NumVal`` literals are expected in float form
    (``c(1.0, 2.0)``). This is the same rendering that the other vector tests
    assert for an identical ``VectorVal([NumVal(1), NumVal(2)])``. The
    previous ``c(1, 2)`` expectation contradicted them, so the two could not
    both pass against one interpreter.
    """
    expr = ast.BinVectorExpr(
        ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]),
        ast.VectorVal([ast.NumVal(3), ast.NumVal(4)]),
        ast.BinNumOpType.ADD)

    expected_code = """
score <- function(input) {
    return((c(1.0, 2.0)) + (c(3.0, 4.0)))
}
"""

    interpreter = RInterpreter()
    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
Ejemplo n.º 5
0
def test_softmax_expr():
    """SoftmaxExpr emits a softmax helper function and calls it from score."""
    logits = [ast.NumVal(2.0), ast.NumVal(3.0)]
    softmax_expr = ast.SoftmaxExpr(logits)

    # The helper subtracts max(x) before exponentiating (see expected code).
    expected_code = """
softmax <- function (x) {
    m <- max(x)
    exps <- exp(x - m)
    s <- sum(exps)
    return(exps / s)
}
score <- function(input) {
    return(softmax(c(2.0, 3.0)))
}
"""

    rendered = RInterpreter().interpret(softmax_expr)
    utils.assert_code_equal(rendered, expected_code)
Ejemplo n.º 6
0
def test_if_expr():
    """An IfExpr becomes an R if/else that assigns a single result variable."""
    condition = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    branching = ast.IfExpr(condition, ast.NumVal(2), ast.NumVal(3))

    expected_code = """
score <- function(input) {
    if ((1.0) == (input[1])) {
        var0 <- 2.0
    } else {
        var0 <- 3.0
    }
    return(var0)
}
"""

    rendered = RInterpreter().interpret(branching)
    utils.assert_code_equal(rendered, expected_code)
Ejemplo n.º 7
0
def test_sigmoid_expr():
    """SigmoidExpr emits a sigmoid helper function and calls it from score."""
    sigmoid_expr = ast.SigmoidExpr(ast.NumVal(2.0))

    # The expected helper branches on the sign of x, using exp(x) for
    # negative inputs and exp(-x) otherwise.
    expected_code = """
sigmoid <- function(x) {
    if (x < 0.0) {
        z <- exp(x)
        return(z / (1.0 + z))
    }
    return(1.0 / (1.0 + exp(-x)))
}
score <- function(input) {
    return(sigmoid(2.0))
}
"""

    rendered = RInterpreter().interpret(sigmoid_expr)
    utils.assert_code_equal(rendered, expected_code)
Ejemplo n.º 8
0
def test_subroutine():
    """A SubroutineExpr is emitted as a separate R function taking ``input``.

    Integer-valued ``NumVal`` literals are expected in float form (``1.0``).
    This is consistent with how ``NumVal`` is rendered throughout the other R
    interpreter tests, e.g. ``(1.0) + (2.0)`` in the nested-condition test.
    """
    expr = ast.BinNumExpr(
        ast.FeatureRef(0),
        ast.SubroutineExpr(
            ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2),
                           ast.BinNumOpType.ADD)), ast.BinNumOpType.MUL)

    expected_code = """
score <- function(input) {
    return((input[1]) * (subroutine0(input)))
}
subroutine0 <- function(input) {
    return((1.0) + (2.0))
}
"""

    interpreter = RInterpreter()
    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
Ejemplo n.º 9
0
def test_nested_condition():
    """Nested IfExprs that share one condition node render it at each use.

    The same ``bool_test`` node drives both the outer and the inner IfExpr.
    Per the expected code, the IfExpr embedded in that condition is rendered
    separately at each use site (var1 and var2) rather than shared.
    """
    embedded_if = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.NumVal(1),
        ast.NumVal(2))
    sum_expr = ast.BinNumExpr(embedded_if, ast.NumVal(2),
                              ast.BinNumOpType.ADD)

    bool_test = ast.CompExpr(ast.NumVal(1), sum_expr, ast.CompOpType.EQ)

    # FeatureRef(2) is expected to appear as input[3] (1-based R indexing).
    inner_if = ast.IfExpr(bool_test, ast.FeatureRef(2), ast.NumVal(2))
    outer_if = ast.IfExpr(bool_test, inner_if, ast.NumVal(2))

    expected_code = """
score <- function(input) {
    if ((1.0) == (1.0)) {
        var1 <- 1.0
    } else {
        var1 <- 2.0
    }
    if ((1.0) == ((var1) + (2.0))) {
        if ((1.0) == (1.0)) {
            var2 <- 1.0
        } else {
            var2 <- 2.0
        }
        if ((1.0) == ((var2) + (2.0))) {
            var0 <- input[3]
        } else {
            var0 <- 2.0
        }
    } else {
        var0 <- 2.0
    }
    return(var0)
}
"""

    rendered = RInterpreter().interpret(outer_if)
    utils.assert_code_equal(rendered, expected_code)
Ejemplo n.º 10
0
def test_multi_output():
    """An IfExpr with vector branches assigns an R vector to its result."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    branching = ast.IfExpr(condition, when_true, when_false)

    expected_code = """
score <- function(input) {
    if ((1.0) == (1.0)) {
        var0 <- c(1.0, 2.0)
    } else {
        var0 <- c(3.0, 4.0)
    }
    return(var0)
}
"""

    rendered = RInterpreter().interpret(branching)
    utils.assert_code_equal(rendered, expected_code)
Ejemplo n.º 11
0
class RExecutor(BaseExecutor):
    """Run a model's R translation through ``Rscript``.

    ``prepare()`` must be called before ``predict()``. It renders the model
    AST into an R script inside ``self._resource_tmp_dir`` and records the
    script path. (Presumably ``_resource_tmp_dir`` is provided by
    ``BaseExecutor``; TODO confirm.)
    """

    def __init__(self, model):
        self.model_name = "score"
        self.model = model
        self.interpreter = RInterpreter()

        assembler_cls = get_assembler_cls(model)
        self.model_ast = assembler_cls(model).assemble()

        # Set by prepare(); predict() refuses to run until it is populated.
        self.script_path = None

    def predict(self, X):
        """Run the generated R script on feature values ``X``.

        Each element of ``X`` is formatted with ``utils.format_arg`` and
        passed as a command-line argument. The result is whatever
        ``utils.predict_from_commandline`` returns for the command.

        Raises:
            RuntimeError: if ``prepare()`` has not been called yet.
        """
        if self.script_path is None:
            # Without this guard the command would silently become
            # ``Rscript --vanilla None ...``.
            raise RuntimeError("prepare() must be called before predict()")
        exec_args = [
            "Rscript", "--vanilla",
            str(self.script_path), *map(utils.format_arg, X)
        ]
        return utils.predict_from_commandline(exec_args)

    def prepare(self):
        """Render the model to R code and write it to ``<model_name>.r``."""
        executor_code = EXECUTOR_CODE_TPL.format(
            model_code=self.interpreter.interpret(self.model_ast))

        self.script_path = self._resource_tmp_dir / f"{self.model_name}.r"
        utils.write_content_to_file(executor_code, self.script_path)
Ejemplo n.º 12
0
def test_deep_mixed_exprs_exceeding_threshold():
    """Deep if/else chains are split into subroutines when limits are hit.

    Builds four nested IfExprs. Each condition compares a chain of four
    additions against a constant. The low thresholds configured below force
    most of each addition chain out into its own subroutine.
    """
    chain_length = 4
    # The comparison constant is the last inner-loop index (3). It is spelled
    # out explicitly instead of being read from the inner loop variable after
    # the loop has ended.
    compared_value = chain_length - 1

    expr = ast.NumVal(1)
    for i in range(4):
        inner = ast.NumVal(1)
        for _ in range(chain_length):
            inner = ast.BinNumExpr(ast.NumVal(i), inner, ast.BinNumOpType.ADD)

        expr = ast.IfExpr(
            ast.CompExpr(inner, ast.NumVal(compared_value),
                         ast.CompOpType.EQ),
            ast.NumVal(1), expr)

    interpreter = RInterpreter()
    interpreter.bin_depth_threshold = 1
    interpreter.ast_size_check_frequency = 2
    interpreter.ast_size_per_subroutine_threshold = 6

    expected_code = """
score <- function(input) {
    var1 <- subroutine0(input)
    if (((3.0) + (var1)) == (3.0)) {
        var0 <- 1.0
    } else {
        var2 <- subroutine1(input)
        if (((2.0) + (var2)) == (3.0)) {
            var0 <- 1.0
        } else {
            var3 <- subroutine2(input)
            if (((1.0) + (var3)) == (3.0)) {
                var0 <- 1.0
            } else {
                var4 <- subroutine3(input)
                if (((0.0) + (var4)) == (3.0)) {
                    var0 <- 1.0
                } else {
                    var0 <- 1.0
                }
            }
        }
    }
    return(var0)
}
subroutine0 <- function(input) {
    var0 <- (3.0) + (1.0)
    var1 <- (3.0) + (var0)
    return((3.0) + (var1))
}
subroutine1 <- function(input) {
    var0 <- (2.0) + (1.0)
    var1 <- (2.0) + (var0)
    return((2.0) + (var1))
}
subroutine2 <- function(input) {
    var0 <- (1.0) + (1.0)
    var1 <- (1.0) + (var0)
    return((1.0) + (var1))
}
subroutine3 <- function(input) {
    var0 <- (0.0) + (1.0)
    var1 <- (0.0) + (var0)
    return((0.0) + (var1))
}
"""

    utils.assert_code_equal(interpreter.interpret(expr), expected_code)