Пример #1
0
def transform_block(block: BasicBlock, pre_live: AnalysisDict[Register],
                    post_live: AnalysisDict[Register],
                    pre_borrow: AnalysisDict[Register],
                    env: Environment) -> None:
    """Insert explicit IncRef/DecRef ops into a block and replace block.ops.

    Analysis results are looked up per op with the key (block.label, index of
    the op in the original block).

    Args:
        block: basic block rewritten in place (block.ops is replaced by a new
            list containing the original ops plus reference-count ops)
        pre_live: registers treated as live before each op (assumed to come
            from a liveness analysis — TODO confirm)
        post_live: registers treated as live after each op
        pre_borrow: registers treated as holding borrowed references before
            each op (presumably "borrowed" means not owned, so never
            DecRef'd here — verify against the borrow analysis)
        env: supplies register types (env.types) and new temporaries
            (env.add_temp)
    """
    old_ops = block.ops
    ops = []  # type: List[Op]
    for i, op in enumerate(old_ops):
        # Key used to index the per-op analysis dictionaries.
        key = (block.label, i)
        if isinstance(op, (Assign, Cast, Box)):
            # These operations just copy/steal a reference and don't create new
            # references.
            # If src is still needed afterwards (or is borrowed), the copy in
            # dest needs its own reference, so src is IncRef'd.
            if op.src in post_live[key] or op.src in pre_borrow[key]:
                ops.append(IncRef(op.src, env.types[op.src]))
                # Release the value dest held before being overwritten, unless
                # that value was borrowed or dest was not live beforehand.
                # NOTE(review): this release only happens on the copy path;
                # when src's reference is stolen (outer condition false), no
                # DecRef of a previously-live dest is emitted — confirm intended.
                if (op.dest not in pre_borrow[key]
                        and op.dest in pre_live[key]):
                    ops.append(DecRef(op.dest, env.types[op.dest]))
            ops.append(op)
            # A dest that is dead right after the op is released immediately.
            if op.dest not in post_live[key]:
                ops.append(DecRef(op.dest, env.types[op.dest]))
        elif isinstance(op, RegisterOp):
            # These operations construct a new reference.
            # Holds the old value of dest when dest is also one of the op's
            # sources, so that it is released only after the op has read it.
            tmp_reg = None  # type: Optional[Register]
            if (op.dest not in pre_borrow[key] and op.dest in pre_live[key]):
                if op.dest not in op.sources():
                    ops.append(DecRef(op.dest, env.types[op.dest]))
                else:
                    tmp_reg = env.add_temp(env.types[op.dest])
                    ops.append(Assign(tmp_reg, op.dest))
            ops.append(op)
            for src in op.unique_sources():
                # Decrement source that won't be live afterwards.
                if src not in post_live[key] and src not in pre_borrow[key]:
                    # A source that is also the destination is skipped; its
                    # old value, if it needed releasing, was saved in tmp_reg.
                    if src != op.dest:
                        ops.append(DecRef(src, env.types[src]))
            # op.dest may be None, so guard before checking post-op liveness.
            if op.dest is not None and op.dest not in post_live[key]:
                ops.append(DecRef(op.dest, env.types[op.dest]))
            if tmp_reg is not None:
                ops.append(DecRef(tmp_reg, env.types[tmp_reg]))
        elif isinstance(op, Return) and op.reg in pre_borrow[key]:
            # The return op returns a new reference.
            # A borrowed register therefore needs an IncRef before returning.
            ops.append(IncRef(op.reg, env.types[op.reg]))
            ops.append(op)
        else:
            ops.append(op)
    block.ops = ops
Пример #2
0
class TestGenerateFunction(unittest.TestCase):
    """Golden-output tests for C code emitted by generate_native_function."""

    def setUp(self) -> None:
        # Each test gets one int argument named 'arg' (bound to a local
        # register) and an empty basic block labelled 0 to fill with ops.
        self.var = Var('arg')
        self.arg = RuntimeArg('arg', IntRType())
        self.env = Environment()
        self.reg = self.env.add_local(self.var, IntRType())
        self.block = BasicBlock(Label(0))

    def test_simple(self) -> None:
        """A function that returns its argument unchanged."""
        expected = [
            'static CPyTagged CPyDef_myfunc(CPyTagged cpy_r_arg) {\n',
            'CPyL0: ;\n',
            '    return cpy_r_arg;\n',
            '}\n',
        ]

        self.block.ops.append(Return(self.reg))
        func = FuncIR('myfunc', [self.arg], IntRType(), [self.block], self.env)
        emitter = Emitter(EmitterContext())
        generate_native_function(func, emitter)

        actual = emitter.fragments
        assert_string_arrays_equal(expected, actual, msg='Generated code invalid')

    def test_register(self) -> None:
        """A temporary register is declared and assigned a loaded int."""
        # The IR loads 5 but the expected C literal is 10 (presumably the
        # tagged-int encoding — see the emitter for the exact rule).
        expected = [
            'static PyObject *CPyDef_myfunc(CPyTagged cpy_r_arg) {\n',
            '    CPyTagged cpy_r_r0;\n',
            'CPyL0: ;\n',
            '    cpy_r_r0 = 10;\n',
            '}\n',
        ]

        self.temp = self.env.add_temp(IntRType())
        self.block.ops.append(LoadInt(self.temp, 5))
        func = FuncIR('myfunc', [self.arg], ListRType(), [self.block], self.env)
        emitter = Emitter(EmitterContext())
        generate_native_function(func, emitter)

        actual = emitter.fragments
        assert_string_arrays_equal(expected, actual, msg='Generated code invalid')
