def test__cps_atom(self):
    """Check that ``_cps_atom`` hands atoms to the continuation unchanged.

    The continuation is the identity function, so for every literal and
    variable node the result must equal the input node.
    """
    atoms = (
        LiteralAst(1.0),
        LiteralAst("abc"),
        LiteralAst(False),
        VarAst('a'),
    )
    for atom_ast in atoms:
        cps_ast = _cps_atom(atom_ast, lambda node: node)
        self.assertEqual(atom_ast, cps_ast)
def test__cps_call(self):
    """Check the CPS form of ``foo(1, 2)`` under an identity continuation.

    The expected call passes a one-parameter lambda ahead of the original
    arguments; that lambda simply returns its parameter.
    """
    source_call = CallAst(VarAst('foo'), [LiteralAst(1), LiteralAst(2)])
    cps_ast = _cps_call(source_call, lambda node: node)

    # The parameter name of the continuation lambda is read back from the
    # result rather than hard-coded.
    result_param = cps_ast.args[0].params[0]
    continuation = LambdaAst('', [result_param], VarAst(result_param))
    expected_ast = CallAst(
        VarAst('foo'),
        [continuation, LiteralAst(1), LiteralAst(2)],
    )
    self.assertEqual(cps_ast, expected_ast)
def cps_body(body: List[Ast]) -> Ast:
    """Convert a list of expressions to CPS, one after another.

    * An empty list yields ``k(LiteralAst(False))``.
    * A single expression is converted with ``k`` as its continuation.
    * Otherwise the first expression is converted with a continuation that
      wraps its result and the CPS form of the remaining expressions in a
      two-element ``ProgAst``.

    ``k`` and ``to_cps`` come from the enclosing scope.

    :param body: expressions still to be converted.
    :return: the CPS form of the sequence.
    """
    if not body:
        return k(LiteralAst(False))

    head, *rest = body
    if not rest:
        return to_cps(head, k)

    def then_rest(first: Ast) -> Ast:
        # The remaining expressions are only converted once the first
        # expression's continuation is actually invoked.
        return ProgAst([first, cps_body(rest)])

    return to_cps(head, then_rest)
def parser() -> Ast:
    """Parse a single atomic expression from the token stream.

    The current token is probed in a fixed order: parenthesised expression,
    ``{`` block, ``if``, ``let``, boolean keyword, ``lambda`` / ``λ``, and
    ``js``. If none matches, the next token is consumed and turned into a
    variable or a string/number literal. Any other token type is reported
    through ``self.unexpected()``.

    ``self`` is taken from the enclosing scope.

    :return: the parsed node. If ``self.unexpected()`` returns instead of
        raising, the function falls through and returns ``None``.
    """
    if self._is_punc('('):
        self._skip_punc('(')
        inner = self._parse_expression()
        self._skip_punc(')')
        return inner
    if self._is_punc('{'):
        return self._parse_prog()
    if self._is_kw('if'):
        return self._parse_if()
    if self._is_kw('let'):
        return self._parse_let()
    if any(self._is_kw(word) for word in ('true', 'false')):
        return self._parse_bool()
    # Both spellings of the lambda keyword are accepted; the one that
    # matched is forwarded to the lambda parser.
    for lambda_kw in ('lambda', 'λ'):
        if self._is_kw(lambda_kw):
            return self._parse_lambda(lambda_kw)
    if self._is_kw('js'):
        return self._parse_js_raw()

    token = self._token_stream.next()
    if token.type == 'var':
        return VarAst(token.value)
    if token.type in ('str', 'num'):
        return LiteralAst(token.value)
    self.unexpected()
def test__cps_lambda(self):
    """Check the CPS form of two-parameter lambdas.

    Expected shape: ``lambda (continue, x, y) continue(body)``. The
    continuation parameter is prepended to the original parameters and the
    body is passed to it. Its name is read back from the result.
    """
    body_factories = (
        lambda: LiteralAst(1),                             # lambda (x, y) 1
        lambda: BinaryAst('+', VarAst('x'), VarAst('y')),  # lambda (x, y) x + y
    )
    for make_body in body_factories:
        lambda_ast = LambdaAst('', ['x', 'y'], make_body())
        cps_ast = _cps_lambda(lambda_ast, lambda node: node)

        k_param = cps_ast.params[0]
        expected_ast = LambdaAst(
            '',
            [k_param] + ['x', 'y'],
            CallAst(VarAst(k_param), [make_body()]),
        )
        self.assertEqual(cps_ast, expected_ast)
