def test_maybe_sqr_output_transform():
    """A reg_sqrt LightGBM regressor assembles to ``abs(raw) * raw``.

    ``raw`` is the reused sum of the two IfExpr branches below; the assembled
    expression is compared structurally with ``utils.cmp_exprs``.
    """
    estimator = lgb.LGBMRegressor(
        n_estimators=2,
        random_state=1,
        max_depth=1,
        reg_sqrt=True,
        objective="regression_l1")
    utils.get_regression_model_trainer()(estimator)

    actual = LightGBMModelAssembler(estimator).assemble()

    first_tree = ast.IfExpr(
        ast.CompExpr(ast.FeatureRef(12), ast.NumVal(9.725), ast.CompOpType.GT),
        ast.NumVal(4.569350528717041),
        ast.NumVal(4.663526439666748))
    second_tree = ast.IfExpr(
        ast.CompExpr(ast.FeatureRef(12), ast.NumVal(11.655), ast.CompOpType.GT),
        ast.NumVal(-0.04462450027465819),
        ast.NumVal(0.033305134773254384))
    raw_output = ast.IdExpr(
        ast.BinNumExpr(first_tree, second_tree, ast.BinNumOpType.ADD),
        to_reuse=True)
    expected = ast.BinNumExpr(
        ast.AbsExpr(raw_output), raw_output, ast.BinNumOpType.MUL)

    assert utils.cmp_exprs(actual, expected)
def test_maybe_sqr_output_transform():
    """A reg_sqrt LightGBM regressor assembles to ``abs(raw) * raw``.

    ``raw`` is the reused sum of the two IfExpr branches below; the assembled
    expression is compared structurally with ``utils.cmp_exprs``.
    """
    estimator = lightgbm.LGBMRegressor(
        n_estimators=2,
        random_state=1,
        max_depth=1,
        reg_sqrt=True,
        objective="regression_l1")
    utils.get_regression_model_trainer()(estimator)

    actual = assemblers.LightGBMModelAssembler(estimator).assemble()

    first_tree = ast.IfExpr(
        ast.CompExpr(ast.FeatureRef(12), ast.NumVal(9.905), ast.CompOpType.GT),
        ast.NumVal(4.5658116817),
        ast.NumVal(4.6620790482))
    second_tree = ast.IfExpr(
        ast.CompExpr(ast.FeatureRef(12), ast.NumVal(9.77), ast.CompOpType.GT),
        ast.NumVal(-0.0340889740),
        ast.NumVal(0.0543687153))
    raw_output = ast.IdExpr(
        ast.BinNumExpr(first_tree, second_tree, ast.BinNumOpType.ADD),
        to_reuse=True)
    expected = ast.BinNumExpr(
        ast.AbsExpr(raw_output), raw_output, ast.BinNumOpType.MUL)

    assert utils.cmp_exprs(actual, expected)
def test_count_all_exprs_types():
    """``ast.count_exprs`` returns 31 for a tree containing 31 nodes.

    The tree mixes vector, unary, binary, comparison, conditional and
    feature-reference expressions so that every node type is visited.
    """
    unary_and_binary = [
        ast.AbsExpr(ast.NumVal(-2)),
        ast.ExpExpr(ast.NumVal(2)),
        ast.SqrtExpr(ast.NumVal(2)),
        ast.PowExpr(ast.NumVal(2), ast.NumVal(3)),
        ast.TanhExpr(ast.NumVal(1)),
        ast.BinNumExpr(
            ast.NumVal(0), ast.FeatureRef(0), ast.BinNumOpType.ADD),
    ]
    plain_values = [ast.NumVal(v) for v in (1, 2, 3, 4, 5)]
    plain_values.append(ast.FeatureRef(1))

    vector_difference = ast.BinVectorExpr(
        ast.VectorVal(unary_and_binary),
        ast.IdExpr(ast.VectorVal(plain_values)),
        ast.BinNumOpType.SUB)
    scalar_branch = ast.IfExpr(
        ast.CompExpr(ast.NumVal(2), ast.NumVal(0), ast.CompOpType.GT),
        ast.NumVal(3),
        ast.NumVal(4),
    )
    expr = ast.BinVectorNumExpr(
        vector_difference, scalar_branch, ast.BinNumOpType.MUL)

    assert ast.count_exprs(expr) == 31
def test_abs_expr():
    """AbsExpr(-1.0) is rendered by the F# interpreter as ``abs (-1.0)``."""
    interpreter = FSharpInterpreter()
    negative_one_abs = ast.AbsExpr(ast.NumVal(-1.0))

    generated = interpreter.interpret(negative_one_abs)

    expected_code = """
let score (input : double list) =
    abs (-1.0)
"""
    utils.assert_code_equal(generated, expected_code)
