class TestGenerateFunction(unittest.TestCase):
    """Check the C text produced by generate_native_function() for tiny FuncIRs."""

    def setUp(self) -> None:
        # A single int argument 'arg', registered as a local of a fresh
        # environment, plus one empty basic block (index 0) for the body.
        self.var = Var('arg')
        self.arg = RuntimeArg('arg', int_rprimitive)
        self.env = Environment()
        self.reg = self.env.add_local(self.var, int_rprimitive)
        self.block = BasicBlock(0)

    def test_simple(self) -> None:
        # Body: just return the argument register.
        self.block.ops.append(Return(self.reg))
        signature = FuncSignature([self.arg], int_rprimitive)
        decl = FuncDecl('myfunc', None, 'mod', signature)
        fn = FuncIR(decl, [self.block], self.env)

        emitter = Emitter(EmitterContext(NameGenerator([['mod']])))
        generate_native_function(fn, emitter, 'prog.py', 'prog')

        expected = [
            'CPyTagged CPyDef_myfunc(CPyTagged cpy_r_arg) {\n',
            'CPyL0: ;\n',
            ' return cpy_r_arg;\n',
            '}\n',
        ]
        assert_string_arrays_equal(expected, emitter.fragments,
                                   msg='Generated code invalid')

    def test_register(self) -> None:
        # Reset temp numbering so the LoadInt target is named r0.
        self.env.temp_index = 0
        load = LoadInt(5)
        self.block.ops.append(load)
        self.env.add_op(load)
        signature = FuncSignature([self.arg], list_rprimitive)
        decl = FuncDecl('myfunc', None, 'mod', signature)
        fn = FuncIR(decl, [self.block], self.env)

        emitter = Emitter(EmitterContext(NameGenerator([['mod']])))
        generate_native_function(fn, emitter, 'prog.py', 'prog')

        expected = [
            'PyObject *CPyDef_myfunc(CPyTagged cpy_r_arg) {\n',
            ' CPyTagged cpy_r_r0;\n',
            'CPyL0: ;\n',
            ' cpy_r_r0 = 10;\n',
            '}\n',
        ]
        assert_string_arrays_equal(expected, emitter.fragments,
                                   msg='Generated code invalid')
class TestEmitter(unittest.TestCase):
    """Unit tests for Emitter's label/register naming and line emission."""

    def setUp(self) -> None:
        self.env = Environment()
        self.n = self.env.add_local(Var('n'), int_rprimitive)
        self.context = EmitterContext(NameGenerator([['mod']]))
        self.emitter = Emitter(self.context, self.env)

    def test_label(self) -> None:
        label = self.emitter.label(BasicBlock(4))
        assert label == 'CPyL4'

    def test_reg(self) -> None:
        name = self.emitter.reg(self.n)
        assert name == 'cpy_r_n'

    def test_emit_line(self) -> None:
        # The brace lines open and close an indented region for 'f();'.
        for text in ('line;', 'a {', 'f();', '}'):
            self.emitter.emit_line(text)
        expected = ['line;\n', 'a {\n', ' f();\n', '}\n']
        assert self.emitter.fragments == expected
class TestEmitter(unittest.TestCase):
    """Unit tests for Emitter's label/register naming and line emission."""

    def setUp(self) -> None:
        self.env = Environment()
        self.n = self.env.add_local(Var('n'), IntRType())
        self.context = EmitterContext()
        self.emitter = Emitter(self.context, self.env)

    def test_label(self) -> None:
        label = self.emitter.label(Label(4))
        assert label == 'CPyL4'

    def test_reg(self) -> None:
        name = self.emitter.reg(self.n)
        assert name == 'cpy_r_n'

    def test_emit_line(self) -> None:
        # The brace lines open and close an indented region for 'f();'.
        for text in ('line;', 'a {', 'f();', '}'):
            self.emitter.emit_line(text)
        expected = ['line;\n', 'a {\n', ' f();\n', '}\n']
        assert self.emitter.fragments == expected
class TestGenerateFunction(unittest.TestCase):
    """Check the C text produced by generate_native_function() for tiny FuncIRs."""

    def setUp(self) -> None:
        # A single int argument 'arg', registered as a local of a fresh
        # environment, plus one empty basic block labelled 0.
        self.var = Var('arg')
        self.arg = RuntimeArg('arg', IntRType())
        self.env = Environment()
        self.reg = self.env.add_local(self.var, IntRType())
        self.block = BasicBlock(Label(0))

    def test_simple(self) -> None:
        # Body: just return the argument register.
        self.block.ops.append(Return(self.reg))
        ret_type = IntRType()
        fn = FuncIR('myfunc', [self.arg], ret_type, [self.block], self.env)

        emitter = Emitter(EmitterContext())
        generate_native_function(fn, emitter)

        expected = [
            'static CPyTagged CPyDef_myfunc(CPyTagged cpy_r_arg) {\n',
            'CPyL0: ;\n',
            ' return cpy_r_arg;\n',
            '}\n',
        ]
        assert_string_arrays_equal(expected, emitter.fragments,
                                   msg='Generated code invalid')

    def test_register(self) -> None:
        # Allocate the first temp (r0) and load a constant into it.
        self.temp = self.env.add_temp(IntRType())
        self.block.ops.append(LoadInt(self.temp, 5))
        ret_type = ListRType()
        fn = FuncIR('myfunc', [self.arg], ret_type, [self.block], self.env)

        emitter = Emitter(EmitterContext())
        generate_native_function(fn, emitter)

        expected = [
            'static PyObject *CPyDef_myfunc(CPyTagged cpy_r_arg) {\n',
            ' CPyTagged cpy_r_r0;\n',
            'CPyL0: ;\n',
            ' cpy_r_r0 = 10;\n',
            '}\n',
        ]
        assert_string_arrays_equal(expected, emitter.fragments,
                                   msg='Generated code invalid')
