def test__cps_call(self):
    # foo(1, 2) in CPS: a continuation is prepended to the argument list.
    source = CallAst(VarAst('foo'), [LiteralAst(1), LiteralAst(2)])
    transformed = _cps_call(source, lambda x: x)

    # The continuation's parameter name is generated, so it is read back
    # from the transformed tree instead of being hard-coded.
    continuation = transformed.args[0]
    result_param = continuation.params[0]

    expected = CallAst(
        VarAst('foo'),
        [
            LambdaAst('', [result_param], VarAst(result_param)),
            LiteralAst(1),
            LiteralAst(2),
        ],
    )
    self.assertEqual(transformed, expected)
def test_to_cps(self):
    def identity(node):
        return node

    # Raw JavaScript is returned untouched.
    raw_js = JsAst("aa")
    self.assertEqual(_cps_js_raw(raw_js, identity), raw_js)

    # An atom is handed straight to the continuation.
    literal = LiteralAst(1.0)
    self.assertEqual(literal, to_cps(literal, identity))

    # A let without bindings reduces to its body.
    self.assertEqual(to_cps(LetAst([], LiteralAst(False)), identity),
                     LiteralAst(False))

    # Programs: empty -> false; one expression -> that expression;
    # several expressions -> a program of the converted expressions.
    self.assertEqual(to_cps(ProgAst([]), identity), LiteralAst(False))
    self.assertEqual(to_cps(ProgAst([LiteralAst(1)]), identity),
                     LiteralAst(1))
    self.assertEqual(
        to_cps(ProgAst([LiteralAst(1), LiteralAst(2)]), identity),
        ProgAst([LiteralAst(1), LiteralAst(2)]))

    # if: an IIFE receiving the continuation; both branches call it.
    # Parameter names are generated, so they are read back from the result.
    converted_if = to_cps(
        IfAst(LiteralAst(1), LiteralAst(2), LiteralAst(3)), identity)
    if_params = converted_if.func.params
    ret_params = converted_if.args[0].params
    self.assertEqual(
        converted_if,
        CallAst(
            LambdaAst(
                '', if_params,
                IfAst(LiteralAst(1),
                      CallAst(VarAst(if_params[0]), [LiteralAst(2)]),
                      CallAst(VarAst(if_params[0]), [LiteralAst(3)]))),
            [LambdaAst('', ret_params, VarAst(ret_params[0]))]))

    # lambda: gains a leading continuation parameter that receives the body.
    converted_lambda = to_cps(LambdaAst('', ['x', 'y'], LiteralAst(1)),
                              identity)
    k_name = converted_lambda.params[0]
    self.assertEqual(
        converted_lambda,
        LambdaAst('', [k_name] + ['x', 'y'],
                  CallAst(VarAst(k_name), [LiteralAst(1)])))

    # A binary expression over atoms is left as it is.
    addition = BinaryAst('+', LiteralAst(1), LiteralAst(2))
    self.assertEqual(to_cps(addition, identity), addition)

    # a = foo(10);  ==>  foo(λ (R) a = R, 10)
    program = Parser(TokenStream(InputStream("a = foo(10);")))()
    converted_call = to_cps(program, identity)
    value_name = converted_call.args[0].params[0]
    self.assertEqual(
        converted_call,
        CallAst(VarAst('foo'), [
            LambdaAst('', [value_name],
                      AssignAst(VarAst('a'), VarAst(value_name))),
            LiteralAst(10),
        ]))
def _cps_if(if_ast: IfAst, k: Callable[[Ast], Ast]) -> CallAst:
    """
    Convert an ``if`` expression to CPS.

    The previous version of cps_if passed k when compiling both branches,
    which made the code grow massively for consecutive if-s. The (quite
    expensive) fix is to wrap the rest of the program in a continuation and
    turn the "if" node into an IIFE that receives that continuation as an
    argument. The body of this IIFE just calls the continuation with the
    result of the if expression.

    A missing else branch (``else_`` is None) evaluates to ``false``.

    The following is a sample transformation:

    a = if foo then 1 else 2;
    print(a);

    (λ (ifcont){
      if foo then ifcont(1) else ifcont(2);
    })(λ (ifret){
      a = ifret;
      print(a);
    });

    :param if_ast: the if node to convert
    :param k: continuation receiving the value of the whole if expression
    :return: the IIFE call node described above
    """
    def cps_cond(cond_ast: Ast) -> Ast:
        def cps_then_and_else(result: Ast) -> Ast:
            # Each branch hands its value to the shared if-continuation.
            return CallAst(VarAst(if_continuation), [result])

        # IfAst permits a None else branch; convert it as `false` instead of
        # handing None to to_cps.
        else_ast = (if_ast.else_ if if_ast.else_ is not None
                    else LiteralAst(False))
        return IfAst(cond_ast,
                     to_cps(if_ast.then, cps_then_and_else),
                     to_cps(else_ast, cps_then_and_else))

    if_continuation = gensym("I")
    return CallAst(
        LambdaAst('', [if_continuation], to_cps(if_ast.cond, cps_cond)),
        [_make_continuation(k)])