def _js_let(ast: LetAst) -> str:
    """Generate JavaScript for a ``let`` node.

    A ``let`` with no bindings compiles to its body. Otherwise the first
    binding becomes the parameter of an immediately invoked function
    expression whose body is a ``let`` over the remaining bindings. A
    binding without a (truthy) definition is initialised with
    ``LiteralAst(False)``.

    :param ast: the ``let`` node to compile.
    :return: JavaScript source; the IIFE form is wrapped in parentheses.
    """
    if not ast.vardefs:
        return _to_js(ast.body)

    first, remaining = ast.vardefs[0], ast.vardefs[1:]
    inner_let = LetAst(remaining, ast.body)
    binder = LambdaAst('', [first.name], inner_let)
    initial_value = first.define or LiteralAst(False)
    iife = CallAst(binder, [initial_value])
    return f'({_to_js(iife)})'
def _optimize_binary_ast(self, ast: BinaryAst) -> Ast:
    """Optimize a binary expression, folding it when both sides are literals.

    Both operands are optimized first. If both results are ``LiteralAst``
    nodes the operator is applied immediately via ``apply_op`` and the
    outcome is wrapped in a new literal, counting as one change. Otherwise
    a ``BinaryAst`` is rebuilt from the optimized operands.

    :param ast: the binary node to optimize.
    :return: a folded ``LiteralAst`` or a rebuilt ``BinaryAst``.
    """
    left: Ast = self._optimize_aux(ast.left)
    right: Ast = self._optimize_aux(ast.right)

    foldable = isinstance(left, LiteralAst) and isinstance(right, LiteralAst)
    if not foldable:
        return BinaryAst(ast.operator, left, right)

    # Two constant operands: the result is known before the program runs.
    self.changes += 1
    return LiteralAst(apply_op(ast.operator, left.value, right.value))
def _cps_let(let_ast: LetAst, k: Callable[[Ast], Ast]) -> Ast:
    """Convert a ``let`` node to CPS by rewriting it as nested IIFEs.

    With no bindings, the body is converted directly. Otherwise the first
    binding becomes the parameter of an anonymous lambda whose body is a
    ``let`` over the remaining bindings. That lambda is called with the
    binding's definition, or ``LiteralAst(False)`` when the definition is
    missing/falsy, and the resulting call is converted with ``to_cps``.

    :param let_ast: the ``let`` node to convert.
    :param k: continuation receiving the converted node.
    :return: the CPS form of the ``let``.
    """
    vardefs = let_ast.vardefs
    if not vardefs:
        return to_cps(let_ast.body, k)

    head = vardefs[0]
    inner_let = LetAst(vardefs[1:], let_ast.body)
    binder = LambdaAst('', [head.name], inner_let)
    initial_value = head.define or LiteralAst(False)
    return to_cps(CallAst(binder, [initial_value]), k)
def _parse_if(self) -> IfAst:
    """Parse an ``if`` expression.

    The ``then`` keyword is required unless the next token is ``{``. When
    there is no ``else`` clause, the else branch is ``LiteralAst(False)``.

    :return: the parsed ``IfAst``.
    """
    self._skip_kw("if")
    condition = self._parse_expression()

    # Skip the `then` keyword unless the branch opens with '{'.
    if not self._is_punc('{'):
        self._skip_kw('then')
    then_branch = self._parse_expression()

    if self._is_kw('else'):
        self._skip_kw('else')
        else_branch = self._parse_expression()
    else:
        else_branch = LiteralAst(False)
    return IfAst(condition, then_branch, else_branch)
def test__cps_let(self):
    """Check ``_cps_let`` for an empty ``let`` and a single-binding ``let``."""

    def identity(node):
        return node

    # A let without bindings is just the CPS form of its body.
    empty_let = LetAst(
        [],
        LiteralAst(False),
    )
    cps_ast = _cps_let(empty_let, identity)
    self.assertEqual(cps_ast, LiteralAst(False))

    # let (a = 1) a
    single_let = LetAst([VarDefAst('a', LiteralAst(1))], VarAst('a'))
    cps_ast = _cps_let(single_let, identity)

    # Names introduced by the transformation are read back from the result.
    k_param = cps_ast.func.params[0]
    result_param = cps_ast.args[0].params[0]
    expected_ast = CallAst(
        LambdaAst(
            '',
            [k_param, 'a'],
            CallAst(VarAst(k_param), [VarAst('a')]),
        ),
        [
            LambdaAst('', [result_param], VarAst(result_param)),
            LiteralAst(1),
        ],
    )
    self.assertEqual(cps_ast, expected_ast)