def test_abs_expr():
    """AbsExpr(-1.0) is rendered by the Python interpreter as ``abs(-1.0)``."""
    interpreter = PythonInterpreter()
    negative_one_abs = ast.AbsExpr(ast.NumVal(-1.0))

    generated = interpreter.interpret(negative_one_abs)

    expected_code = """
def score(input):
    return abs(-1.0)
"""
    assert_code_equal(generated, expected_code)
def test_abs_expr():
    """AbsExpr(-1.0) is rendered by the Ruby interpreter as ``(-1.0).abs()``."""
    interpreter = RubyInterpreter()
    negative_one_abs = ast.AbsExpr(ast.NumVal(-1.0))

    generated = interpreter.interpret(negative_one_abs)

    expected_code = """
def score(input)
    (-1.0).abs()
end
"""
    assert_code_equal(generated, expected_code)
def test_abs_expr():
    """AbsExpr(-1.0) is rendered by the R interpreter as ``abs(-1.0)``."""
    interpreter = RInterpreter()
    negative_one_abs = ast.AbsExpr(ast.NumVal(-1.0))

    generated = interpreter.interpret(negative_one_abs)

    expected_code = """
score <- function(input) {
    return(abs(-1.0))
}
"""
    utils.assert_code_equal(generated, expected_code)
def test_abs_expr():
    """AbsExpr(-1.0) is rendered by the Dart interpreter as ``(-1.0).abs()``."""
    interpreter = DartInterpreter()
    negative_one_abs = ast.AbsExpr(ast.NumVal(-1.0))

    generated = interpreter.interpret(negative_one_abs)

    expected_code = """
double score(List<double> input) {
    return (-1.0).abs();
}
"""
    utils.assert_code_equal(generated, expected_code)
def test_abs_expr():
    """AbsExpr(-1.0) is rendered by the JavaScript interpreter as ``Math.abs(-1.0)``."""
    interpreter = JavascriptInterpreter()
    negative_one_abs = ast.AbsExpr(ast.NumVal(-1.0))

    generated = interpreter.interpret(negative_one_abs)

    expected_code = """
function score(input) {
    return Math.abs(-1.0);
}
"""
    assert_code_equal(generated, expected_code)
def test_abs_expr():
    """AbsExpr(-1.0) is rendered by the PowerShell interpreter as ``[math]::Abs(-1.0)``."""
    interpreter = PowershellInterpreter()
    negative_one_abs = ast.AbsExpr(ast.NumVal(-1.0))

    generated = interpreter.interpret(negative_one_abs)

    expected_code = """
function Score([double[]] $InputVector) {
    return [math]::Abs(-1.0)
}
"""
    utils.assert_code_equal(generated, expected_code)
def test_abs_expr():
    """AbsExpr(-1.0) is rendered by the Rust interpreter as ``f64::abs(-1.0_f64)``."""
    interpreter = RustInterpreter()
    negative_one_abs = ast.AbsExpr(ast.NumVal(-1.0))

    generated = interpreter.interpret(negative_one_abs)

    expected_code = """
fn score(input: Vec<f64>) -> f64 {
    f64::abs(-1.0_f64)
}
"""
    assert_code_equal(generated, expected_code)
def test_abs_expr():
    """AbsExpr(-1.0) is rendered by the Haskell interpreter as ``abs (-1.0)``."""
    interpreter = HaskellInterpreter()
    negative_one_abs = ast.AbsExpr(ast.NumVal(-1.0))

    generated = interpreter.interpret(negative_one_abs)

    expected_code = """
module Model where
score :: [Double] -> Double
score input =
    abs (-1.0)
"""
    utils.assert_code_equal(generated, expected_code)
def test_abs_expr():
    """AbsExpr(-1.0) is rendered by the PHP interpreter as ``abs(-1.0)``."""
    interpreter = PhpInterpreter()
    negative_one_abs = ast.AbsExpr(ast.NumVal(-1.0))

    generated = interpreter.interpret(negative_one_abs)

    expected_code = """
<?php
function score(array $input) {
    return abs(-1.0);
}
"""
    utils.assert_code_equal(generated, expected_code)