Пример #3
0
class LowLevelIRBuilder:
    """Emit low-level IR ops into a growing list of basic blocks.

    Ops are appended to the most recently activated block (self.blocks[-1]);
    register ops are additionally recorded in self.environment. A stack of
    error handler blocks is kept, and every newly activated block is tagged
    with the handler currently on top of that stack.
    """

    def __init__(
        self,
        current_module: str,
        mapper: Mapper,
    ) -> None:
        self.current_module = current_module
        self.mapper = mapper
        self.environment = Environment()
        self.blocks = []  # type: List[BasicBlock]
        # Stack of except handler entry blocks (None at the bottom: new blocks
        # get no error handler until one is pushed).
        self.error_handlers = [None]  # type: List[Optional[BasicBlock]]

    def add(self, op: Op) -> Value:
        """Append op to the current block and return the op as a value.

        Register ops are also added to the environment.
        """
        assert not self.blocks[-1].terminated, "Can't add to finished block"

        self.blocks[-1].ops.append(op)
        if isinstance(op, RegisterOp):
            self.environment.add_op(op)
        return op

    def goto(self, target: BasicBlock) -> None:
        """Jump to target, unless the current block is already terminated."""
        if not self.blocks[-1].terminated:
            self.add(Goto(target))

    def activate_block(self, block: BasicBlock) -> None:
        """Make block the current block.

        The previously current block (if any) must already be terminated.
        The block inherits the error handler on top of the handler stack.
        """
        if self.blocks:
            assert self.blocks[-1].terminated

        block.error_handler = self.error_handlers[-1]
        self.blocks.append(block)

    def goto_and_activate(self, block: BasicBlock) -> None:
        """Jump to block (if needed) and then make it the current block."""
        self.goto(block)
        self.activate_block(block)

    def push_error_handler(self, handler: Optional[BasicBlock]) -> None:
        """Use handler for blocks activated from now on."""
        self.error_handlers.append(handler)

    def pop_error_handler(self) -> Optional[BasicBlock]:
        """Restore the previous error handler and return the removed one."""
        return self.error_handlers.pop()

    ##

    def get_native_type(self, cls: ClassIR) -> Value:
        """Load the type object of native class cls."""
        fullname = '%s.%s' % (cls.module_name, cls.name)
        return self.load_native_type_object(fullname)

    def primitive_op(self, desc: OpDescription, args: List[Value],
                     line: int) -> Value:
        """Coerce args to desc's formal argument types and emit a PrimitiveOp."""
        assert desc.result_type is not None
        coerced = []
        for i, arg in enumerate(args):
            formal_type = self.op_arg_type(desc, i)
            arg = self.coerce(arg, formal_type, line)
            coerced.append(arg)
        target = self.add(PrimitiveOp(coerced, desc, line))
        return target

    def alloc_temp(self, type: RType) -> Register:
        """Allocate a new temporary register of the given type."""
        return self.environment.add_temp(type)

    def op_arg_type(self, desc: OpDescription, n: int) -> RType:
        """Return the formal type of argument n of desc.

        For var-arg ops, arguments past the declared ones use the last
        declared argument type.
        """
        if n >= len(desc.arg_types):
            assert desc.is_var_arg
            return desc.arg_types[-1]
        return desc.arg_types[n]

    def box(self, src: Value) -> Value:
        """Box src if its type is unboxed; otherwise return it unchanged."""
        if src.type.is_unboxed:
            return self.add(Box(src))
        else:
            return src

    def unbox_or_cast(self, src: Value, target_type: RType,
                      line: int) -> Value:
        """Emit an Unbox for unboxed target types, otherwise a Cast."""
        if target_type.is_unboxed:
            return self.add(Unbox(src, target_type, line))
        else:
            return self.add(Cast(src, target_type, line))

    def coerce(self,
               src: Value,
               target_type: RType,
               line: int,
               force: bool = False) -> Value:
        """Generate a coercion/cast from one type to other (only if needed).

        For example, int -> object boxes the source int; int -> int emits nothing;
        object -> int unboxes the object. All conversions preserve object value.

        If force is true, always generate an op (even if it is just an assignment) so
        that the result will have exactly target_type as the type.

        Returns the register with the converted value (may be same as src).
        """
        if src.type.is_unboxed and not target_type.is_unboxed:
            return self.box(src)
        if ((src.type.is_unboxed and target_type.is_unboxed)
                and not is_runtime_subtype(src.type, target_type)):
            # To go from one unboxed type to another, we go through a boxed
            # in-between value, for simplicity.
            tmp = self.box(src)
            return self.unbox_or_cast(tmp, target_type, line)
        if ((not src.type.is_unboxed and target_type.is_unboxed)
                or not is_subtype(src.type, target_type)):
            return self.unbox_or_cast(src, target_type, line)
        elif force:
            tmp = self.alloc_temp(target_type)
            self.add(Assign(tmp, src))
            return tmp
        return src

    def none(self) -> Value:
        """Emit none_op (no arguments, line -1) and return its result."""
        return self.add(PrimitiveOp([], none_op, line=-1))

    def none_object(self) -> Value:
        """Emit none_object_op (no arguments, line -1) and return its result."""
        return self.add(PrimitiveOp([], none_object_op, line=-1))

    def get_attr(self, obj: Value, attr: str, result_type: RType,
                 line: int) -> Value:
        """Read obj.attr.

        Uses a native GetAttr when obj is an instance of a native extension
        class that declares attr, dispatches per item for union types, and
        otherwise falls back to a generic Python attribute lookup.
        """
        if (isinstance(obj.type, RInstance) and obj.type.class_ir.is_ext_class
                and obj.type.class_ir.has_attr(attr)):
            return self.add(GetAttr(obj, attr, line))
        elif isinstance(obj.type, RUnion):
            return self.union_get_attr(obj, obj.type, attr, result_type, line)
        else:
            return self.py_get_attr(obj, attr, line)

    def union_get_attr(self, obj: Value, rtype: RUnion, attr: str,
                       result_type: RType, line: int) -> Value:
        """Read obj.attr for a union-typed obj by specializing per item."""
        def get_item_attr(value: Value) -> Value:
            return self.get_attr(value, attr, result_type, line)

        return self.decompose_union_helper(obj, rtype, result_type,
                                           get_item_attr, line)

    def decompose_union_helper(self, obj: Value, rtype: RUnion,
                               result_type: RType,
                               process_item: Callable[[Value], Value],
                               line: int) -> Value:
        """Generate isinstance() + specialized operations for union items.

        Say, for Union[A, B] generate ops resembling this (pseudocode):

            if isinstance(obj, A):
                result = <result of process_item(cast(A, obj)>
            else:
                result = <result of process_item(cast(B, obj)>

        Args:
            obj: value with a union type
            rtype: the union type
            result_type: result of the operation
            process_item: callback to generate op for a single union item (arg is coerced
                to union item type)
            line: line number

        Returns:
            A temporary register of result_type holding the result; the exit
            block is left as the current block.
        """
        # TODO: Optimize cases where a single operation can handle multiple union items
        #     (say a method is implemented in a common base class)
        fast_items = []
        rest_items = []
        for item in rtype.items:
            if isinstance(item, RInstance):
                fast_items.append(item)
            else:
                # For everything but RInstance we fall back to C API
                rest_items.append(item)
        exit_block = BasicBlock()
        result = self.alloc_temp(result_type)
        for i, item in enumerate(fast_items):
            more_types = i < len(fast_items) - 1 or bool(rest_items)
            if more_types:
                # We are not at the final item so we need one more branch
                op = self.isinstance_native(obj, item.class_ir, line)
                true_block, false_block = BasicBlock(), BasicBlock()
                self.add_bool_branch(op, true_block, false_block)
                self.activate_block(true_block)
            coerced = self.coerce(obj, item, line)
            temp = process_item(coerced)
            temp2 = self.coerce(temp, result_type, line)
            self.add(Assign(result, temp2))
            self.goto(exit_block)
            if more_types:
                self.activate_block(false_block)
        if rest_items:
            # For everything else we use generic operation. Use force=True to drop the
            # union type.
            coerced = self.coerce(obj, object_rprimitive, line, force=True)
            temp = process_item(coerced)
            temp2 = self.coerce(temp, result_type, line)
            self.add(Assign(result, temp2))
            self.goto(exit_block)
        self.activate_block(exit_block)
        return result

    def isinstance_helper(self, obj: Value, class_irs: List[ClassIR],
                          line: int) -> Value:
        """Fast path for isinstance() that checks against a list of native classes.

        The per-class checks are OR-ed together with short-circuiting; an
        empty list yields false_op.
        """
        if not class_irs:
            return self.primitive_op(false_op, [], line)
        ret = self.isinstance_native(obj, class_irs[0], line)
        for class_ir in class_irs[1:]:

            # Bind the loop variable and the previous result via defaults so
            # the callbacks never depend on late binding of `class_ir`/`ret`.
            def other(class_ir: ClassIR = class_ir) -> Value:
                return self.isinstance_native(obj, class_ir, line)

            def prev(value: Value = ret) -> Value:
                return value

            ret = self.shortcircuit_helper('or', bool_rprimitive, prev,
                                           other, line)
        return ret

    def isinstance_native(self, obj: Value, class_ir: ClassIR,
                          line: int) -> Value:
        """Fast isinstance() check for a native class.

        If the class and all its children contain at most
        FAST_ISINSTANCE_MAX_SUBCLASSES + 1 concrete (non-trait) classes, use
        even faster exact type comparisons (`type(obj) is typ`), OR-ed
        together. Otherwise, or if all_concrete_classes() returns None, emit
        fast_isinstance_op against the class's type object.
        """
        concrete = all_concrete_classes(class_ir)
        if concrete is None or len(
                concrete) > FAST_ISINSTANCE_MAX_SUBCLASSES + 1:
            return self.primitive_op(fast_isinstance_op,
                                     [obj, self.get_native_type(class_ir)],
                                     line)
        if not concrete:
            # There can't be any concrete instance that matches this.
            return self.primitive_op(false_op, [], line)
        type_obj = self.get_native_type(concrete[0])
        ret = self.primitive_op(type_is_op, [obj, type_obj], line)
        for c in concrete[1:]:

            # Bind the loop variable and the previous result via defaults so
            # the callbacks never depend on late binding of `c`/`ret`.
            def other(c: ClassIR = c) -> Value:
                return self.primitive_op(type_is_op,
                                         [obj, self.get_native_type(c)], line)

            def prev(value: Value = ret) -> Value:
                return value

            ret = self.shortcircuit_helper('or', bool_rprimitive, prev,
                                           other, line)
        return ret

    def py_get_attr(self, obj: Value, attr: str, line: int) -> Value:
        """Emit a generic Python attribute lookup of attr on obj."""
        key = self.load_static_unicode(attr)
        return self.add(PrimitiveOp([obj, key], py_getattr_op, line))

    def py_call(self,
                function: Value,
                arg_values: List[Value],
                line: int,
                arg_kinds: Optional[List[int]] = None,
                arg_names: Optional[Sequence[Optional[str]]] = None) -> Value:
        """Use py_call_op or py_call_with_kwargs_op for function call.

        All-positional calls use py_call_op directly. Otherwise positional and
        *args values are packed into a tuple, keyword and **kwargs values into
        a dict, and py_call_with_kwargs_op is emitted.
        """
        # If all arguments are positional, we can use py_call_op.
        if (arg_kinds is None) or all(kind == ARG_POS for kind in arg_kinds):
            return self.primitive_op(py_call_op, [function] + arg_values, line)

        # Otherwise fallback to py_call_with_kwargs_op.
        assert arg_names is not None

        pos_arg_values = []
        kw_arg_key_value_pairs = []  # type: List[DictEntry]
        star_arg_values = []
        for value, kind, name in zip(arg_values, arg_kinds, arg_names):
            if kind == ARG_POS:
                pos_arg_values.append(value)
            elif kind == ARG_NAMED:
                assert name is not None
                key = self.load_static_unicode(name)
                kw_arg_key_value_pairs.append((key, value))
            elif kind == ARG_STAR:
                star_arg_values.append(value)
            elif kind == ARG_STAR2:
                # NOTE: mypy currently only supports a single ** arg, but python supports multiple.
                # This code supports multiple primarily to make the logic easier to follow.
                kw_arg_key_value_pairs.append((None, value))
            else:
                assert False, ("Argument kind should not be possible:", kind)

        if len(star_arg_values) == 0:
            # We can directly construct a tuple if there are no star args.
            pos_args_tuple = self.primitive_op(new_tuple_op, pos_arg_values,
                                               line)
        else:
            # Otherwise we construct a list and call extend it with the star args, since tuples
            # don't have an extend method.
            pos_args_list = self.primitive_op(new_list_op, pos_arg_values,
                                              line)
            for star_arg_value in star_arg_values:
                self.primitive_op(list_extend_op,
                                  [pos_args_list, star_arg_value], line)
            pos_args_tuple = self.primitive_op(list_tuple_op, [pos_args_list],
                                               line)

        kw_args_dict = self.make_dict(kw_arg_key_value_pairs, line)

        return self.primitive_op(py_call_with_kwargs_op,
                                 [function, pos_args_tuple, kw_args_dict],
                                 line)

    def py_method_call(self, obj: Value, method_name: str,
                       arg_values: List[Value], line: int,
                       arg_kinds: Optional[List[int]],
                       arg_names: Optional[Sequence[Optional[str]]]) -> Value:
        """Call obj.method_name(...) through the generic Python path.

        All-positional calls use py_method_call_op; other calls look up the
        method as an attribute and go through py_call.
        """
        if (arg_kinds is None) or all(kind == ARG_POS for kind in arg_kinds):
            method_name_reg = self.load_static_unicode(method_name)
            return self.primitive_op(py_method_call_op,
                                     [obj, method_name_reg] + arg_values, line)
        else:
            method = self.py_get_attr(obj, method_name, line)
            return self.py_call(method,
                                arg_values,
                                line,
                                arg_kinds=arg_kinds,
                                arg_names=arg_names)

    def call(self, decl: FuncDecl, args: Sequence[Value], arg_kinds: List[int],
             arg_names: Sequence[Optional[str]], line: int) -> Value:
        """Emit a native Call to decl after normalizing the arguments."""
        # Normalize args to positionals.
        args = self.native_args_to_positional(args, arg_kinds, arg_names,
                                              decl.sig, line)
        return self.add(Call(decl, args, line))

    def native_args_to_positional(self, args: Sequence[Value],
                                  arg_kinds: List[int],
                                  arg_names: Sequence[Optional[str]],
                                  sig: FuncSignature,
                                  line: int) -> List[Value]:
        """Prepare arguments for a native call.

        Given args/kinds/names and a target signature for a native call, map
        keyword arguments to their appropriate place in the argument list,
        fill in error values for unspecified default arguments,
        package arguments that will go into *args/**kwargs into a tuple/dict,
        and coerce arguments to the appropriate type.
        """

        sig_arg_kinds = [arg.kind for arg in sig.args]
        sig_arg_names = [arg.name for arg in sig.args]
        formal_to_actual = map_actuals_to_formals(
            arg_kinds, arg_names, sig_arg_kinds, sig_arg_names,
            lambda n: AnyType(TypeOfAny.special_form))

        # Flatten out the arguments, loading error values for default
        # arguments, constructing tuples/dicts for star args, and
        # coercing everything to the expected type.
        output_args = []
        for lst, arg in zip(formal_to_actual, sig.args):
            output_arg = None
            if arg.kind == ARG_STAR:
                output_arg = self.primitive_op(new_tuple_op,
                                               [args[i] for i in lst], line)
            elif arg.kind == ARG_STAR2:
                dict_entries = [
                    (self.load_static_unicode(cast(str,
                                                   arg_names[i])), args[i])
                    for i in lst
                ]
                output_arg = self.make_dict(dict_entries, line)
            elif not lst:
                output_arg = self.add(
                    LoadErrorValue(arg.type, is_borrowed=True))
            else:
                output_arg = args[lst[0]]
            output_args.append(self.coerce(output_arg, arg.type, line))

        return output_args

    def make_dict(self, key_value_pairs: Sequence[DictEntry],
                  line: int) -> Value:
        """Build a dict from entries in order.

        An entry with key None means "**value" (merge another mapping).
        key:value entries seen before the first ** entry are passed directly
        to new_dict_op; later ones are added via a translated __setitem__.
        """
        result = None  # type: Union[Value, None]
        initial_items = []  # type: List[Value]
        for key, value in key_value_pairs:
            if key is not None:
                # key:value
                if result is None:
                    initial_items.extend((key, value))
                    continue

                # NOTE(review): translate_special_method_call returns None if
                # no primitive matches, which would silently drop this entry —
                # confirm a dict __setitem__ primitive always applies here.
                self.translate_special_method_call(result,
                                                   '__setitem__', [key, value],
                                                   result_type=None,
                                                   line=line)
            else:
                # **value
                if result is None:
                    result = self.primitive_op(new_dict_op, initial_items,
                                               line)

                self.primitive_op(dict_update_in_display_op, [result, value],
                                  line=line)

        if result is None:
            result = self.primitive_op(new_dict_op, initial_items, line)

        return result

    # Loading stuff
    def literal_static_name(
            self, value: Union[int, float, complex, str, bytes]) -> str:
        """Return the static symbol name used for a literal in this module."""
        return self.mapper.literal_static_name(self.current_module, value)

    def load_static_int(self, value: int) -> Value:
        """Load an integer literal into a register.

        Values whose magnitude exceeds MAX_LITERAL_SHORT_INT are loaded from a
        static Python 'int' object; smaller ones are emitted with LoadInt.
        """
        if abs(value) > MAX_LITERAL_SHORT_INT:
            static_symbol = self.literal_static_name(value)
            return self.add(
                LoadStatic(int_rprimitive, static_symbol, ann=value))
        else:
            return self.add(LoadInt(value))

    def load_static_float(self, value: float) -> Value:
        """Loads a static float value into a register."""
        static_symbol = self.literal_static_name(value)
        return self.add(LoadStatic(float_rprimitive, static_symbol, ann=value))

    def load_static_bytes(self, value: bytes) -> Value:
        """Loads a static bytes value into a register."""
        static_symbol = self.literal_static_name(value)
        return self.add(LoadStatic(object_rprimitive, static_symbol,
                                   ann=value))

    def load_static_complex(self, value: complex) -> Value:
        """Loads a static complex value into a register."""
        static_symbol = self.literal_static_name(value)
        return self.add(LoadStatic(object_rprimitive, static_symbol,
                                   ann=value))

    def load_static_unicode(self, value: str) -> Value:
        """Loads a static unicode value into a register.

        This is useful for more than just unicode literals; for example, method calls
        also require a PyObject * form for the name of the method.
        """
        static_symbol = self.literal_static_name(value)
        return self.add(LoadStatic(str_rprimitive, static_symbol, ann=value))

    def load_module(self, name: str) -> Value:
        """Load the module object called name from the module namespace."""
        return self.add(
            LoadStatic(object_rprimitive, name, namespace=NAMESPACE_MODULE))

    def load_native_type_object(self, fullname: str) -> Value:
        """Load a native type object given its 'module.Name' full name."""
        module, name = fullname.rsplit('.', 1)
        return self.add(
            LoadStatic(object_rprimitive, name, module, NAMESPACE_TYPE))

    def matching_primitive_op(
            self,
            candidates: List[OpDescription],
            args: List[Value],
            line: int,
            result_type: Optional[RType] = None) -> Optional[Value]:
        """Emit the best-matching primitive op from candidates, if any.

        A candidate matches when its arity equals len(args) and each argument
        type is a subtype of the formal type; the highest priority wins. When
        result_type is given and the op's type is not a runtime subtype of it,
        the result is coerced (or replaced by None for a None result type).
        Returns None if no candidate matches.
        """
        # Find the highest-priority primitive op that matches.
        matching = None  # type: Optional[OpDescription]
        for desc in candidates:
            if len(desc.arg_types) != len(args):
                continue
            if all(
                    is_subtype(actual.type, formal)
                    for actual, formal in zip(args, desc.arg_types)):
                if matching:
                    assert matching.priority != desc.priority, 'Ambiguous:\n1) %s\n2) %s' % (
                        matching, desc)
                    if desc.priority > matching.priority:
                        matching = desc
                else:
                    matching = desc
        if matching:
            target = self.primitive_op(matching, args, line)
            if result_type and not is_runtime_subtype(target.type,
                                                      result_type):
                if is_none_rprimitive(result_type):
                    # Special case None return. The actual result may actually be a bool
                    # and so we can't just coerce it.
                    target = self.none()
                else:
                    target = self.coerce(target, result_type, line)
            return target
        return None

    def binary_op(self, lreg: Value, rreg: Value, expr_op: str,
                  line: int) -> Value:
        """Emit the binary operation `lreg <expr_op> rreg`.

        Raises AssertionError if no primitive op supports the operand types.
        """
        # Special case == and != when we can resolve the method call statically.
        value = None
        if expr_op in ('==', '!='):
            value = self.translate_eq_cmp(lreg, rreg, expr_op, line)
        if value is not None:
            return value

        ops = binary_ops.get(expr_op, [])
        target = self.matching_primitive_op(ops, [lreg, rreg], line)
        assert target, 'Unsupported binary operation: %s' % expr_op
        return target

    def unary_op(self, lreg: Value, expr_op: str, line: int) -> Value:
        """Emit the unary operation `<expr_op> lreg`.

        Raises AssertionError if no primitive op supports the operand type.
        """
        ops = unary_ops.get(expr_op, [])
        target = self.matching_primitive_op(ops, [lreg], line)
        assert target, 'Unsupported unary operation: %s' % expr_op
        return target

    def shortcircuit_helper(self, op: str, expr_type: RType,
                            left: Callable[[], Value],
                            right: Callable[[], Value], line: int) -> Value:
        """Emit a short-circuiting 'and'/'or' and return the result register.

        left() is emitted first; right() is emitted in its own block, reached
        only when the left value does not decide the result. Both values are
        coerced to expr_type and assigned to a shared temporary.
        """
        # Having actual Phi nodes would be really nice here!
        target = self.alloc_temp(expr_type)
        # left_body takes the value of the left side, right_body the right
        left_body, right_body, next_block = BasicBlock(), BasicBlock(), BasicBlock()
        # true_body is taken if the left is true, false_body if it is false.
        # For 'and' the value is the right side if the left is true, and for 'or'
        # it is the right side if the left is false.
        true_body, false_body = ((right_body, left_body) if op == 'and' else
                                 (left_body, right_body))

        left_value = left()
        self.add_bool_branch(left_value, true_body, false_body)

        self.activate_block(left_body)
        left_coerced = self.coerce(left_value, expr_type, line)
        self.add(Assign(target, left_coerced))
        self.goto(next_block)

        self.activate_block(right_body)
        right_value = right()
        right_coerced = self.coerce(right_value, expr_type, line)
        self.add(Assign(target, right_coerced))
        self.goto(next_block)

        self.activate_block(next_block)
        return target

    def add_bool_branch(self, value: Value, true: BasicBlock,
                        false: BasicBlock) -> None:
        """Branch to `true` if value is truthy, otherwise to `false`.

        Values whose type is a runtime subtype of int are compared with 0,
        lists by their length, and native extension classes defining __bool__
        call it directly. Optional values are first compared with None; unless
        the non-None type is known to be always truthy, the narrowed value is
        then checked recursively. Other non-bool values go through bool_op.
        """
        if is_runtime_subtype(value.type, int_rprimitive):
            zero = self.add(LoadInt(0))
            value = self.binary_op(value, zero, '!=', value.line)
        elif is_same_type(value.type, list_rprimitive):
            length = self.primitive_op(list_len_op, [value], value.line)
            zero = self.add(LoadInt(0))
            value = self.binary_op(length, zero, '!=', value.line)
        elif (isinstance(value.type, RInstance)
              and value.type.class_ir.is_ext_class
              and value.type.class_ir.has_method('__bool__')):
            # Directly call the __bool__ method on classes that have it.
            value = self.gen_method_call(value, '__bool__', [],
                                         bool_rprimitive, value.line)
        else:
            value_type = optional_value_type(value.type)
            if value_type is not None:
                is_not_none = self.binary_op(value, self.none_object(),
                                             'is not', value.line)
                branch = Branch(is_not_none, true, false, Branch.BOOL_EXPR)
                self.add(branch)
                always_truthy = False
                if isinstance(value_type, RInstance):
                    # check whether X.__bool__ is always just the default (object.__bool__)
                    if (not value_type.class_ir.has_method('__bool__') and
                            value_type.class_ir.is_method_final('__bool__')):
                        always_truthy = True

                if not always_truthy:
                    # Optional[X] where X may be falsey and requires a check
                    branch.true = BasicBlock()
                    self.activate_block(branch.true)
                    # unbox_or_cast instead of coerce because we want the
                    # type to change even if it is a subtype.
                    remaining = self.unbox_or_cast(value, value_type,
                                                   value.line)
                    self.add_bool_branch(remaining, true, false)
                return
            elif not is_same_type(value.type, bool_rprimitive):
                value = self.primitive_op(bool_op, [value], value.line)
        self.add(Branch(value, true, false, Branch.BOOL_EXPR))

    def translate_special_method_call(self, base_reg: Value, name: str,
                                      args: List[Value],
                                      result_type: Optional[RType],
                                      line: int) -> Optional[Value]:
        """Translate a method call which is handled nongenerically.

        These are special in the sense that we have code generated specifically for them.
        They tend to be method calls which have equivalents in C that are more direct
        than calling with the PyObject api.

        Return None if no translation found; otherwise return the target register.
        """
        ops = method_ops.get(name, [])
        return self.matching_primitive_op(ops, [base_reg] + args,
                                          line,
                                          result_type=result_type)

    def translate_eq_cmp(self, lreg: Value, rreg: Value, expr_op: str,
                         line: int) -> Optional[Value]:
        """Try to resolve == / != statically for two same-class native operands.

        Returns None when the comparison cannot be resolved statically, so the
        caller falls back to the generic implementation.
        """
        ltype = lreg.type
        rtype = rreg.type
        if not (isinstance(ltype, RInstance) and ltype == rtype):
            return None

        class_ir = ltype.class_ir
        # Check whether any subclasses of the operand redefines __eq__
        # or it might be redefined in a Python parent class or by
        # dataclasses
        cmp_varies_at_runtime = (not class_ir.is_method_final('__eq__')
                                 or not class_ir.is_method_final('__ne__')
                                 or class_ir.inherits_python
                                 or class_ir.is_augmented)

        if cmp_varies_at_runtime:
            # We might need to call left.__eq__(right) or right.__eq__(left)
            # depending on which is the more specific type.
            return None

        if not class_ir.has_method('__eq__'):
            # There's no __eq__ defined, so just use object identity.
            identity_ref_op = 'is' if expr_op == '==' else 'is not'
            return self.binary_op(lreg, rreg, identity_ref_op, line)

        # NOTE(review): ltype is passed as result_type; it is not used on the
        # native MethodCall path but would be on the fallback paths — confirm.
        return self.gen_method_call(lreg, op_methods[expr_op], [rreg], ltype,
                                    line)

    def gen_method_call(
            self,
            base: Value,
            name: str,
            arg_values: List[Value],
            result_type: Optional[RType],
            line: int,
            arg_kinds: Optional[List[int]] = None,
            arg_names: Optional[List[Optional[str]]] = None) -> Value:
        """Generate a native or Python method call base.name(...).

        Native MethodCall is used for methods of native extension classes
        (without a builtin base); attributes of such classes are fetched and
        called through py_call; union receivers are dispatched per item. Other
        calls try a special-cased primitive first, then a Python method call.
        """
        # If arg_kinds contains values other than arg_pos and arg_named, then fallback to
        # Python method call.
        if (arg_kinds is not None and not all(kind in (ARG_POS, ARG_NAMED)
                                              for kind in arg_kinds)):
            # Use the call's own line, as the other fallbacks below do,
            # rather than the line of the receiver value.
            return self.py_method_call(base, name, arg_values, line,
                                       arg_kinds, arg_names)

        # If the base type is one of ours, do a MethodCall
        if (isinstance(base.type, RInstance)
                and base.type.class_ir.is_ext_class
                and not base.type.class_ir.builtin_base):
            if base.type.class_ir.has_method(name):
                decl = base.type.class_ir.method_decl(name)
                if arg_kinds is None:
                    assert arg_names is None, "arg_kinds not present but arg_names is"
                    arg_kinds = [ARG_POS for _ in arg_values]
                    arg_names = [None for _ in arg_values]
                else:
                    assert arg_names is not None, "arg_kinds present but arg_names is not"

                # Normalize args to positionals.
                assert decl.bound_sig
                arg_values = self.native_args_to_positional(
                    arg_values, arg_kinds, arg_names, decl.bound_sig, line)
                return self.add(MethodCall(base, name, arg_values, line))
            elif base.type.class_ir.has_attr(name):
                function = self.add(GetAttr(base, name, line))
                return self.py_call(function,
                                    arg_values,
                                    line,
                                    arg_kinds=arg_kinds,
                                    arg_names=arg_names)

        elif isinstance(base.type, RUnion):
            return self.union_method_call(base, base.type, name, arg_values,
                                          result_type, line, arg_kinds,
                                          arg_names)

        # Try to do a special-cased method call
        if not arg_kinds or arg_kinds == [ARG_POS] * len(arg_values):
            target = self.translate_special_method_call(
                base, name, arg_values, result_type, line)
            if target:
                return target

        # Fall back to Python method call
        return self.py_method_call(base, name, arg_values, line, arg_kinds,
                                   arg_names)

    def union_method_call(self, base: Value, obj_type: RUnion, name: str,
                          arg_values: List[Value],
                          return_rtype: Optional[RType], line: int,
                          arg_kinds: Optional[List[int]],
                          arg_names: Optional[List[Optional[str]]]) -> Value:
        """Call base.name(...) on a union-typed base by specializing per item."""
        # Union method call needs a return_rtype for the type of the output register.
        # If we don't have one, use object_rprimitive.
        return_rtype = return_rtype or object_rprimitive

        def call_union_item(value: Value) -> Value:
            return self.gen_method_call(value, name, arg_values, return_rtype,
                                        line, arg_kinds, arg_names)

        return self.decompose_union_helper(base, obj_type, return_rtype,
                                           call_union_item, line)