def _cps_lambda(lambda_ast: LambdaAst, k: Callable[[Ast], Ast]) -> Ast:
    """Convert a lambda to CPS by adding a leading continuation parameter.

    A fresh parameter name (from ``gensym("K")``) is prepended to the
    lambda's parameters. The body is converted so that its value is passed
    to that parameter. The rewritten lambda, which keeps its original name,
    is then handed to ``k``.
    """
    k_param = gensym("K")

    def return_to_caller(body_value: Ast) -> Ast:
        return CallAst(VarAst(k_param), [body_value])

    cps_body = to_cps(lambda_ast.body, return_to_caller)
    params = [k_param] + lambda_ast.params
    return k(LambdaAst(lambda_ast.name, params, cps_body))
def loop(args: List[Ast], i: int) -> Ast:
    """Convert ``call_ast.args[i:]`` in order, then build the final call.

    ``args`` holds the arguments converted so far: whatever the enclosing
    function seeded it with, followed by the values of ``call_ast.args[:i]``.
    ``call_ast`` and ``func`` come from the enclosing scope.

    :param args: arguments already converted
    :param i: index of the next argument of ``call_ast`` to convert
    :return: ``CallAst(func, <all converted args>)`` wrapped in whatever the
        argument conversions produce
    """
    def arg_callback(arg: Ast) -> Ast:
        # Build a new list rather than appending: `args` is shared by every
        # invocation of this callback. Mutating it would corrupt the
        # argument list if a continuation were ever invoked more than once.
        return loop(args + [arg], i + 1)

    if i == len(call_ast.args):
        return CallAst(func, args)
    return to_cps(call_ast.args[i], arg_callback)
def main():
    """Print the CPS form of the λanguage program in the file named by argv[1].

    The value of the whole program is passed to the continuation
    ``β_TOPLEVEL``.
    """
    # Source files may contain non-ASCII characters such as 'λ'. Decode them
    # as UTF-8 explicitly rather than with the platform's locale encoding.
    with open(sys.argv[1], encoding='utf-8') as file:
        code = file.read()
    parser = Parser(TokenStream(InputStream(code)))
    cps_code = to_cps(parser(), lambda ast: CallAst(
        VarAst('β_TOPLEVEL'),
        [ast],
    ))
    print(cps_code)
def test__cps_lambda(self):
    def identity(node):
        return node

    def expected_for(transformed, body):
        # Shape: lambda (K, x, y) K(body). K is generated, so read it back.
        k_name = transformed.params[0]
        return LambdaAst('', [k_name] + ['x', 'y'],
                         CallAst(VarAst(k_name), [body]))

    # lambda (x, y) 1
    constant = _cps_lambda(LambdaAst('', ['x', 'y'], LiteralAst(1)), identity)
    self.assertEqual(constant, expected_for(constant, LiteralAst(1)))

    # lambda (x, y) x + y
    summed = _cps_lambda(
        LambdaAst('', ['x', 'y'], BinaryAst('+', VarAst('x'), VarAst('y'))),
        identity)
    self.assertEqual(
        summed,
        expected_for(summed, BinaryAst('+', VarAst('x'), VarAst('y'))))
def _js_let(ast: LetAst) -> str:
    """Compile a ``let`` expression to JavaScript.

    A let with no bindings compiles to its body. Otherwise the first binding
    becomes an immediately invoked function expression. Its single parameter
    is the bound name, and its body is a let over the remaining bindings.
    A binding whose ``define`` is falsy (e.g. None) is passed ``false``.
    In that case the generated code is wrapped in parentheses.
    """
    vardefs = ast.vardefs
    if vardefs:
        head = vardefs[0]
        rest_let = LetAst(vardefs[1:], ast.body)
        initial_value = head.define or LiteralAst(False)
        iife = CallAst(LambdaAst('', [head.name], rest_let), [initial_value])
        return f'({_to_js(iife)})'
    return _to_js(ast.body)