def test_abs_expr():
    """AbsExpr(-1.0) is rendered by the Go interpreter as ``math.Abs(-1.0)``.

    The expected output also imports the ``math`` package.
    """
    negative_one_abs = ast.AbsExpr(ast.NumVal(-1.0))
    generated = GoInterpreter().interpret(negative_one_abs)

    expected_code = """
import "math"
func score(input []float64) float64 {
    return math.Abs(-1.0)
}"""
    utils.assert_code_equal(generated, expected_code)
def test_abs_expr():
    """AbsExpr(-1.0) is rendered by the C interpreter as ``fabs(-1.0)``.

    The expected output also includes ``<math.h>``.
    """
    negative_one_abs = ast.AbsExpr(ast.NumVal(-1.0))
    generated = interpreters.CInterpreter().interpret(negative_one_abs)

    expected_code = """
#include <math.h>
double score(double * input) {
    return fabs(-1.0);
}"""
    utils.assert_code_equal(generated, expected_code)
def test_abs_expr():
    """AbsExpr(-1.0) is rendered by the Visual Basic interpreter as ``Math.Abs(-1.0)``."""
    interpreter = VisualBasicInterpreter()
    negative_one_abs = ast.AbsExpr(ast.NumVal(-1.0))

    generated = interpreter.interpret(negative_one_abs)

    expected_code = """
Module Model
Function Score(ByRef inputVector() As Double) As Double
    Score = Math.Abs(-1.0)
End Function
End Module
"""
    utils.assert_code_equal(generated, expected_code)
def test_abs_expr():
    """AbsExpr(-1.0) is rendered by the Java interpreter as ``Math.abs(-1.0)``."""
    negative_one_abs = ast.AbsExpr(ast.NumVal(-1.0))
    generated = JavaInterpreter().interpret(negative_one_abs)

    expected_code = """
public class Model {
    public static double score(double[] input) {
        return Math.abs(-1.0);
    }
}"""
    utils.assert_code_equal(generated, expected_code)
def test_abs_expr():
    """AbsExpr(-1.0) is rendered by the C# interpreter as ``Abs(-1.0)``.

    The expected output brings ``Abs`` into scope via ``using static System.Math``.
    """
    interpreter = CSharpInterpreter()
    negative_one_abs = ast.AbsExpr(ast.NumVal(-1.0))

    generated = interpreter.interpret(negative_one_abs)

    expected_code = """
using static System.Math;
namespace ML {
    public static class Model {
        public static double Score(double[] input) {
            return Abs(-1.0);
        }
    }
}
"""
    assert_code_equal(generated, expected_code)
def test_abs_fallback_expr():
    """With ``abs_function_name`` unset, the C interpreter expands AbsExpr
    into an explicit ``< 0.0`` branch instead of calling a library function."""
    interpreter = CInterpreter()
    # Setting the name to NotImplemented disables the native abs call, so the
    # expected output below is the if/else expansion.
    interpreter.abs_function_name = NotImplemented
    negative_two_abs = ast.AbsExpr(ast.NumVal(-2.0))

    generated = interpreter.interpret(negative_two_abs)

    expected_code = """
double score(double * input) {
    double var0;
    double var1;
    var1 = -2.0;
    if ((var1) < (0.0)) {
        var0 = (0.0) - (var1);
    } else {
        var0 = var1;
    }
    return var0;
}
"""
    assert_code_equal(generated, expected_code)