class TestFunctionEmitterVisitor(unittest.TestCase):
    """Check the C emitted by FunctionEmitterVisitor for individual ops."""

    def setUp(self) -> None:
        # Locals are registered in a fixed order; the emitted C refers to
        # them as cpy_r_<name>.
        self.env = Environment()
        self.n = self.env.add_local(Var('n'), int_rprimitive)
        self.m = self.env.add_local(Var('m'), int_rprimitive)
        self.k = self.env.add_local(Var('k'), int_rprimitive)
        self.l = self.env.add_local(Var('l'), list_rprimitive)  # noqa
        self.ll = self.env.add_local(Var('ll'), list_rprimitive)
        self.o = self.env.add_local(Var('o'), object_rprimitive)
        self.o2 = self.env.add_local(Var('o2'), object_rprimitive)
        self.d = self.env.add_local(Var('d'), dict_rprimitive)
        self.b = self.env.add_local(Var('b'), bool_rprimitive)
        self.t = self.env.add_local(Var('t'), RTuple([int_rprimitive, bool_rprimitive]))
        self.tt = self.env.add_local(
            Var('tt'),
            RTuple([RTuple([int_rprimitive, bool_rprimitive]), bool_rprimitive]))
        # A native class 'A' in module 'mod' with attributes x: bool, y: int.
        ir = ClassIR('A', 'mod')
        ir.attributes = OrderedDict([('x', bool_rprimitive), ('y', int_rprimitive)])
        compute_vtable(ir)
        ir.mro = [ir]
        self.r = self.env.add_local(Var('r'), RInstance(ir))

        self.context = EmitterContext(NameGenerator([['mod']]))
        self.emitter = Emitter(self.context, self.env)
        self.declarations = Emitter(self.context, self.env)
        self.visitor = FunctionEmitterVisitor(self.emitter, self.declarations,
                                              'prog.py', 'prog')

    def test_goto(self) -> None:
        self.assert_emit(Goto(BasicBlock(2)), "goto CPyL2;")

    def test_return(self) -> None:
        self.assert_emit(Return(self.m), "return cpy_r_m;")

    def test_load_int(self) -> None:
        self.assert_emit(LoadInt(5), "cpy_r_r0 = 10;")

    def test_tuple_get(self) -> None:
        self.assert_emit(TupleGet(self.t, 1, 0), 'cpy_r_r0 = cpy_r_t.f1;')

    def test_load_None(self) -> None:
        self.assert_emit(PrimitiveOp([], none_object_op, 0), "cpy_r_r0 = Py_None;")

    def test_load_True(self) -> None:
        self.assert_emit(PrimitiveOp([], true_op, 0), "cpy_r_r0 = 1;")

    def test_load_False(self) -> None:
        self.assert_emit(PrimitiveOp([], false_op, 0), "cpy_r_r0 = 0;")

    def test_assign_int(self) -> None:
        self.assert_emit(Assign(self.m, self.n), "cpy_r_m = cpy_r_n;")

    def test_int_add(self) -> None:
        self.assert_emit_binary_op(
            '+', self.n, self.m, self.k,
            "cpy_r_r0 = CPyTagged_Add(cpy_r_m, cpy_r_k);")

    def test_int_sub(self) -> None:
        self.assert_emit_binary_op(
            '-', self.n, self.m, self.k,
            "cpy_r_r0 = CPyTagged_Subtract(cpy_r_m, cpy_r_k);")

    def test_list_repeat(self) -> None:
        self.assert_emit_binary_op(
            '*', self.ll, self.l, self.n,
            """Py_ssize_t __tmp1; __tmp1 = CPyTagged_AsSsize_t(cpy_r_n); if (__tmp1 == -1 && PyErr_Occurred()) CPyError_OutOfMemory(); cpy_r_r0 = PySequence_Repeat(cpy_r_l, __tmp1); """)

    def test_int_neg(self) -> None:
        self.assert_emit(PrimitiveOp([self.m], int_neg_op, 55),
                         "cpy_r_r0 = CPyTagged_Negate(cpy_r_m);")

    def test_list_len(self) -> None:
        self.assert_emit(
            PrimitiveOp([self.l], list_len_op, 55),
            """Py_ssize_t __tmp1; __tmp1 = PyList_GET_SIZE(cpy_r_l); cpy_r_r0 = CPyTagged_ShortFromSsize_t(__tmp1); """)

    def test_branch(self) -> None:
        self.assert_emit(
            Branch(self.b, BasicBlock(8), BasicBlock(9), Branch.BOOL_EXPR),
            """if (cpy_r_b) { goto CPyL8; } else goto CPyL9; """)
        # Same branch, negated: the condition is wrapped in '!'.
        b = Branch(self.b, BasicBlock(8), BasicBlock(9), Branch.BOOL_EXPR)
        b.negated = True
        self.assert_emit(
            b,
            """if (!cpy_r_b) { goto CPyL8; } else goto CPyL9; """)

    def test_call(self) -> None:
        decl = FuncDecl(
            'myfn', None, 'mod',
            FuncSignature([RuntimeArg('m', int_rprimitive)], int_rprimitive))
        self.assert_emit(Call(decl, [self.m], 55),
                         "cpy_r_r0 = CPyDef_myfn(cpy_r_m);")

    def test_call_two_args(self) -> None:
        decl = FuncDecl(
            'myfn', None, 'mod',
            FuncSignature([
                RuntimeArg('m', int_rprimitive),
                RuntimeArg('n', int_rprimitive),
            ], int_rprimitive))
        self.assert_emit(Call(decl, [self.m, self.k], 55),
                         "cpy_r_r0 = CPyDef_myfn(cpy_r_m, cpy_r_k);")

    def test_inc_ref(self) -> None:
        self.assert_emit(IncRef(self.m), "CPyTagged_IncRef(cpy_r_m);")

    def test_dec_ref(self) -> None:
        self.assert_emit(DecRef(self.m), "CPyTagged_DecRef(cpy_r_m);")

    def test_dec_ref_tuple(self) -> None:
        self.assert_emit(DecRef(self.t), 'CPyTagged_DecRef(cpy_r_t.f0);')

    def test_dec_ref_tuple_nested(self) -> None:
        self.assert_emit(DecRef(self.tt), 'CPyTagged_DecRef(cpy_r_tt.f0.f0);')

    def test_list_get_item(self) -> None:
        # Use a list register and an int index, matching test_list_set_item,
        # rather than two int registers.
        self.assert_emit(PrimitiveOp([self.l, self.n], list_get_item_op, 55),
                         """cpy_r_r0 = CPyList_GetItem(cpy_r_l, cpy_r_n);""")

    def test_list_set_item(self) -> None:
        self.assert_emit(
            PrimitiveOp([self.l, self.n, self.o], list_set_item_op, 55),
            """cpy_r_r0 = CPyList_SetItem(cpy_r_l, cpy_r_n, cpy_r_o);""")

    def test_box(self) -> None:
        self.assert_emit(Box(self.n),
                         """cpy_r_r0 = CPyTagged_StealAsObject(cpy_r_n);""")

    def test_unbox(self) -> None:
        self.assert_emit(
            Unbox(self.m, int_rprimitive, 55),
            """if (likely(PyLong_Check(cpy_r_m))) cpy_r_r0 = CPyTagged_FromObject(cpy_r_m); else { CPy_TypeError("int", cpy_r_m); cpy_r_r0 = CPY_INT_TAG; } """)

    def test_new_list(self) -> None:
        self.assert_emit(
            PrimitiveOp([self.n, self.m], new_list_op, 55),
            """cpy_r_r0 = PyList_New(2); if (likely(cpy_r_r0 != NULL)) { PyList_SET_ITEM(cpy_r_r0, 0, cpy_r_n); PyList_SET_ITEM(cpy_r_r0, 1, cpy_r_m); } """)

    def test_list_append(self) -> None:
        self.assert_emit(
            PrimitiveOp([self.l, self.o], list_append_op, 1),
            """cpy_r_r0 = PyList_Append(cpy_r_l, cpy_r_o) >= 0;""")

    def test_get_attr(self) -> None:
        self.assert_emit(
            GetAttr(self.r, 'y', 1),
            """cpy_r_r0 = native_A_gety((AObject *)cpy_r_r); /* y */""")

    def test_set_attr(self) -> None:
        self.assert_emit(
            SetAttr(self.r, 'y', self.m, 1),
            "cpy_r_r0 = native_A_sety((AObject *)cpy_r_r, cpy_r_m); /* y */")

    def test_dict_get_item(self) -> None:
        self.assert_emit(PrimitiveOp([self.d, self.o2], dict_get_item_op, 1),
                         """cpy_r_r0 = CPyDict_GetItem(cpy_r_d, cpy_r_o2);""")

    def test_dict_set_item(self) -> None:
        self.assert_emit(
            PrimitiveOp([self.d, self.o, self.o2], dict_set_item_op, 1),
            """cpy_r_r0 = CPyDict_SetItem(cpy_r_d, cpy_r_o, cpy_r_o2) >= 0;""")

    def test_dict_update(self) -> None:
        self.assert_emit(
            PrimitiveOp([self.d, self.o], dict_update_op, 1),
            """cpy_r_r0 = CPyDict_Update(cpy_r_d, cpy_r_o) >= 0;""")

    def test_new_dict(self) -> None:
        self.assert_emit(PrimitiveOp([], new_dict_op, 1),
                         """cpy_r_r0 = PyDict_New();""")

    def test_dict_contains(self) -> None:
        self.assert_emit_binary_op(
            'in', self.b, self.o, self.d,
            """int __tmp1 = PyDict_Contains(cpy_r_d, cpy_r_o); if (__tmp1 < 0) cpy_r_r0 = 2; else cpy_r_r0 = __tmp1; """)

    def assert_emit(self, op: Op, expected: str) -> None:
        """Emit `op` through the visitor and compare against `expected`.

        Output buffers and temp numbering are reset first, so the first temp
        is r0. Declarations fragments are compared before body fragments.
        Each line on both sides is compared with surrounding spaces stripped.
        `expected` has trailing whitespace removed and is split on newlines.
        """
        self.emitter.fragments = []
        self.declarations.fragments = []
        self.env.temp_index = 0
        if isinstance(op, RegisterOp):
            # Register ops need a destination temp allocated in the env.
            self.env.add_op(op)
        op.accept(self.visitor)
        frags = self.declarations.fragments + self.emitter.fragments
        actual_lines = [line.strip(' ') for line in frags]
        # assertTrue (unlike a bare assert) is not stripped under python -O.
        self.assertTrue(all(line.endswith('\n') for line in actual_lines))
        actual_lines = [line.rstrip('\n') for line in actual_lines]
        expected_lines = expected.rstrip().split('\n')
        expected_lines = [line.strip(' ') for line in expected_lines]
        assert_string_arrays_equal(expected_lines, actual_lines,
                                   msg='Generated code unexpected')

    def assert_emit_binary_op(self,
                              op: str,
                              dest: Value,
                              left: Value,
                              right: Value,
                              expected: str) -> None:
        """Emit the first binary_ops[op] descriptor matching the operand types.

        A descriptor matches when left.type and right.type are subtypes of its
        first and second arg_types. `dest` is accepted for call-site symmetry
        but is not used; the destination is a fresh temp allocated by
        assert_emit. Fails the test if no descriptor matches.
        """
        for desc in binary_ops[op]:
            if (is_subtype(left.type, desc.arg_types[0])
                    and is_subtype(right.type, desc.arg_types[1])):
                self.assert_emit(PrimitiveOp([left, right], desc, 55), expected)
                break
        else:
            # self.fail is used instead of `assert False`, which python -O
            # strips and would let a missing op pass silently.
            self.fail('Could not find matching op')