def _parse_call(self, func: Ast) -> CallAst:
    """Parse the argument list of a call whose callee is already parsed.

    :param func: the callee expression, parsed by the caller beforehand
    :return: a CallAst applying ``func`` to the parenthesised,
        comma-separated expressions that follow
    """
    arguments = self._delimited('(', ')', ',', self._parse_expression)
    return CallAst(func, arguments)
def _cps_let(let_ast: LetAst, k: Callable[[Ast], Ast]) -> Ast:
    """Convert a let to CPS by first desugaring it into nested IIFEs.

    ``let (a = x, rest...) body`` becomes ``(λ (a) let (rest...) body)(x)``,
    and that call is converted with ``k``. If the first binding's
    ``define`` is falsy, ``false`` is used instead. A let with no bindings
    is converted as its body.
    """
    vardefs = let_ast.vardefs
    if vardefs:
        first = vardefs[0]
        init = first.define or LiteralAst(False)
        inner = LetAst(vardefs[1:], let_ast.body)
        iife = CallAst(LambdaAst('', [first.name], inner), [init])
        return to_cps(iife, k)
    return to_cps(let_ast.body, k)
def main():
    """Compile the λanguage file named by argv[1] to JavaScript and print it.

    Pipeline: parse -> CPS transform (the program's value is passed to
    ``β_TOPLEVEL``) -> optimize -> emit JavaScript.
    """
    # Source files may contain non-ASCII characters such as 'λ'. Decode them
    # as UTF-8 explicitly rather than with the platform's locale encoding.
    with open(sys.argv[1], encoding='utf-8') as file:
        code = file.read()
    ast = Parser(TokenStream(InputStream(code)))()
    ast = to_cps(ast, lambda node: CallAst(VarAst('β_TOPLEVEL'), [node]))
    ast = Optimizer().optimize(ast)
    js_code = to_js(ast)
    print(js_code)
def _optimize_call_ast(self, ast: CallAst) -> Ast:
    """Optimize a call node.

    If the callee is an anonymous lambda (the call is an IIFE) and
    ``self.closure`` is a LambdaAst, the IIFE is unwrapped via
    ``_unwrap_iife`` and ``self.changes`` is incremented. Otherwise the
    callee and every argument are optimized recursively into a new CallAst.
    """
    callee = ast.func
    is_iife = isinstance(callee, LambdaAst) and not callee.name
    if is_iife and isinstance(self.closure, LambdaAst):
        self.changes += 1
        return self._unwrap_iife(ast)
    optimized_callee = self._optimize_aux(callee)
    optimized_args = list(map(self._optimize_aux, ast.args))
    return CallAst(optimized_callee, optimized_args)
def test__cps_if(self):
    def identity(node):
        return node

    def expected_for(transformed, cond, then_value, else_value):
        # Shape: (λ (I) if cond then I(then) else I(else))(λ (R) R).
        # I and R are generated, so they are read back from the result.
        if_params = transformed.func.params
        ret_params = transformed.args[0].params

        def resume(value):
            return CallAst(VarAst(if_params[0]), [value])

        return CallAst(
            LambdaAst('', if_params,
                      IfAst(cond, resume(then_value), resume(else_value))),
            [LambdaAst('', ret_params, VarAst(ret_params[0]))])

    transformed = _cps_if(
        IfAst(LiteralAst(1), LiteralAst(2), LiteralAst(3)), identity)
    self.assertEqual(
        transformed,
        expected_for(transformed, LiteralAst(1), LiteralAst(2),
                     LiteralAst(3)))

    transformed = _cps_if(
        IfAst(LiteralAst(1), LiteralAst(2), LiteralAst(False)), identity)
    self.assertEqual(
        transformed,
        expected_for(transformed, LiteralAst(1), LiteralAst(2),
                     LiteralAst(False)))
def test__cps_let(self):
    def identity(node):
        return node

    # A let without bindings is converted as its body.
    self.assertEqual(_cps_let(LetAst([], LiteralAst(False)), identity),
                     LiteralAst(False))

    # let (a = 1) a  ==>  (λ (K, a) K(a))(λ (R) R, 1)
    transformed = _cps_let(LetAst([VarDefAst('a', LiteralAst(1))],
                                  VarAst('a')), identity)
    k_name = transformed.func.params[0]
    ret_name = transformed.args[0].params[0]
    expected = CallAst(
        LambdaAst('', [k_name, 'a'],
                  CallAst(VarAst(k_name), [VarAst('a')])),
        [LambdaAst('', [ret_name], VarAst(ret_name)), LiteralAst(1)])
    self.assertEqual(transformed, expected)
