Пример #1
0
class TestGenerateFunction(unittest.TestCase):
    """Tests for generate_native_function on tiny single-block functions."""

    def setUp(self) -> None:
        self.var = Var('arg')
        self.arg = RuntimeArg('arg', int_rprimitive)
        self.env = Environment()
        self.reg = self.env.add_local(self.var, int_rprimitive)
        self.block = BasicBlock(0)

    def test_simple(self) -> None:
        # myfunc(arg: int) -> int that just returns its argument.
        self.block.ops.append(Return(self.reg))
        signature = FuncSignature([self.arg], int_rprimitive)
        decl = FuncDecl('myfunc', None, 'mod', signature)
        fn = FuncIR(decl, [self.block], self.env)
        context = EmitterContext(NameGenerator([['mod']]))
        emitter = Emitter(context)
        generate_native_function(fn, emitter, 'prog.py', 'prog')
        expected = [
            'CPyTagged CPyDef_myfunc(CPyTagged cpy_r_arg) {\n',
            'CPyL0: ;\n',
            '    return cpy_r_arg;\n',
            '}\n',
        ]
        assert_string_arrays_equal(expected, emitter.fragments,
                                   msg='Generated code invalid')

    def test_register(self) -> None:
        # A temporary register must be declared at the top of the function.
        # The expected C shows LoadInt(5) emitted as 10 (presumably the
        # tagged-integer encoding).
        self.env.temp_index = 0
        load = LoadInt(5)
        self.block.ops.append(load)
        self.env.add_op(load)
        signature = FuncSignature([self.arg], list_rprimitive)
        decl = FuncDecl('myfunc', None, 'mod', signature)
        fn = FuncIR(decl, [self.block], self.env)
        context = EmitterContext(NameGenerator([['mod']]))
        emitter = Emitter(context)
        generate_native_function(fn, emitter, 'prog.py', 'prog')
        expected = [
            'PyObject *CPyDef_myfunc(CPyTagged cpy_r_arg) {\n',
            '    CPyTagged cpy_r_r0;\n',
            'CPyL0: ;\n',
            '    cpy_r_r0 = 10;\n',
            '}\n',
        ]
        assert_string_arrays_equal(expected, emitter.fragments,
                                   msg='Generated code invalid')
Пример #2
0
class TestEmitter(unittest.TestCase):
    """Tests for the small formatting helpers on Emitter."""

    def setUp(self) -> None:
        self.env = Environment()
        self.n = self.env.add_local(Var('n'), int_rprimitive)
        self.context = EmitterContext(NameGenerator([['mod']]))
        self.emitter = Emitter(self.context, self.env)

    def test_label(self) -> None:
        # Basic block N is rendered as the C label CPyL<N>.
        rendered = self.emitter.label(BasicBlock(4))
        assert rendered == 'CPyL4'

    def test_reg(self) -> None:
        # A local named n is rendered as cpy_r_n.
        rendered = self.emitter.reg(self.n)
        assert rendered == 'cpy_r_n'

    def test_emit_line(self) -> None:
        # Lines between '{' and '}' are expected to be indented by 4 spaces.
        for text in ('line;', 'a {', 'f();', '}'):
            self.emitter.emit_line(text)
        expected = ['line;\n', 'a {\n', '    f();\n', '}\n']
        assert self.emitter.fragments == expected
Пример #3
0
class TestEmitter(unittest.TestCase):
    """Tests for the small formatting helpers on Emitter."""

    def setUp(self) -> None:
        self.env = Environment()
        self.n = self.env.add_local(Var('n'), IntRType())
        self.context = EmitterContext()
        self.emitter = Emitter(self.context, self.env)

    def test_label(self) -> None:
        # Label N is rendered as the C label CPyL<N>.
        rendered = self.emitter.label(Label(4))
        assert rendered == 'CPyL4'

    def test_reg(self) -> None:
        # A local named n is rendered as cpy_r_n.
        rendered = self.emitter.reg(self.n)
        assert rendered == 'cpy_r_n'

    def test_emit_line(self) -> None:
        # Lines between '{' and '}' are expected to be indented by 4 spaces.
        for text in ('line;', 'a {', 'f();', '}'):
            self.emitter.emit_line(text)
        expected = ['line;\n', 'a {\n', '    f();\n', '}\n']
        assert self.emitter.fragments == expected
Пример #4
0
class TestGenerateFunction(unittest.TestCase):
    """Tests for generate_native_function on tiny single-block functions."""

    def setUp(self) -> None:
        self.var = Var('arg')
        self.arg = RuntimeArg('arg', IntRType())
        self.env = Environment()
        self.reg = self.env.add_local(self.var, IntRType())
        self.block = BasicBlock(Label(0))

    def test_simple(self) -> None:
        # myfunc(arg: int) -> int that just returns its argument.
        self.block.ops.append(Return(self.reg))
        fn = FuncIR('myfunc', [self.arg], IntRType(), [self.block], self.env)
        context = EmitterContext()
        emitter = Emitter(context)
        generate_native_function(fn, emitter)
        expected = [
            'static CPyTagged CPyDef_myfunc(CPyTagged cpy_r_arg) {\n',
            'CPyL0: ;\n',
            '    return cpy_r_arg;\n',
            '}\n',
        ]
        assert_string_arrays_equal(expected, emitter.fragments,
                                   msg='Generated code invalid')

    def test_register(self) -> None:
        # A temporary register must be declared at the top of the function.
        # The expected C shows LoadInt(..., 5) emitted as 10 (presumably the
        # tagged-integer encoding).
        self.temp = self.env.add_temp(IntRType())
        load = LoadInt(self.temp, 5)
        self.block.ops.append(load)
        fn = FuncIR('myfunc', [self.arg], ListRType(), [self.block], self.env)
        context = EmitterContext()
        emitter = Emitter(context)
        generate_native_function(fn, emitter)
        expected = [
            'static PyObject *CPyDef_myfunc(CPyTagged cpy_r_arg) {\n',
            '    CPyTagged cpy_r_r0;\n',
            'CPyL0: ;\n',
            '    cpy_r_r0 = 10;\n',
            '}\n',
        ]
        assert_string_arrays_equal(expected, emitter.fragments,
                                   msg='Generated code invalid')
