def test_jump(self):
    """An unconditional JMP skips everything between it and its label.

    The assembled function loads 42 into xmm0, jumps over a second load of
    123, and returns, so calling it must produce 42.
    """
    asm = FunctionAssembler('foo', [])
    skip = asm.Label()
    asm.MOVSD(asm.xmm0, asm.const(42))
    asm.JMP(skip)
    # dead code: the JMP above always branches past this load
    asm.MOVSD(asm.xmm0, asm.const(123))
    asm.LABEL(skip)
    asm.RET()
    result = self.load(asm)()
    assert result == 42
class AstCompiler: def __init__(self, src): self.tree = ast.parse(textwrap.dedent(src)) self.asm = None def show(self, node): import astpretty from ast2png import ast2png astpretty.pprint(node) ast2png(self.tree, highlight_node=node, filename='ast.png') def _newfunc(self, name, argnames): self.asm = FA(name, argnames) self.regs = RegAllocator() for argname in argnames: self.regs.get(argname) self.tmp0 = self.regs.get('__scratch_register_0__') self.tmp1 = self.regs.get('__scratch_register_1__') def compile(self): self.visit(self.tree) assert self.asm is not None, 'No function found?' code = self.asm.assemble_and_relocate() return CompiledFunction(self.asm.nargs, code) def visit(self, node): methname = node.__class__.__name__ meth = getattr(self, methname, None) if meth is None: raise NotImplementedError(methname) return meth(node) def Module(self, node): for child in node.body: self.visit(child) def FunctionDef(self, node): assert not self.asm, 'cannot compile more than one function' argnames = [arg.arg for arg in node.args.args] self._newfunc(node.name, argnames) for child in node.body: self.visit(child) # return 0 by default self.asm.PXOR(self.asm.xmm0, self.asm.xmm0) self.asm.RET() def Pass(self, node): pass def Return(self, node): self.visit(node.value) self.asm.popsd(self.asm.xmm0) self.asm.RET() def Num(self, node): self.asm.MOVSD(self.tmp0, self.asm.const(node.n)) self.asm.pushsd(self.tmp0) def BinOp(self, node): OPS = { 'ADD': self.asm.ADDSD, 'SUB': self.asm.SUBSD, 'MULT': self.asm.MULSD, 'DIV': self.asm.DIVSD, } opname = node.op.__class__.__name__.upper() self.visit(node.left) self.visit(node.right) self.asm.popsd(self.tmp1) self.asm.popsd(self.tmp0) OPS[opname](self.tmp0, self.tmp1) self.asm.pushsd(self.tmp0) def Name(self, node): reg = self.regs.get(node.id) self.asm.pushsd(reg) def Assign(self, node): assert len(node.targets) == 1 varname = node.targets[0].id reg = self.regs.get(varname) self.visit(node.value) self.asm.popsd(self.tmp0) 
self.asm.MOVSD(reg, self.tmp0) def If(self, node): """ IF <test> GOTO then_label GOTO end_label then_label: <BODY> end_label: ... """ CMP = { 'LT': self.asm.JB } assert not node.orelse then_label = self.asm.Label() end_label = self.asm.Label() op = self.visit(node.test) CMP[op](then_label) self.asm.JMP(end_label) self.asm.LABEL(then_label) for child in node.body: self.visit(child) self.asm.LABEL(end_label) def Compare(self, node): assert len(node.ops) == 1 cmp_op = node.ops[0].__class__.__name__.upper() self.visit(node.left) self.visit(node.comparators[0]) self.asm.popsd(self.tmp1) self.asm.popsd(self.tmp0) self.asm.UCOMISD(self.tmp0, self.tmp1) return cmp_op def While(self, node): """ begin_label: IF <test> GOTO body_label GOTO end_label body_label: <BODY> GOTO begin_label end_label: ... """ CMP = { 'LT': self.asm.JB } begin_label = self.asm.Label() body_label = self.asm.Label() end_label = self.asm.Label() # self.asm.LABEL(begin_label) op = self.visit(node.test) CMP[op](body_label) self.asm.JMP(end_label) self.asm.LABEL(body_label) for child in node.body: self.visit(child) self.asm.JMP(begin_label) self.asm.LABEL(end_label)