def _optimize_prog_ast(self, ast: ProgAst) -> Ast:
    """Optimize a sequence node.

    * An empty sequence becomes ``LiteralAst(False)`` (counts as a change).
    * A single-expression sequence becomes that expression, optimized
      (not counted as a change).
    * A leading expression without side effects is dropped (counts as a
      change) and the remainder is optimized.
    * Otherwise the result is a two-element ``ProgAst`` holding the
      optimized first expression and the optimized remainder.

    :param ast: the sequence node to optimize.
    :return: the optimized node.
    """
    statements = ast.prog
    if not statements:
        self.changes += 1
        return LiteralAst(False)

    first, rest = statements[0], statements[1:]
    if not rest:
        return self._optimize_aux(first)

    if has_side_effect(first):
        optimized_first = self._optimize_aux(first)
        optimized_rest = self._optimize_aux(ProgAst(rest))
        return ProgAst([optimized_first, optimized_rest])

    # The first expression's value is discarded and it has no side effect,
    # so it can be removed entirely.
    self.changes += 1
    return self._optimize_aux(ProgAst(rest))
def _optimize_if_ast(self, ast: IfAst) -> Ast:
    """Optimize an ``if`` node, selecting a branch when the outcome is known.

    The condition and both branches are optimized first. Then:

    * If the condition is a literal, the branch that would run is returned.
      Only a literal ``False`` selects the else branch.
    * If the condition is a constant variable whose current value equals
      ``LiteralAst(False)``, the else branch is returned. If its current
      value is any other literal or a lambda, the then branch is returned.
    * Otherwise an ``IfAst`` is rebuilt from the optimized parts.

    Every branch selection increments ``self.changes``.

    :param ast: the ``if`` node to optimize.
    :return: the selected branch, or a rebuilt ``IfAst``.
    """
    cond = self._optimize_aux(ast.cond)
    then = self._optimize_aux(ast.then)
    else_ = self._optimize_aux(ast.else_)

    # If cond is constant, we can decide which branch to execute before
    # running the program.
    if isinstance(cond, LiteralAst):
        self.changes += 1
        # Test the wrapped value, not the node: the truthiness of the node
        # object does not necessarily reflect the literal it carries. Only
        # a literal ``False`` is treated as false, matching the
        # constant-variable case below.
        return else_ if cond.value is False else then

    if isinstance(cond, VarAst) and _is_constant_var(cond):
        # For lambda function params we don't know the current value, so
        # their current value is None and neither check below matches.
        current_value = cond.define.current_value
        if current_value == LiteralAst(False):
            self.changes += 1
            return else_
        if isinstance(current_value, (LiteralAst, LambdaAst)):
            self.changes += 1
            return then

    return IfAst(cond, then, else_)
def parser() -> Ast:
    """Parse an expression and desugar a top-level ``||`` / ``&&``.

    Short-circuit operators are rewritten so the right operand is only
    evaluated when needed:

    * ``left || right`` becomes
      ``(lambda (tmp) { if tmp then tmp else right })(left)``, where ``tmp``
      is a fresh name from ``gensym``.
    * ``left && right`` becomes ``if left then right else false``.

    Only the operator of the outermost binary node is inspected here.
    ``self`` is taken from the enclosing scope.

    :return: the parsed (and possibly rewritten) node.
    """
    ast = self._maybe_binary(self._parse_atom(), 0)
    if not isinstance(ast, BinaryAst):
        return ast

    binary_ast = cast(BinaryAst, ast)
    operator = binary_ast.operator

    if operator == '||':
        # Bind the left operand to a fresh parameter so it is evaluated
        # only once, even though it is both tested and returned.
        iife_param = gensym('left')
        select = IfAst(
            VarAst(iife_param),
            VarAst(iife_param),
            binary_ast.right,
        )
        return CallAst(LambdaAst('', [iife_param], select), [binary_ast.left])

    if operator == '&&':
        return IfAst(binary_ast.left, binary_ast.right, LiteralAst(False))

    return ast