def atan(expr):
    """Build an expression tree approximating ``atan(expr)``.

    Uses only abs, comparisons, branches and arithmetic, so it can stand in
    where a target language has no native arctangent. The scheme (and the
    coefficients, which appear to match Cephes' ``atan.c`` — verify if
    changed) is:

    1. reduce ``|x|`` into a small interval,
    2. evaluate ``r + r * s * P(s) / Q(s)`` with ``s = r * r``,
    3. add back the interval offset (0, pi/4 or pi/2),
    4. restore the sign of ``x``.

    Args:
        expr: the argument expression; it is wrapped for reuse because it is
            referenced several times.

    Returns:
        An expression tree evaluating to the arctangent of ``expr``.
    """
    tan_3pi_8 = 2.4142135623730950488  # tan(3*pi/8), upper reduction threshold
    small_limit = 0.66                 # below this |x| is used unreduced

    expr = ast.IdExpr(expr, to_reuse=True)
    expr_abs = ast.AbsExpr(expr, to_reuse=True)

    # Range reduction of |x|:
    #   |x| > tan(3pi/8)          -> 1 / |x|
    #   0.66 < |x| <= tan(3pi/8)  -> (|x| - 1) / (|x| + 1)
    #   otherwise                 -> |x|
    reduced = ast.IdExpr(
        ast.IfExpr(
            utils.gt(expr_abs, ast.NumVal(tan_3pi_8)),
            utils.div(ast.NumVal(1.0), expr_abs),
            ast.IfExpr(
                utils.gt(expr_abs, ast.NumVal(small_limit)),
                utils.div(
                    utils.sub(expr_abs, ast.NumVal(1.0)),
                    utils.add(expr_abs, ast.NumVal(1.0))),
                expr_abs)),
        to_reuse=True)
    squared = utils.mul(reduced, reduced, to_reuse=True)

    # Numerator coefficients, stored positive after P0 and applied by
    # subtraction: ((((s*P0 - P1)*s - P2)*s - P3)*s - P4).
    numerator_coeffs = (
        -8.750608600031904122785e-01,
        1.615753718733365076637e+01,
        7.500855792314704667340e+01,
        1.228866684490136173410e+02,
        6.485021904942025371773e+01,
    )
    # Monic denominator: Q4 + s*(Q3 + s*(Q2 + s*(Q1 + s*(Q0 + s)))).
    denominator_coeffs = (
        2.485846490142306297962e+01,
        1.650270098316988542046e+02,
        4.328810604912902668951e+02,
        4.853903996359136964868e+02,
        1.945506571482613964425e+02,
    )

    numerator = ast.NumVal(numerator_coeffs[0])
    for coeff in numerator_coeffs[1:]:
        numerator = utils.sub(utils.mul(squared, numerator), ast.NumVal(coeff))

    denominator = utils.add(ast.NumVal(denominator_coeffs[0]), squared)
    for coeff in denominator_coeffs[1:]:
        denominator = utils.add(ast.NumVal(coeff), utils.mul(squared, denominator))

    poly = utils.mul(squared, utils.div(numerator, denominator))
    poly = utils.add(utils.mul(reduced, poly), reduced)

    # The large-|x| reduction used 1/|x| rather than -1/|x|; since the
    # approximation is odd in its argument, negating afterwards is equivalent.
    flip = ast.IfExpr(
        utils.gt(expr_abs, ast.NumVal(tan_3pi_8)),
        ast.NumVal(-1.0),
        ast.NumVal(1.0))
    result = utils.mul(poly, flip)

    offset = ast.IfExpr(
        utils.lte(expr_abs, ast.NumVal(small_limit)),
        ast.NumVal(0.0),
        ast.IfExpr(
            utils.gt(expr_abs, ast.NumVal(tan_3pi_8)),
            ast.NumVal(1.570796326794896680463661649),     # pi / 2
            ast.NumVal(0.7853981633974483402318308245)))   # pi / 4
    result = utils.add(result, offset)

    # atan is odd: mirror the result for negative inputs.
    sign = ast.IfExpr(
        utils.lt(expr, ast.NumVal(0.0)),
        ast.NumVal(-1.0),
        ast.NumVal(1.0))
    return utils.mul(result, sign)
def _maybe_sqr_transform(self, expr):
    """Square the raw output when the objective config contains ``"sqrt"``.

    The square is expressed as ``abs(expr) * expr``, so the sign of ``expr``
    is preserved. ``expr`` is wrapped for reuse because it appears twice.

    Args:
        expr: the raw model output expression.

    Returns:
        ``expr`` unchanged, or the signed-square expression.
    """
    if "sqrt" not in self.objective_config_parts:
        return expr
    shared = ast.IdExpr(expr, to_reuse=True)
    return utils.mul(ast.AbsExpr(shared), shared)
def test_count_exprs_exclude_list(): assert ast.count_exprs( ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.ADD), exclude_list={ast.BinExpr, ast.NumVal} ) == 0 assert ast.count_exprs( ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.ADD), exclude_list={ast.BinNumExpr} ) == 2 EXPR_WITH_ALL_EXPRS = ast.BinVectorNumExpr( ast.BinVectorExpr( ast.VectorVal([ ast.AbsExpr(ast.NumVal(-2)), ast.AtanExpr(ast.NumVal(2)), ast.ExpExpr(ast.NumVal(2)), ast.LogExpr(ast.NumVal(2)), ast.Log1pExpr(ast.NumVal(2)), ast.SigmoidExpr(ast.NumVal(2)), ast.SqrtExpr(ast.NumVal(2)), ast.PowExpr(ast.NumVal(2), ast.NumVal(3)), ast.TanhExpr(ast.NumVal(1)), ast.BinNumExpr( ast.NumVal(0), ast.FeatureRef(0), ast.BinNumOpType.ADD) ]), ast.IdExpr( ast.SoftmaxExpr([