def parser() -> Ast:
    """Parse one expression and desugar the short-circuit logical operators.

    ``left || right`` becomes ``(λ (left) if left then left else right)(left)``.
    ``left && right`` becomes ``if left then right else false``.
    Either way, ``right`` is only evaluated when it is needed.

    Every ``||``/``&&`` node in the operator tree built by ``_maybe_binary``
    is rewritten, not just the outermost one. Chains such as
    ``a || b || c`` or ``a || b && c`` are therefore fully desugared.
    """
    def desugar(node: Ast) -> Ast:
        # Only BinaryAst nodes are rewritten; anything else is returned as-is.
        if not isinstance(node, BinaryAst):
            return node
        left = desugar(node.left)
        right = desugar(node.right)
        if node.operator == '||':
            iife_param = gensym('left')
            return CallAst(
                LambdaAst(
                    '', [iife_param],
                    IfAst(VarAst(iife_param), VarAst(iife_param), right)),
                [left])
        if node.operator == '&&':
            return IfAst(left, right, LiteralAst(False))
        # Keep the original node when its operands are unchanged. Also keep it
        # for BinaryAst subclasses, whose constructors may differ.
        if (left is node.left and right is node.right) \
                or type(node) is not BinaryAst:
            return node
        return BinaryAst(node.operator, left, right)

    # NOTE(review): an `||`/`&&` on the right-hand side of an assignment
    # (`x = a || b`) is not desugared here when `_maybe_binary` represents
    # `=` with a non-BinaryAst node; confirm whether that case needs handling.
    return desugar(self._maybe_binary(self._parse_atom(), 0))
def _parse_let(self) -> Union[CallAst, LetAst]:
    """Parse a ``let`` expression; the ``let`` keyword is consumed here.

    The plain form ``let (vardefs) body`` yields a ``LetAst``.

    The named form ``let name (vardefs) body`` is desugared into an
    immediately applied named lambda: ``(λ name (vars...) body)(defines...)``.
    A binding whose ``define`` is falsy (no initializer) is passed ``false``.
    """
    self._skip_kw('let')
    named = self._token_stream.peek().type == 'var'
    if not named:
        vardefs = self._delimited('(', ')', ',', self._parse_vardef)
        return LetAst(vardefs, self._parse_expression())

    name = self._token_stream.next().value
    vardefs = self._delimited('(', ')', ',', self._parse_vardef)
    params = [vardef.name for vardef in vardefs]
    arguments = [vardef.define or LiteralAst(False) for vardef in vardefs]
    body = self._parse_expression()
    return CallAst(LambdaAst(name, params, body), arguments)