def _parse_let(self) -> Union[CallAst, LetAst]:
    """Parse a ``let`` expression, plain or named.

    A named let (``let name (a = 1, b) body``) is turned into a call of a
    lambda called ``name`` whose parameters are the bound variables. A
    variable that is not followed by a defining expression receives a
    false value from the parser. A plain let is returned as a ``LetAst``.

    :return: a ``CallAst`` for a named let, otherwise a ``LetAst``.
    """
    self._skip_kw('let')

    if self._token_stream.peek().type != 'var':
        vardefs = self._delimited('(', ')', ',', self._parse_vardef)
        body = self._parse_expression()
        return LetAst(vardefs, body)

    # Named let: the variable token right after `let` is the lambda's name.
    name = self._token_stream.next().value
    vardefs = self._delimited('(', ')', ',', self._parse_vardef)
    varnames = []
    defines = []
    for vardef in vardefs:
        varnames.append(vardef.name)
        defines.append(vardef.define or LiteralAst(False))
    body = self._parse_expression()
    return CallAst(LambdaAst(name, varnames, body), defines)
def _unwrap_iife(self, iife_ast: CallAst) -> ProgAst:
    """Unwrap an immediately invoked function expression into its closure.

    Each parameter of the IIFE's lambda is turned into an assignment of the
    matching argument, with ``LiteralAst(False)`` standing in for missing
    arguments. The assignments are followed by the optimized lambda body.

    :param iife_ast: a call whose ``func`` is a ``LambdaAst``.
    :return: a ``ProgAst`` holding the assignments and then the body.
    """
    assert isinstance(iife_ast.func, LambdaAst)
    assert isinstance(self.closure, LambdaAst)

    iife_func = iife_ast.func
    iife_args = iife_ast.args
    assert len(iife_func.params) >= len(iife_args)

    def rename_iife_param(param: str) -> str:
        """Register ``param`` in the closure, renaming it on a collision.

        :param param: parameter name as written in the IIFE.
        :return: the name the parameter has inside the closure.
        """
        self.closure = cast(LambdaAst, self.closure)
        env: Environment = self.closure.env
        var_define: VarDefine = iife_func.env.get(param)
        new_name = gensym(param + '$') if param in env.vars else param
        self.closure.iife_params.append(new_name)
        env.define(new_name, True)
        # Point every reference at the final name, since the parameter may
        # have been renamed.
        for ref in var_define.refs:
            ref.name = new_name
        return new_name

    prog: List[Ast] = []
    pairs = zip_longest(
        iife_func.params, iife_args, fillvalue=LiteralAst(False))
    for param, arg in pairs:
        prog.append(AssignAst(VarAst(rename_iife_param(param)), arg))
    prog.append(self._optimize_aux(iife_func.body))
    return ProgAst(prog)
