def test_softmax_expr(): expr = ast.SoftmaxExpr([ast.NumVal(2.0), ast.NumVal(3.0)]) expected_code = """ import 'dart:math'; List<double> score(List<double> input) { return softmax([2.0, 3.0]); } List<double> softmax(List<double> x) { int size = x.length; List<double> result = new List<double>.filled(size, 0.0); double maxElem = x.reduce(max); double sum = 0.0; for (int i = 0; i < size; ++i) { result[i] = exp(x[i] - maxElem); sum += result[i]; } for (int i = 0; i < size; ++i) result[i] /= sum; return result; } """ interpreter = DartInterpreter() assert_code_equal(interpreter.interpret(expr), expected_code)
def test_dependable_condition(): left = ast.BinNumExpr( ast.IfExpr( ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ), ast.NumVal(1), ast.NumVal(2)), ast.NumVal(2), ast.BinNumOpType.ADD) right = ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.DIV) bool_test = ast.CompExpr(left, right, ast.CompOpType.GTE) expr = ast.IfExpr(bool_test, ast.NumVal(1), ast.FeatureRef(0)) expected_code = """ double score(List<double> input) { double var0; double var1; if ((1.0) == (1.0)) { var1 = 1.0; } else { var1 = 2.0; } if (((var1) + (2.0)) >= ((1.0) / (2.0))) { var0 = 1.0; } else { var0 = input[0]; } return var0; } """ interpreter = DartInterpreter() utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_bin_vector_num_expr(): expr = ast.BinVectorNumExpr(ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]), ast.NumVal(1), ast.BinNumOpType.MUL) expected_code = """ List<double> score(List<double> input) { return mulVectorNumber([1.0, 2.0], 1.0); } List<double> addVectors(List<double> v1, List<double> v2) { List<double> result = new List<double>(v1.length); for (int i = 0; i < v1.length; i++) { result[i] = v1[i] + v2[i]; } return result; } List<double> mulVectorNumber(List<double> v1, double num) { List<double> result = new List<double>(v1.length); for (int i = 0; i < v1.length; i++) { result[i] = v1[i] * num; } return result; } """ interpreter = DartInterpreter() utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def __init__(self, model): self.model_name = "score" self.model = model self.interpreter = DartInterpreter() assembler_cls = get_assembler_cls(model) self.model_ast = assembler_cls(model).assemble() self.script_path = None
def test_abs_expr(): expr = ast.AbsExpr(ast.NumVal(-1.0)) expected_code = """ double score(List<double> input) { return (-1.0).abs(); } """ interpreter = DartInterpreter() utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_atan_expr(): expr = ast.AtanExpr(ast.NumVal(2.0)) expected_code = """ import 'dart:math'; double score(List<double> input) { return atan(2.0); } """ interpreter = DartInterpreter() utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_raw_array(): expr = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)]) expected_code = """ List<double> score(List<double> input) { return [3.0, 4.0]; } """ interpreter = DartInterpreter() utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_pow_expr(): expr = ast.PowExpr(ast.NumVal(2.0), ast.NumVal(3.0)) expected_code = """ import 'dart:math'; double score(List<double> input) { return (pow(2.0, 3.0)).toDouble(); } """ interpreter = DartInterpreter() utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_log_expr(): expr = ast.LogExpr(ast.NumVal(2.0)) expected_code = """ import 'dart:math'; double score(List<double> input) { return log(2.0); } """ interpreter = DartInterpreter() assert_code_equal(interpreter.interpret(expr), expected_code)
def test_bin_num_expr(): expr = ast.BinNumExpr( ast.BinNumExpr(ast.FeatureRef(0), ast.NumVal(-2), ast.BinNumOpType.DIV), ast.NumVal(2), ast.BinNumOpType.MUL) expected_code = """ double score(List<double> input) { return ((input[0]) / (-2.0)) * (2.0); } """ interpreter = DartInterpreter() utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) expected_code = """ import 'dart:math'; double score(List<double> input) { double var0; var0 = exp(1.0); return (var0) / (var0); } """ interpreter = DartInterpreter() utils.assert_code_equal(interpreter.interpret(expr), expected_code)
class DartExecutor(BaseExecutor): def __init__(self, model): self.model_name = "score" self.model = model self.interpreter = DartInterpreter() assembler_cls = get_assembler_cls(model) self.model_ast = assembler_cls(model).assemble() self.script_path = None def predict(self, X): exec_args = [ "dart", str(self.script_path), *map(utils.format_arg, X) ] return utils.predict_from_commandline(exec_args) def prepare(self): if self.model_ast.output_size > 1: print_code = EXECUTE_AND_PRINT_VECTOR else: print_code = EXECUTE_AND_PRINT_SCALAR executor_code = EXECUTOR_CODE_TPL.format( model_code=self.interpreter.interpret(self.model_ast), print_code=print_code) self.script_path = self._resource_tmp_dir / f"{self.model_name}.dart" utils.write_content_to_file(executor_code, self.script_path)
def test_nested_condition(): left = ast.BinNumExpr( ast.IfExpr( ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ), ast.NumVal(1), ast.NumVal(2)), ast.NumVal(2), ast.BinNumOpType.ADD) bool_test = ast.CompExpr(ast.NumVal(1), left, ast.CompOpType.EQ) expr_nested = ast.IfExpr(bool_test, ast.FeatureRef(2), ast.NumVal(2)) expr = ast.IfExpr(bool_test, expr_nested, ast.NumVal(2)) expected_code = """ double score(List<double> input) { double var0; double var1; if ((1) == (1)) { var1 = 1; } else { var1 = 2; } if ((1) == ((var1) + (2))) { double var2; if ((1) == (1)) { var2 = 1; } else { var2 = 2; } if ((1) == ((var2) + (2))) { var0 = input[2]; } else { var0 = 2; } } else { var0 = 2; } return var0; } """ interpreter = DartInterpreter() utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_sigmoid_expr(): expr = ast.SigmoidExpr(ast.NumVal(2.0)) expected_code = """ import 'dart:math'; double score(List<double> input) { return sigmoid(2.0); } double sigmoid(double x) { if (x < 0.0) { double z = exp(x); return z / (1.0 + z); } return 1.0 / (1.0 + exp(-x)); } """ interpreter = DartInterpreter() utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_tanh_expr(): expr = ast.TanhExpr(ast.NumVal(2.0)) expected_code = """ import 'dart:math'; double score(List<double> input) { return tanh(2.0); } double tanh(double x) { if (x > 22.0) return 1.0; if (x < -22.0) return -1.0; return ((exp(2*x) - 1)/(exp(2*x) + 1)); } """ interpreter = DartInterpreter() utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_if_expr(): expr = ast.IfExpr( ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ), ast.NumVal(2), ast.NumVal(3)) expected_code = """ double score(List<double> input) { double var0; if ((1.0) == (input[0])) { var0 = 2.0; } else { var0 = 3.0; } return var0; } """ interpreter = DartInterpreter() utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_multi_output(): expr = ast.IfExpr( ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ), ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]), ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])) expected_code = """ List<double> score(List<double> input) { List<double> var0; if ((1.0) == (1.0)) { var0 = [1.0, 2.0]; } else { var0 = [3.0, 4.0]; } return var0; } """ interpreter = DartInterpreter() utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_tanh_expr(): expr = ast.TanhExpr(ast.NumVal(2.0)) expected_code = """ import 'dart:math'; double score(List<double> input) { return tanh(2.0); } double tanh(double x) { // Implementation is taken from // https://github.com/golang/go/blob/master/src/math/tanh.go double z; z = x.abs(); if (z > 0.440148459655565271479942397125e+2) { if (x < 0) { return -1.0; } return 1.0; } if (z >= 0.625) { z = 1 - 2 / (exp(2 * z) + 1); if (x < 0) { z = -z; } return z; } if (x == 0) { return 0.0; } double s; s = x * x; z = x + x * s * ((-0.964399179425052238628e+0 * s + -0.992877231001918586564e+2) * s + -0.161468768441708447952e+4) / (((s + 0.112811678491632931402e+3) * s + 0.223548839060100448583e+4) * s + 0.484406305325125486048e+4); return z; } """ interpreter = DartInterpreter() assert_code_equal(interpreter.interpret(expr), expected_code)
def test_log1p_expr(): expr = ast.Log1pExpr(ast.NumVal(2.0)) expected_code = """ import 'dart:math'; double score(List<double> input) { return log1p(2.0); } double log1p(double x) { if (x == 0.0) return 0.0; if (x == -1.0) return double.negativeInfinity; if (x < -1.0) return double.nan; double xAbs = x.abs(); if (xAbs < 0.5 * 4.94065645841247e-324) return x; if ((x > 0.0 && x < 1e-8) || (x > -1e-9 && x < 0.0)) return x * (1.0 - x * 0.5); if (xAbs < 0.375) { List<double> coeffs = [ 0.10378693562743769800686267719098e+1, -0.13364301504908918098766041553133e+0, 0.19408249135520563357926199374750e-1, -0.30107551127535777690376537776592e-2, 0.48694614797154850090456366509137e-3, -0.81054881893175356066809943008622e-4, 0.13778847799559524782938251496059e-4, -0.23802210894358970251369992914935e-5, 0.41640416213865183476391859901989e-6, -0.73595828378075994984266837031998e-7, 0.13117611876241674949152294345011e-7, -0.23546709317742425136696092330175e-8, 0.42522773276034997775638052962567e-9, -0.77190894134840796826108107493300e-10, 0.14075746481359069909215356472191e-10, -0.25769072058024680627537078627584e-11, 0.47342406666294421849154395005938e-12, -0.87249012674742641745301263292675e-13, 0.16124614902740551465739833119115e-13, -0.29875652015665773006710792416815e-14, 0.55480701209082887983041321697279e-15, -0.10324619158271569595141333961932e-15]; return x * (1.0 - x * chebyshevBroucke(x / 0.375, coeffs)); } return log(1.0 + x); } double chebyshevBroucke(double x, List<double> coeffs) { double b0, b1, b2, x2; b2 = b1 = b0 = 0.0; x2 = x * 2; for (int i = coeffs.length - 1; i >= 0; --i) { b2 = b1; b1 = b0; b0 = x2 * b1 - b2 + coeffs[i]; } return (b0 - b2) * 0.5; } """ interpreter = DartInterpreter() utils.assert_code_equal(interpreter.interpret(expr), expected_code)