def test_evaluate(self):
    """Exercise the CPS evaluator ``evaluate(ast, env, continuation)``.

    Value checks go through ``values_of``, which fails the test if the
    continuation is never invoked. Previously, an assertion placed inside
    the continuation passed vacuously in that case.

    Failure checks use a continuation that cannot itself raise. Only an
    error coming from ``evaluate`` can therefore satisfy ``assertRaises``.
    """
    def values_of(ast, environment=None):
        # Every value evaluate() passes to its continuation, in call order.
        results = []
        evaluate(ast, Environment() if environment is None else environment,
                 results.append)
        self.assertTrue(results, 'evaluate() never called its continuation')
        return results

    def assert_value(ast, expected):
        for value in values_of(ast):
            self.assertEqual(value, expected)

    def assert_falsy(ast):
        for value in values_of(ast):
            self.assertFalse(value)

    def assert_fails(ast, environment=None):
        with self.assertRaises(Exception):
            evaluate(ast,
                     Environment() if environment is None else environment,
                     lambda value: value)

    def xyz_bindings():
        # x = 2, y = x + 1, z = x + y
        return [
            VarDefAst('x', LiteralAst(2)),
            VarDefAst('y', BinaryAst('+', VarAst('x'), LiteralAst(1))),
            VarDefAst('z', BinaryAst('+', VarAst('x'), VarAst('y'))),
        ]

    def x_plus_y_plus_z():
        return BinaryAst('+', BinaryAst('+', VarAst('x'), VarAst('y')),
                         VarAst('z'))

    # Literals and arithmetic.
    assert_value(LiteralAst(1.0), 1.0)
    for value in values_of(LiteralAst(True)):
        self.assertTrue(value)
    assert_falsy(LiteralAst(False))
    assert_value(LiteralAst("aaa"), "aaa")
    assert_value(BinaryAst('+', LiteralAst(1), LiteralAst(2)), 3.0)

    # Programs.
    assert_falsy(ProgAst([]))
    assert_value(ProgAst([LiteralAst(1)]), 1.0)
    assert_value(ProgAst([LiteralAst(1), LiteralAst(2)]), 2.0)

    # Assignment: only variables can be assigned to.
    assert_fails(AssignAst(LiteralAst(1), LiteralAst("a")))
    assert_value(
        ProgAst([AssignAst(VarAst('a'), LiteralAst("foo")), VarAst('a')]),
        "foo")
    # Assigning the undefined `a` from a nested environment fails.
    assert_fails(AssignAst(VarAst("a"), LiteralAst("foo")),
                 Environment(Environment()))

    # Calls.
    assert_value(CallAst(LambdaAst("", ["a"], VarAst("a")), [LiteralAst(1)]),
                 1.0)
    assert_value(
        CallAst(LambdaAst("", ["a"], VarAst("a")), [LiteralAst("abc")]),
        "abc")

    # (λ loop (n) if n > 0 then n + loop(n - 1) else 0) (10)
    assert_value(
        CallAst(
            LambdaAst(
                "loop", ["n"],
                IfAst(
                    BinaryAst(">", VarAst("n"), LiteralAst(0)),
                    BinaryAst(
                        "+", VarAst("n"),
                        CallAst(VarAst("loop"),
                                [BinaryAst('-', VarAst('n'), LiteralAst(1))])),
                    LiteralAst(0))),
            [LiteralAst(10)]),
        55.0)

    # let (x) x;
    assert_falsy(LetAst([VarDefAst("x", None)], VarAst("x")))

    # let (x = 2, y = x + 1, z = x + y) x + y + z
    assert_value(LetAst(xyz_bindings(), x_plus_y_plus_z()), 10.0)

    # let (x = 2, y = x + 1, z = x + y) x + y + z; x + y + z
    # x, y and z are bound only inside the let body, so the second
    # expression fails.
    assert_fails(ProgAst([LetAst(xyz_bindings(), x_plus_y_plus_z()),
                          x_plus_y_plus_z()]))

    # Conditionals.
    assert_value(IfAst(LiteralAst(""), LiteralAst(1), None), 1.0)
    assert_value(IfAst(LiteralAst(False), LiteralAst(1), LiteralAst(2)), 2.0)
    assert_falsy(IfAst(LiteralAst(False), LiteralAst(1), LiteralAst(False)))

    # Something that is not an AST node at all.
    assert_fails({"type": "foo", "value": 'foo'})

    # fib = λ(n) if n < 2 then n else fib(n - 1) + fib(n - 2);
    # fib(6);
    assert_value(
        ProgAst([
            AssignAst(
                VarAst('fib'),
                LambdaAst(
                    '', ['n'],
                    IfAst(
                        BinaryAst('<', VarAst('n'), LiteralAst(2)),
                        VarAst('n'),
                        BinaryAst(
                            '+',
                            CallAst(VarAst('fib'),
                                    [BinaryAst('-', VarAst('n'),
                                               LiteralAst(1))]),
                            CallAst(VarAst('fib'),
                                    [BinaryAst('-', VarAst('n'),
                                               LiteralAst(2))]))))),
            CallAst(VarAst('fib'), [LiteralAst(6)]),
        ]),
        8.0)

    # Calling a non-function fails.
    assert_fails(CallAst(LiteralAst(1), []))

    # Smoke test with the primitives installed: evaluating the parsed
    # program must not raise.
    code = """
    2 + twice(3, 4)
    """
    global_env = Environment()
    for name, func in primitive.items():
        global_env.define(name, func)
    parser = Parser(TokenStream(InputStream(code)))
    evaluate(parser(), global_env, lambda result: result)
#!/usr/bin/env python
# encoding: utf-8
"""Compile one or more λanguage source files to JavaScript.

The files named on the command line are concatenated in order and parsed.
The program is CPS-transformed (its value is handed to ``β_TOPLEVEL``) and
optimized, and the resulting JavaScript is printed to stdout.
"""
import sys