def test_evaluate(self):
    """Exercise the callback-style ``evaluate`` on each kind of AST node.

    ``evaluate`` delivers its result to a callback, so value checks are
    made inside the callback handed to it.
    """

    def expect_value(ast, expected):
        # Evaluate in a fresh environment and compare the delivered value.
        evaluate(ast, Environment(),
                 lambda value: self.assertEqual(value, expected))

    def expect_error(ast, make_env=Environment,
                     callback=lambda value: value):
        # The environment is built inside the assertRaises block, as part
        # of the call that is expected to fail.
        with self.assertRaises(Exception):
            evaluate(ast, make_env(), callback)

    def xyz_bindings():
        # x = 2, y = x + 1, z = x + y
        return [
            VarDefAst('x', LiteralAst(2)),
            VarDefAst('y', BinaryAst('+', VarAst('x'), LiteralAst(1))),
            VarDefAst('z', BinaryAst('+', VarAst('x'), VarAst('y'))),
        ]

    def xyz_sum():
        # x + y + z
        return BinaryAst(
            '+', BinaryAst('+', VarAst('x'), VarAst('y')), VarAst('z'))

    # Literals.
    expect_value(LiteralAst(1.0), 1.0)
    evaluate(LiteralAst(True), Environment(), self.assertTrue)
    evaluate(LiteralAst(False), Environment(), self.assertFalse)
    expect_value(LiteralAst("aaa"), "aaa")

    # Binary expression.
    expect_value(BinaryAst('+', LiteralAst(1), LiteralAst(2)), 3.0)

    # Sequences: empty, single and multiple expressions.
    evaluate(ProgAst([]), Environment(), self.assertFalse)
    expect_value(ProgAst([LiteralAst(1)]), 1.0)
    expect_value(ProgAst([LiteralAst(1), LiteralAst(2)]), 2.0)

    # Assignment: the target must be a variable.
    expect_error(AssignAst(LiteralAst(1), LiteralAst("a")))
    expect_value(
        ProgAst([AssignAst(VarAst('a'), LiteralAst("foo")), VarAst('a')]),
        "foo")
    # Assigning an undefined variable inside a nested environment fails.
    expect_error(
        AssignAst(VarAst("a"), LiteralAst("foo")),
        make_env=lambda: Environment(Environment()))

    # Calls of anonymous lambdas.
    expect_value(
        CallAst(
            LambdaAst("", ["a"], VarAst("a")),
            [LiteralAst(1)],
        ),
        1.0)
    expect_value(
        CallAst(LambdaAst("", ["a"], VarAst("a")), [LiteralAst("abc")]),
        "abc")

    # (λ loop (n) if n > 0 then n + loop(n - 1) else 0) (10)
    named_loop = CallAst(
        LambdaAst(
            "loop", ["n"],
            IfAst(
                BinaryAst(">", VarAst("n"), LiteralAst(0)),
                BinaryAst(
                    "+", VarAst("n"),
                    CallAst(
                        VarAst("loop"),
                        [BinaryAst('-', VarAst('n'), LiteralAst(1))])),
                LiteralAst(0))),
        [LiteralAst(10)])
    expect_value(named_loop, 55.0)

    # let (x) x;
    evaluate(
        LetAst([VarDefAst("x", None)], VarAst("x")),
        Environment(), self.assertFalse)

    # let (x = 2, y = x + 1, z = x + y) x + y + z
    expect_value(LetAst(xyz_bindings(), xyz_sum()), 10.0)

    # let (x = 2, y = x + 1, z = x + y) x + y + z; x + y + z
    # The second expression fails because x, y and z are only bound
    # inside the let body.
    expect_error(ProgAst([LetAst(xyz_bindings(), xyz_sum()), xyz_sum()]))

    # Conditionals.
    expect_value(IfAst(LiteralAst(""), LiteralAst(1), None), 1.0)
    expect_value(IfAst(LiteralAst(False), LiteralAst(1), LiteralAst(2)), 2.0)
    evaluate(
        IfAst(LiteralAst(False), LiteralAst(1), LiteralAst(False)),
        Environment(), self.assertFalse)

    # Something that is not an AST node at all.
    expect_error({"type": "foo", "value": 'foo'})

    # fib = λ(n) if n < 2 then n else fib(n - 1) + fib(n - 2);
    # fib(6);
    fib_program = ProgAst([
        AssignAst(
            VarAst('fib'),
            LambdaAst(
                'n', ['n'],
                IfAst(
                    BinaryAst('<', VarAst('n'), LiteralAst(2)),
                    VarAst('n'),
                    BinaryAst(
                        '+',
                        CallAst(
                            VarAst('fib'),
                            [BinaryAst('-', VarAst('n'), LiteralAst(1))]),
                        CallAst(
                            VarAst('fib'),
                            [BinaryAst('-', VarAst('n'), LiteralAst(2))]))))),
        CallAst(VarAst('fib'), [LiteralAst(6)]),
    ])
    expect_value(fib_program, 8.0)

    evaluate(
        IfAst(LiteralAst(False), LiteralAst(1), LiteralAst(False)),
        Environment(), self.assertFalse)

    # Calling something that is not a function.
    expect_error(CallAst(LiteralAst(1), []), callback=self.assertFalse)

    # Parse and evaluate source text against the primitives.
    code = """ 2 + twice(3, 4) """
    global_env = Environment()
    for prim_name, prim_func in primitive.items():
        global_env.define(prim_name, prim_func)
    parse = Parser(TokenStream(InputStream(code)))
    evaluate(parse(), global_env, lambda result: result)
def _parse_bool(self) -> LiteralAst:
    """Consume the next token and turn it into a boolean literal.

    The token must be a keyword. The literal is ``True`` exactly when the
    keyword's value is ``'true'``.

    :return: the boolean ``LiteralAst``.
    """
    bool_token = self._token_stream.next()
    assert bool_token.type == 'kw'
    is_true = bool_token.value == 'true'
    return LiteralAst(is_true)