class IRBuilder(NodeVisitor[Register]): def __init__(self, types: Dict[Expression, Type], mapper: Mapper) -> None: self.types = types self.environment = Environment() self.environments = [self.environment] self.blocks = [] # type: List[List[BasicBlock]] self.functions = [] # type: List[FuncIR] self.classes = [] # type: List[ClassIR] self.targets = [] # type: List[Register] # These lists operate as stack frames for loops. Each loop adds a new # frame (i.e. adds a new empty list [] to the outermost list). Each # break or continue is inserted within that frame as they are visited # and at the end of the loop the stack is popped and any break/continue # gotos have their targets rewritten to the next basic block. self.break_gotos = [] # type: List[List[Goto]] self.continue_gotos = [] # type: List[List[Goto]] self.mapper = mapper self.imports = [] # type: List[str] self.current_module_name = None # type: Optional[str] def visit_mypy_file(self, mypyfile: MypyFile) -> Register: if mypyfile.fullname() in ('typing', 'abc'): # These module are special; their contents are currently all # built-in primitives. return INVALID_REGISTER # First pass: Build ClassIRs and TypeInfo-to-ClassIR mapping. for node in mypyfile.defs: if isinstance(node, ClassDef): self.prepare_class_def(node) # Second pass: Generate ops. 
self.current_module_name = mypyfile.fullname() for node in mypyfile.defs: node.accept(self) return INVALID_REGISTER def prepare_class_def(self, cdef: ClassDef) -> None: ir = ClassIR(cdef.name, []) # Populate attributes later in visit_class_def self.classes.append(ir) self.mapper.type_to_ir[cdef.info] = ir def visit_class_def(self, cdef: ClassDef) -> Register: attributes = [] for name, node in cdef.info.names.items(): if isinstance(node.node, Var): attributes.append((name, self.type_to_rtype(node.node.type))) ir = self.mapper.type_to_ir[cdef.info] ir.attributes = attributes return INVALID_REGISTER def visit_import(self, node: Import) -> Register: if node.is_unreachable or node.is_mypy_only: pass if not node.is_top_level: assert False, "non-toplevel imports not supported" for node_id, _ in node.ids: self.imports.append(node_id) return INVALID_REGISTER def visit_import_from(self, node: ImportFrom) -> Register: if node.is_unreachable or node.is_mypy_only: pass if not node.is_top_level: assert False, "non-toplevel imports not supported" self.imports.append(node.id) return INVALID_REGISTER def visit_import_all(self, node: ImportAll) -> Register: if node.is_unreachable or node.is_mypy_only: pass if not node.is_top_level: assert False, "non-toplevel imports not supported" self.imports.append(node.id) return INVALID_REGISTER def visit_func_def(self, fdef: FuncDef) -> Register: self.enter() for arg in fdef.arguments: self.environment.add_local(arg.variable, self.type_to_rtype(arg.variable.type)) fdef.body.accept(self) ret_type = self.convert_return_type(fdef) if ret_type.name == 'None': self.add_implicit_return() else: self.add_implicit_unreachable() blocks, env = self.leave() args = self.convert_args(fdef) func = FuncIR(fdef.name(), args, ret_type, blocks, env) self.functions.append(func) return INVALID_REGISTER def convert_args(self, fdef: FuncDef) -> List[RuntimeArg]: assert isinstance(fdef.type, CallableType) ann = fdef.type return [ RuntimeArg(arg.variable.name(), 
self.type_to_rtype(ann.arg_types[i])) for i, arg in enumerate(fdef.arguments) ] def convert_return_type(self, fdef: FuncDef) -> RType: assert isinstance(fdef.type, CallableType) return self.type_to_rtype(fdef.type.ret_type) def add_implicit_return(self) -> None: block = self.blocks[-1][-1] if not block.ops or not isinstance(block.ops[-1], Return): retval = self.environment.add_temp(NoneRType()) self.add(PrimitiveOp(retval, PrimitiveOp.NONE)) self.add(Return(retval)) def add_implicit_unreachable(self) -> None: block = self.blocks[-1][-1] if not block.ops or not isinstance(block.ops[-1], Return): self.add(Unreachable()) def visit_block(self, block: Block) -> Register: for stmt in block.body: stmt.accept(self) return INVALID_REGISTER def visit_expression_stmt(self, stmt: ExpressionStmt) -> Register: self.accept(stmt.expr) return INVALID_REGISTER def visit_return_stmt(self, stmt: ReturnStmt) -> Register: if stmt.expr: retval = self.accept(stmt.expr) else: retval = self.environment.add_temp(NoneRType()) self.add(PrimitiveOp(retval, PrimitiveOp.NONE)) self.add(Return(retval)) return INVALID_REGISTER def visit_assignment_stmt(self, stmt: AssignmentStmt) -> Register: assert len(stmt.lvalues) == 1 lvalue = stmt.lvalues[0] if stmt.type: lvalue_type = self.type_to_rtype(stmt.type) else: if isinstance(lvalue, IndexExpr): # TODO: This won't be right for user-defined classes. Store the # lvalue type in mypy and remove this special case. 
lvalue_type = ObjectRType() else: lvalue_type = self.node_type(lvalue) rvalue_type = self.node_type(stmt.rvalue) return self.assign(lvalue, stmt.rvalue, rvalue_type, lvalue_type, declare_new=(stmt.type is not None)) def visit_operator_assignment_stmt( self, stmt: OperatorAssignmentStmt) -> Register: target = self.get_assignment_target(stmt.lvalue, declare_new=False) if isinstance(target, AssignmentTargetRegister): ltype = self.environment.types[target.register] rtype = self.node_type(stmt.rvalue) rreg = self.accept(stmt.rvalue) return self.binary_op(ltype, target.register, rtype, rreg, stmt.op, target=target.register) # NOTE: List index not supported yet for compound assignments. assert False, 'Unsupported lvalue: %r' def get_assignment_target(self, lvalue: Lvalue, declare_new: bool) -> AssignmentTarget: if isinstance(lvalue, NameExpr): # Assign to local variable. assert lvalue.kind == LDEF if lvalue.is_def or declare_new: # Define a new variable. assert isinstance(lvalue.node, Var) # TODO: Can this fail? lvalue_num = self.environment.add_local( lvalue.node, self.node_type(lvalue)) else: # Assign to a previously defined variable. assert isinstance(lvalue.node, Var) # TODO: Can this fail? 
lvalue_num = self.environment.lookup(lvalue.node) return AssignmentTargetRegister(lvalue_num) elif isinstance(lvalue, IndexExpr): # Indexed assignment x[y] = e base_type = self.node_type(lvalue.base) index_type = self.node_type(lvalue.index) base_reg = self.accept(lvalue.base) index_reg = self.accept(lvalue.index) if isinstance(base_type, ListRType) and isinstance( index_type, IntRType): # Indexed list set return AssignmentTargetIndex(base_reg, index_reg, base_type) elif isinstance(base_type, DictRType): # Indexed dict set boxed_index = self.box(index_reg, index_type) return AssignmentTargetIndex(base_reg, boxed_index, base_type) elif isinstance(lvalue, MemberExpr): # Attribute assignment x.y = e obj_type = self.node_type(lvalue.expr) assert isinstance( obj_type, UserRType), 'Attribute set only supported for user types' obj_reg = self.accept(lvalue.expr) return AssignmentTargetAttr(obj_reg, lvalue.name, obj_type) assert False, 'Unsupported lvalue: %r' % lvalue def assign_to_target(self, target: AssignmentTarget, rvalue: Expression, rvalue_type: RType, needs_box: bool) -> Register: rvalue_type = rvalue_type or self.node_type(rvalue) if isinstance(target, AssignmentTargetRegister): if needs_box: unboxed = self.accept(rvalue) return self.box(unboxed, rvalue_type, target=target.register) else: return self.accept(rvalue, target=target.register) elif isinstance(target, AssignmentTargetAttr): rvalue_reg = self.accept(rvalue) if needs_box: rvalue_reg = self.box(rvalue_reg, rvalue_type) self.add( SetAttr(target.obj_reg, target.attr, rvalue_reg, target.obj_type)) return INVALID_REGISTER elif isinstance(target, AssignmentTargetIndex): item_reg = self.accept(rvalue) boxed_item_reg = self.box(item_reg, rvalue_type) if isinstance(target.rtype, ListRType): op = PrimitiveOp.LIST_SET elif isinstance(target.rtype, DictRType): op = PrimitiveOp.DICT_SET else: assert False, target.rtype self.add( PrimitiveOp(None, op, target.base_reg, target.index_reg, boxed_item_reg)) return 
INVALID_REGISTER assert False, 'Unsupported assignment target' def assign(self, lvalue: Lvalue, rvalue: Expression, rvalue_type: RType, lvalue_type: RType, declare_new: bool) -> Register: target = self.get_assignment_target(lvalue, declare_new) needs_box = rvalue_type.supports_unbox and not lvalue_type.supports_unbox return self.assign_to_target(target, rvalue, rvalue_type, needs_box) def visit_if_stmt(self, stmt: IfStmt) -> Register: # If statements are normalized assert len(stmt.expr) == 1 branches = self.process_conditional(stmt.expr[0]) if_body = self.new_block() self.set_branches(branches, True, if_body) stmt.body[0].accept(self) if_leave = self.add_leave() if stmt.else_body: else_body = self.new_block() self.set_branches(branches, False, else_body) stmt.else_body.accept(self) else_leave = self.add_leave() next = self.new_block() if else_leave: else_leave.label = next.label else: # No else block. next = self.new_block() self.set_branches(branches, False, next) if if_leave: if_leave.label = next.label return INVALID_REGISTER def add_leave(self) -> Optional[Goto]: if not self.blocks[-1][-1].ops or not isinstance( self.blocks[-1][-1].ops[-1], Return): leave = Goto(INVALID_LABEL) self.add(leave) return leave return None def push_loop_stack(self) -> None: self.break_gotos.append([]) self.continue_gotos.append([]) def pop_loop_stack(self, continue_block: BasicBlock, break_block: BasicBlock) -> None: for continue_goto in self.continue_gotos.pop(): continue_goto.label = continue_block.label for break_goto in self.break_gotos.pop(): break_goto.label = break_block.label def visit_while_stmt(self, s: WhileStmt) -> Register: self.push_loop_stack() # Split block so that we get a handle to the top of the loop. goto = Goto(INVALID_LABEL) self.add(goto) top = self.new_block() goto.label = top.label branches = self.process_conditional(s.expr) body = self.new_block() # Bind "true" branches to the body block. 
self.set_branches(branches, True, body) s.body.accept(self) # Add branch to the top at the end of the body. self.add(Goto(top.label)) next = self.new_block() # Bind "false" branches to the new block. self.set_branches(branches, False, next) self.pop_loop_stack(top, next) return INVALID_REGISTER def visit_for_stmt(self, s: ForStmt) -> Register: if (isinstance(s.expr, CallExpr) and isinstance(s.expr.callee, RefExpr) and s.expr.callee.fullname == 'builtins.range'): self.push_loop_stack() # Special case for x in range(...) # TODO: Check argument counts and kinds; check the lvalue end = s.expr.args[0] end_reg = self.accept(end) # Initialize loop index to 0. index_reg = self.assign(s.index, IntExpr(0), IntRType(), IntRType(), declare_new=True) goto = Goto(INVALID_LABEL) self.add(goto) # Add loop condition check. top = self.new_block() goto.label = top.label branch = Branch(index_reg, end_reg, INVALID_LABEL, INVALID_LABEL, Branch.INT_LT) self.add(branch) branches = [branch] body = self.new_block() self.set_branches(branches, True, body) s.body.accept(self) end_goto = Goto(INVALID_LABEL) self.add(end_goto) end_block = self.new_block() end_goto.label = end_block.label # Increment index register. one_reg = self.alloc_temp(IntRType()) self.add(LoadInt(one_reg, 1)) self.add( PrimitiveOp(index_reg, PrimitiveOp.INT_ADD, index_reg, one_reg)) # Go back to loop condition check. 
self.add(Goto(top.label)) next = self.new_block() self.set_branches(branches, False, next) self.pop_loop_stack(end_block, next) return INVALID_REGISTER if self.node_type(s.expr).name == 'list': self.push_loop_stack() expr_reg = self.accept(s.expr) index_reg = self.alloc_temp(IntRType()) self.add(LoadInt(index_reg, 0)) one_reg = self.alloc_temp(IntRType()) self.add(LoadInt(one_reg, 1)) assert isinstance(s.index, NameExpr) assert isinstance(s.index.node, Var) lvalue_reg = self.environment.add_local(s.index.node, self.node_type(s.index)) condition_block = self.goto_new_block() # For compatibility with python semantics we recalculate the length # at every iteration. len_reg = self.alloc_temp(IntRType()) self.add(PrimitiveOp(len_reg, PrimitiveOp.LIST_LEN, expr_reg)) branch = Branch(index_reg, len_reg, INVALID_LABEL, INVALID_LABEL, Branch.INT_LT) self.add(branch) branches = [branch] body_block = self.new_block() self.set_branches(branches, True, body_block) target_list_type = self.types[s.expr] assert isinstance(target_list_type, Instance) target_type = self.type_to_rtype(target_list_type.args[0]) value_box = self.alloc_temp(ObjectRType()) self.add( PrimitiveOp(value_box, PrimitiveOp.LIST_GET, expr_reg, index_reg)) self.unbox_or_cast(value_box, target_type, target=lvalue_reg) s.body.accept(self) end_block = self.goto_new_block() self.add( PrimitiveOp(index_reg, PrimitiveOp.INT_ADD, index_reg, one_reg)) self.add(Goto(condition_block.label)) next_block = self.new_block() self.set_branches(branches, False, next_block) self.pop_loop_stack(end_block, next_block) return INVALID_REGISTER assert False, 'for not supported' def visit_break_stmt(self, node: BreakStmt) -> Register: self.break_gotos[-1].append(Goto(INVALID_LABEL)) self.add(self.break_gotos[-1][-1]) return INVALID_REGISTER def visit_continue_stmt(self, node: ContinueStmt) -> Register: self.continue_gotos[-1].append(Goto(INVALID_LABEL)) self.add(self.continue_gotos[-1][-1]) return INVALID_REGISTER int_binary_ops = { 
'+': PrimitiveOp.INT_ADD, '-': PrimitiveOp.INT_SUB, '*': PrimitiveOp.INT_MUL, '//': PrimitiveOp.INT_DIV, '%': PrimitiveOp.INT_MOD, '&': PrimitiveOp.INT_AND, '|': PrimitiveOp.INT_OR, '^': PrimitiveOp.INT_XOR, '<<': PrimitiveOp.INT_SHL, '>>': PrimitiveOp.INT_SHR, '>>': PrimitiveOp.INT_SHR, } def visit_unary_expr(self, expr: UnaryExpr) -> Register: if expr.op != '-': assert False, 'Unsupported unary operation' etype = self.node_type(expr.expr) reg = self.accept(expr.expr) if etype.name != 'int': assert False, 'Unsupported unary operation' target = self.alloc_target(IntRType()) zero = self.accept(IntExpr(0)) self.add(PrimitiveOp(target, PrimitiveOp.INT_SUB, zero, reg)) return target def visit_op_expr(self, expr: OpExpr) -> Register: ltype = self.node_type(expr.left) rtype = self.node_type(expr.right) lreg = self.accept(expr.left) rreg = self.accept(expr.right) return self.binary_op(ltype, lreg, rtype, rreg, expr.op) def binary_op(self, ltype: RType, lreg: Register, rtype: RType, rreg: Register, expr_op: str, target: Optional[Register] = None) -> Register: if ltype.name == 'int' and rtype.name == 'int': # Primitive int operation if target is None: target = self.alloc_target(IntRType()) op = self.int_binary_ops[expr_op] elif (ltype.name == 'list' or rtype.name == 'list') and expr_op == '*': if rtype.name == 'list': ltype, rtype = rtype, ltype lreg, rreg = rreg, lreg if rtype.name != 'int': assert False, 'Unsupported binary operation' # TODO: Operator overloading if target is None: target = self.alloc_target(ListRType()) op = PrimitiveOp.LIST_REPEAT elif isinstance(rtype, DictRType): if expr_op == 'in': if target is None: target = self.alloc_target(BoolRType()) lreg = self.box(lreg, ltype) op = PrimitiveOp.DICT_CONTAINS else: assert False, 'Unsupported binary operation' else: assert False, 'Unsupported binary operation' self.add(PrimitiveOp(target, op, lreg, rreg)) return target def visit_index_expr(self, expr: IndexExpr) -> Register: base_rtype = 
self.node_type(expr.base) base_reg = self.accept(expr.base) target_type = self.node_type(expr) if isinstance(base_rtype, (ListRType, SequenceTupleRType, DictRType)): index_type = self.node_type(expr.index) if not isinstance(base_rtype, DictRType): assert isinstance( index_type, IntRType), 'Unsupported indexing operation' # TODO if isinstance(base_rtype, ListRType): op = PrimitiveOp.LIST_GET elif isinstance(base_rtype, DictRType): op = PrimitiveOp.DICT_GET else: op = PrimitiveOp.HOMOGENOUS_TUPLE_GET index_reg = self.accept(expr.index) if isinstance(base_rtype, DictRType): index_reg = self.box(index_reg, index_type) tmp = self.alloc_temp(ObjectRType()) self.add(PrimitiveOp(tmp, op, base_reg, index_reg)) target = self.alloc_target(target_type) return self.unbox_or_cast(tmp, target_type, target) elif isinstance(base_rtype, TupleRType): assert isinstance(expr.index, IntExpr) # TODO target = self.alloc_target(target_type) self.add( TupleGet(target, base_reg, expr.index.value, base_rtype.types[expr.index.value])) return target assert False, 'Unsupported indexing operation' def visit_int_expr(self, expr: IntExpr) -> Register: reg = self.alloc_target(IntRType()) self.add(LoadInt(reg, expr.value)) return reg def is_native_name_expr(self, expr: NameExpr) -> bool: # TODO later we want to support cross-module native calls too if '.' 
in expr.node.fullname(): module_name = '.'.join(expr.node.fullname().split('.')[:-1]) return module_name == self.current_module_name return True def visit_name_expr(self, expr: NameExpr) -> Register: if expr.node.fullname() == 'builtins.None': target = self.alloc_target(NoneRType()) self.add(PrimitiveOp(target, PrimitiveOp.NONE)) return target elif expr.node.fullname() == 'builtins.True': target = self.alloc_target(BoolRType()) self.add(PrimitiveOp(target, PrimitiveOp.TRUE)) return target elif expr.node.fullname() == 'builtins.False': target = self.alloc_target(BoolRType()) self.add(PrimitiveOp(target, PrimitiveOp.FALSE)) return target if not self.is_native_name_expr(expr): return self.load_static_module_attr(expr) # TODO: We assume that this is a Var node, which is very limited assert isinstance(expr.node, Var) reg = self.environment.lookup(expr.node) return self.get_using_binder(reg, expr.node, expr) def get_using_binder(self, reg: Register, var: Var, expr: Expression) -> Register: var_type = self.type_to_rtype(var.type) target_type = self.node_type(expr) if var_type != target_type: # Cast/unbox to the narrower given by the binder. if self.targets[-1] < 0: target = self.alloc_temp(target_type) else: target = self.targets[-1] return self.unbox_or_cast(reg, target_type, target) else: # Regular register access -- binder is not active. 
if self.targets[-1] < 0: return reg else: target = self.targets[-1] self.add(Assign(target, reg)) return target def is_module_member_expr(self, expr: MemberExpr): return isinstance(expr.expr, RefExpr) and expr.expr.kind == MODULE_REF def visit_member_expr(self, expr: MemberExpr) -> Register: if self.is_module_member_expr(expr): return self.load_static_module_attr(expr) else: obj_reg = self.accept(expr.expr) attr_type = self.node_type(expr) target = self.alloc_target(attr_type) obj_type = self.node_type(expr.expr) assert isinstance( obj_type, UserRType), 'Attribute access not supported: %s' % obj_type self.add(GetAttr(target, obj_reg, expr.name, obj_type)) return target def load_static_module_attr(self, expr: RefExpr) -> Register: target = self.alloc_target(self.node_type(expr)) module = '.'.join(expr.node.fullname().split('.')[:-1]) right = expr.node.fullname().split('.')[-1] left = self.alloc_temp(ObjectRType()) self.add(LoadStatic(left, c_module_name(module))) self.add(PyGetAttr(target, left, right)) return target def py_call(self, function: Register, args: List[Expression], target_type: RType) -> Register: target_box = self.alloc_temp(ObjectRType()) arg_boxes = [] # type: List[Register] for arg_expr in args: arg_reg = self.accept(arg_expr) arg_boxes.append(self.box(arg_reg, self.node_type(arg_expr))) self.add(PyCall(target_box, function, arg_boxes)) return self.unbox_or_cast(target_box, target_type) def visit_call_expr(self, expr: CallExpr) -> Register: if isinstance(expr.callee, MemberExpr): is_module_call = self.is_module_member_expr(expr.callee) if expr.callee.expr in self.types and not is_module_call: target = self.translate_special_method_call(expr.callee, expr) if target: return target # Either its a module call or translating to a special method call failed, so we have # to fallback to a PyCall function = self.accept(expr.callee) return self.py_call(function, expr.args, self.node_type(expr)) assert isinstance(expr.callee, NameExpr) fn = expr.callee.name # 
TODO: fullname if fn == 'len' and len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]: target = self.alloc_target(IntRType()) arg = self.accept(expr.args[0]) expr_rtype = self.node_type(expr.args[0]) if expr_rtype.name == 'list': self.add(PrimitiveOp(target, PrimitiveOp.LIST_LEN, arg)) elif expr_rtype.name == 'sequence_tuple': self.add( PrimitiveOp(target, PrimitiveOp.HOMOGENOUS_TUPLE_LEN, arg)) elif isinstance(expr_rtype, TupleRType): self.add(LoadInt(target, len(expr_rtype.types))) else: assert False, "unsupported use of len" # Handle conversion to sequence tuple elif fn == 'tuple' and len( expr.args) == 1 and expr.arg_kinds == [ARG_POS]: target = self.alloc_target(SequenceTupleRType()) arg = self.accept(expr.args[0]) self.add( PrimitiveOp(target, PrimitiveOp.LIST_TO_HOMOGENOUS_TUPLE, arg)) else: target_type = self.node_type(expr) if not (self.is_native_name_expr(expr.callee)): function = self.accept(expr.callee) return self.py_call(function, expr.args, target_type) target = self.alloc_target(target_type) args = [self.accept(arg) for arg in expr.args] self.add(Call(target, fn, args)) return target def visit_conditional_expr(self, expr: ConditionalExpr) -> Register: branches = self.process_conditional(expr.cond) target = self.alloc_target(self.node_type(expr)) if_body = self.new_block() self.set_branches(branches, True, if_body) self.accept(expr.if_expr, target=target) if_goto_next = Goto(INVALID_LABEL) self.add(if_goto_next) else_body = self.new_block() self.set_branches(branches, False, else_body) self.accept(expr.else_expr, target=target) else_goto_next = Goto(INVALID_LABEL) self.add(else_goto_next) next = self.new_block() if_goto_next.label = next.label else_goto_next.label = next.label return target def translate_special_method_call(self, callee: MemberExpr, expr: CallExpr) -> Register: base_type = self.node_type(callee.expr) result_type = self.node_type(expr) base = self.accept(callee.expr) if callee.name == 'append' and base_type.name == 'list': target = 
INVALID_REGISTER # TODO: Do we sometimes need to allocate a register? arg = self.box_expr(expr.args[0]) self.add(PrimitiveOp(target, PrimitiveOp.LIST_APPEND, base, arg)) else: assert False, 'Unsupported method call: %s.%s' % (base_type.name, callee.name) return target def visit_list_expr(self, expr: ListExpr) -> Register: list_type = self.types[expr] assert isinstance(list_type, Instance) item_type = self.type_to_rtype(list_type.args[0]) target = self.alloc_target(ListRType()) items = [] for item in expr.items: item_reg = self.accept(item) boxed = self.box(item_reg, item_type) items.append(boxed) self.add(PrimitiveOp(target, PrimitiveOp.NEW_LIST, *items)) return target def visit_tuple_expr(self, expr: TupleExpr) -> Register: tuple_type = self.types[expr] assert isinstance(tuple_type, TupleType) target = self.alloc_target(self.type_to_rtype(tuple_type)) items = [self.accept(i) for i in expr.items] self.add(PrimitiveOp(target, PrimitiveOp.NEW_TUPLE, *items)) return target def visit_dict_expr(self, expr: DictExpr): assert not expr.items # TODO target = self.alloc_target(DictRType()) self.add(PrimitiveOp(target, PrimitiveOp.NEW_DICT)) return target # Conditional expressions int_relative_ops = { '==': Branch.INT_EQ, '!=': Branch.INT_NE, '<': Branch.INT_LT, '<=': Branch.INT_LE, '>': Branch.INT_GT, '>=': Branch.INT_GE, } def process_conditional(self, e: Node) -> List[Branch]: if isinstance(e, ComparisonExpr): # TODO: Verify operand types. 
assert len(e.operators) == 1, 'more than 1 operator not supported' op = e.operators[0] if op in ['==', '!=', '<', '<=', '>', '>=']: # TODO: check operand types left = self.accept(e.operands[0]) right = self.accept(e.operands[1]) opcode = self.int_relative_ops[op] branch = Branch(left, right, INVALID_LABEL, INVALID_LABEL, opcode) elif op in ['is', 'is not']: # TODO: check if right operand is None left = self.accept(e.operands[0]) branch = Branch(left, INVALID_REGISTER, INVALID_LABEL, INVALID_LABEL, Branch.IS_NONE) if op == 'is not': branch.negated = True elif op in ['in', 'not in']: left = self.accept(e.operands[0]) ltype = self.node_type(e.operands[0]) right = self.accept(e.operands[1]) rtype = self.node_type(e.operands[1]) target = self.alloc_temp(self.node_type(e)) self.binary_op(ltype, left, rtype, right, 'in', target=target) branch = Branch(target, INVALID_REGISTER, INVALID_LABEL, INVALID_LABEL, Branch.BOOL_EXPR) if op == 'not in': branch.negated = True else: assert False, "unsupported comparison epxression" self.add(branch) return [branch] elif isinstance(e, OpExpr) and e.op in ['and', 'or']: if e.op == 'and': # Short circuit 'and' in a conditional context. lbranches = self.process_conditional(e.left) new = self.new_block() self.set_branches(lbranches, True, new) rbranches = self.process_conditional(e.right) return lbranches + rbranches else: # Short circuit 'or' in a conditional context. lbranches = self.process_conditional(e.left) new = self.new_block() self.set_branches(lbranches, False, new) rbranches = self.process_conditional(e.right) return lbranches + rbranches elif isinstance(e, UnaryExpr) and e.op == 'not': branches = self.process_conditional(e.expr) for b in branches: b.invert() return branches # Catch-all for arbitrary expressions. 
else: reg = self.accept(e) branch = Branch(reg, INVALID_REGISTER, INVALID_LABEL, INVALID_LABEL, Branch.BOOL_EXPR) self.add(branch) return [branch] def set_branches(self, branches: List[Branch], condition: bool, target: BasicBlock) -> None: """Set branch targets for the given condition (True or False). If the target has already been set for a branch, skip the branch. """ for b in branches: if condition: if b.true < 0: b.true = target.label else: if b.false < 0: b.false = target.label # Helpers def enter(self) -> None: self.environment = Environment() self.environments.append(self.environment) self.blocks.append([]) self.new_block() def new_block(self) -> BasicBlock: new = BasicBlock(Label(len(self.blocks[-1]))) self.blocks[-1].append(new) return new def goto_new_block(self) -> BasicBlock: goto = Goto(INVALID_LABEL) self.add(goto) block = self.new_block() goto.label = block.label return block def leave(self) -> Tuple[List[BasicBlock], Environment]: blocks = self.blocks.pop() env = self.environments.pop() self.environment = self.environments[-1] return blocks, env def add(self, op: Op) -> None: self.blocks[-1][-1].ops.append(op) def accept(self, node: Node, target: Register = INVALID_REGISTER) -> Register: self.targets.append(target) actual = node.accept(self) self.targets.pop() return actual def alloc_target(self, type: RType) -> Register: if self.targets[-1] < 0: return self.environment.add_temp(type) else: return self.targets[-1] def alloc_temp(self, type: RType) -> Register: return self.environment.add_temp(type) def type_to_rtype(self, typ: Type) -> RType: return self.mapper.type_to_rtype(typ) def node_type(self, node: Expression) -> RType: mypy_type = self.types[node] return self.type_to_rtype(mypy_type) def box(self, src: Register, typ: RType, target: Optional[Register] = None) -> Register: if typ.supports_unbox: if target is None: target = self.alloc_temp(ObjectRType()) self.add(Box(target, src, typ)) return target else: # Already boxed if target is not None: 
self.add(Assign(target, src)) return target else: return src def unbox_or_cast(self, src: Register, target_type: RType, target: Optional[Register] = None) -> Register: if target is None: target = self.alloc_temp(target_type) if target_type.supports_unbox: self.add(Unbox(target, src, target_type)) else: self.add(Cast(target, src, target_type)) return target def box_expr(self, expr: Expression) -> Register: typ = self.node_type(expr) return self.box(self.accept(expr), typ)
class TestFunctionEmitterVisitor(unittest.TestCase):
    """Check the C code FunctionEmitterVisitor emits for individual IR ops.

    Each test feeds a single op to the visitor and compares the emitted lines against
    an expected snippet using ``assert_emit``.
    """

    def setUp(self) -> None:
        """Create an environment with locals of several types to use as op operands."""
        self.env = Environment()
        self.n = self.env.add_local(Var('n'), IntRType())
        self.m = self.env.add_local(Var('m'), IntRType())
        self.k = self.env.add_local(Var('k'), IntRType())
        self.l = self.env.add_local(Var('l'), ListRType())
        self.ll = self.env.add_local(Var('ll'), ListRType())
        self.o = self.env.add_local(Var('o'), ObjectRType())
        self.o2 = self.env.add_local(Var('o2'), ObjectRType())
        self.d = self.env.add_local(Var('d'), DictRType())
        self.b = self.env.add_local(Var('b'), BoolRType())
        self.context = EmitterContext()
        self.emitter = Emitter(self.context, self.env)
        # Second emitter passed to the visitor; presumably it receives temporary
        # declarations such as ``long long __tmp1;`` (assert_emit checks its output first).
        self.declarations = Emitter(self.context, self.env)
        self.visitor = FunctionEmitterVisitor(self.emitter, self.declarations)

    def test_goto(self) -> None:
        self.assert_emit(Goto(Label(2)),
                         "goto CPyL2;")

    def test_return(self) -> None:
        self.assert_emit(Return(self.m),
                         "return cpy_r_m;")

    def test_load_int(self) -> None:
        # 5 is emitted as 10, presumably the tagged-int encoding (value << 1).
        self.assert_emit(LoadInt(self.m, 5),
                         "cpy_r_m = 10;")

    def test_tuple_get(self) -> None:
        self.assert_emit(TupleGet(self.m, self.n, 1, BoolRType()), 'cpy_r_m = cpy_r_n.f1;')

    def test_load_None(self) -> None:
        # None is loaded by reference, so the emitted code also takes a new reference.
        self.assert_emit(PrimitiveOp(self.m, PrimitiveOp.NONE),
                         """cpy_r_m = Py_None;
                            Py_INCREF(cpy_r_m);
                         """)

    def test_load_True(self) -> None:
        self.assert_emit(PrimitiveOp(self.m, PrimitiveOp.TRUE), "cpy_r_m = 1;")

    def test_load_False(self) -> None:
        self.assert_emit(PrimitiveOp(self.m, PrimitiveOp.FALSE), "cpy_r_m = 0;")

    def test_assign_int(self) -> None:
        self.assert_emit(Assign(self.m, self.n),
                         "cpy_r_m = cpy_r_n;")

    def test_int_add(self) -> None:
        self.assert_emit(PrimitiveOp(self.n, PrimitiveOp.INT_ADD, self.m, self.k),
                         "cpy_r_n = CPyTagged_Add(cpy_r_m, cpy_r_k);")

    def test_int_sub(self) -> None:
        self.assert_emit(PrimitiveOp(self.n, PrimitiveOp.INT_SUB, self.m, self.k),
                         "cpy_r_n = CPyTagged_Subtract(cpy_r_m, cpy_r_k);")

    def test_list_repeat(self) -> None:
        # The repeat count is converted to a C integer first; both steps check for errors.
        self.assert_emit(PrimitiveOp(self.ll, PrimitiveOp.LIST_REPEAT, self.l, self.n),
                         """long long __tmp1;
                            __tmp1 = CPyTagged_AsLongLong(cpy_r_n);
                            if (__tmp1 == -1 && PyErr_Occurred())
                                abort();
                            cpy_r_ll = PySequence_Repeat(cpy_r_l, __tmp1);
                            if (!cpy_r_ll)
                                abort();
                         """)

    def test_int_neg(self) -> None:
        self.assert_emit(PrimitiveOp(self.n, PrimitiveOp.INT_NEG, self.m),
                         "cpy_r_n = CPy_NegateInt(cpy_r_m);")

    def test_list_len(self) -> None:
        self.assert_emit(PrimitiveOp(self.n, PrimitiveOp.LIST_LEN, self.l),
                         """long long __tmp1;
                            __tmp1 = PyList_GET_SIZE(cpy_r_l);
                            cpy_r_n = CPyTagged_ShortFromLongLong(__tmp1);
                         """)

    def test_branch_eq(self) -> None:
        self.assert_emit(Branch(self.n, self.m, Label(8), Label(9), Branch.INT_EQ),
                         """if (CPyTagged_IsEq(cpy_r_n, cpy_r_m))
                                goto CPyL8;
                            else
                                goto CPyL9;
                         """)
        # A negated branch wraps the comparison in '!' but keeps the same targets.
        b = Branch(self.n, self.m, Label(8), Label(9), Branch.INT_LT)
        b.negated = True
        self.assert_emit(b,
                         """if (!CPyTagged_IsLt(cpy_r_n, cpy_r_m))
                                goto CPyL8;
                            else
                                goto CPyL9;
                         """)

    def test_call(self) -> None:
        self.assert_emit(Call(self.n, 'myfn', [self.m]),
                         "cpy_r_n = CPyDef_myfn(cpy_r_m);")

    def test_call_two_args(self) -> None:
        self.assert_emit(Call(self.n, 'myfn', [self.m, self.k]),
                         "cpy_r_n = CPyDef_myfn(cpy_r_m, cpy_r_k);")

    def test_call_no_return(self) -> None:
        # With no destination register the call is emitted as a bare statement.
        self.assert_emit(Call(None, 'myfn', [self.m, self.k]),
                         "CPyDef_myfn(cpy_r_m, cpy_r_k);")

    def test_inc_ref(self) -> None:
        self.assert_emit(IncRef(self.m, IntRType()),
                         "CPyTagged_IncRef(cpy_r_m);")

    def test_dec_ref(self) -> None:
        self.assert_emit(DecRef(self.m, IntRType()),
                         "CPyTagged_DecRef(cpy_r_m);")

    def test_dec_ref_tuple(self) -> None:
        # Only the int field (f0) gets a DecRef; the bool field emits nothing.
        tuple_type = TupleRType([IntRType(), BoolRType()])
        self.assert_emit(DecRef(self.m, tuple_type), 'CPyTagged_DecRef(cpy_r_m.f0);')

    def test_dec_ref_tuple_nested(self) -> None:
        # Nested tuples are traversed field by field (f0.f0 is the only int).
        tuple_type = TupleRType([TupleRType([IntRType(), BoolRType()]), BoolRType()])
        self.assert_emit(DecRef(self.m, tuple_type), 'CPyTagged_DecRef(cpy_r_m.f0.f0);')

    def test_list_get_item(self) -> None:
        self.assert_emit(PrimitiveOp(self.n, PrimitiveOp.LIST_GET, self.m, self.k),
                         """cpy_r_n = CPyList_GetItem(cpy_r_m, cpy_r_k);
                            if (!cpy_r_n)
                                abort();
                         """)

    def test_list_set_item(self) -> None:
        self.assert_emit(PrimitiveOp(None, PrimitiveOp.LIST_SET, self.l, self.n, self.o),
                         """if (!CPyList_SetItem(cpy_r_l, cpy_r_n, cpy_r_o))
                                abort();
                         """)

    def test_box(self) -> None:
        self.assert_emit(Box(self.o, self.n, IntRType()),
                         """cpy_r_o = CPyTagged_StealAsObject(cpy_r_n);""")

    def test_unbox(self) -> None:
        # Unboxing type-checks the object and aborts on a mismatch.
        self.assert_emit(Unbox(self.n, self.m, IntRType()),
                         """if (PyLong_Check(cpy_r_m))
                                cpy_r_n = CPyTagged_FromObject(cpy_r_m);
                            else
                                abort();
                         """)

    def test_new_list(self) -> None:
        # Each item gets a new reference before it is stored into the list.
        self.assert_emit(PrimitiveOp(self.l, PrimitiveOp.NEW_LIST, self.n, self.m),
                         """cpy_r_l = PyList_New(2);
                            Py_INCREF(cpy_r_n);
                            PyList_SET_ITEM(cpy_r_l, 0, cpy_r_n);
                            Py_INCREF(cpy_r_m);
                            PyList_SET_ITEM(cpy_r_l, 1, cpy_r_m);
                         """)

    def test_list_append(self) -> None:
        self.assert_emit(PrimitiveOp(None, PrimitiveOp.LIST_APPEND, self.l, self.o),
                         """if (PyList_Append(cpy_r_l, cpy_r_o) == -1)
                                abort();
                         """)

    def test_get_attr(self) -> None:
        # NOTE(review): index 2 for 'y' here vs 3 in test_set_attr presumably reflects
        # interleaved getter/setter slots per attribute — confirm against the emitter.
        ir = ClassIR('A', [('x', BoolRType()),
                           ('y', IntRType())])
        rtype = UserRType(ir)
        self.assert_emit(
            GetAttr(self.n, self.m, 'y', rtype),
            """cpy_r_n = CPY_GET_ATTR(cpy_r_m, 2, AObject, CPyTagged);""")

    def test_set_attr(self) -> None:
        ir = ClassIR('A', [('x', BoolRType()),
                           ('y', IntRType())])
        rtype = UserRType(ir)
        self.assert_emit(
            SetAttr(self.n, 'y', self.m, rtype),
            """CPY_SET_ATTR(cpy_r_n, 3, cpy_r_m, AObject, CPyTagged);""")

    def test_dict_get_item(self) -> None:
        self.assert_emit(PrimitiveOp(self.o, PrimitiveOp.DICT_GET, self.d, self.o2),
                         """cpy_r_o = PyDict_GetItem(cpy_r_d, cpy_r_o2);
                            if (!cpy_r_o)
                                abort();
                            Py_INCREF(cpy_r_o);
                         """)

    def test_dict_set_item(self) -> None:
        self.assert_emit(PrimitiveOp(None, PrimitiveOp.DICT_SET, self.d, self.o, self.o2),
                         """if (PyDict_SetItem(cpy_r_d, cpy_r_o, cpy_r_o2) < 0)
                                abort();
                         """)

    def test_new_dict(self) -> None:
        self.assert_emit(PrimitiveOp(self.d, PrimitiveOp.NEW_DICT),
                         """cpy_r_d = PyDict_New();
                            if (!cpy_r_d)
                                abort();
                         """)

    def test_dict_contains(self) -> None:
        # Note the operand order: the op takes (key, dict) but the C call is (dict, key).
        self.assert_emit(PrimitiveOp(self.b, PrimitiveOp.DICT_CONTAINS, self.o, self.d),
                         """int __tmp1 = PyDict_Contains(cpy_r_d, cpy_r_o);
                            if (__tmp1 < 0)
                                abort();
                            cpy_r_b = __tmp1;
                         """)

    def assert_emit(self, op: Op, expected: str) -> None:
        """Emit *op* and assert the output matches *expected* line by line.

        Declarations are compared before statements. Spaces around each line are
        ignored, as is trailing whitespace at the end of *expected*, so the expected
        snippet may be indented freely inside a triple-quoted string.
        """
        self.emitter.fragments = []
        self.declarations.fragments = []
        op.accept(self.visitor)
        frags = self.declarations.fragments + self.emitter.fragments
        actual_lines = [line.strip(' ') for line in frags]
        # Every emitted fragment must be a complete, newline-terminated line.
        assert all(line.endswith('\n') for line in actual_lines)
        actual_lines = [line.rstrip('\n') for line in actual_lines]
        expected_lines = expected.rstrip().split('\n')
        expected_lines = [line.strip(' ') for line in expected_lines]
        assert_string_arrays_equal(expected_lines, actual_lines,
                                   msg='Generated code unexpected')