from ast import CallAst, VarAst
from compiler import to_js
from cps_transformer import to_cps
from input_stream import InputStream
from optimize import Optimizer
from parse import Parser
from token_stream import TokenStream

# Source files may contain non-ASCII characters such as 'λ'. Decode them as
# UTF-8 explicitly rather than with the platform's locale encoding.
sources = []
for path in sys.argv[1:]:
    with open(path, encoding='utf-8') as source_file:
        sources.append(source_file.read())
code = "".join(sources)

parser = Parser(TokenStream(InputStream(code)))
ast = parser()
ast = to_cps(ast, lambda node: CallAst(
    VarAst('β_TOPLEVEL'),
    [node],
))
ast = Optimizer().optimize(ast)
js_code = to_js(ast)
print(js_code)
def test_evaluate(self):
    """Check the direct (non-CPS) evaluator against values and errors.

    Value cases are each evaluated in a fresh ``Environment()``.
    Failure cases carry the environment they are evaluated in.
    """
    def xyz_bindings():
        # x = 2, y = x + 1, z = x + y
        return [
            VarDefAst('x', LiteralAst(2)),
            VarDefAst('y', BinaryAst('+', VarAst('x'), LiteralAst(1))),
            VarDefAst('z', BinaryAst('+', VarAst('x'), VarAst('y'))),
        ]

    def x_plus_y_plus_z():
        return BinaryAst('+', BinaryAst('+', VarAst('x'), VarAst('y')),
                         VarAst('z'))

    # (λ loop (n) if n > 0 then n + loop(n - 1) else 0) (10)
    countdown_sum = CallAst(
        LambdaAst(
            "loop", ["n"],
            IfAst(
                BinaryAst(">", VarAst("n"), LiteralAst(0)),
                BinaryAst(
                    "+", VarAst("n"),
                    CallAst(VarAst("loop"),
                            [BinaryAst('-', VarAst('n'), LiteralAst(1))])),
                LiteralAst(0))),
        [LiteralAst(10)])

    value_cases = [
        (LiteralAst(1.0), 1.0),
        (LiteralAst(True), True),
        (LiteralAst(False), False),
        (LiteralAst("aaa"), "aaa"),
        (BinaryAst('+', LiteralAst(1), LiteralAst(2)), 3.0),
        (ProgAst([]), False),
        (ProgAst([LiteralAst(1)]), 1.0),
        (ProgAst([LiteralAst(1), LiteralAst(2)]), 2.0),
        (ProgAst([AssignAst(VarAst('a'), LiteralAst("foo")), VarAst('a')]),
         "foo"),
        (CallAst(LambdaAst("", ["a"], VarAst("a")), [LiteralAst(1)]), 1.0),
        (CallAst(LambdaAst("", ["a"], VarAst("a")), [LiteralAst("abc")]),
         "abc"),
        (countdown_sum, 55.0),
        # let (x) x;
        (LetAst([VarDefAst("x", None)], VarAst("x")), False),
        # let (x = 2, y = x + 1, z = x + y) x + y + z
        (LetAst(xyz_bindings(), x_plus_y_plus_z()), 10.0),
        (IfAst(LiteralAst(""), LiteralAst(1), None), 1.0),
        (IfAst(LiteralAst(False), LiteralAst(1), LiteralAst(2)), 2.0),
        (IfAst(LiteralAst(False), LiteralAst(1), LiteralAst(False)), False),
    ]
    for node, expected in value_cases:
        self.assertEqual(evaluate(node, Environment()), expected)

    failure_cases = [
        # Assigning to something that is not a variable.
        (AssignAst(LiteralAst(1), LiteralAst("a")), Environment()),
        # Assigning the undefined `a` from a nested environment.
        (AssignAst(VarAst("a"), LiteralAst("foo")),
         Environment(Environment())),
        # let (...) x + y + z; x + y + z -- x, y and z are bound only inside
        # the let body, so the second expression fails.
        (ProgAst([LetAst(xyz_bindings(), x_plus_y_plus_z()),
                  x_plus_y_plus_z()]), Environment()),
        # Not an AST node at all.
        ({"type": "foo", "value": 'foo'}, Environment()),
    ]
    for node, environment in failure_cases:
        with self.assertRaises(Exception):
            evaluate(node, environment)
def cps_then_and_else(result: Ast) -> Ast:
    """Pass a branch's value to the enclosing ``if_continuation``."""
    continuation_ref = VarAst(if_continuation)
    return CallAst(continuation_ref, [result])