Пример #5
0
class TestFunctionEmitterVisitor(unittest.TestCase):
    """Check the C code FunctionEmitterVisitor emits for individual ops.

    Each test builds one op over the registers created in setUp() and
    compares the emitted lines, stripped of surrounding spaces, with the
    expected C text via assert_emit().
    """

    def setUp(self) -> None:
        self.env = Environment()
        self.n = self.env.add_local(Var('n'), int_rprimitive)
        self.m = self.env.add_local(Var('m'), int_rprimitive)
        self.k = self.env.add_local(Var('k'), int_rprimitive)
        self.l = self.env.add_local(Var('l'), list_rprimitive)  # noqa
        self.ll = self.env.add_local(Var('ll'), list_rprimitive)
        self.o = self.env.add_local(Var('o'), object_rprimitive)
        self.o2 = self.env.add_local(Var('o2'), object_rprimitive)
        self.d = self.env.add_local(Var('d'), dict_rprimitive)
        self.b = self.env.add_local(Var('b'), bool_rprimitive)
        self.t = self.env.add_local(Var('t'),
                                    RTuple([int_rprimitive, bool_rprimitive]))
        self.tt = self.env.add_local(
            Var('tt'),
            RTuple(
                [RTuple([int_rprimitive, bool_rprimitive]), bool_rprimitive]))
        # Native class A with attributes x: bool and y: int, used by the
        # attribute get/set tests.
        ir = ClassIR('A', 'mod')
        ir.attributes = OrderedDict([('x', bool_rprimitive),
                                     ('y', int_rprimitive)])
        compute_vtable(ir)
        ir.mro = [ir]
        self.r = self.env.add_local(Var('r'), RInstance(ir))

        self.context = EmitterContext(NameGenerator([['mod']]))
        self.emitter = Emitter(self.context, self.env)
        # Separate emitter for declarations; assert_emit() checks its output
        # ahead of the main emitter's output.
        self.declarations = Emitter(self.context, self.env)
        self.visitor = FunctionEmitterVisitor(self.emitter, self.declarations,
                                              'prog.py', 'prog')

    def test_goto(self) -> None:
        self.assert_emit(Goto(BasicBlock(2)), "goto CPyL2;")

    def test_return(self) -> None:
        self.assert_emit(Return(self.m), "return cpy_r_m;")

    def test_load_int(self) -> None:
        self.assert_emit(LoadInt(5), "cpy_r_r0 = 10;")

    def test_tuple_get(self) -> None:
        self.assert_emit(TupleGet(self.t, 1, 0), 'cpy_r_r0 = cpy_r_t.f1;')

    def test_load_None(self) -> None:
        self.assert_emit(PrimitiveOp([], none_object_op, 0),
                         "cpy_r_r0 = Py_None;")

    def test_load_True(self) -> None:
        self.assert_emit(PrimitiveOp([], true_op, 0), "cpy_r_r0 = 1;")

    def test_load_False(self) -> None:
        self.assert_emit(PrimitiveOp([], false_op, 0), "cpy_r_r0 = 0;")

    def test_assign_int(self) -> None:
        self.assert_emit(Assign(self.m, self.n), "cpy_r_m = cpy_r_n;")

    def test_int_add(self) -> None:
        self.assert_emit_binary_op(
            '+', self.n, self.m, self.k,
            "cpy_r_r0 = CPyTagged_Add(cpy_r_m, cpy_r_k);")

    def test_int_sub(self) -> None:
        self.assert_emit_binary_op(
            '-', self.n, self.m, self.k,
            "cpy_r_r0 = CPyTagged_Subtract(cpy_r_m, cpy_r_k);")

    def test_list_repeat(self) -> None:
        self.assert_emit_binary_op(
            '*', self.ll, self.l, self.n, """Py_ssize_t __tmp1;
               __tmp1 = CPyTagged_AsSsize_t(cpy_r_n);
               if (__tmp1 == -1 && PyErr_Occurred())
                   CPyError_OutOfMemory();
               cpy_r_r0 = PySequence_Repeat(cpy_r_l, __tmp1);
            """)

    def test_int_neg(self) -> None:
        self.assert_emit(PrimitiveOp([self.m], int_neg_op, 55),
                         "cpy_r_r0 = CPyTagged_Negate(cpy_r_m);")

    def test_list_len(self) -> None:
        self.assert_emit(
            PrimitiveOp([self.l], list_len_op, 55), """Py_ssize_t __tmp1;
                            __tmp1 = PyList_GET_SIZE(cpy_r_l);
                            cpy_r_r0 = CPyTagged_ShortFromSsize_t(__tmp1);
                         """)

    def test_branch(self) -> None:
        self.assert_emit(
            Branch(self.b, BasicBlock(8), BasicBlock(9), Branch.BOOL_EXPR),
            """if (cpy_r_b) {
                                goto CPyL8;
                            } else
                                goto CPyL9;
                         """)
        b = Branch(self.b, BasicBlock(8), BasicBlock(9), Branch.BOOL_EXPR)
        b.negated = True
        self.assert_emit(
            b, """if (!cpy_r_b) {
                                goto CPyL8;
                            } else
                                goto CPyL9;
                         """)

    def test_call(self) -> None:
        decl = FuncDecl(
            'myfn', None, 'mod',
            FuncSignature([RuntimeArg('m', int_rprimitive)], int_rprimitive))
        self.assert_emit(Call(decl, [self.m], 55),
                         "cpy_r_r0 = CPyDef_myfn(cpy_r_m);")

    def test_call_two_args(self) -> None:
        decl = FuncDecl(
            'myfn', None, 'mod',
            FuncSignature([
                RuntimeArg('m', int_rprimitive),
                RuntimeArg('n', int_rprimitive)
            ], int_rprimitive))
        self.assert_emit(Call(decl, [self.m, self.k], 55),
                         "cpy_r_r0 = CPyDef_myfn(cpy_r_m, cpy_r_k);")

    def test_inc_ref(self) -> None:
        self.assert_emit(IncRef(self.m), "CPyTagged_IncRef(cpy_r_m);")

    def test_dec_ref(self) -> None:
        self.assert_emit(DecRef(self.m), "CPyTagged_DecRef(cpy_r_m);")

    def test_dec_ref_tuple(self) -> None:
        self.assert_emit(DecRef(self.t), 'CPyTagged_DecRef(cpy_r_t.f0);')

    def test_dec_ref_tuple_nested(self) -> None:
        self.assert_emit(DecRef(self.tt), 'CPyTagged_DecRef(cpy_r_tt.f0.f0);')

    def test_list_get_item(self) -> None:
        # Use a list register and an int index register (presumably the
        # operand types list_get_item_op expects) rather than two ints.
        self.assert_emit(PrimitiveOp([self.l, self.n], list_get_item_op, 55),
                         """cpy_r_r0 = CPyList_GetItem(cpy_r_l, cpy_r_n);""")

    def test_list_set_item(self) -> None:
        self.assert_emit(
            PrimitiveOp([self.l, self.n, self.o], list_set_item_op, 55),
            """cpy_r_r0 = CPyList_SetItem(cpy_r_l, cpy_r_n, cpy_r_o);""")

    def test_box(self) -> None:
        self.assert_emit(Box(self.n),
                         """cpy_r_r0 = CPyTagged_StealAsObject(cpy_r_n);""")

    def test_unbox(self) -> None:
        self.assert_emit(
            Unbox(self.m, int_rprimitive, 55),
            """if (likely(PyLong_Check(cpy_r_m)))
                                cpy_r_r0 = CPyTagged_FromObject(cpy_r_m);
                            else {
                                CPy_TypeError("int", cpy_r_m);
                                cpy_r_r0 = CPY_INT_TAG;
                            }
                         """)

    def test_new_list(self) -> None:
        self.assert_emit(
            PrimitiveOp([self.n, self.m], new_list_op, 55),
            """cpy_r_r0 = PyList_New(2);
                            if (likely(cpy_r_r0 != NULL)) {
                                PyList_SET_ITEM(cpy_r_r0, 0, cpy_r_n);
                                PyList_SET_ITEM(cpy_r_r0, 1, cpy_r_m);
                            }
                         """)

    def test_list_append(self) -> None:
        self.assert_emit(
            PrimitiveOp([self.l, self.o], list_append_op, 1),
            """cpy_r_r0 = PyList_Append(cpy_r_l, cpy_r_o) >= 0;""")

    def test_get_attr(self) -> None:
        self.assert_emit(
            GetAttr(self.r, 'y', 1),
            """cpy_r_r0 = native_A_gety((AObject *)cpy_r_r); /* y */""")

    def test_set_attr(self) -> None:
        self.assert_emit(
            SetAttr(self.r, 'y', self.m, 1),
            "cpy_r_r0 = native_A_sety((AObject *)cpy_r_r, cpy_r_m); /* y */")

    def test_dict_get_item(self) -> None:
        self.assert_emit(PrimitiveOp([self.d, self.o2], dict_get_item_op, 1),
                         """cpy_r_r0 = CPyDict_GetItem(cpy_r_d, cpy_r_o2);""")

    def test_dict_set_item(self) -> None:
        self.assert_emit(
            PrimitiveOp([self.d, self.o, self.o2], dict_set_item_op, 1),
            """cpy_r_r0 = CPyDict_SetItem(cpy_r_d, cpy_r_o, cpy_r_o2) >= 0;""")

    def test_dict_update(self) -> None:
        self.assert_emit(
            PrimitiveOp([self.d, self.o], dict_update_op, 1),
            """cpy_r_r0 = CPyDict_Update(cpy_r_d, cpy_r_o) >= 0;""")

    def test_new_dict(self) -> None:
        self.assert_emit(PrimitiveOp([], new_dict_op, 1),
                         """cpy_r_r0 = PyDict_New();""")

    def test_dict_contains(self) -> None:
        self.assert_emit_binary_op(
            'in', self.b, self.o, self.d,
            """int __tmp1 = PyDict_Contains(cpy_r_d, cpy_r_o);
               if (__tmp1 < 0)
                   cpy_r_r0 = 2;
               else
                   cpy_r_r0 = __tmp1;
            """)

    def assert_emit(self, op: Op, expected: str) -> None:
        """Emit ``op`` and compare the output line by line with ``expected``.

        Both emitters are cleared and temp_index is reset to 0 first, so the
        first register op's destination appears as cpy_r_r0. Leading and
        trailing spaces are ignored on every line of both sides.
        """
        self.emitter.fragments = []
        self.declarations.fragments = []
        self.env.temp_index = 0
        if isinstance(op, RegisterOp):
            self.env.add_op(op)
        op.accept(self.visitor)
        frags = self.declarations.fragments + self.emitter.fragments
        actual_lines = [line.strip(' ') for line in frags]
        # Every emitted fragment is expected to be a complete line.
        assert all(line.endswith('\n') for line in actual_lines)
        actual_lines = [line.rstrip('\n') for line in actual_lines]
        expected_lines = expected.rstrip().split('\n')
        expected_lines = [line.strip(' ') for line in expected_lines]
        assert_string_arrays_equal(expected_lines,
                                   actual_lines,
                                   msg='Generated code unexpected')

    def assert_emit_binary_op(self, op: str, dest: Value, left: Value,
                              right: Value, expected: str) -> None:
        """Check the code emitted for binary operator ``op`` on left/right.

        Picks the first descriptor in binary_ops[op] whose argument types
        are supertypes of the operand types. ``dest`` is not used; it only
        documents the intended result register at call sites.
        """
        ops = binary_ops[op]
        for desc in ops:
            if (is_subtype(left.type, desc.arg_types[0])
                    and is_subtype(right.type, desc.arg_types[1])):
                self.assert_emit(PrimitiveOp([left, right], desc, 55),
                                 expected)
                break
        else:
            # self.fail (unlike a bare ``assert False``) is not stripped
            # when Python runs with -O.
            self.fail('Could not find matching op')
