def test_sigmoid_expr():
    """Check that a SigmoidExpr is rendered as a ``sigmoid(...)`` call.

    The expected output also contains the ``sigmoid`` helper definition
    emitted after ``score``.
    """
    interpreter = JavascriptInterpreter()
    sigmoid_of_two = ast.SigmoidExpr(ast.NumVal(2.0))

    expected_code = """
function score(input) {
    return sigmoid(2.0);
}
function sigmoid(x) {
    if (x < 0.0) {
        let z = Math.exp(x);
        return z / (1.0 + z);
    }
    return 1.0 / (1.0 + Math.exp(-x));
}
"""

    actual_code = interpreter.interpret(sigmoid_of_two)
    utils.assert_code_equal(actual_code, expected_code)
def test_if_expr():
    """Check that an IfExpr comparing a constant with a feature renders as if/else.

    The result of each branch is assigned to ``var0``, which ``score``
    then returns.
    """
    expr = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ),
        ast.NumVal(2),
        ast.NumVal(3))

    # NOTE(review): this test expects CompOpType.EQ to render as `===`,
    # whereas test_nested_condition and test_multi_output in this file
    # expect `==` for the same op. One of them is likely stale -- confirm
    # against JavascriptInterpreter's comparison-operator mapping.
    expected_code = """
function score(input) {
    var var0;
    if ((1.0) === (input[0])) {
        var0 = 2.0;
    } else {
        var0 = 3.0;
    }
    return var0;
}
"""

    interpreter = JavascriptInterpreter()
    # Use the same `utils.` qualified helper as the other tests in this file.
    # A bare `assert_code_equal` is only valid if it was imported separately.
    utils.assert_code_equal(interpreter.interpret(expr), expected_code)
def test_nested_condition():
    """Check rendering of an IfExpr nested inside another IfExpr.

    Both IfExprs share one condition, which itself contains an IfExpr.
    The expected output evaluates that condition once at the top level
    (``var1``) and again inside the outer true-branch (``var2``).
    """
    inner_if = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.NumVal(1),
        ast.NumVal(2))
    sum_expr = ast.BinNumExpr(inner_if, ast.NumVal(2), ast.BinNumOpType.ADD)
    condition = ast.CompExpr(ast.NumVal(1), sum_expr, ast.CompOpType.EQ)

    # The same `condition` object is deliberately reused by both IfExprs.
    nested_if = ast.IfExpr(condition, ast.FeatureRef(2), ast.NumVal(2))
    outer_if = ast.IfExpr(condition, nested_if, ast.NumVal(2))

    expected_code = """
function score(input) {
    var var0;
    var var1;
    if ((1.0) == (1.0)) {
        var1 = 1.0;
    } else {
        var1 = 2.0;
    }
    if ((1.0) == ((var1) + (2.0))) {
        var var2;
        if ((1.0) == (1.0)) {
            var2 = 1.0;
        } else {
            var2 = 2.0;
        }
        if ((1.0) == ((var2) + (2.0))) {
            var0 = input[2];
        } else {
            var0 = 2.0;
        }
    } else {
        var0 = 2.0;
    }
    return var0;
}
"""

    generated = JavascriptInterpreter().interpret(outer_if)
    utils.assert_code_equal(generated, expected_code)
def test_multi_output():
    """Check that an IfExpr whose branches are VectorVals renders as JS arrays."""
    true_vector = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    false_vector = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    always_true = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    expr = ast.IfExpr(always_true, true_vector, false_vector)

    expected_code = """
function score(input) {
    var var0;
    if ((1.0) == (1.0)) {
        var0 = [1.0, 2.0];
    } else {
        var0 = [3.0, 4.0];
    }
    return var0;
}
"""

    generated = JavascriptInterpreter().interpret(expr)
    utils.assert_code_equal(generated, expected_code)