Пример #4
0
class IRBuilder(NodeVisitor[Register]):
    def __init__(self, types: Dict[Expression, Type], mapper: Mapper) -> None:
        """Create a builder that lowers type-checked mypy nodes into IR.

        types maps each expression to its mypy type; mapper holds the
        TypeInfo -> ClassIR mapping that prepare_class_def fills in.
        """
        self.types = types
        self.mapper = mapper

        # The initial environment is also the bottom of the environment stack.
        initial_env = Environment()
        self.environment = initial_env
        self.environments = [initial_env]

        # Accumulated output of the build.
        self.blocks = []  # type: List[List[BasicBlock]]
        self.functions = []  # type: List[FuncIR]
        self.classes = []  # type: List[ClassIR]
        self.imports = []  # type: List[str]
        self.targets = []  # type: List[Register]

        # Loop stack frames: each loop pushes one empty list onto both stacks.
        # break/continue statements append their placeholder Goto to the
        # innermost frame, and when the loop ends the frame is popped and its
        # gotos are pointed at the right blocks.
        self.break_gotos = []  # type: List[List[Goto]]
        self.continue_gotos = []  # type: List[List[Goto]]

        # Set by visit_mypy_file before the op-generating pass.
        self.current_module_name = None  # type: Optional[str]

    def visit_mypy_file(self, mypyfile: MypyFile) -> Register:
        """Lower a whole module in two passes.

        The first pass creates a ClassIR for every top-level class so later
        definitions can look classes up through the mapper. The second pass
        visits every top-level definition and emits ops.
        """
        module_name = mypyfile.fullname()
        if module_name in ('typing', 'abc'):
            # These modules are special: their contents are currently all
            # built-in primitives, so nothing is generated.
            return INVALID_REGISTER

        # Pass 1: build ClassIRs and the TypeInfo-to-ClassIR mapping.
        for cdef in (defn for defn in mypyfile.defs
                     if isinstance(defn, ClassDef)):
            self.prepare_class_def(cdef)

        # Pass 2: generate ops.
        self.current_module_name = module_name
        for defn in mypyfile.defs:
            defn.accept(self)

        return INVALID_REGISTER

    def prepare_class_def(self, cdef: ClassDef) -> None:
        """Register an attribute-less ClassIR for cdef (first module pass).

        visit_class_def fills in the attributes later.
        """
        class_ir = ClassIR(cdef.name, [])
        self.mapper.type_to_ir[cdef.info] = class_ir
        self.classes.append(class_ir)

    def visit_class_def(self, cdef: ClassDef) -> Register:
        """Fill in the attributes of the ClassIR created by prepare_class_def.

        Every Var in the class symbol table becomes an (name, RType) attribute.
        """
        attrs = [(attr_name, self.type_to_rtype(sym.node.type))
                 for attr_name, sym in cdef.info.names.items()
                 if isinstance(sym.node, Var)]
        self.mapper.type_to_ir[cdef.info].attributes = attrs
        return INVALID_REGISTER

    def visit_import(self, node: Import) -> Register:
        """Record the module ids named by a top-level `import` statement.

        Imports that mypy marked unreachable or mypy-only are skipped. The
        original check for them was a no-op `pass`, so they were still
        recorded or tripped the top-level assert.

        Raises:
            AssertionError: for a reachable import that is not top-level.
        """
        if node.is_unreachable or node.is_mypy_only:
            # Nothing to import at runtime for these.
            return INVALID_REGISTER
        if not node.is_top_level:
            assert False, "non-toplevel imports not supported"

        self.imports.extend(node_id for node_id, _ in node.ids)

        return INVALID_REGISTER

    def visit_import_from(self, node: ImportFrom) -> Register:
        """Record the source module of a top-level `from m import ...`.

        Imports that mypy marked unreachable or mypy-only are skipped. The
        original check for them was a no-op `pass`.

        Raises:
            AssertionError: for a reachable import that is not top-level.
        """
        if node.is_unreachable or node.is_mypy_only:
            # Nothing to import at runtime for these.
            return INVALID_REGISTER
        if not node.is_top_level:
            assert False, "non-toplevel imports not supported"

        self.imports.append(node.id)

        return INVALID_REGISTER

    def visit_import_all(self, node: ImportAll) -> Register:
        """Record the source module of a top-level `from m import *`.

        Imports that mypy marked unreachable or mypy-only are skipped. The
        original check for them was a no-op `pass`.

        Raises:
            AssertionError: for a reachable import that is not top-level.
        """
        if node.is_unreachable or node.is_mypy_only:
            # Nothing to import at runtime for these.
            return INVALID_REGISTER
        if not node.is_top_level:
            assert False, "non-toplevel imports not supported"

        self.imports.append(node.id)

        return INVALID_REGISTER

    def visit_func_def(self, fdef: FuncDef) -> Register:
        """Lower a function definition into a FuncIR appended to self.functions."""
        self.enter()

        # Arguments become the first locals of the function's environment.
        for argument in fdef.arguments:
            var = argument.variable
            self.environment.add_local(var, self.type_to_rtype(var.type))
        fdef.body.accept(self)

        ret_type = self.convert_return_type(fdef)
        # A None-returning function may fall off the end, so it gets an
        # implicit `return None`; any other function is marked unreachable
        # at that point.
        finish = (self.add_implicit_return if ret_type.name == 'None'
                  else self.add_implicit_unreachable)
        finish()

        blocks, env = self.leave()
        runtime_args = self.convert_args(fdef)
        self.functions.append(
            FuncIR(fdef.name(), runtime_args, ret_type, blocks, env))
        return INVALID_REGISTER

    def convert_args(self, fdef: FuncDef) -> List[RuntimeArg]:
        """Build the runtime argument list from fdef's declared signature."""
        assert isinstance(fdef.type, CallableType)
        signature = fdef.type
        runtime_args = []  # type: List[RuntimeArg]
        for position, argument in enumerate(fdef.arguments):
            arg_name = argument.variable.name()
            arg_rtype = self.type_to_rtype(signature.arg_types[position])
            runtime_args.append(RuntimeArg(arg_name, arg_rtype))
        return runtime_args

    def convert_return_type(self, fdef: FuncDef) -> RType:
        """Return the RType corresponding to fdef's declared return type."""
        signature = fdef.type
        assert isinstance(signature, CallableType)
        return self.type_to_rtype(signature.ret_type)

    def add_implicit_return(self) -> None:
        """Append `return None` unless the current block already ends in Return."""
        current_ops = self.blocks[-1][-1].ops
        if current_ops and isinstance(current_ops[-1], Return):
            return
        none_reg = self.environment.add_temp(NoneRType())
        self.add(PrimitiveOp(none_reg, PrimitiveOp.NONE))
        self.add(Return(none_reg))

    def add_implicit_unreachable(self) -> None:
        """Append Unreachable unless the current block already ends in Return."""
        current_ops = self.blocks[-1][-1].ops
        ends_with_return = bool(current_ops) and isinstance(current_ops[-1],
                                                            Return)
        if not ends_with_return:
            self.add(Unreachable())

    def visit_block(self, block: Block) -> Register:
        """Lower each statement of block, in order."""
        for statement in block.body:
            statement.accept(self)
        return INVALID_REGISTER

    def visit_expression_stmt(self, stmt: ExpressionStmt) -> Register:
        """Evaluate an expression statement; its result register is discarded."""
        _ = self.accept(stmt.expr)
        return INVALID_REGISTER

    def visit_return_stmt(self, stmt: ReturnStmt) -> Register:
        """Emit a Return; a bare `return` returns a freshly loaded None."""
        if not stmt.expr:
            value_reg = self.environment.add_temp(NoneRType())
            self.add(PrimitiveOp(value_reg, PrimitiveOp.NONE))
        else:
            value_reg = self.accept(stmt.expr)
        self.add(Return(value_reg))
        return INVALID_REGISTER

    def visit_assignment_stmt(self, stmt: AssignmentStmt) -> Register:
        """Lower a single-target assignment `lvalue = rvalue`.

        The target type comes from an explicit annotation when present;
        otherwise it is inferred. An annotated assignment declares a new local.
        """
        assert len(stmt.lvalues) == 1
        lvalue = stmt.lvalues[0]
        if stmt.type:
            lvalue_type = self.type_to_rtype(stmt.type)
        elif isinstance(lvalue, IndexExpr):
            # TODO: This won't be right for user-defined classes. Store the
            #     lvalue type in mypy and remove this special case.
            lvalue_type = ObjectRType()
        else:
            lvalue_type = self.node_type(lvalue)
        return self.assign(lvalue, stmt.rvalue, self.node_type(stmt.rvalue),
                           lvalue_type, declare_new=stmt.type is not None)

    def visit_operator_assignment_stmt(
            self, stmt: OperatorAssignmentStmt) -> Register:
        """Lower an augmented assignment such as `x += e`.

        Only local-variable targets are supported. The operation writes its
        result straight back into the variable's register.

        Raises:
            AssertionError: for any other lvalue (e.g. list index targets).
        """
        target = self.get_assignment_target(stmt.lvalue, declare_new=False)

        if isinstance(target, AssignmentTargetRegister):
            ltype = self.environment.types[target.register]
            rtype = self.node_type(stmt.rvalue)
            rreg = self.accept(stmt.rvalue)
            return self.binary_op(ltype,
                                  target.register,
                                  rtype,
                                  rreg,
                                  stmt.op,
                                  target=target.register)

        # NOTE: List index not supported yet for compound assignments.
        # The message previously left '%r' unformatted.
        assert False, 'Unsupported lvalue: %r' % stmt.lvalue

    def get_assignment_target(self, lvalue: Lvalue,
                              declare_new: bool) -> AssignmentTarget:
        """Describe where an assignment to lvalue should store its value.

        Subexpressions of the lvalue (the base and index of x[y], or the
        object of x.y) are evaluated here. Their ops are emitted before the
        caller evaluates the rvalue (see assign()).

        Args:
            lvalue: the assignment target expression.
            declare_new: for a NameExpr, always create a new local even if
                the name is not marked is_def. visit_assignment_stmt passes
                True for annotated assignments.

        Returns:
            AssignmentTargetRegister for a local name, AssignmentTargetIndex
            for list[int] or dict indexing, or AssignmentTargetAttr for an
            attribute of a user-defined (UserRType) object.

        Raises:
            AssertionError: for non-local names and any other lvalue form.
        """
        if isinstance(lvalue, NameExpr):
            # Assign to local variable.
            assert lvalue.kind == LDEF
            if lvalue.is_def or declare_new:
                # Define a new variable.
                assert isinstance(lvalue.node, Var)  # TODO: Can this fail?
                lvalue_num = self.environment.add_local(
                    lvalue.node, self.node_type(lvalue))
            else:
                # Assign to a previously defined variable.
                assert isinstance(lvalue.node, Var)  # TODO: Can this fail?
                lvalue_num = self.environment.lookup(lvalue.node)

            return AssignmentTargetRegister(lvalue_num)
        elif isinstance(lvalue, IndexExpr):
            # Indexed assignment x[y] = e
            base_type = self.node_type(lvalue.base)
            index_type = self.node_type(lvalue.index)
            base_reg = self.accept(lvalue.base)
            index_reg = self.accept(lvalue.index)
            if isinstance(base_type, ListRType) and isinstance(
                    index_type, IntRType):
                # Indexed list set
                return AssignmentTargetIndex(base_reg, index_reg, base_type)
            elif isinstance(base_type, DictRType):
                # Indexed dict set; the key is boxed before storing.
                boxed_index = self.box(index_reg, index_type)
                return AssignmentTargetIndex(base_reg, boxed_index, base_type)
            # Any other base/index combination falls through to the final
            # assert, after the base and index ops have already been emitted.
        elif isinstance(lvalue, MemberExpr):
            # Attribute assignment x.y = e
            obj_type = self.node_type(lvalue.expr)
            assert isinstance(
                obj_type,
                UserRType), 'Attribute set only supported for user types'
            obj_reg = self.accept(lvalue.expr)
            return AssignmentTargetAttr(obj_reg, lvalue.name, obj_type)

        assert False, 'Unsupported lvalue: %r' % lvalue

    def assign_to_target(self, target: AssignmentTarget, rvalue: Expression,
                         rvalue_type: RType, needs_box: bool) -> Register:
        """Evaluate rvalue and store the result into target.

        Args:
            target: destination produced by get_assignment_target.
            rvalue: expression to evaluate.
            rvalue_type: RType of rvalue; if falsy, node_type(rvalue) is used.
            needs_box: box the value before storing into a register or
                attribute target. Index targets (list/dict set) always box
                the item regardless of this flag.

        Returns:
            The destination register for a register target, or
            INVALID_REGISTER for attribute and index targets.

        Raises:
            AssertionError: for an unknown target kind or an index target
                whose container type is neither list nor dict.
        """
        rvalue_type = rvalue_type or self.node_type(rvalue)

        if isinstance(target, AssignmentTargetRegister):
            if needs_box:
                # Evaluate unboxed, then box directly into the variable.
                unboxed = self.accept(rvalue)
                return self.box(unboxed, rvalue_type, target=target.register)
            else:
                # Evaluate straight into the variable's register.
                return self.accept(rvalue, target=target.register)
        elif isinstance(target, AssignmentTargetAttr):
            rvalue_reg = self.accept(rvalue)
            if needs_box:
                rvalue_reg = self.box(rvalue_reg, rvalue_type)
            self.add(
                SetAttr(target.obj_reg, target.attr, rvalue_reg,
                        target.obj_type))
            return INVALID_REGISTER
        elif isinstance(target, AssignmentTargetIndex):
            # Container items are always stored boxed.
            item_reg = self.accept(rvalue)
            boxed_item_reg = self.box(item_reg, rvalue_type)
            if isinstance(target.rtype, ListRType):
                op = PrimitiveOp.LIST_SET
            elif isinstance(target.rtype, DictRType):
                op = PrimitiveOp.DICT_SET
            else:
                assert False, target.rtype
            # Destination is None: the set op has no result register.
            self.add(
                PrimitiveOp(None, op, target.base_reg, target.index_reg,
                            boxed_item_reg))
            return INVALID_REGISTER

        assert False, 'Unsupported assignment target'

    def assign(self, lvalue: Lvalue, rvalue: Expression, rvalue_type: RType,
               lvalue_type: RType, declare_new: bool) -> Register:
        """Lower `lvalue = rvalue`, boxing when storing unboxed into boxed.

        The lvalue's subexpressions are evaluated before the rvalue.
        """
        target = self.get_assignment_target(lvalue, declare_new)
        # Boxing is needed only when the value is unboxable but the
        # destination is not.
        if rvalue_type.supports_unbox:
            needs_box = not lvalue_type.supports_unbox
        else:
            needs_box = rvalue_type.supports_unbox
        return self.assign_to_target(target, rvalue, rvalue_type, needs_box)

    def visit_if_stmt(self, stmt: IfStmt) -> Register:
        """Lower an if/else statement into condition branches and blocks.

        Block layout: condition -> if body -> [else body] -> next. A body that
        does not end in a Return gets a Goto patched to point at next.
        """
        # If statements are normalized
        # (presumably an elif chain arrives as a nested if in else_body --
        # NOTE(review): confirm against the mypy version in use).
        assert len(stmt.expr) == 1

        branches = self.process_conditional(stmt.expr[0])
        if_body = self.new_block()
        self.set_branches(branches, True, if_body)
        stmt.body[0].accept(self)
        # None when the if body already ended in a Return.
        if_leave = self.add_leave()
        if stmt.else_body:
            else_body = self.new_block()
            self.set_branches(branches, False, else_body)
            stmt.else_body.accept(self)
            else_leave = self.add_leave()
            next = self.new_block()
            if else_leave:
                else_leave.label = next.label
        else:
            # No else block: a false condition jumps straight to next.
            next = self.new_block()
            self.set_branches(branches, False, next)
        if if_leave:
            if_leave.label = next.label
        return INVALID_REGISTER

    def add_leave(self) -> Optional[Goto]:
        """Terminate the current block with a placeholder Goto, if needed.

        Returns the Goto (label INVALID_LABEL) so the caller can patch it
        once the destination block exists. Returns None when the block
        already ends in a Return.
        """
        current_ops = self.blocks[-1][-1].ops
        if current_ops and isinstance(current_ops[-1], Return):
            return None
        placeholder = Goto(INVALID_LABEL)
        self.add(placeholder)
        return placeholder

    def push_loop_stack(self) -> None:
        """Open a new loop frame for collecting break/continue gotos."""
        for goto_stack in (self.break_gotos, self.continue_gotos):
            goto_stack.append([])

    def pop_loop_stack(self, continue_block: BasicBlock,
                       break_block: BasicBlock) -> None:
        """Close the innermost loop frame and resolve its pending jumps.

        Every `continue` in the frame is pointed at continue_block and every
        `break` at break_block.
        """
        frames = ((self.continue_gotos.pop(), continue_block.label),
                  (self.break_gotos.pop(), break_block.label))
        for pending_gotos, label in frames:
            for goto in pending_gotos:
                goto.label = label

    def visit_while_stmt(self, s: WhileStmt) -> Register:
        """Lower a while loop.

        Block layout: top (condition) -> body -> Goto top; a false condition
        jumps to next. `continue` targets top, so the condition is
        re-evaluated. `break` targets next.
        """
        self.push_loop_stack()

        # Split block so that we get a handle to the top of the loop.
        goto = Goto(INVALID_LABEL)
        self.add(goto)
        top = self.new_block()
        goto.label = top.label
        branches = self.process_conditional(s.expr)

        body = self.new_block()
        # Bind "true" branches to the body block.
        self.set_branches(branches, True, body)
        s.body.accept(self)
        # Add branch to the top at the end of the body.
        self.add(Goto(top.label))
        next = self.new_block()
        # Bind "false" branches to the new block.
        self.set_branches(branches, False, next)

        # continue -> top, break -> next.
        self.pop_loop_stack(top, next)
        return INVALID_REGISTER

    def visit_for_stmt(self, s: ForStmt) -> Register:
        """Lower a for loop over `range(...)` or over a list.

        range form: only the first argument of range() is used as the
        exclusive upper bound, and the counter starts at 0.
        list form: the list is indexed from 0, and its length is re-read on
        every iteration.
        In both forms `continue` jumps to the increment block and `break` to
        the block after the loop.

        Raises:
            AssertionError: for any other iterable.
        """
        if (isinstance(s.expr, CallExpr)
                and isinstance(s.expr.callee, RefExpr)
                and s.expr.callee.fullname == 'builtins.range'):
            self.push_loop_stack()

            # Special case for x in range(...)
            # TODO: Check argument counts and kinds; check the lvalue
            end = s.expr.args[0]
            end_reg = self.accept(end)

            # Initialize loop index to 0.
            # NOTE(review): the loop variable's own register is the counter,
            # so assigning to it inside the body changes the iteration
            # (unlike Python's range semantics).
            index_reg = self.assign(s.index,
                                    IntExpr(0),
                                    IntRType(),
                                    IntRType(),
                                    declare_new=True)
            goto = Goto(INVALID_LABEL)
            self.add(goto)

            # Add loop condition check.
            top = self.new_block()
            goto.label = top.label
            branch = Branch(index_reg, end_reg, INVALID_LABEL, INVALID_LABEL,
                            Branch.INT_LT)
            self.add(branch)
            branches = [branch]

            body = self.new_block()
            self.set_branches(branches, True, body)
            s.body.accept(self)

            # The increment gets its own block so `continue` can target it.
            end_goto = Goto(INVALID_LABEL)
            self.add(end_goto)
            end_block = self.new_block()
            end_goto.label = end_block.label

            # Increment index register.
            one_reg = self.alloc_temp(IntRType())
            self.add(LoadInt(one_reg, 1))
            self.add(
                PrimitiveOp(index_reg, PrimitiveOp.INT_ADD, index_reg,
                            one_reg))

            # Go back to loop condition check.
            self.add(Goto(top.label))
            next = self.new_block()
            self.set_branches(branches, False, next)

            # continue -> increment, break -> after the loop.
            self.pop_loop_stack(end_block, next)
            return INVALID_REGISTER

        if self.node_type(s.expr).name == 'list':
            self.push_loop_stack()

            expr_reg = self.accept(s.expr)

            # Hidden counter, separate from the loop variable.
            index_reg = self.alloc_temp(IntRType())
            self.add(LoadInt(index_reg, 0))

            one_reg = self.alloc_temp(IntRType())
            self.add(LoadInt(one_reg, 1))

            # Only a plain name is supported as the loop variable.
            assert isinstance(s.index, NameExpr)
            assert isinstance(s.index.node, Var)
            lvalue_reg = self.environment.add_local(s.index.node,
                                                    self.node_type(s.index))

            condition_block = self.goto_new_block()

            # For compatibility with python semantics we recalculate the length
            # at every iteration.
            len_reg = self.alloc_temp(IntRType())
            self.add(PrimitiveOp(len_reg, PrimitiveOp.LIST_LEN, expr_reg))

            branch = Branch(index_reg, len_reg, INVALID_LABEL, INVALID_LABEL,
                            Branch.INT_LT)
            self.add(branch)
            branches = [branch]

            body_block = self.new_block()
            self.set_branches(branches, True, body_block)

            # Fetch the boxed item, then unbox/cast it into the loop variable
            # according to the list's declared item type.
            target_list_type = self.types[s.expr]
            assert isinstance(target_list_type, Instance)
            target_type = self.type_to_rtype(target_list_type.args[0])
            value_box = self.alloc_temp(ObjectRType())
            self.add(
                PrimitiveOp(value_box, PrimitiveOp.LIST_GET, expr_reg,
                            index_reg))

            self.unbox_or_cast(value_box, target_type, target=lvalue_reg)

            s.body.accept(self)

            end_block = self.goto_new_block()
            self.add(
                PrimitiveOp(index_reg, PrimitiveOp.INT_ADD, index_reg,
                            one_reg))
            self.add(Goto(condition_block.label))

            next_block = self.new_block()
            self.set_branches(branches, False, next_block)

            # continue -> increment, break -> after the loop.
            self.pop_loop_stack(end_block, next_block)

            return INVALID_REGISTER

        assert False, 'for not supported'

    def visit_break_stmt(self, node: BreakStmt) -> Register:
        """Emit a placeholder jump that pop_loop_stack points at the break block."""
        jump = Goto(INVALID_LABEL)
        self.break_gotos[-1].append(jump)
        self.add(jump)
        return INVALID_REGISTER

    def visit_continue_stmt(self, node: ContinueStmt) -> Register:
        """Emit a placeholder jump that pop_loop_stack points at the continue block."""
        jump = Goto(INVALID_LABEL)
        self.continue_gotos[-1].append(jump)
        self.add(jump)
        return INVALID_REGISTER

    # Python binary operator symbol -> PrimitiveOp opcode, used by binary_op
    # when both operands are ints. (A duplicated '>>' entry was removed; in a
    # dict display the later duplicate silently overrides the earlier one.)
    int_binary_ops = {
        '+': PrimitiveOp.INT_ADD,
        '-': PrimitiveOp.INT_SUB,
        '*': PrimitiveOp.INT_MUL,
        '//': PrimitiveOp.INT_DIV,
        '%': PrimitiveOp.INT_MOD,
        '&': PrimitiveOp.INT_AND,
        '|': PrimitiveOp.INT_OR,
        '^': PrimitiveOp.INT_XOR,
        '<<': PrimitiveOp.INT_SHL,
        '>>': PrimitiveOp.INT_SHR,
    }

    def visit_unary_expr(self, expr: UnaryExpr) -> Register:
        """Lower integer negation `-e` as `0 - e`; other unary forms assert."""
        if expr.op != '-':
            assert False, 'Unsupported unary operation'

        operand_type = self.node_type(expr.expr)
        operand_reg = self.accept(expr.expr)
        if operand_type.name != 'int':
            assert False, 'Unsupported unary operation'

        result_reg = self.alloc_target(IntRType())
        zero_reg = self.accept(IntExpr(0))
        self.add(PrimitiveOp(result_reg, PrimitiveOp.INT_SUB, zero_reg,
                             operand_reg))
        return result_reg

    def visit_op_expr(self, expr: OpExpr) -> Register:
        """Evaluate both operands left to right, then lower via binary_op."""
        left, right = expr.left, expr.right
        left_type, right_type = self.node_type(left), self.node_type(right)
        left_reg = self.accept(left)
        right_reg = self.accept(right)
        return self.binary_op(left_type, left_reg, right_type, right_reg,
                              expr.op)

    def binary_op(self,
                  ltype: RType,
                  lreg: Register,
                  rtype: RType,
                  rreg: Register,
                  expr_op: str,
                  target: Optional[Register] = None) -> Register:
        """Emit a single PrimitiveOp for a binary operation on two registers.

        Supported combinations:
          * int <op> int for any operator in int_binary_ops;
          * list * int and int * list (LIST_REPEAT, with the list on the left);
          * x in dict (DICT_CONTAINS, with x boxed first).

        Args:
            target: result register; allocated via alloc_target when None.

        Returns:
            The result register.

        Raises:
            AssertionError: for unsupported operand/operator combinations.
            KeyError: for an int operator missing from int_binary_ops.
        """
        if ltype.name == 'int' and rtype.name == 'int':
            # Primitive int operation
            if target is None:
                target = self.alloc_target(IntRType())
            op = self.int_binary_ops[expr_op]
        elif (ltype.name == 'list' or rtype.name == 'list') and expr_op == '*':
            # Normalize int * list into list * int.
            if rtype.name == 'list':
                ltype, rtype = rtype, ltype
                lreg, rreg = rreg, lreg
            if rtype.name != 'int':
                assert False, 'Unsupported binary operation'  # TODO: Operator overloading
            if target is None:
                target = self.alloc_target(ListRType())
            op = PrimitiveOp.LIST_REPEAT
        elif isinstance(rtype, DictRType):
            if expr_op == 'in':
                if target is None:
                    target = self.alloc_target(BoolRType())
                # The key is boxed before the membership test.
                lreg = self.box(lreg, ltype)
                op = PrimitiveOp.DICT_CONTAINS
            else:
                assert False, 'Unsupported binary operation'
        else:
            assert False, 'Unsupported binary operation'
        self.add(PrimitiveOp(target, op, lreg, rreg))
        return target

    def visit_index_expr(self, expr: IndexExpr) -> Register:
        """Lower `base[index]`.

        List, sequence-tuple and dict lookups use a primitive that writes to
        an object-typed temp, which is then unboxed or cast to the
        expression's type. List and sequence-tuple indices must be ints.
        Fixed-length tuples require an integer literal index and use
        TupleGet directly.

        Raises:
            AssertionError: for unsupported base or index types.
        """
        base_rtype = self.node_type(expr.base)
        base_reg = self.accept(expr.base)
        target_type = self.node_type(expr)

        if isinstance(base_rtype, (ListRType, SequenceTupleRType, DictRType)):
            index_type = self.node_type(expr.index)
            if not isinstance(base_rtype, DictRType):
                assert isinstance(
                    index_type,
                    IntRType), 'Unsupported indexing operation'  # TODO
            if isinstance(base_rtype, ListRType):
                op = PrimitiveOp.LIST_GET
            elif isinstance(base_rtype, DictRType):
                op = PrimitiveOp.DICT_GET
            else:
                op = PrimitiveOp.HOMOGENOUS_TUPLE_GET
            index_reg = self.accept(expr.index)
            if isinstance(base_rtype, DictRType):
                # Dict keys are passed boxed.
                index_reg = self.box(index_reg, index_type)
            tmp = self.alloc_temp(ObjectRType())
            self.add(PrimitiveOp(tmp, op, base_reg, index_reg))
            target = self.alloc_target(target_type)
            return self.unbox_or_cast(tmp, target_type, target)
        elif isinstance(base_rtype, TupleRType):
            # Fixed-length tuple: the index must be a literal so the item
            # type is known statically.
            assert isinstance(expr.index, IntExpr)  # TODO
            target = self.alloc_target(target_type)
            self.add(
                TupleGet(target, base_reg, expr.index.value,
                         base_rtype.types[expr.index.value]))
            return target

        assert False, 'Unsupported indexing operation'

    def visit_int_expr(self, expr: IntExpr) -> Register:
        """Load an integer literal into a register obtained from alloc_target."""
        dest = self.alloc_target(IntRType())
        self.add(LoadInt(dest, expr.value))
        return dest

    def is_native_name_expr(self, expr: NameExpr) -> bool:
        """Return True if expr refers to a name compiled natively.

        A qualified name is native only if its module is the module being
        compiled; an unqualified name (no '.') is always treated as native.
        """
        # TODO later we want to support cross-module native calls too
        fullname = expr.node.fullname()
        if '.' not in fullname:
            return True
        module_name = fullname.rpartition('.')[0]
        return module_name == self.current_module_name

    def visit_name_expr(self, expr: NameExpr) -> Register:
        """Lower a name reference.

        None/True/False load a primitive constant. Names from other modules
        are fetched as module attributes. Local variables are read through
        get_using_binder.
        """
        fullname = expr.node.fullname()
        builtin_constants = {
            'builtins.None': (NoneRType, PrimitiveOp.NONE),
            'builtins.True': (BoolRType, PrimitiveOp.TRUE),
            'builtins.False': (BoolRType, PrimitiveOp.FALSE),
        }
        if fullname in builtin_constants:
            rtype_class, opcode = builtin_constants[fullname]
            constant_reg = self.alloc_target(rtype_class())
            self.add(PrimitiveOp(constant_reg, opcode))
            return constant_reg

        if not self.is_native_name_expr(expr):
            return self.load_static_module_attr(expr)

        # TODO: We assume that this is a Var node, which is very limited
        assert isinstance(expr.node, Var)
        return self.get_using_binder(self.environment.lookup(expr.node),
                                     expr.node, expr)

    def get_using_binder(self, reg: Register, var: Var,
                         expr: Expression) -> Register:
        """Read var (held in reg) at the type the binder assigned to expr.

        If expr's narrowed type differs from var's declared type, the value
        is unboxed/cast into the pending target (or a new temp when there is
        none). Otherwise reg is returned as is, or copied into the pending
        target when one is set. A negative self.targets[-1] means "no target".
        """
        declared_type = self.type_to_rtype(var.type)
        narrowed_type = self.node_type(expr)
        no_target = self.targets[-1] < 0
        if declared_type != narrowed_type:
            # Cast/unbox to the narrower type given by the binder.
            dest = (self.alloc_temp(narrowed_type) if no_target
                    else self.targets[-1])
            return self.unbox_or_cast(reg, narrowed_type, dest)
        # Regular register access -- binder is not active.
        if no_target:
            return reg
        dest = self.targets[-1]
        self.add(Assign(dest, reg))
        return dest

    def is_module_member_expr(self, expr: MemberExpr) -> bool:
        """Return True if expr is `module.attr`, i.e. its base is a module ref.

        (The return annotation was missing, which typed the result as Any.)
        """
        return isinstance(expr.expr, RefExpr) and expr.expr.kind == MODULE_REF

    def visit_member_expr(self, expr: MemberExpr) -> Register:
        """Lower `x.attr`: a static module lookup, or GetAttr on a user object.

        Raises:
            AssertionError: if x is neither a module nor a UserRType value.
        """
        if self.is_module_member_expr(expr):
            return self.load_static_module_attr(expr)

        obj_reg = self.accept(expr.expr)
        result_reg = self.alloc_target(self.node_type(expr))
        obj_type = self.node_type(expr.expr)
        assert isinstance(obj_type, UserRType), (
            'Attribute access not supported: %s' % obj_type)
        self.add(GetAttr(result_reg, obj_reg, expr.name, obj_type))
        return result_reg

    def load_static_module_attr(self, expr: RefExpr) -> Register:
        """Fetch `module.name` for expr via the module's static reference.

        The fullname is split at its last '.' into module and attribute name.
        """
        result_reg = self.alloc_target(self.node_type(expr))
        module, _, attr_name = expr.node.fullname().rpartition('.')
        module_reg = self.alloc_temp(ObjectRType())
        self.add(LoadStatic(module_reg, c_module_name(module)))
        self.add(PyGetAttr(result_reg, module_reg, attr_name))
        return result_reg

    def py_call(self, function: Register, args: List[Expression],
                target_type: RType) -> Register:
        """Call the Python object in `function` with boxed arguments.

        The object-typed result is unboxed or cast to target_type.
        """
        result_box = self.alloc_temp(ObjectRType())
        boxed_args = [self.box(self.accept(arg), self.node_type(arg))
                      for arg in args]
        self.add(PyCall(result_box, function, boxed_args))
        return self.unbox_or_cast(result_box, target_type)

    def visit_call_expr(self, expr: CallExpr) -> Register:
        """Lower a call expression.

        Method calls try translate_special_method_call first, and otherwise
        fall back to a generic Python call. Plain calls special-case
        len(x) and tuple(x). Other non-native callees become Python calls,
        and native callees become a direct Call by short name.

        Raises:
            AssertionError: for unsupported len() arguments or callee forms.
        """
        if isinstance(expr.callee, MemberExpr):
            is_module_call = self.is_module_member_expr(expr.callee)
            if expr.callee.expr in self.types and not is_module_call:
                # NOTE(review): translate_special_method_call returns
                # INVALID_REGISTER for list.append and asserts for anything
                # else, so whether this early return is taken depends on the
                # truthiness of INVALID_REGISTER -- confirm.
                target = self.translate_special_method_call(expr.callee, expr)
                if target:
                    return target

            # Either it's a module call or translating to a special method call failed, so we have
            # to fallback to a PyCall
            function = self.accept(expr.callee)
            return self.py_call(function, expr.args, self.node_type(expr))

        assert isinstance(expr.callee, NameExpr)
        # NOTE(review): matching on the short name means a user-defined
        # function called `len` or `tuple` would also hit the special cases.
        fn = expr.callee.name  # TODO: fullname
        if fn == 'len' and len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]:
            target = self.alloc_target(IntRType())
            arg = self.accept(expr.args[0])

            expr_rtype = self.node_type(expr.args[0])
            if expr_rtype.name == 'list':
                self.add(PrimitiveOp(target, PrimitiveOp.LIST_LEN, arg))
            elif expr_rtype.name == 'sequence_tuple':
                self.add(
                    PrimitiveOp(target, PrimitiveOp.HOMOGENOUS_TUPLE_LEN, arg))
            elif isinstance(expr_rtype, TupleRType):
                # Fixed-length tuple: the length is a compile-time constant.
                self.add(LoadInt(target, len(expr_rtype.types)))
            else:
                assert False, "unsupported use of len"

        # Handle conversion to sequence tuple
        elif fn == 'tuple' and len(
                expr.args) == 1 and expr.arg_kinds == [ARG_POS]:
            target = self.alloc_target(SequenceTupleRType())
            arg = self.accept(expr.args[0])

            self.add(
                PrimitiveOp(target, PrimitiveOp.LIST_TO_HOMOGENOUS_TUPLE, arg))
        else:
            target_type = self.node_type(expr)
            if not (self.is_native_name_expr(expr.callee)):
                function = self.accept(expr.callee)
                return self.py_call(function, expr.args, target_type)

            # Native call: arguments are passed unboxed, in source order.
            target = self.alloc_target(target_type)
            args = [self.accept(arg) for arg in expr.args]
            self.add(Call(target, fn, args))
        return target

    def visit_conditional_expr(self, expr: ConditionalExpr) -> Register:
        """Lower `a if cond else b`.

        Both arms evaluate into the same target register, then jump to a
        common next block.
        """
        branches = self.process_conditional(expr.cond)
        target = self.alloc_target(self.node_type(expr))

        if_body = self.new_block()
        self.set_branches(branches, True, if_body)
        self.accept(expr.if_expr, target=target)
        if_goto_next = Goto(INVALID_LABEL)
        self.add(if_goto_next)

        else_body = self.new_block()
        self.set_branches(branches, False, else_body)
        self.accept(expr.else_expr, target=target)
        else_goto_next = Goto(INVALID_LABEL)
        self.add(else_goto_next)

        # Both arms rejoin here.
        next = self.new_block()
        if_goto_next.label = next.label
        else_goto_next.label = next.label

        return target

    def translate_special_method_call(self, callee: MemberExpr,
                                      expr: CallExpr) -> Register:
        """Emit a primitive op for a recognized method call.

        Only list.append is recognized. It emits LIST_APPEND with the boxed
        argument and returns INVALID_REGISTER.

        Raises:
            AssertionError: for any other method. This happens after the
                receiver (callee.expr) has already been evaluated.
        """
        base_type = self.node_type(callee.expr)
        result_type = self.node_type(expr)
        # The receiver is evaluated here, before the method is checked.
        base = self.accept(callee.expr)
        if callee.name == 'append' and base_type.name == 'list':
            target = INVALID_REGISTER  # TODO: Do we sometimes need to allocate a register?
            arg = self.box_expr(expr.args[0])
            self.add(PrimitiveOp(target, PrimitiveOp.LIST_APPEND, base, arg))
        else:
            assert False, 'Unsupported method call: %s.%s' % (base_type.name,
                                                              callee.name)
        return target

    def visit_list_expr(self, expr: ListExpr) -> Register:
        """Build a list from a list display; items are boxed before insertion."""
        list_type = self.types[expr]
        assert isinstance(list_type, Instance)
        item_rtype = self.type_to_rtype(list_type.args[0])
        list_reg = self.alloc_target(ListRType())
        boxed_items = [self.box(self.accept(item), item_rtype)
                       for item in expr.items]
        self.add(PrimitiveOp(list_reg, PrimitiveOp.NEW_LIST, *boxed_items))
        return list_reg

    def visit_tuple_expr(self, expr: TupleExpr) -> Register:
        """Build a fixed-length tuple; item registers are passed unboxed."""
        tuple_type = self.types[expr]
        assert isinstance(tuple_type, TupleType)

        tuple_reg = self.alloc_target(self.type_to_rtype(tuple_type))
        item_regs = []  # type: List[Register]
        for item in expr.items:
            item_regs.append(self.accept(item))
        self.add(PrimitiveOp(tuple_reg, PrimitiveOp.NEW_TUPLE, *item_regs))
        return tuple_reg

    def visit_dict_expr(self, expr: DictExpr) -> Register:
        """Lower a dict display; only the empty dict `{}` is supported.

        (The return annotation was missing, which typed the result as Any.)
        """
        assert not expr.items  # TODO
        target = self.alloc_target(DictRType())
        self.add(PrimitiveOp(target, PrimitiveOp.NEW_DICT))
        return target

    # Conditional expressions

    # Comparison operator -> integer Branch opcode. process_conditional
    # applies these without checking operand types (see the TODO there).
    int_relative_ops = {
        '==': Branch.INT_EQ,
        '!=': Branch.INT_NE,
        '<': Branch.INT_LT,
        '<=': Branch.INT_LE,
        '>': Branch.INT_GT,
        '>=': Branch.INT_GE,
    }

    def process_conditional(self, e: Node) -> List[Branch]:
        """Emit branch ops that test the boolean condition e.

        Returns every Branch emitted whose outcome the caller still has to
        bind with set_branches (set_branches skips targets already set).

        Supported forms: a single comparison (int relational operators,
        `is`/`is not` treated as a None check, `in`/`not in` delegated to
        binary_op), short-circuit `and`/`or`, `not`, and any other expression
        used as a bool value.

        Raises:
            AssertionError: for chained comparisons or unsupported operators.
        """
        if isinstance(e, ComparisonExpr):
            # TODO: Verify operand types.
            assert len(e.operators) == 1, 'more than 1 operator not supported'
            op = e.operators[0]
            if op in ['==', '!=', '<', '<=', '>', '>=']:
                # TODO: check operand types
                left = self.accept(e.operands[0])
                right = self.accept(e.operands[1])
                opcode = self.int_relative_ops[op]
                branch = Branch(left, right, INVALID_LABEL, INVALID_LABEL,
                                opcode)
            elif op in ['is', 'is not']:
                # TODO: check if right operand is None
                # The right operand is not evaluated; the test is IS_NONE.
                left = self.accept(e.operands[0])
                branch = Branch(left, INVALID_REGISTER, INVALID_LABEL,
                                INVALID_LABEL, Branch.IS_NONE)
                if op == 'is not':
                    branch.negated = True
            elif op in ['in', 'not in']:
                left = self.accept(e.operands[0])
                ltype = self.node_type(e.operands[0])
                right = self.accept(e.operands[1])
                rtype = self.node_type(e.operands[1])
                # Compute the membership test into a temp, then branch on it.
                target = self.alloc_temp(self.node_type(e))
                self.binary_op(ltype, left, rtype, right, 'in', target=target)
                branch = Branch(target, INVALID_REGISTER, INVALID_LABEL,
                                INVALID_LABEL, Branch.BOOL_EXPR)
                if op == 'not in':
                    branch.negated = True
            else:
                # Message previously misspelled "expression" as "epxression".
                assert False, "unsupported comparison expression"
            self.add(branch)
            return [branch]
        elif isinstance(e, OpExpr) and e.op in ['and', 'or']:
            # Short circuit in a conditional context. The right operand is
            # evaluated in a new block that is reached only when the left
            # operand is true ('and') or false ('or'). The left's other
            # outcome stays unbound and is resolved by the caller together
            # with the right operand's branches.
            lbranches = self.process_conditional(e.left)
            new = self.new_block()
            self.set_branches(lbranches, e.op == 'and', new)
            rbranches = self.process_conditional(e.right)
            return lbranches + rbranches
        elif isinstance(e, UnaryExpr) and e.op == 'not':
            # Invert every branch produced for the operand.
            branches = self.process_conditional(e.expr)
            for b in branches:
                b.invert()
            return branches
        # Catch-all for arbitrary expressions.
        else:
            reg = self.accept(e)
            branch = Branch(reg, INVALID_REGISTER, INVALID_LABEL,
                            INVALID_LABEL, Branch.BOOL_EXPR)
            self.add(branch)
            return [branch]

    def set_branches(self, branches: List[Branch], condition: bool,
                     target: BasicBlock) -> None:
        """Point one outcome of every branch in ``branches`` at ``target``.

        ``condition`` selects the edge: True updates each branch's ``true``
        label, False updates its ``false`` label. An edge that already holds
        a label (its value is not negative) is left untouched, so targets
        assigned earlier take precedence.
        """
        edge = 'true' if condition else 'false'
        for branch in branches:
            if getattr(branch, edge) < 0:
                setattr(branch, edge, target.label)

    # Helpers

    def enter(self) -> None:
        """Push a fresh environment and block list, and open the first block.

        The new Environment becomes the current one. It stays current until
        the matching leave() call pops it.
        """
        fresh_env = Environment()
        self.environment = fresh_env
        self.environments.append(fresh_env)
        self.blocks.append([])
        self.new_block()

    def new_block(self) -> BasicBlock:
        """Append a new basic block to the innermost block list and return it.

        The block's label is its index within that list. Because add() emits
        into the last block of the innermost list, later ops go here.
        """
        current_blocks = self.blocks[-1]
        block = BasicBlock(Label(len(current_blocks)))
        current_blocks.append(block)
        return block

    def goto_new_block(self) -> BasicBlock:
        """End the current block with a Goto into a freshly created block.

        The Goto is emitted before the new block exists, so it lands in the
        block that was current. Its label is patched once the successor's
        label is known. Returns the successor, which is now current.
        """
        jump = Goto(INVALID_LABEL)
        self.add(jump)
        successor = self.new_block()
        jump.label = successor.label
        return successor

    def leave(self) -> Tuple[List[BasicBlock], Environment]:
        """Pop the innermost block list and environment pushed by enter().

        The enclosing environment becomes current again. Returns the popped
        blocks together with the popped environment.
        """
        finished_blocks = self.blocks.pop()
        finished_env = self.environments.pop()
        self.environment = self.environments[-1]
        return (finished_blocks, finished_env)

    def add(self, op: Op) -> None:
        """Append ``op`` to the last block of the innermost block list."""
        current_block = self.blocks[-1][-1]
        current_block.ops.append(op)

    def accept(self,
               node: Node,
               target: Register = INVALID_REGISTER) -> Register:
        """Build IR for ``node`` and return the register holding its value.

        ``target`` is pushed onto ``self.targets`` while ``node`` is being
        visited. Visitors that call alloc_target write their result into it.
        A negative target (presumably what INVALID_REGISTER is; confirm)
        makes alloc_target allocate a fresh temporary instead.
        """
        self.targets.append(target)
        try:
            return node.accept(self)
        finally:
            # Pop even if the visitor raises, so the target stack stays
            # balanced with the nesting of accept() calls.
            self.targets.pop()

    def alloc_target(self, type: RType) -> Register:
        """Return the register the current visitor should write its result to.

        If the target requested by the innermost accept() call is a valid
        (non-negative) register, that register is returned. Otherwise a
        fresh temporary of ``type`` is allocated.
        """
        requested = self.targets[-1]
        if requested < 0:
            return self.environment.add_temp(type)
        return requested

    def alloc_temp(self, type: RType) -> Register:
        """Allocate a new temporary of ``type`` in the current environment."""
        env = self.environment
        return env.add_temp(type)

    def type_to_rtype(self, typ: Type) -> RType:
        """Map a mypy type to an RType by delegating to ``self.mapper``."""
        mapper = self.mapper
        return mapper.type_to_rtype(typ)

    def node_type(self, node: Expression) -> RType:
        """Return the RType of the type recorded for ``node`` in self.types."""
        return self.type_to_rtype(self.types[node])

    def box(self,
            src: Register,
            typ: RType,
            target: Optional[Register] = None) -> Register:
        """Return a register that holds ``src`` (of type ``typ``) as an object.

        When ``typ.supports_unbox`` is true, a Box op is emitted into
        ``target``, or into a new object temporary if no target is given.
        Otherwise ``src`` is treated as already boxed. It is copied into
        ``target`` with an Assign when a target is given, and returned
        unchanged when none is.
        """
        if not typ.supports_unbox:
            # Already boxed: copy only when a destination was requested.
            if target is None:
                return src
            self.add(Assign(target, src))
            return target

        if target is None:
            target = self.alloc_temp(ObjectRType())
        self.add(Box(target, src, typ))
        return target

    def unbox_or_cast(self,
                      src: Register,
                      target_type: RType,
                      target: Optional[Register] = None) -> Register:
        """Convert ``src`` to ``target_type`` and return the result register.

        Emits an Unbox when ``target_type.supports_unbox`` is true and a Cast
        otherwise. The result goes into ``target``, or into a freshly
        allocated temporary of ``target_type`` when no target is given.
        """
        dest = self.alloc_temp(target_type) if target is None else target
        convert = Unbox if target_type.supports_unbox else Cast
        self.add(convert(dest, src, target_type))
        return dest

    def box_expr(self, expr: Expression) -> Register:
        """Build IR for ``expr`` and return its value as a boxed object.

        The expression's RType is looked up before the expression is
        evaluated. The evaluated register is then passed to box().
        """
        rtype = self.node_type(expr)
        value = self.accept(expr)
        return self.box(value, rtype)