Пример #6
0
class IRBuilder(NodeVisitor[Register]):
    def __init__(self, types: Dict[Expression, Type], mapper: Mapper) -> None:
        """Set up an empty builder.

        ``types`` maps expressions to their mypy types. ``mapper`` receives
        the TypeInfo-to-ClassIR entries created by prepare_class_def().
        """
        self.types = types
        self.mapper = mapper

        # Current environment, plus a list that starts out holding only it.
        self.environment = Environment()
        self.environments = [self.environment]

        # Collected output.
        self.blocks = []  # type: List[List[BasicBlock]]
        self.functions = []  # type: List[FuncIR]
        self.classes = []  # type: List[ClassIR]
        self.targets = []  # type: List[Register]

        # These lists operate as stack frames for loops. Each loop adds a new
        # frame (i.e. adds a new empty list [] to the outermost list). Each
        # break or continue is inserted within that frame as they are visited
        # and at the end of the loop the stack is popped and any break/continue
        # gotos have their targets rewritten to the next basic block.
        self.break_gotos = []  # type: List[List[Goto]]
        self.continue_gotos = []  # type: List[List[Goto]]

        # Module names seen in top-level import statements.
        self.imports = []  # type: List[str]

        # Set by visit_mypy_file() before the op-generating pass.
        self.current_module_name = None  # type: Optional[str]

    def visit_mypy_file(self, mypyfile: MypyFile) -> Register:
        """Build IR for every top-level definition in ``mypyfile``.

        First registers a ClassIR for each top-level class, then visits all
        definitions to generate ops. The typing and abc modules are skipped.
        """
        if mypyfile.fullname() in ('typing', 'abc'):
            # These module are special; their contents are currently all
            # built-in primitives.
            return INVALID_REGISTER

        # Pass 1: ClassIRs and the TypeInfo-to-ClassIR mapping.
        class_defs = [node for node in mypyfile.defs
                      if isinstance(node, ClassDef)]
        for cdef in class_defs:
            self.prepare_class_def(cdef)

        # Pass 2: ops.
        self.current_module_name = mypyfile.fullname()
        for definition in mypyfile.defs:
            definition.accept(self)
        return INVALID_REGISTER

    def prepare_class_def(self, cdef: ClassDef) -> None:
        """Create an attribute-less ClassIR for ``cdef`` and register it.

        The attributes are filled in later by visit_class_def().
        """
        class_ir = ClassIR(cdef.name, [])
        self.classes.append(class_ir)
        self.mapper.type_to_ir[cdef.info] = class_ir

    def visit_class_def(self, cdef: ClassDef) -> Register:
        """Fill in the attributes of the ClassIR created for ``cdef``.

        Every Var in the class symbol table becomes a (name, RType) pair.
        """
        attributes = [
            (name, self.type_to_rtype(symbol.node.type))
            for name, symbol in cdef.info.names.items()
            if isinstance(symbol.node, Var)
        ]
        class_ir = self.mapper.type_to_ir[cdef.info]
        class_ir.attributes = attributes
        return INVALID_REGISTER

    def visit_import(self, node: Import) -> Register:
        """Record the module names from an ``import a, b`` statement.

        Imports flagged as unreachable or mypy-only are skipped entirely.
        Non-top-level imports are not supported.
        """
        if node.is_unreachable or node.is_mypy_only:
            # Previously this branch was a no-op ``pass``, so such imports were
            # still recorded (or tripped the top-level assert below).
            return INVALID_REGISTER
        if not node.is_top_level:
            assert False, "non-toplevel imports not supported"

        self.imports.extend(node_id for node_id, _ in node.ids)
        return INVALID_REGISTER

    def visit_import_from(self, node: ImportFrom) -> Register:
        """Record the source module of a ``from m import ...`` statement.

        Imports flagged as unreachable or mypy-only are skipped entirely.
        Non-top-level imports are not supported.
        """
        if node.is_unreachable or node.is_mypy_only:
            # Previously this branch was a no-op ``pass``, so such imports were
            # still recorded (or tripped the top-level assert below).
            return INVALID_REGISTER
        if not node.is_top_level:
            assert False, "non-toplevel imports not supported"

        self.imports.append(node.id)
        return INVALID_REGISTER

    def visit_import_all(self, node: ImportAll) -> Register:
        """Record the source module of a ``from m import *`` statement.

        Imports flagged as unreachable or mypy-only are skipped entirely.
        Non-top-level imports are not supported.
        """
        if node.is_unreachable or node.is_mypy_only:
            # Previously this branch was a no-op ``pass``, so such imports were
            # still recorded (or tripped the top-level assert below).
            return INVALID_REGISTER
        if not node.is_top_level:
            assert False, "non-toplevel imports not supported"

        self.imports.append(node.id)
        return INVALID_REGISTER

    def visit_func_def(self, fdef: FuncDef) -> Register:
        """Lower a function definition into a FuncIR appended to self.functions."""
        self.enter()

        # Each argument becomes a local of the new environment.
        for argument in fdef.arguments:
            variable = argument.variable
            self.environment.add_local(variable,
                                       self.type_to_rtype(variable.type))
        fdef.body.accept(self)

        # Falling off the end means "return None" for None-returning
        # functions; for anything else it must be unreachable.
        ret_type = self.convert_return_type(fdef)
        if ret_type.name != 'None':
            self.add_implicit_unreachable()
        else:
            self.add_implicit_return()

        blocks, env = self.leave()
        runtime_args = self.convert_args(fdef)
        self.functions.append(
            FuncIR(fdef.name(), runtime_args, ret_type, blocks, env))
        return INVALID_REGISTER

    def convert_args(self, fdef: FuncDef) -> List[RuntimeArg]:
        """Build a RuntimeArg for each argument, typed from the signature."""
        signature = fdef.type
        assert isinstance(signature, CallableType)
        runtime_args = []  # type: List[RuntimeArg]
        for position, argument in enumerate(fdef.arguments):
            name = argument.variable.name()
            rtype = self.type_to_rtype(signature.arg_types[position])
            runtime_args.append(RuntimeArg(name, rtype))
        return runtime_args

    def convert_return_type(self, fdef: FuncDef) -> RType:
        """Return the RType of ``fdef``'s declared return type."""
        signature = fdef.type
        assert isinstance(signature, CallableType)
        return self.type_to_rtype(signature.ret_type)

    def add_implicit_return(self) -> None:
        """Append ``return None`` unless the current block already ends in Return."""
        current = self.blocks[-1][-1]
        if current.ops and isinstance(current.ops[-1], Return):
            return
        none_reg = self.environment.add_temp(NoneRType())
        self.add(PrimitiveOp(none_reg, PrimitiveOp.NONE))
        self.add(Return(none_reg))

    def add_implicit_unreachable(self) -> None:
        """Append Unreachable unless the current block already ends in Return."""
        current = self.blocks[-1][-1]
        ends_with_return = bool(current.ops) and isinstance(current.ops[-1],
                                                            Return)
        if not ends_with_return:
            self.add(Unreachable())

    def visit_block(self, block: Block) -> Register:
        """Visit the statements of ``block`` in order."""
        for statement in block.body:
            statement.accept(self)
        return INVALID_REGISTER

    def visit_expression_stmt(self, stmt: ExpressionStmt) -> Register:
        """Emit ops for an expression statement, discarding its value."""
        _ = self.accept(stmt.expr)
        return INVALID_REGISTER

    def visit_return_stmt(self, stmt: ReturnStmt) -> Register:
        """Emit a Return; a bare ``return`` returns a freshly loaded None."""
        if not stmt.expr:
            retval = self.environment.add_temp(NoneRType())
            self.add(PrimitiveOp(retval, PrimitiveOp.NONE))
        else:
            retval = self.accept(stmt.expr)
        self.add(Return(retval))
        return INVALID_REGISTER

    def visit_assignment_stmt(self, stmt: AssignmentStmt) -> Register:
        """Lower a single-target assignment ``lvalue = rvalue``.

        An explicit type annotation both fixes the lvalue type and declares
        a new variable.
        """
        assert len(stmt.lvalues) == 1
        [lvalue] = stmt.lvalues
        if stmt.type:
            lvalue_type = self.type_to_rtype(stmt.type)
        elif isinstance(lvalue, IndexExpr):
            # TODO: This won't be right for user-defined classes. Store the
            #     lvalue type in mypy and remove this special case.
            lvalue_type = ObjectRType()
        else:
            lvalue_type = self.node_type(lvalue)
        rvalue_type = self.node_type(stmt.rvalue)
        is_declaration = stmt.type is not None
        return self.assign(lvalue, stmt.rvalue, rvalue_type, lvalue_type,
                           declare_new=is_declaration)

    def visit_operator_assignment_stmt(
            self, stmt: OperatorAssignmentStmt) -> Register:
        """Lower ``x op= e`` where ``x`` is a local variable.

        The binary op writes its result straight into the variable's
        register. Other lvalue kinds are not supported yet.
        """
        target = self.get_assignment_target(stmt.lvalue, declare_new=False)

        if isinstance(target, AssignmentTargetRegister):
            ltype = self.environment.types[target.register]
            rtype = self.node_type(stmt.rvalue)
            rreg = self.accept(stmt.rvalue)
            return self.binary_op(ltype,
                                  target.register,
                                  rtype,
                                  rreg,
                                  stmt.op,
                                  target=target.register)

        # NOTE: List index not supported yet for compound assignments.
        # The message was previously never formatted (a literal '%r').
        assert False, 'Unsupported lvalue: %r' % stmt.lvalue

    def get_assignment_target(self, lvalue: Lvalue,
                              declare_new: bool) -> AssignmentTarget:
        """Describe where an assignment to ``lvalue`` should store its value.

        Handles three lvalue forms:
        - a local name: a register, newly added when the name is a definition
          or ``declare_new`` is set, otherwise looked up;
        - ``base[index]`` with a list base and int index, or a dict base
          (the index is boxed for dicts);
        - ``obj.attr`` where obj has a UserRType.

        Evaluating base/index/obj emits ops via self.accept() as a side
        effect. Any other lvalue fails the final assert, including an
        IndexExpr whose types match neither case (after its base and index
        ops have already been emitted).
        """
        if isinstance(lvalue, NameExpr):
            # Assign to local variable.
            assert lvalue.kind == LDEF
            if lvalue.is_def or declare_new:
                # Define a new variable.
                assert isinstance(lvalue.node, Var)  # TODO: Can this fail?
                lvalue_num = self.environment.add_local(
                    lvalue.node, self.node_type(lvalue))
            else:
                # Assign to a previously defined variable.
                assert isinstance(lvalue.node, Var)  # TODO: Can this fail?
                lvalue_num = self.environment.lookup(lvalue.node)

            return AssignmentTargetRegister(lvalue_num)
        elif isinstance(lvalue, IndexExpr):
            # Indexed assignment x[y] = e
            # Base and index are evaluated (ops emitted) before dispatching
            # on their types.
            base_type = self.node_type(lvalue.base)
            index_type = self.node_type(lvalue.index)
            base_reg = self.accept(lvalue.base)
            index_reg = self.accept(lvalue.index)
            if isinstance(base_type, ListRType) and isinstance(
                    index_type, IntRType):
                # Indexed list set
                return AssignmentTargetIndex(base_reg, index_reg, base_type)
            elif isinstance(base_type, DictRType):
                # Indexed dict set
                boxed_index = self.box(index_reg, index_type)
                return AssignmentTargetIndex(base_reg, boxed_index, base_type)
            # Unsupported base/index combination: falls through to the
            # assert below.
        elif isinstance(lvalue, MemberExpr):
            # Attribute assignment x.y = e
            obj_type = self.node_type(lvalue.expr)
            assert isinstance(
                obj_type,
                UserRType), 'Attribute set only supported for user types'
            obj_reg = self.accept(lvalue.expr)
            return AssignmentTargetAttr(obj_reg, lvalue.name, obj_type)

        assert False, 'Unsupported lvalue: %r' % lvalue

    def assign_to_target(self, target: AssignmentTarget, rvalue: Expression,
                         rvalue_type: RType, needs_box: bool) -> Register:
        """Evaluate ``rvalue`` and store it into ``target``.

        A falsy ``rvalue_type`` is replaced by the rvalue's node type. For a
        register target the rvalue is computed directly into that register
        (boxed first when ``needs_box``), and the box/accept result is
        returned. For attribute targets a SetAttr is emitted (boxed when
        ``needs_box``). For index targets the item is always boxed and a
        LIST_SET or DICT_SET op is emitted. Attribute and index targets
        return INVALID_REGISTER.
        """
        rvalue_type = rvalue_type or self.node_type(rvalue)

        if isinstance(target, AssignmentTargetRegister):
            if needs_box:
                unboxed = self.accept(rvalue)
                return self.box(unboxed, rvalue_type, target=target.register)
            else:
                return self.accept(rvalue, target=target.register)
        elif isinstance(target, AssignmentTargetAttr):
            rvalue_reg = self.accept(rvalue)
            if needs_box:
                rvalue_reg = self.box(rvalue_reg, rvalue_type)
            self.add(
                SetAttr(target.obj_reg, target.attr, rvalue_reg,
                        target.obj_type))
            return INVALID_REGISTER
        elif isinstance(target, AssignmentTargetIndex):
            # Index targets box unconditionally, independent of needs_box.
            item_reg = self.accept(rvalue)
            boxed_item_reg = self.box(item_reg, rvalue_type)
            if isinstance(target.rtype, ListRType):
                op = PrimitiveOp.LIST_SET
            elif isinstance(target.rtype, DictRType):
                op = PrimitiveOp.DICT_SET
            else:
                assert False, target.rtype
            self.add(
                PrimitiveOp(None, op, target.base_reg, target.index_reg,
                            boxed_item_reg))
            return INVALID_REGISTER

        assert False, 'Unsupported assignment target'

    def assign(self, lvalue: Lvalue, rvalue: Expression, rvalue_type: RType,
               lvalue_type: RType, declare_new: bool) -> Register:
        """Resolve ``lvalue`` to a target and assign ``rvalue`` to it.

        Boxing is requested when the rvalue type supports unboxing but the
        lvalue type does not.
        """
        target = self.get_assignment_target(lvalue, declare_new)
        needs_box = False
        if rvalue_type.supports_unbox:
            needs_box = not lvalue_type.supports_unbox
        return self.assign_to_target(target, rvalue, rvalue_type, needs_box)

    def visit_if_stmt(self, stmt: IfStmt) -> Register:
        """Lower an if statement with a single condition and optional else.

        True branches go to the body and false branches to the else body (or
        straight to the join block when there is no else). Both arms then
        jump to the join block unless they end in a Return.
        """
        # If statements are normalized
        assert len(stmt.expr) == 1

        branches = self.process_conditional(stmt.expr[0])
        then_block = self.new_block()
        self.set_branches(branches, True, then_block)
        stmt.body[0].accept(self)
        then_exit = self.add_leave()

        if not stmt.else_body:
            # No else block.
            join_block = self.new_block()
            self.set_branches(branches, False, join_block)
        else:
            else_block = self.new_block()
            self.set_branches(branches, False, else_block)
            stmt.else_body.accept(self)
            else_exit = self.add_leave()
            join_block = self.new_block()
            if else_exit:
                else_exit.label = join_block.label

        if then_exit:
            then_exit.label = join_block.label
        return INVALID_REGISTER

    def add_leave(self) -> Optional[Goto]:
        """Append an unresolved Goto unless the current block ends in Return.

        Returns the Goto so the caller can patch its label later, or None
        when no Goto was added.
        """
        ops = self.blocks[-1][-1].ops
        if ops and isinstance(ops[-1], Return):
            return None
        leave = Goto(INVALID_LABEL)
        self.add(leave)
        return leave

    def push_loop_stack(self) -> None:
        """Open fresh break/continue frames for a loop being entered."""
        for frames in (self.break_gotos, self.continue_gotos):
            frames.append([])

    def pop_loop_stack(self, continue_block: BasicBlock,
                       break_block: BasicBlock) -> None:
        """Close the innermost loop frames and resolve their pending gotos.

        Continues are pointed at ``continue_block`` and breaks at
        ``break_block``.
        """
        pending_continues = self.continue_gotos.pop()
        for goto in pending_continues:
            goto.label = continue_block.label

        pending_breaks = self.break_gotos.pop()
        for goto in pending_breaks:
            goto.label = break_block.label

    def visit_while_stmt(self, s: WhileStmt) -> Register:
        """Lower a while loop into condition, body and exit blocks.

        This method does not reference ``s.else_body``.
        """
        self.push_loop_stack()

        # Split block so that we get a handle to the top of the loop.
        enter_loop = Goto(INVALID_LABEL)
        self.add(enter_loop)
        cond_block = self.new_block()
        enter_loop.label = cond_block.label
        branches = self.process_conditional(s.expr)

        # True branches run the body, which then jumps back to the condition.
        body_block = self.new_block()
        self.set_branches(branches, True, body_block)
        s.body.accept(self)
        self.add(Goto(cond_block.label))

        # False branches leave the loop.
        exit_block = self.new_block()
        self.set_branches(branches, False, exit_block)

        # continue re-checks the condition; break goes to the exit block.
        self.pop_loop_stack(cond_block, exit_block)
        return INVALID_REGISTER

    def visit_for_stmt(self, s: ForStmt) -> Register:
        """Lower a for loop over ``range(...)`` or over a list.

        Only these two iterable forms are handled; anything else fails the
        final assert. This method does not reference ``s.else_body``.
        In both forms ``continue`` jumps to the index-increment block and
        ``break`` to the block after the loop.
        """
        if (isinstance(s.expr, CallExpr)
                and isinstance(s.expr.callee, RefExpr)
                and s.expr.callee.fullname == 'builtins.range'):
            self.push_loop_stack()

            # Special case for x in range(...)
            # TODO: Check argument counts and kinds; check the lvalue
            # Only the first argument is used as the (exclusive) end bound;
            # the end expression is evaluated once, before the loop.
            end = s.expr.args[0]
            end_reg = self.accept(end)

            # Initialize loop index to 0.
            # NOTE(review): index_reg is the register returned by assign() for
            # the loop variable and is incremented in place below, so a body
            # that reassigns the loop variable would change iteration (unlike
            # Python's range) — confirm whether that is intended.
            index_reg = self.assign(s.index,
                                    IntExpr(0),
                                    IntRType(),
                                    IntRType(),
                                    declare_new=True)
            goto = Goto(INVALID_LABEL)
            self.add(goto)

            # Add loop condition check.
            top = self.new_block()
            goto.label = top.label
            branch = Branch(index_reg, end_reg, INVALID_LABEL, INVALID_LABEL,
                            Branch.INT_LT)
            self.add(branch)
            branches = [branch]

            body = self.new_block()
            self.set_branches(branches, True, body)
            s.body.accept(self)

            # Separate block for the increment so `continue` can target it.
            end_goto = Goto(INVALID_LABEL)
            self.add(end_goto)
            end_block = self.new_block()
            end_goto.label = end_block.label

            # Increment index register.
            one_reg = self.alloc_temp(IntRType())
            self.add(LoadInt(one_reg, 1))
            self.add(
                PrimitiveOp(index_reg, PrimitiveOp.INT_ADD, index_reg,
                            one_reg))

            # Go back to loop condition check.
            self.add(Goto(top.label))
            next = self.new_block()
            self.set_branches(branches, False, next)

            self.pop_loop_stack(end_block, next)
            return INVALID_REGISTER

        if self.node_type(s.expr).name == 'list':
            self.push_loop_stack()

            expr_reg = self.accept(s.expr)

            # Hidden index counter (separate from the user's loop variable).
            index_reg = self.alloc_temp(IntRType())
            self.add(LoadInt(index_reg, 0))

            one_reg = self.alloc_temp(IntRType())
            self.add(LoadInt(one_reg, 1))

            # Only a plain name is supported as the loop target here.
            assert isinstance(s.index, NameExpr)
            assert isinstance(s.index.node, Var)
            lvalue_reg = self.environment.add_local(s.index.node,
                                                    self.node_type(s.index))

            condition_block = self.goto_new_block()

            # For compatibility with python semantics we recalculate the length
            # at every iteration.
            len_reg = self.alloc_temp(IntRType())
            self.add(PrimitiveOp(len_reg, PrimitiveOp.LIST_LEN, expr_reg))

            branch = Branch(index_reg, len_reg, INVALID_LABEL, INVALID_LABEL,
                            Branch.INT_LT)
            self.add(branch)
            branches = [branch]

            body_block = self.new_block()
            self.set_branches(branches, True, body_block)

            # Fetch the (boxed) item and convert it to the element type
            # declared by the list's type argument.
            target_list_type = self.types[s.expr]
            assert isinstance(target_list_type, Instance)
            target_type = self.type_to_rtype(target_list_type.args[0])
            value_box = self.alloc_temp(ObjectRType())
            self.add(
                PrimitiveOp(value_box, PrimitiveOp.LIST_GET, expr_reg,
                            index_reg))

            self.unbox_or_cast(value_box, target_type, target=lvalue_reg)

            s.body.accept(self)

            # Increment block: target of `continue`.
            end_block = self.goto_new_block()
            self.add(
                PrimitiveOp(index_reg, PrimitiveOp.INT_ADD, index_reg,
                            one_reg))
            self.add(Goto(condition_block.label))

            next_block = self.new_block()
            self.set_branches(branches, False, next_block)

            self.pop_loop_stack(end_block, next_block)

            return INVALID_REGISTER

        assert False, 'for not supported'

    def visit_break_stmt(self, node: BreakStmt) -> Register:
        """Emit a placeholder jump for ``break`` in the innermost loop.

        The goto is created with INVALID_LABEL and recorded in
        ``break_gotos[-1]``; presumably the loop's exit handling
        (pop_loop_stack) patches the real label later.
        """
        pending_breaks = self.break_gotos[-1]
        jump = Goto(INVALID_LABEL)
        pending_breaks.append(jump)
        self.add(jump)
        return INVALID_REGISTER

    def visit_continue_stmt(self, node: ContinueStmt) -> Register:
        """Emit a placeholder jump for ``continue`` in the innermost loop.

        The goto is created with INVALID_LABEL and recorded in
        ``continue_gotos[-1]``; presumably the loop handling
        (pop_loop_stack) patches the real label later.
        """
        pending_continues = self.continue_gotos[-1]
        jump = Goto(INVALID_LABEL)
        pending_continues.append(jump)
        self.add(jump)
        return INVALID_REGISTER

    # Maps a Python binary operator to the primitive op used when both
    # operands are ints (looked up by binary_op). Each key appears once.
    int_binary_ops = {
        '+': PrimitiveOp.INT_ADD,
        '-': PrimitiveOp.INT_SUB,
        '*': PrimitiveOp.INT_MUL,
        '//': PrimitiveOp.INT_DIV,
        '%': PrimitiveOp.INT_MOD,
        '&': PrimitiveOp.INT_AND,
        '|': PrimitiveOp.INT_OR,
        '^': PrimitiveOp.INT_XOR,
        '<<': PrimitiveOp.INT_SHL,
        '>>': PrimitiveOp.INT_SHR,
    }

    def visit_unary_expr(self, expr: UnaryExpr) -> Register:
        """Translate unary minus on an int operand as ``0 - operand``.

        Only ``-`` applied to an int is supported; anything else fails an
        assertion (after the operand has been translated).
        """
        assert expr.op == '-', 'Unsupported unary operation'

        operand_type = self.node_type(expr.expr)
        operand = self.accept(expr.expr)
        assert operand_type.name == 'int', 'Unsupported unary operation'

        result = self.alloc_target(IntRType())
        zero_reg = self.accept(IntExpr(0))
        self.add(PrimitiveOp(result, PrimitiveOp.INT_SUB, zero_reg, operand))
        return result

    def visit_op_expr(self, expr: OpExpr) -> Register:
        """Translate both operands of a binary expression, then defer to binary_op."""
        left_type, right_type = self.node_type(expr.left), self.node_type(expr.right)
        left = self.accept(expr.left)
        right = self.accept(expr.right)
        return self.binary_op(left_type, left, right_type, right, expr.op)

    def binary_op(self,
                  ltype: RType,
                  lreg: Register,
                  rtype: RType,
                  rreg: Register,
                  expr_op: str,
                  target: Optional[Register] = None) -> Register:
        """Emit a primitive op computing ``lreg <expr_op> rreg``.

        Supported forms:
          * int <op> int, for any operator listed in ``int_binary_ops``;
          * list * int and int * list (list repetition);
          * x in dict (the left operand is boxed first).
        Anything else fails an assertion.

        Args:
            ltype, rtype: rtypes of the left and right operands.
            lreg, rreg: registers holding the operands.
            expr_op: the source-level operator string.
            target: register to write the result into; when None, one is
                obtained via alloc_target.

        Returns:
            The register holding the result.
        """
        if ltype.name == 'int' and rtype.name == 'int':
            # Primitive int operation. Reject operators with no primitive
            # (e.g. '/' or '**') with a clear message instead of a KeyError.
            if expr_op not in self.int_binary_ops:
                assert False, 'Unsupported binary operation'
            if target is None:
                target = self.alloc_target(IntRType())
            op = self.int_binary_ops[expr_op]
        elif (ltype.name == 'list' or rtype.name == 'list') and expr_op == '*':
            # Normalize "int * list" to "list * int".
            if rtype.name == 'list':
                ltype, rtype = rtype, ltype
                lreg, rreg = rreg, lreg
            if rtype.name != 'int':
                assert False, 'Unsupported binary operation'  # TODO: Operator overloading
            if target is None:
                target = self.alloc_target(ListRType())
            op = PrimitiveOp.LIST_REPEAT
        elif isinstance(rtype, DictRType):
            if expr_op == 'in':
                if target is None:
                    target = self.alloc_target(BoolRType())
                # The dict-contains primitive takes the key as a Python object.
                lreg = self.box(lreg, ltype)
                op = PrimitiveOp.DICT_CONTAINS
            else:
                assert False, 'Unsupported binary operation'
        else:
            assert False, 'Unsupported binary operation'
        self.add(PrimitiveOp(target, op, lreg, rreg))
        return target

    def visit_index_expr(self, expr: IndexExpr) -> Register:
        """Translate ``base[index]``.

        Lists, dicts and variable-length tuples go through a get primitive that
        yields a boxed object, which is then unboxed/cast to the expression's
        type. Fixed-length tuples require a literal int index and read the field
        directly with TupleGet.
        """
        base_rtype = self.node_type(expr.base)
        base_reg = self.accept(expr.base)
        target_type = self.node_type(expr)

        if isinstance(base_rtype, (ListRType, SequenceTupleRType, DictRType)):
            indexing_dict = isinstance(base_rtype, DictRType)
            index_type = self.node_type(expr.index)
            if not indexing_dict:
                assert isinstance(index_type, IntRType), 'Unsupported indexing operation'  # TODO
            if isinstance(base_rtype, ListRType):
                get_op = PrimitiveOp.LIST_GET
            elif indexing_dict:
                get_op = PrimitiveOp.DICT_GET
            else:
                get_op = PrimitiveOp.HOMOGENOUS_TUPLE_GET
            index_reg = self.accept(expr.index)
            if indexing_dict:
                # Dict keys are handed to the runtime as Python objects.
                index_reg = self.box(index_reg, index_type)
            boxed_item = self.alloc_temp(ObjectRType())
            self.add(PrimitiveOp(boxed_item, get_op, base_reg, index_reg))
            return self.unbox_or_cast(boxed_item, target_type,
                                      self.alloc_target(target_type))

        if isinstance(base_rtype, TupleRType):
            assert isinstance(expr.index, IntExpr)  # TODO
            position = expr.index.value
            result = self.alloc_target(target_type)
            self.add(TupleGet(result, base_reg, position, base_rtype.types[position]))
            return result

        assert False, 'Unsupported indexing operation'

    def visit_int_expr(self, expr: IntExpr) -> Register:
        """Load an integer literal into the target (or a fresh) int register."""
        dest = self.alloc_target(IntRType())
        self.add(LoadInt(dest, expr.value))
        return dest

    def is_native_name_expr(self, expr: NameExpr) -> bool:
        """Return True if *expr* names something in the module being compiled.

        Unqualified names count as native; a qualified name is native only when
        its module prefix equals ``current_module_name``.
        """
        # TODO later we want to support cross-module native calls too
        fullname = expr.node.fullname()
        if '.' not in fullname:
            return True
        module_name, _, _ = fullname.rpartition('.')
        return module_name == self.current_module_name

    def visit_name_expr(self, expr: NameExpr) -> Register:
        """Translate a name reference.

        ``None``/``True``/``False`` become constant primitives; names from other
        modules are loaded through the module object; local variables are read
        from their register (with binder narrowing applied by get_using_binder).
        """
        builtin_constants = {
            'builtins.None': (NoneRType, PrimitiveOp.NONE),
            'builtins.True': (BoolRType, PrimitiveOp.TRUE),
            'builtins.False': (BoolRType, PrimitiveOp.FALSE),
        }
        fullname = expr.node.fullname()
        if fullname in builtin_constants:
            rtype_cls, const_op = builtin_constants[fullname]
            const_reg = self.alloc_target(rtype_cls())
            self.add(PrimitiveOp(const_reg, const_op))
            return const_reg

        if not self.is_native_name_expr(expr):
            return self.load_static_module_attr(expr)

        # TODO: We assume that this is a Var node, which is very limited
        assert isinstance(expr.node, Var)

        var_reg = self.environment.lookup(expr.node)
        return self.get_using_binder(var_reg, expr.node, expr)

    def get_using_binder(self, reg: Register, var: Var,
                         expr: Expression) -> Register:
        """Read variable *var* (stored in *reg*) as it is typed at *expr*.

        If the type at *expr* differs from the variable's declared type (the
        binder narrowed it), unbox/cast into a register of the narrower type.
        Otherwise return *reg* itself, or copy it into the caller-requested
        target register if there is one.
        """
        declared_type = self.type_to_rtype(var.type)
        narrowed_type = self.node_type(expr)
        if declared_type != narrowed_type:
            return self.unbox_or_cast(reg, narrowed_type,
                                      self.alloc_target(narrowed_type))

        requested = self.targets[-1]
        if requested < 0:
            return reg
        self.add(Assign(requested, reg))
        return requested

    def is_module_member_expr(self, expr: MemberExpr) -> bool:
        """Return True if *expr* is an attribute access on a module (``mod.attr``)."""
        # The explicit return annotation makes mypy type-check this body.
        return isinstance(expr.expr, RefExpr) and expr.expr.kind == MODULE_REF

    def visit_member_expr(self, expr: MemberExpr) -> Register:
        """Translate ``obj.attr``.

        Module attributes are loaded via load_static_module_attr; any other
        attribute read requires a native class (UserRType) and emits GetAttr.
        """
        if self.is_module_member_expr(expr):
            return self.load_static_module_attr(expr)

        obj_reg = self.accept(expr.expr)
        result = self.alloc_target(self.node_type(expr))
        obj_type = self.node_type(expr.expr)
        assert isinstance(obj_type, UserRType), 'Attribute access not supported: %s' % obj_type
        self.add(GetAttr(result, obj_reg, expr.name, obj_type))
        return result

    def load_static_module_attr(self, expr: RefExpr) -> Register:
        """Load ``module.name`` by fetching the module object and doing a Python getattr."""
        result = self.alloc_target(self.node_type(expr))
        # Split "pkg.mod.name" into "pkg.mod" and "name".
        module, _, attr_name = expr.node.fullname().rpartition('.')
        module_obj = self.alloc_temp(ObjectRType())
        self.add(LoadStatic(module_obj, c_module_name(module)))
        self.add(PyGetAttr(result, module_obj, attr_name))
        return result

    def py_call(self, function: Register, args: List[Expression],
                target_type: RType) -> Register:
        """Call the Python object in *function* with boxed *args*.

        The boxed result is converted to *target_type* in a fresh temporary via
        unbox_or_cast.
        """
        boxed_result = self.alloc_temp(ObjectRType())
        # Each argument is translated first, then boxed according to its own type.
        boxed_args = [self.box(self.accept(arg), self.node_type(arg))
                      for arg in args]  # type: List[Register]
        self.add(PyCall(boxed_result, function, boxed_args))
        return self.unbox_or_cast(boxed_result, target_type)

    def visit_call_expr(self, expr: CallExpr) -> Register:
        """Translate a call expression and return the register holding its result.

        Dispatch, in order:
          * ``x.attr(...)``: if ``x`` is not a module and has a recorded type,
            try translate_special_method_call; otherwise (or if that returns a
            falsy register) fall back to py_call on the value of ``x.attr``.
          * ``len(x)`` with one positional argument: LIST_LEN for lists,
            HOMOGENOUS_TUPLE_LEN for variable-length tuples, a constant for
            fixed-length tuples.
          * ``tuple(x)`` with one positional argument: LIST_TO_HOMOGENOUS_TUPLE.
          * Any other plain name: a native Call when is_native_name_expr says the
            name belongs to this module, otherwise py_call.
        Builtins are matched by short name (see the TODO below), so a local
        function named ``len`` or ``tuple`` would also take those paths.
        """
        if isinstance(expr.callee, MemberExpr):
            is_module_call = self.is_module_member_expr(expr.callee)
            if expr.callee.expr in self.types and not is_module_call:
                target = self.translate_special_method_call(expr.callee, expr)
                # NOTE(review): translate_special_method_call returns a register
                # or fails an assertion; this truthiness test would only fall
                # through if it returned a falsy register such as 0 -- confirm
                # that is intended.
                if target:
                    return target

            # Either it's a module call or translating to a special method call failed, so we have
            # to fallback to a PyCall
            function = self.accept(expr.callee)
            return self.py_call(function, expr.args, self.node_type(expr))

        assert isinstance(expr.callee, NameExpr)
        fn = expr.callee.name  # TODO: fullname
        if fn == 'len' and len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]:
            target = self.alloc_target(IntRType())
            arg = self.accept(expr.args[0])

            expr_rtype = self.node_type(expr.args[0])
            if expr_rtype.name == 'list':
                self.add(PrimitiveOp(target, PrimitiveOp.LIST_LEN, arg))
            elif expr_rtype.name == 'sequence_tuple':
                self.add(
                    PrimitiveOp(target, PrimitiveOp.HOMOGENOUS_TUPLE_LEN, arg))
            elif isinstance(expr_rtype, TupleRType):
                # Fixed-length tuple: the length is known at compile time.
                self.add(LoadInt(target, len(expr_rtype.types)))
            else:
                assert False, "unsupported use of len"

        # Handle conversion to sequence tuple
        elif fn == 'tuple' and len(
                expr.args) == 1 and expr.arg_kinds == [ARG_POS]:
            target = self.alloc_target(SequenceTupleRType())
            arg = self.accept(expr.args[0])

            self.add(
                PrimitiveOp(target, PrimitiveOp.LIST_TO_HOMOGENOUS_TUPLE, arg))
        else:
            target_type = self.node_type(expr)
            # Functions outside the current module are called through Python.
            if not (self.is_native_name_expr(expr.callee)):
                function = self.accept(expr.callee)
                return self.py_call(function, expr.args, target_type)

            target = self.alloc_target(target_type)
            args = [self.accept(arg) for arg in expr.args]
            self.add(Call(target, fn, args))
        return target

    def visit_conditional_expr(self, expr: ConditionalExpr) -> Register:
        """Translate ``a if cond else b``.

        Both arms are written into the same result register; each arm ends with
        a jump to a shared join block.
        """
        branches = self.process_conditional(expr.cond)
        result = self.alloc_target(self.node_type(expr))

        jumps_to_join = []  # type: List[Goto]
        for taken, arm in ((True, expr.if_expr), (False, expr.else_expr)):
            arm_block = self.new_block()
            self.set_branches(branches, taken, arm_block)
            self.accept(arm, target=result)
            jump = Goto(INVALID_LABEL)
            self.add(jump)
            jumps_to_join.append(jump)

        join_block = self.new_block()
        for jump in jumps_to_join:
            jump.label = join_block.label
        return result

    def translate_special_method_call(self, callee: MemberExpr,
                                      expr: CallExpr) -> Register:
        """Translate a method call that has a dedicated primitive.

        Only ``lst.append(x)`` on a list with exactly one argument is supported:
        the argument is boxed and appended with LIST_APPEND. Anything else fails
        an assertion naming the receiver type and method.

        Returns:
            INVALID_REGISTER, since append produces no value.
        """
        base_type = self.node_type(callee.expr)
        base = self.accept(callee.expr)
        if (callee.name == 'append' and base_type.name == 'list'
                and len(expr.args) == 1):
            target = INVALID_REGISTER  # TODO: Do we sometimes need to allocate a register?
            arg = self.box_expr(expr.args[0])
            self.add(PrimitiveOp(target, PrimitiveOp.LIST_APPEND, base, arg))
        else:
            assert False, 'Unsupported method call: %s.%s' % (base_type.name,
                                                              callee.name)
        return target

    def visit_list_expr(self, expr: ListExpr) -> Register:
        """Build a new list from the item expressions.

        Each item is boxed according to its OWN expression type, not the list's
        declared item type. Inference context can widen the list type (e.g.
        ``[1]`` typed as ``List[object]``), and boxing by the declared type
        would then skip boxing an unboxed int register. This matches how py_call
        boxes its arguments.
        """
        target = self.alloc_target(ListRType())
        items = [self.box_expr(item) for item in expr.items]
        self.add(PrimitiveOp(target, PrimitiveOp.NEW_LIST, *items))
        return target

    def visit_tuple_expr(self, expr: TupleExpr) -> Register:
        """Build a fixed-length tuple value from the (unboxed) item registers."""
        mypy_tuple = self.types[expr]
        assert isinstance(mypy_tuple, TupleType)

        result = self.alloc_target(self.type_to_rtype(mypy_tuple))
        item_regs = list(map(self.accept, expr.items))
        self.add(PrimitiveOp(result, PrimitiveOp.NEW_TUPLE, *item_regs))
        return result

    def visit_dict_expr(self, expr: DictExpr) -> Register:
        """Create a new empty dict; non-empty dict displays are not supported yet."""
        assert not expr.items, 'Non-empty dict displays are not supported'  # TODO
        target = self.alloc_target(DictRType())
        self.add(PrimitiveOp(target, PrimitiveOp.NEW_DICT))
        return target

    # Conditional expressions

    # Maps a comparison operator to the Branch opcode that compares two int
    # registers; process_conditional looks operators up here.
    int_relative_ops = {
        '==': Branch.INT_EQ,
        '!=': Branch.INT_NE,
        '<': Branch.INT_LT,
        '<=': Branch.INT_LE,
        '>': Branch.INT_GT,
        '>=': Branch.INT_GE,
    }

    def process_conditional(self, e: Node) -> List[Branch]:
        """Emit code that evaluates *e* as a branch condition.

        Handles a single comparison (int relational ops, ``is``/``is not``
        against None, ``in``/``not in``), short-circuiting ``and``/``or``,
        ``not``, and otherwise tests any expression as a bool.

        Returns:
            The Branch ops emitted for the condition. Targets that are not yet
            known are left as INVALID_LABEL; callers resolve them with
            set_branches once the destination blocks exist.
        """
        if isinstance(e, ComparisonExpr):
            # TODO: Verify operand types.
            assert len(e.operators) == 1, 'more than 1 operator not supported'
            op = e.operators[0]
            if op in self.int_relative_ops:
                # TODO: check operand types
                left = self.accept(e.operands[0])
                right = self.accept(e.operands[1])
                opcode = self.int_relative_ops[op]
                branch = Branch(left, right, INVALID_LABEL, INVALID_LABEL,
                                opcode)
            elif op in ['is', 'is not']:
                # TODO: check if right operand is None
                left = self.accept(e.operands[0])
                branch = Branch(left, INVALID_REGISTER, INVALID_LABEL,
                                INVALID_LABEL, Branch.IS_NONE)
                if op == 'is not':
                    branch.negated = True
            elif op in ['in', 'not in']:
                left = self.accept(e.operands[0])
                ltype = self.node_type(e.operands[0])
                right = self.accept(e.operands[1])
                rtype = self.node_type(e.operands[1])
                target = self.alloc_temp(self.node_type(e))
                self.binary_op(ltype, left, rtype, right, 'in', target=target)
                branch = Branch(target, INVALID_REGISTER, INVALID_LABEL,
                                INVALID_LABEL, Branch.BOOL_EXPR)
                if op == 'not in':
                    branch.negated = True
            else:
                assert False, "unsupported comparison expression"
            self.add(branch)
            return [branch]
        elif isinstance(e, OpExpr) and e.op in ['and', 'or']:
            # Short circuit in a conditional context: the right operand gets its
            # own block, reached when the left operand is true (for 'and') or
            # false (for 'or'). The remaining left-operand edges stay unresolved.
            lbranches = self.process_conditional(e.left)
            right_block = self.new_block()
            self.set_branches(lbranches, e.op == 'and', right_block)
            rbranches = self.process_conditional(e.right)
            return lbranches + rbranches
        elif isinstance(e, UnaryExpr) and e.op == 'not':
            branches = self.process_conditional(e.expr)
            for b in branches:
                b.invert()
            return branches
        # Catch-all for arbitrary expressions.
        else:
            reg = self.accept(e)
            branch = Branch(reg, INVALID_REGISTER, INVALID_LABEL,
                            INVALID_LABEL, Branch.BOOL_EXPR)
            self.add(branch)
            return [branch]

    def set_branches(self, branches: List[Branch], condition: bool,
                     target: BasicBlock) -> None:
        """Point every unresolved True (or False) edge of *branches* at *target*.

        An edge whose label is already set (non-negative) is left untouched.
        """
        for branch in branches:
            if condition and branch.true < 0:
                branch.true = target.label
            elif not condition and branch.false < 0:
                branch.false = target.label

    # Helpers

    def enter(self) -> None:
        """Open a new function scope: fresh environment, block list and entry block."""
        env = Environment()
        self.environments.append(env)
        self.environment = env
        self.blocks.append([])
        self.new_block()

    def new_block(self) -> BasicBlock:
        """Append an empty block to the current function; its label is its index."""
        current_blocks = self.blocks[-1]
        block = BasicBlock(Label(len(current_blocks)))
        current_blocks.append(block)
        return block

    def goto_new_block(self) -> BasicBlock:
        """Terminate the current block with a jump to a new block, and return that block."""
        jump = Goto(INVALID_LABEL)
        self.add(jump)
        successor = self.new_block()
        # The label is only known once the successor exists.
        jump.label = successor.label
        return successor

    def leave(self) -> Tuple[List[BasicBlock], Environment]:
        """Close the current function scope and return its blocks and environment.

        The environment now on top of the stack becomes current again.
        """
        finished_blocks, finished_env = self.blocks.pop(), self.environments.pop()
        self.environment = self.environments[-1]
        return finished_blocks, finished_env

    def add(self, op: Op) -> None:
        """Append *op* to the most recently created block of the current function."""
        current_block = self.blocks[-1][-1]
        current_block.ops.append(op)

    def accept(self,
               node: Node,
               target: Register = INVALID_REGISTER) -> Register:
        """Translate *node* and return the register holding its value.

        *target* is pushed on ``self.targets`` for the duration of the visit, so
        visitors (via alloc_target) can write directly into a caller-chosen
        register; INVALID_REGISTER means "allocate a temporary as needed".
        The stack is restored even if the visitor raises.
        """
        self.targets.append(target)
        try:
            return node.accept(self)
        finally:
            self.targets.pop()

    def alloc_target(self, type: RType) -> Register:
        """Return the register requested by the innermost accept(), or a new temp of *type*."""
        requested = self.targets[-1]
        return self.environment.add_temp(type) if requested < 0 else requested

    def alloc_temp(self, type: RType) -> Register:
        """Allocate a fresh temporary register of rtype *type* in the current environment."""
        new_reg = self.environment.add_temp(type)
        return new_reg

    def type_to_rtype(self, typ: Type) -> RType:
        """Map a mypy type to its runtime representation using the builder's mapper."""
        mapper = self.mapper
        return mapper.type_to_rtype(typ)

    def node_type(self, node: Expression) -> RType:
        """Return the rtype corresponding to mypy's inferred type for *node*."""
        return self.type_to_rtype(self.types[node])

    def box(self,
            src: Register,
            typ: RType,
            target: Optional[Register] = None) -> Register:
        """Return a register holding the value of *src* (of rtype *typ*) as an object.

        Unboxed rtypes (``typ.supports_unbox``) get a Box op into *target* or a
        new object temp. Values that are already objects are returned as-is, or
        copied into *target* when one is given.
        """
        if not typ.supports_unbox:
            # Already boxed
            if target is None:
                return src
            self.add(Assign(target, src))
            return target

        if target is None:
            target = self.alloc_temp(ObjectRType())
        self.add(Box(target, src, typ))
        return target

    def unbox_or_cast(self,
                      src: Register,
                      target_type: RType,
                      target: Optional[Register] = None) -> Register:
        """Convert the object in *src* to *target_type* and return the destination.

        Unboxable rtypes use an Unbox op; all other rtypes use a Cast op. A new
        temporary is allocated when *target* is None.
        """
        dest = self.alloc_temp(target_type) if target is None else target
        op_cls = Unbox if target_type.supports_unbox else Cast
        self.add(op_cls(dest, src, target_type))
        return dest

    def box_expr(self, expr: Expression) -> Register:
        """Translate *expr* and return its value boxed as a Python object."""
        expr_type = self.node_type(expr)
        value = self.accept(expr)
        return self.box(value, expr_type)