def test__cps_if(self):
    """Check the CPS form of ``if`` nodes under an identity continuation.

    Expected shape: a call of a lambda whose body is the ``if``, each
    branch passing its value to the lambda's first parameter. The single
    argument is a lambda that returns its own parameter. Both an ordinary
    else branch and a literal-false else branch are covered.
    """
    for else_value in (3, False):
        if_ast = IfAst(
            LiteralAst(1),
            LiteralAst(2),
            LiteralAst(else_value),
        )
        cps_ast = _cps_if(if_ast, lambda node: node)

        # Names introduced by the transformation are read from the result.
        outer_params = cps_ast.func.params
        inner_params = cps_ast.args[0].params
        branching_body = IfAst(
            LiteralAst(1),
            CallAst(VarAst(outer_params[0]), [LiteralAst(2)]),
            CallAst(VarAst(outer_params[0]), [LiteralAst(else_value)]),
        )
        expected_ast = CallAst(
            LambdaAst('', outer_params, branching_body),
            [LambdaAst('', inner_params, VarAst(inner_params[0]))],
        )
        self.assertEqual(cps_ast, expected_ast)
def test_evaluate(self):
    """Exercise the direct-return ``evaluate`` on each kind of AST node."""

    def expect_value(ast, expected):
        # Evaluate in a fresh environment and compare the returned value.
        self.assertEqual(evaluate(ast, Environment()), expected)

    def expect_error(ast, make_env=Environment):
        # The environment is built inside the assertRaises block, as part
        # of the call that is expected to fail.
        with self.assertRaises(Exception):
            evaluate(ast, make_env())

    def xyz_bindings():
        # x = 2, y = x + 1, z = x + y
        return [
            VarDefAst('x', LiteralAst(2)),
            VarDefAst('y', BinaryAst('+', VarAst('x'), LiteralAst(1))),
            VarDefAst('z', BinaryAst('+', VarAst('x'), VarAst('y'))),
        ]

    def xyz_sum():
        # x + y + z
        return BinaryAst(
            '+',
            BinaryAst('+', VarAst('x'), VarAst('y')),
            VarAst('z'),
        )

    # Literals.
    expect_value(LiteralAst(1.0), 1.0)
    expect_value(LiteralAst(True), True)
    expect_value(LiteralAst(False), False)
    expect_value(LiteralAst("aaa"), "aaa")

    # Binary expression.
    expect_value(BinaryAst('+', LiteralAst(1), LiteralAst(2)), 3.0)

    # Sequences: empty, single and multiple expressions.
    expect_value(ProgAst([]), False)
    expect_value(ProgAst([LiteralAst(1)]), 1.0)
    expect_value(ProgAst([LiteralAst(1), LiteralAst(2)]), 2.0)

    # Assignment: the target must be a variable.
    expect_error(AssignAst(LiteralAst(1), LiteralAst("a")))
    expect_value(
        ProgAst([AssignAst(VarAst('a'), LiteralAst("foo")), VarAst('a')]),
        "foo")
    # Assigning an undefined variable inside a nested environment fails.
    expect_error(
        AssignAst(VarAst("a"), LiteralAst("foo")),
        make_env=lambda: Environment(Environment()))

    # Calls of anonymous lambdas.
    expect_value(
        CallAst(
            LambdaAst("", ["a"], VarAst("a")),
            [LiteralAst(1)],
        ),
        1.0)
    expect_value(
        CallAst(
            LambdaAst("", ["a"], VarAst("a")),
            [LiteralAst("abc")],
        ),
        "abc")

    # (λ loop (n) if n > 0 then n + loop(n - 1) else 0) (10)
    named_loop = CallAst(
        LambdaAst(
            "loop", ["n"],
            IfAst(
                BinaryAst(">", VarAst("n"), LiteralAst(0)),
                BinaryAst(
                    "+", VarAst("n"),
                    CallAst(
                        VarAst("loop"),
                        [BinaryAst('-', VarAst('n'), LiteralAst(1))])),
                LiteralAst(0),
            ),
        ),
        [LiteralAst(10)])
    expect_value(named_loop, 55.0)

    # let (x) x;
    expect_value(LetAst([VarDefAst("x", None)], VarAst("x")), False)

    # let (x = 2, y = x + 1, z = x + y) x + y + z
    expect_value(LetAst(xyz_bindings(), xyz_sum()), 10.0)

    # let (x = 2, y = x + 1, z = x + y) x + y + z; x + y + z
    # The second expression fails because x, y and z are only bound
    # inside the let body.
    expect_error(ProgAst([LetAst(xyz_bindings(), xyz_sum()), xyz_sum()]))

    # Conditionals.
    expect_value(
        IfAst(
            LiteralAst(""),
            LiteralAst(1),
            None,
        ),
        1.0)
    expect_value(
        IfAst(
            LiteralAst(False),
            LiteralAst(1),
            LiteralAst(2),
        ),
        2.0)
    expect_value(
        IfAst(
            LiteralAst(False),
            LiteralAst(1),
            LiteralAst(False),
        ),
        False)

    # Something that is not an AST node at all.
    expect_error({"type": "foo", "value": 'foo'})