Пример #7
0
class TestFunctionEmitterVisitor(unittest.TestCase):
    """Unit tests for FunctionEmitterVisitor.

    Each test feeds a single op to the visitor and compares the emitted C
    (declaration fragments followed by body fragments) with the expected text.
    """

    def setUp(self) -> None:
        """Create an environment with named locals of various rtypes and the emitters."""
        self.env = Environment()
        self.n = self.env.add_local(Var('n'), IntRType())
        self.m = self.env.add_local(Var('m'), IntRType())
        self.k = self.env.add_local(Var('k'), IntRType())
        self.l = self.env.add_local(Var('l'), ListRType())
        self.ll = self.env.add_local(Var('ll'), ListRType())
        self.o = self.env.add_local(Var('o'), ObjectRType())
        self.o2 = self.env.add_local(Var('o2'), ObjectRType())
        self.d = self.env.add_local(Var('d'), DictRType())
        self.b = self.env.add_local(Var('b'), BoolRType())
        self.context = EmitterContext()
        self.emitter = Emitter(self.context, self.env)
        # Separate emitter collecting declarations (e.g. __tmp variables).
        self.declarations = Emitter(self.context, self.env)
        self.visitor = FunctionEmitterVisitor(self.emitter, self.declarations)

    def test_goto(self) -> None:
        self.assert_emit(Goto(Label(2)), "goto CPyL2;")

    def test_return(self) -> None:
        self.assert_emit(Return(self.m), "return cpy_r_m;")

    def test_load_int(self) -> None:
        # NOTE(review): 5 is emitted as 10, presumably the tagged-int
        # encoding (value shifted left by one) -- confirm.
        self.assert_emit(LoadInt(self.m, 5), "cpy_r_m = 10;")

    def test_tuple_get(self) -> None:
        self.assert_emit(TupleGet(self.m, self.n, 1, BoolRType()),
                         'cpy_r_m = cpy_r_n.f1;')

    def test_load_None(self) -> None:
        self.assert_emit(
            PrimitiveOp(self.m, PrimitiveOp.NONE), """cpy_r_m = Py_None;
                            Py_INCREF(cpy_r_m);
                         """)

    def test_load_True(self) -> None:
        self.assert_emit(PrimitiveOp(self.m, PrimitiveOp.TRUE), "cpy_r_m = 1;")

    def test_load_False(self) -> None:
        self.assert_emit(PrimitiveOp(self.m, PrimitiveOp.FALSE),
                         "cpy_r_m = 0;")

    def test_assign_int(self) -> None:
        self.assert_emit(Assign(self.m, self.n), "cpy_r_m = cpy_r_n;")

    def test_int_add(self) -> None:
        self.assert_emit(
            PrimitiveOp(self.n, PrimitiveOp.INT_ADD, self.m, self.k),
            "cpy_r_n = CPyTagged_Add(cpy_r_m, cpy_r_k);")

    def test_int_sub(self) -> None:
        self.assert_emit(
            PrimitiveOp(self.n, PrimitiveOp.INT_SUB, self.m, self.k),
            "cpy_r_n = CPyTagged_Subtract(cpy_r_m, cpy_r_k);")

    def test_list_repeat(self) -> None:
        self.assert_emit(
            PrimitiveOp(self.ll, PrimitiveOp.LIST_REPEAT, self.l, self.n),
            """long long __tmp1;
                            __tmp1 = CPyTagged_AsLongLong(cpy_r_n);
                            if (__tmp1 == -1 && PyErr_Occurred())
                                abort();
                            cpy_r_ll = PySequence_Repeat(cpy_r_l, __tmp1);
                            if (!cpy_r_ll)
                                abort();
                         """)

    def test_int_neg(self) -> None:
        self.assert_emit(PrimitiveOp(self.n, PrimitiveOp.INT_NEG, self.m),
                         "cpy_r_n = CPy_NegateInt(cpy_r_m);")

    def test_list_len(self) -> None:
        self.assert_emit(
            PrimitiveOp(self.n, PrimitiveOp.LIST_LEN, self.l),
            """long long __tmp1;
                            __tmp1 = PyList_GET_SIZE(cpy_r_l);
                            cpy_r_n = CPyTagged_ShortFromLongLong(__tmp1);
                         """)

    def test_branch_eq(self) -> None:
        self.assert_emit(
            Branch(self.n, self.m, Label(8), Label(9), Branch.INT_EQ),
            """if (CPyTagged_IsEq(cpy_r_n, cpy_r_m))
                                goto CPyL8;
                            else
                                goto CPyL9;
                         """)
        # A negated branch wraps the condition in '!' but keeps the targets.
        b = Branch(self.n, self.m, Label(8), Label(9), Branch.INT_LT)
        b.negated = True
        self.assert_emit(
            b, """if (!CPyTagged_IsLt(cpy_r_n, cpy_r_m))
                                goto CPyL8;
                            else
                                goto CPyL9;
                         """)

    def test_call(self) -> None:
        self.assert_emit(Call(self.n, 'myfn', [self.m]),
                         "cpy_r_n = CPyDef_myfn(cpy_r_m);")

    def test_call_two_args(self) -> None:
        self.assert_emit(Call(self.n, 'myfn', [self.m, self.k]),
                         "cpy_r_n = CPyDef_myfn(cpy_r_m, cpy_r_k);")

    def test_call_no_return(self) -> None:
        # With no destination register the call is emitted as a bare statement.
        self.assert_emit(Call(None, 'myfn', [self.m, self.k]),
                         "CPyDef_myfn(cpy_r_m, cpy_r_k);")

    def test_inc_ref(self) -> None:
        self.assert_emit(IncRef(self.m, IntRType()),
                         "CPyTagged_IncRef(cpy_r_m);")

    def test_dec_ref(self) -> None:
        self.assert_emit(DecRef(self.m, IntRType()),
                         "CPyTagged_DecRef(cpy_r_m);")

    def test_dec_ref_tuple(self) -> None:
        # Only the int field (f0) is expected to be decref'd; the bool field emits nothing.
        tuple_type = TupleRType([IntRType(), BoolRType()])
        self.assert_emit(DecRef(self.m, tuple_type),
                         'CPyTagged_DecRef(cpy_r_m.f0);')

    def test_dec_ref_tuple_nested(self) -> None:
        tuple_type = TupleRType(
            [TupleRType([IntRType(), BoolRType()]),
             BoolRType()])
        self.assert_emit(DecRef(self.m, tuple_type),
                         'CPyTagged_DecRef(cpy_r_m.f0.f0);')

    def test_list_get_item(self) -> None:
        self.assert_emit(
            PrimitiveOp(self.n, PrimitiveOp.LIST_GET, self.m, self.k),
            """cpy_r_n = CPyList_GetItem(cpy_r_m, cpy_r_k);
                            if (!cpy_r_n)
                                abort();
                         """)

    def test_list_set_item(self) -> None:
        self.assert_emit(
            PrimitiveOp(None, PrimitiveOp.LIST_SET, self.l, self.n, self.o),
            """if (!CPyList_SetItem(cpy_r_l, cpy_r_n, cpy_r_o))
                                abort();
                         """)

    def test_box(self) -> None:
        self.assert_emit(Box(self.o, self.n, IntRType()),
                         """cpy_r_o = CPyTagged_StealAsObject(cpy_r_n);""")

    def test_unbox(self) -> None:
        self.assert_emit(
            Unbox(self.n, self.m, IntRType()), """if (PyLong_Check(cpy_r_m))
                                cpy_r_n = CPyTagged_FromObject(cpy_r_m);
                            else
                                abort();
                         """)

    def test_new_list(self) -> None:
        self.assert_emit(
            PrimitiveOp(self.l, PrimitiveOp.NEW_LIST, self.n, self.m),
            """cpy_r_l = PyList_New(2);
                            Py_INCREF(cpy_r_n);
                            PyList_SET_ITEM(cpy_r_l, 0, cpy_r_n);
                            Py_INCREF(cpy_r_m);
                            PyList_SET_ITEM(cpy_r_l, 1, cpy_r_m);
                         """)

    def test_list_append(self) -> None:
        self.assert_emit(
            PrimitiveOp(None, PrimitiveOp.LIST_APPEND, self.l, self.o),
            """if (PyList_Append(cpy_r_l, cpy_r_o) == -1)
                                abort();
                         """)

    def test_get_attr(self) -> None:
        # NOTE(review): attribute 'y' is accessed at index 2 here but 3 in
        # test_set_attr; presumably separate getter/setter slots -- confirm.
        ir = ClassIR('A', [('x', BoolRType()), ('y', IntRType())])
        rtype = UserRType(ir)
        self.assert_emit(
            GetAttr(self.n, self.m, 'y', rtype),
            """cpy_r_n = CPY_GET_ATTR(cpy_r_m, 2, AObject, CPyTagged);""")

    def test_set_attr(self) -> None:
        ir = ClassIR('A', [('x', BoolRType()), ('y', IntRType())])
        rtype = UserRType(ir)
        self.assert_emit(
            SetAttr(self.n, 'y', self.m, rtype),
            """CPY_SET_ATTR(cpy_r_n, 3, cpy_r_m, AObject, CPyTagged);""")

    def test_dict_get_item(self) -> None:
        self.assert_emit(
            PrimitiveOp(self.o, PrimitiveOp.DICT_GET, self.d, self.o2),
            """cpy_r_o = PyDict_GetItem(cpy_r_d, cpy_r_o2);
                            if (!cpy_r_o)
                                abort();
                            Py_INCREF(cpy_r_o);
                         """)

    def test_dict_set_item(self) -> None:
        self.assert_emit(
            PrimitiveOp(None, PrimitiveOp.DICT_SET, self.d, self.o, self.o2),
            """if (PyDict_SetItem(cpy_r_d, cpy_r_o, cpy_r_o2) < 0)
                                abort();
                         """)

    def test_new_dict(self) -> None:
        self.assert_emit(
            PrimitiveOp(self.d, PrimitiveOp.NEW_DICT),
            """cpy_r_d = PyDict_New();
                            if (!cpy_r_d)
                                abort();
                         """)

    def test_dict_contains(self) -> None:
        self.assert_emit(
            PrimitiveOp(self.b, PrimitiveOp.DICT_CONTAINS, self.o, self.d),
            """int __tmp1 = PyDict_Contains(cpy_r_d, cpy_r_o);
                            if (__tmp1 < 0)
                                abort();
                            cpy_r_b = __tmp1;
                         """)

    def assert_emit(self, op: Op, expected: str) -> None:
        """Emit *op* and assert the generated lines equal *expected*.

        Declaration fragments come first, then body fragments. Every fragment
        must end with a newline. Leading/trailing spaces (not tabs) are ignored
        on both sides, as is trailing whitespace at the end of *expected*.
        """
        self.emitter.fragments = []
        self.declarations.fragments = []
        op.accept(self.visitor)
        frags = self.declarations.fragments + self.emitter.fragments
        actual_lines = [line.strip(' ') for line in frags]
        assert all(line.endswith('\n') for line in actual_lines)
        actual_lines = [line.rstrip('\n') for line in actual_lines]
        expected_lines = expected.rstrip().split('\n')
        expected_lines = [line.strip(' ') for line in expected_lines]
        assert_string_arrays_equal(expected_lines,
                                   actual_lines,
                                   msg='Generated code unexpected')