def test_to_cps(self):
    """Check ``to_cps`` across node kinds, using an identity continuation."""

    def identity(node):
        return node

    # Raw JS nodes pass through the dedicated helper unchanged.
    js_raw_ast = JsAst("aa")
    self.assertEqual(_cps_js_raw(js_raw_ast, identity), js_raw_ast)

    # Atoms are returned as they are.
    atom_ast = LiteralAst(1.0)
    self.assertEqual(atom_ast, to_cps(atom_ast, identity))

    # A let without bindings reduces to its body.
    self.assertEqual(
        to_cps(LetAst([], LiteralAst(False)), identity),
        LiteralAst(False))

    # Sequences of zero, one and two expressions.
    self.assertEqual(to_cps(ProgAst([]), identity), LiteralAst(False))
    self.assertEqual(
        to_cps(ProgAst([LiteralAst(1)]), identity),
        LiteralAst(1))
    self.assertEqual(
        to_cps(ProgAst([LiteralAst(1), LiteralAst(2)]), identity),
        ProgAst([LiteralAst(1), LiteralAst(2)]))

    # if 1 then 2 else 3 -- names introduced by the transformation are
    # read back from the result.
    if_ast = IfAst(LiteralAst(1), LiteralAst(2), LiteralAst(3))
    cps_ast: CallAst = to_cps(if_ast, identity)
    outer_params = cps_ast.func.params
    inner_params = cps_ast.args[0].params
    expected_ast = CallAst(
        LambdaAst(
            '', outer_params,
            IfAst(
                LiteralAst(1),
                CallAst(VarAst(outer_params[0]), [LiteralAst(2)]),
                CallAst(VarAst(outer_params[0]), [LiteralAst(3)]))),
        [LambdaAst('', inner_params, VarAst(inner_params[0]))])
    self.assertEqual(cps_ast, expected_ast)

    # lambda (x, y) 1
    lambda_ast = LambdaAst('', ['x', 'y'], LiteralAst(1))
    cps_ast = to_cps(lambda_ast, identity)
    k_param = cps_ast.params[0]
    expected_ast = LambdaAst(
        '',
        [k_param] + ['x', 'y'],
        CallAst(VarAst(k_param), [LiteralAst(1)]))
    self.assertEqual(cps_ast, expected_ast)

    # A binary expression over literals is unchanged.
    binary_ast = BinaryAst('+', LiteralAst(1), LiteralAst(2))
    self.assertEqual(to_cps(binary_ast, identity), binary_ast)

    # Parsed source: the assignment ends up inside the call's continuation.
    parse = Parser(TokenStream(InputStream("a = foo(10);")))
    cps_ast = to_cps(parse(), identity)
    result_param = cps_ast.args[0].params[0]
    expected_ast = CallAst(
        VarAst('foo'),
        [
            LambdaAst(
                '', [result_param],
                AssignAst(VarAst('a'), VarAst(result_param))),
            LiteralAst(10),
        ])
    self.assertEqual(cps_ast, expected_ast)
def test__cps_prog(self):
    """Check ``_cps_prog`` for sequences of zero to three literals.

    An empty sequence becomes a false literal and a single expression
    becomes itself. Longer sequences become right-nested two-element
    ``ProgAst`` nodes.
    """
    cases = (
        ([], LiteralAst(False)),
        ([1], LiteralAst(1)),
        ([1, 2], ProgAst([
            LiteralAst(1),
            LiteralAst(2),
        ])),
        ([1, 2, 3], ProgAst([
            LiteralAst(1),
            ProgAst([LiteralAst(2), LiteralAst(3)]),
        ])),
    )
    for values, expected_ast in cases:
        prog_ast = ProgAst([LiteralAst(value) for value in values])
        cps_ast = _cps_prog(prog_ast, lambda node: node)
        self.assertEqual(cps_ast, expected_ast)