Example #1
0
class TestGenerateFunction(unittest.TestCase):
    """Check the C fragments emitted by generate_native_function for tiny functions."""

    def setUp(self) -> None:
        # Every test starts with one int local named 'arg' registered in a
        # fresh environment, the matching runtime argument, and an empty block.
        self.var = Var('arg')
        self.arg = RuntimeArg('arg', int_rprimitive)
        self.env = Environment()
        self.reg = self.env.add_local(self.var, int_rprimitive)
        self.block = BasicBlock(0)

    def test_simple(self) -> None:
        # A function whose only op returns its argument.
        self.block.ops.append(Return(self.reg))

        sig = FuncSignature([self.arg], int_rprimitive)
        decl = FuncDecl('myfunc', None, 'mod', sig)
        fn = FuncIR(decl, [self.block], self.env)

        names = generate_names_for_env(self.env)
        context = EmitterContext(NameGenerator([['mod']]))
        emitter = Emitter(context, self.env, names)
        generate_native_function(fn, emitter, 'prog.py', 'prog', optimize_int=False)

        expected = [
            'CPyTagged CPyDef_myfunc(CPyTagged cpy_r_arg) {\n',
            'CPyL0: ;\n',
            '    return cpy_r_arg;\n',
            '}\n',
        ]
        assert_string_arrays_equal(expected, emitter.fragments, msg='Generated code invalid')

    def test_register(self) -> None:
        # A function whose only op loads an int into a register that has to be
        # declared at the top of the generated C function.
        load = LoadInt(5)
        self.block.ops.append(load)
        self.env.add_op(load)

        sig = FuncSignature([self.arg], list_rprimitive)
        decl = FuncDecl('myfunc', None, 'mod', sig)
        fn = FuncIR(decl, [self.block], self.env)

        names = generate_names_for_env(self.env)
        context = EmitterContext(NameGenerator([['mod']]))
        emitter = Emitter(context, self.env, names)
        generate_native_function(fn, emitter, 'prog.py', 'prog', optimize_int=False)

        expected = [
            'PyObject *CPyDef_myfunc(CPyTagged cpy_r_arg) {\n',
            '    CPyTagged cpy_r_i0;\n',
            'CPyL0: ;\n',
            '    cpy_r_i0 = 10;\n',
            '}\n',
        ]
        assert_string_arrays_equal(expected, emitter.fragments, msg='Generated code invalid')
Example #2
0
def split_blocks_at_uninits(env: Environment,
                            blocks: List[BasicBlock],
                            pre_must_defined: 'AnalysisDict[Value]') -> List[BasicBlock]:
    """Insert definedness checks before reads of possibly-uninitialized registers.

    Each input block is emptied and refilled in place.  Whenever an op reads a
    Register that is not in the must-defined set for that op, the block is
    split: the current block ends with an IS_ERROR Branch on the register,
    the error branch raises UnboundLocalError, and the op itself is placed in
    a fresh continuation block.

    Args:
        env: environment; checked registers are added to `vars_needing_init`
            and the generated raise ops are registered with `add_op`.
        blocks: blocks to process (their `ops` lists are replaced).
        pre_must_defined: indexed by `(block, op_index)`; gives the values
            known to be defined before that op.

    Returns:
        The original blocks interleaved with the newly created error and
        continuation blocks.
    """
    result = []  # type: List[BasicBlock]

    for block in blocks:
        pending_ops = block.ops
        block.ops = []
        current = block
        result.append(current)

        for index, op in enumerate(pending_ops):
            defined = pre_must_defined[block, index]
            for src in op.unique_sources():
                # Only registers that are not guaranteed to be initialized
                # need a check.
                if not isinstance(src, Register) or src in defined:
                    continue
                # An IS_ERROR branch is itself a check that the register is
                # defined, so it must not be guarded again.
                if isinstance(op, Branch) and op.op == Branch.IS_ERROR:
                    continue
                # A register operand of LoadAddress may be used without
                # initialization, as its address may be needed to update it.
                if isinstance(op, LoadAddress):
                    continue

                continuation, error_block = BasicBlock(), BasicBlock()
                handler = current.error_handler
                continuation.error_handler = handler
                error_block.error_handler = handler
                result.append(error_block)
                result.append(continuation)

                env.vars_needing_init.add(src)

                check = Branch(src,
                               true_label=error_block,
                               false_label=continuation,
                               op=Branch.IS_ERROR,
                               line=op.line)
                current.ops.append(check)

                raise_std = RaiseStandardError(
                    RaiseStandardError.UNBOUND_LOCAL_ERROR,
                    "local variable '{}' referenced before assignment".format(src.name),
                    op.line)
                env.add_op(raise_std)
                error_block.ops.append(raise_std)
                error_block.ops.append(Unreachable())

                # The op (and everything after it) continues in the new block.
                current = continuation
            current.ops.append(op)

    return result
Example #3
0
class LowLevelIRBuilder:
    def __init__(
        self,
        current_module: str,
        mapper: Mapper,
    ) -> None:
        """Create a builder with a fresh environment and no basic blocks.

        Args:
            current_module: name of the module being built; stored as-is.
            mapper: stored as-is; used by literal_static_name().
        """
        # Stack of except handler entry blocks
        self.error_handlers = [None]  # type: List[Optional[BasicBlock]]
        self.blocks = []  # type: List[BasicBlock]
        self.environment = Environment()
        self.mapper = mapper
        self.current_module = current_module

    # Basic operations

    def add(self, op: Op) -> Value:
        """Append `op` to the active (last) block and return it.

        RegisterOps are additionally registered with the environment.
        """
        active = self.blocks[-1]
        assert not active.terminated, "Can't add to finished block"

        active.ops.append(op)
        if isinstance(op, RegisterOp):
            self.environment.add_op(op)
        return op

    def goto(self, target: BasicBlock) -> None:
        """Emit a Goto to `target`, unless the active block is already terminated."""
        if self.blocks[-1].terminated:
            return
        self.add(Goto(target))

    def activate_block(self, block: BasicBlock) -> None:
        """Append `block` and make it the active one (the target of add()).

        The block's error handler is set to the handler currently on top of
        the error handler stack.
        """
        # A previously active block must have been terminated first.
        assert not self.blocks or self.blocks[-1].terminated

        block.error_handler = self.error_handlers[-1]
        self.blocks.append(block)

    def goto_and_activate(self, block: BasicBlock) -> None:
        """Jump to a block and make it the active block.

        The Goto is only emitted if the currently active block is not already
        terminated (see goto()).
        """
        self.goto(block)
        self.activate_block(block)

    def push_error_handler(self, handler: Optional[BasicBlock]) -> None:
        """Push `handler` onto the error handler stack.

        The top of this stack is assigned as the error handler of every block
        passed to activate_block().
        """
        self.error_handlers.append(handler)

    def pop_error_handler(self) -> Optional[BasicBlock]:
        """Remove and return the handler on top of the error handler stack."""
        return self.error_handlers.pop()

    def alloc_temp(self, type: RType) -> Register:
        """Ask the environment for a new temporary of the given type and return it."""
        return self.environment.add_temp(type)

    # Type conversions

    def box(self, src: Value) -> Value:
        """Return a boxed form of `src`.

        A Box op is emitted only when the value's type is unboxed; any other
        value is returned unchanged.
        """
        if not src.type.is_unboxed:
            return src
        return self.add(Box(src))

    def unbox_or_cast(self, src: Value, target_type: RType, line: int) -> Value:
        """Convert `src` to `target_type`: Unbox for unboxed targets, Cast otherwise."""
        if target_type.is_unboxed:
            conversion = Unbox(src, target_type, line)  # type: Op
        else:
            conversion = Cast(src, target_type, line)
        return self.add(conversion)

    def coerce(self, src: Value, target_type: RType, line: int, force: bool = False) -> Value:
        """Convert `src` to `target_type`, emitting ops only when required.

        The cases, checked in this order:

        * unboxed -> boxed: box the value;
        * unboxed -> unboxed that is not a runtime subtype: box, then unbox/cast;
        * boxed -> unboxed, or source not a subtype of the target: unbox/cast;
        * otherwise no conversion is needed.  If `force` is true an Assign into
          a new temporary of exactly `target_type` is still emitted.

        Returns the value holding the converted result (may be `src` itself).
        """
        src_unboxed = src.type.is_unboxed
        dst_unboxed = target_type.is_unboxed

        if src_unboxed and not dst_unboxed:
            return self.box(src)

        if src_unboxed and dst_unboxed and not is_runtime_subtype(src.type, target_type):
            # To go from one unboxed type to another, we go through a boxed
            # in-between value, for simplicity.
            boxed = self.box(src)
            return self.unbox_or_cast(boxed, target_type, line)

        if (not src_unboxed and dst_unboxed) or not is_subtype(src.type, target_type):
            return self.unbox_or_cast(src, target_type, line)

        if force:
            copy = self.alloc_temp(target_type)
            self.add(Assign(copy, src))
            return copy
        return src

    # Attribute access

    def get_attr(self, obj: Value, attr: str, result_type: RType, line: int) -> Value:
        """Read attribute `attr` of `obj`, natively when possible.

        A GetAttr op is used for extension-class instances that have the
        attribute, union-typed objects are handled by union_get_attr(), and
        everything else goes through py_get_attr().
        """
        typ = obj.type
        if (isinstance(typ, RInstance)
                and typ.class_ir.is_ext_class
                and typ.class_ir.has_attr(attr)):
            return self.add(GetAttr(obj, attr, line))
        if isinstance(typ, RUnion):
            return self.union_get_attr(obj, typ, attr, result_type, line)
        return self.py_get_attr(obj, attr, line)

    def union_get_attr(self,
                       obj: Value,
                       rtype: RUnion,
                       attr: str,
                       result_type: RType,
                       line: int) -> Value:
        """Read attribute `attr` from an object whose type is a union.

        decompose_union_helper() is handed a callback that performs the
        attribute read through get_attr().
        """
        return self.decompose_union_helper(
            obj,
            rtype,
            result_type,
            lambda item: self.get_attr(item, attr, result_type, line),
            line)

    def py_get_attr(self, obj: Value, attr: str, line: int) -> Value:
        """Read an attribute through the generic Python getattr primitive (slow).

        get_attr() should be preferred, since it emits optimized code for
        native classes.
        """
        attr_name = self.load_static_unicode(attr)
        getattr_op = PrimitiveOp([obj, attr_name], py_getattr_op, line)
        return self.add(getattr_op)

    # isinstance() checks

    def isinstance_helper(self, obj: Value, class_irs: List[ClassIR], line: int) -> Value:
        """Emit an isinstance() check of `obj` against a list of native classes.

        An empty list yields the constant false op.  Otherwise the per-class
        checks from isinstance_native() are chained with short-circuiting 'or'.
        """
        if not class_irs:
            return self.primitive_op(false_op, [], line)

        ret = self.isinstance_native(obj, class_irs[0], line)
        for class_ir in class_irs[1:]:
            # Both callbacks are invoked by shortcircuit_helper before it
            # returns, so they see this iteration's `prev` and `class_ir`.
            prev = ret
            ret = self.shortcircuit_helper(
                'or',
                bool_rprimitive,
                lambda: prev,
                lambda: self.isinstance_native(obj, class_ir, line),
                line)
        return ret

    def isinstance_native(self, obj: Value, class_ir: ClassIR, line: int) -> Value:
        """Emit a fast isinstance() check against one native class.

        When the set of concrete classes for `class_ir` is unknown or has more
        than FAST_ISINSTANCE_MAX_SUBCLASSES + 1 members, fast_isinstance_op is
        used.  With no concrete classes the result is the constant false op.
        Otherwise a chain of `type(obj) is typ` comparisons joined with
        short-circuiting 'or' is emitted, one per concrete class.
        """
        concrete = all_concrete_classes(class_ir)
        if concrete is None or len(concrete) > FAST_ISINSTANCE_MAX_SUBCLASSES + 1:
            native_type = self.get_native_type(class_ir)
            return self.primitive_op(fast_isinstance_op, [obj, native_type], line)
        if not concrete:
            # There can't be any concrete instance that matches this.
            return self.primitive_op(false_op, [], line)

        first_type = self.get_native_type(concrete[0])
        ret = self.primitive_op(type_is_op, [obj, first_type], line)
        for cls in concrete[1:]:
            # Both callbacks run inside shortcircuit_helper, so the type load
            # for `cls` is emitted in the right-hand block.
            prev = ret
            ret = self.shortcircuit_helper(
                'or',
                bool_rprimitive,
                lambda: prev,
                lambda: self.primitive_op(type_is_op, [obj, self.get_native_type(cls)], line),
                line)
        return ret

    # Calls

    def py_call(self,
                function: Value,
                arg_values: List[Value],
                line: int,
                arg_kinds: Optional[List[int]] = None,
                arg_names: Optional[Sequence[Optional[str]]] = None) -> Value:
        """Call a Python function (non-native and slow).

        Purely positional calls use py_call_op.  Any other call builds a
        positional tuple and a keyword dict and uses py_call_with_kwargs_op;
        `arg_names` must be given in that case.
        """
        # If all arguments are positional, we can use py_call_op.
        if arg_kinds is None or all(kind == ARG_POS for kind in arg_kinds):
            return self.primitive_op(py_call_op, [function] + arg_values, line)

        # Otherwise fallback to py_call_with_kwargs_op.
        assert arg_names is not None

        positional = []
        keyword_entries = []  # type: List[DictEntry]
        starred = []
        for value, kind, name in zip(arg_values, arg_kinds, arg_names):
            if kind == ARG_POS:
                positional.append(value)
            elif kind == ARG_NAMED:
                assert name is not None
                keyword_entries.append((self.load_static_unicode(name), value))
            elif kind == ARG_STAR:
                starred.append(value)
            elif kind == ARG_STAR2:
                # NOTE: mypy currently only supports a single ** arg, but python supports
                # multiple. Supporting multiple here keeps the logic easier to follow.
                # A None key marks a ** argument for make_dict().
                keyword_entries.append((None, value))
            else:
                assert False, ("Argument kind should not be possible:", kind)

        if starred:
            # Tuples don't have an extend method, so collect everything in a
            # list, extend it with each star arg, and convert it at the end.
            args_list = self.primitive_op(new_list_op, positional, line)
            for star_value in starred:
                self.call_c(list_extend_op, [args_list, star_value], line)
            args_tuple = self.call_c(list_tuple_op, [args_list], line)
        else:
            # We can directly construct a tuple if there are no star args.
            args_tuple = self.primitive_op(new_tuple_op, positional, line)

        kwargs = self.make_dict(keyword_entries, line)

        return self.primitive_op(
            py_call_with_kwargs_op, [function, args_tuple, kwargs], line)

    def py_method_call(self,
                       obj: Value,
                       method_name: str,
                       arg_values: List[Value],
                       line: int,
                       arg_kinds: Optional[List[int]],
                       arg_names: Optional[Sequence[Optional[str]]]) -> Value:
        """Call a Python method (non-native and slow).

        Purely positional calls use py_method_call_op directly; any other call
        first fetches the bound method with py_get_attr() and then goes
        through py_call().
        """
        only_positional = arg_kinds is None or all(kind == ARG_POS for kind in arg_kinds)
        if not only_positional:
            method = self.py_get_attr(obj, method_name, line)
            return self.py_call(method, arg_values, line, arg_kinds=arg_kinds, arg_names=arg_names)

        name_reg = self.load_static_unicode(method_name)
        return self.primitive_op(py_method_call_op, [obj, name_reg] + arg_values, line)

    def call(self,
             decl: FuncDecl,
             args: Sequence[Value],
             arg_kinds: List[int],
             arg_names: Sequence[Optional[str]],
             line: int) -> Value:
        """Call a native function.

        The arguments are first normalized to the positional layout of
        `decl.sig`, then a Call op is emitted.
        """
        positional = self.native_args_to_positional(args, arg_kinds, arg_names, decl.sig, line)
        return self.add(Call(decl, positional, line))

    def native_args_to_positional(self,
                                  args: Sequence[Value],
                                  arg_kinds: List[int],
                                  arg_names: Sequence[Optional[str]],
                                  sig: FuncSignature,
                                  line: int) -> List[Value]:
        """Arrange call arguments in the positional order a native signature expects.

        For every formal argument of `sig`, in order:

        * a *args formal receives a tuple of its actuals;
        * a **kwargs formal receives a dict keyed by the actuals' names;
        * a formal without any actual receives a borrowed error value (the
          marker for an unspecified default argument);
        * otherwise the first mapped actual is used.

        Each resulting value is coerced to the formal's type.
        """
        formal_kinds = [formal.kind for formal in sig.args]
        formal_names = [formal.name for formal in sig.args]
        formal_to_actual = map_actuals_to_formals(
            arg_kinds,
            arg_names,
            formal_kinds,
            formal_names,
            lambda n: AnyType(TypeOfAny.special_form))

        # Flatten out the arguments, loading error values for default
        # arguments, constructing tuples/dicts for star args, and
        # coercing everything to the expected type.
        output_args = []
        for actuals, formal in zip(formal_to_actual, sig.args):
            if formal.kind == ARG_STAR:
                packed = [args[i] for i in actuals]
                value = self.primitive_op(new_tuple_op, packed, line)
            elif formal.kind == ARG_STAR2:
                entries = []
                for i in actuals:
                    key = self.load_static_unicode(cast(str, arg_names[i]))
                    entries.append((key, args[i]))
                value = self.make_dict(entries, line)
            elif not actuals:
                value = self.add(LoadErrorValue(formal.type, is_borrowed=True))
            else:
                value = args[actuals[0]]
            output_args.append(self.coerce(value, formal.type, line))

        return output_args

    def gen_method_call(self,
                        base: Value,
                        name: str,
                        arg_values: List[Value],
                        result_type: Optional[RType],
                        line: int,
                        arg_kinds: Optional[List[int]] = None,
                        arg_names: Optional[List[Optional[str]]] = None) -> Value:
        """Generate either a native or Python method call.

        The strategies are tried in this order:

        1. If any argument is neither ARG_POS nor ARG_NAMED, use a generic
           Python method call.
        2. If `base` is an instance of an extension class without a builtin
           base: emit a MethodCall when the class has the method, or load the
           attribute and call it through py_call() when the class only has an
           attribute of that name.
        3. If `base` has a union type, dispatch through union_method_call().
        4. For purely positional calls, try translate_special_method_call().
        5. Fall back to a generic Python method call.

        Args:
            base: the receiver of the call.
            name: method name.
            arg_values: argument values (the receiver is not included).
            result_type: expected result type, if known.
            line: line number of the call; used for every emitted op.
            arg_kinds: kinds of the arguments, or None if all are positional.
            arg_names: names of the arguments; must be given exactly when
                `arg_kinds` is given for a native method call.
        """
        # If arg_kinds contains values other than arg_pos and arg_named, then fallback to
        # Python method call.
        if (arg_kinds is not None
                and not all(kind in (ARG_POS, ARG_NAMED) for kind in arg_kinds)):
            # Use the line of the call itself (not base.line), consistent with
            # the final fallback below.
            return self.py_method_call(base, name, arg_values, line, arg_kinds, arg_names)

        # If the base type is one of ours, do a MethodCall
        if (isinstance(base.type, RInstance) and base.type.class_ir.is_ext_class
                and not base.type.class_ir.builtin_base):
            if base.type.class_ir.has_method(name):
                decl = base.type.class_ir.method_decl(name)
                if arg_kinds is None:
                    assert arg_names is None, "arg_kinds not present but arg_names is"
                    arg_kinds = [ARG_POS for _ in arg_values]
                    arg_names = [None for _ in arg_values]
                else:
                    assert arg_names is not None, "arg_kinds present but arg_names is not"

                # Normalize args to positionals.
                assert decl.bound_sig
                arg_values = self.native_args_to_positional(
                    arg_values, arg_kinds, arg_names, decl.bound_sig, line)
                return self.add(MethodCall(base, name, arg_values, line))
            elif base.type.class_ir.has_attr(name):
                # The class stores a callable in an attribute of this name.
                function = self.add(GetAttr(base, name, line))
                return self.py_call(function, arg_values, line,
                                    arg_kinds=arg_kinds, arg_names=arg_names)
            # Neither a method nor an attribute: fall through to the generic
            # strategies below.

        elif isinstance(base.type, RUnion):
            return self.union_method_call(base, base.type, name, arg_values, result_type, line,
                                          arg_kinds, arg_names)

        # Try to do a special-cased method call
        if not arg_kinds or arg_kinds == [ARG_POS] * len(arg_values):
            target = self.translate_special_method_call(base, name, arg_values, result_type, line)
            if target:
                return target

        # Fall back to Python method call
        return self.py_method_call(base, name, arg_values, line, arg_kinds, arg_names)

    def union_method_call(self,
                          base: Value,
                          obj_type: RUnion,
                          name: str,
                          arg_values: List[Value],
                          return_rtype: Optional[RType],
                          line: int,
                          arg_kinds: Optional[List[int]],
                          arg_names: Optional[List[Optional[str]]]) -> Value:
        """Generate a method call on a receiver whose type is a union.

        decompose_union_helper() is handed a callback that performs the call
        through gen_method_call().
        """
        # The output register needs a type; default to object_rprimitive when
        # the caller did not supply one.
        result_rtype = return_rtype or object_rprimitive

        return self.decompose_union_helper(
            base,
            obj_type,
            result_rtype,
            lambda item: self.gen_method_call(item, name, arg_values, result_rtype, line,
                                              arg_kinds, arg_names),
            line)

    # Loading various values

    def none(self) -> Value:
        """Load unboxed None value (type: none_rprimitive).

        Emits a PrimitiveOp for none_op with no arguments and line -1.
        """
        return self.add(PrimitiveOp([], none_op, line=-1))

    def none_object(self) -> Value:
        """Load Python None value (type: object_rprimitive).

        Emits a PrimitiveOp for none_object_op with no arguments and line -1.
        """
        return self.add(PrimitiveOp([], none_object_op, line=-1))

    def literal_static_name(self, value: Union[int, float, complex, str, bytes]) -> str:
        """Return the identifier of the static holding the literal `value`.

        The result is STATIC_PREFIX followed by the name the mapper assigns to
        the literal for the current module.
        """
        return STATIC_PREFIX + self.mapper.literal_static_name(self.current_module, value)

    def load_static_int(self, value: int) -> Value:
        """Load the integer literal `value` into a register.

        Values whose magnitude does not exceed MAX_LITERAL_SHORT_INT are
        loaded with LoadInt; larger ones are loaded from a static via
        LoadGlobal.
        """
        if abs(value) <= MAX_LITERAL_SHORT_INT:
            return self.add(LoadInt(value))
        identifier = self.literal_static_name(value)
        return self.add(LoadGlobal(int_rprimitive, identifier, ann=value))

    def load_static_float(self, value: float) -> Value:
        """Load the static float literal `value` into a register (type: float_rprimitive)."""
        load = LoadGlobal(float_rprimitive, self.literal_static_name(value), ann=value)
        return self.add(load)

    def load_static_bytes(self, value: bytes) -> Value:
        """Load the static bytes literal `value` into a register (type: object_rprimitive)."""
        load = LoadGlobal(object_rprimitive, self.literal_static_name(value), ann=value)
        return self.add(load)

    def load_static_complex(self, value: complex) -> Value:
        """Load the static complex literal `value` into a register (type: object_rprimitive)."""
        load = LoadGlobal(object_rprimitive, self.literal_static_name(value), ann=value)
        return self.add(load)

    def load_static_unicode(self, value: str) -> Value:
        """Load the static str literal `value` into a register (type: str_rprimitive).

        Besides unicode literals in source code, this is also how names are
        turned into objects, e.g. the attribute and method names needed by
        py_get_attr() and py_method_call().
        """
        load = LoadGlobal(str_rprimitive, self.literal_static_name(value), ann=value)
        return self.add(load)

    def load_static_checked(self, typ: RType, identifier: str, module_name: Optional[str] = None,
                            namespace: str = NAMESPACE_STATIC,
                            line: int = -1,
                            error_msg: Optional[str] = None) -> Value:
        """Load a static and raise a NameError-kind error if the load yields an error value.

        Emits a LoadStatic followed by an IS_ERROR branch (marked rare).  The
        error block raises RaiseStandardError.NAME_ERROR with `error_msg`
        (default: "name '<identifier>' is not defined") and ends with
        Unreachable.  On return the ok block is the active block.

        Returns the loaded value.
        """
        message = (error_msg if error_msg is not None
                   else "name '{}' is not defined".format(identifier))
        ok_block, error_block = BasicBlock(), BasicBlock()

        value = self.add(LoadStatic(typ, identifier, module_name, namespace, line=line))
        self.add(Branch(value, error_block, ok_block, Branch.IS_ERROR, rare=True))

        # Error path: raise and never fall through.
        self.activate_block(error_block)
        raise_op = RaiseStandardError(RaiseStandardError.NAME_ERROR, message, line)
        self.add(raise_op)
        self.add(Unreachable())

        # Normal path continues here.
        self.activate_block(ok_block)
        return value

    def load_module(self, name: str) -> Value:
        """Load the static named `name` from the module namespace (NAMESPACE_MODULE).

        The result is typed as object_rprimitive.
        """
        return self.add(LoadStatic(object_rprimitive, name, namespace=NAMESPACE_MODULE))

    def get_native_type(self, cls: ClassIR) -> Value:
        """Load native type object.

        The full name is built as '<cls.module_name>.<cls.name>' and passed to
        load_native_type_object().
        """
        fullname = '%s.%s' % (cls.module_name, cls.name)
        return self.load_native_type_object(fullname)

    def load_native_type_object(self, fullname: str) -> Value:
        """Load the type object static for a fully qualified class name.

        `fullname` is split at its last '.' into a module part and a name
        part, so it must contain at least one '.' (otherwise the unpacking
        raises ValueError).  The static is loaded from the NAMESPACE_TYPE
        namespace and typed as object_rprimitive.
        """
        module, name = fullname.rsplit('.', 1)
        return self.add(LoadStatic(object_rprimitive, name, module, NAMESPACE_TYPE))

    # Other primitive operations

    def primitive_op(self, desc: OpDescription, args: List[Value], line: int) -> Value:
        """Emit the primitive op described by `desc`.

        Every argument is first coerced to the formal type reported by
        op_arg_type() for its position.  `desc` must have a result type.
        """
        assert desc.result_type is not None
        coerced = [self.coerce(arg, self.op_arg_type(desc, index), line)
                   for index, arg in enumerate(args)]
        return self.add(PrimitiveOp(coerced, desc, line))

    def matching_primitive_op(self,
                              candidates: List[OpDescription],
                              args: List[Value],
                              line: int,
                              result_type: Optional[RType] = None) -> Optional[Value]:
        """Emit the highest-priority candidate op whose signature fits `args`.

        A candidate fits when it takes as many arguments as were given and
        every actual type is a subtype of the corresponding formal type.  Two
        fitting candidates must not share a priority.

        If `result_type` is given and the op's result is not a runtime subtype
        of it, the result is coerced; a None result type is special-cased by
        loading None instead.

        Returns the result value, or None when no candidate fits.
        """
        best = None  # type: Optional[OpDescription]
        for desc in candidates:
            formals = desc.arg_types
            if len(formals) != len(args):
                continue
            if not all(is_subtype(actual.type, formal)
                       for actual, formal in zip(args, formals)):
                continue
            if not best:
                best = desc
                continue
            assert best.priority != desc.priority, 'Ambiguous:\n1) %s\n2) %s' % (
                best, desc)
            if desc.priority > best.priority:
                best = desc

        if not best:
            return None

        target = self.primitive_op(best, args, line)
        if result_type and not is_runtime_subtype(target.type, result_type):
            if is_none_rprimitive(result_type):
                # Special case None return. The actual result may actually be a bool
                # and so we can't just coerce it.
                target = self.none()
            else:
                target = self.coerce(target, result_type, line)
        return target

    def binary_op(self,
                  lreg: Value,
                  rreg: Value,
                  expr_op: str,
                  line: int) -> Value:
        """Emit the binary operation `lreg <expr_op> rreg`.

        Tried in order: a statically resolved ==/!= comparison, a direct
        integer op for two short ints, a matching C-call op, and finally a
        matching primitive op.  An operation with no match fails an assertion.
        """
        # Special case == and != when we can resolve the method call statically.
        if expr_op in ('==', '!='):
            resolved = self.translate_eq_cmp(lreg, rreg, expr_op, line)
            if resolved is not None:
                return resolved

        # generate fast binary logic ops on short ints
        both_short = (is_short_int_rprimitive(lreg.type)
                      and is_short_int_rprimitive(rreg.type))
        if both_short and expr_op in int_logical_op_mapping:
            int_op = int_logical_op_mapping[expr_op][0]
            return self.binary_int_op(bool_rprimitive, lreg, rreg, int_op, line)

        operands = [lreg, rreg]
        target = self.matching_call_c(c_binary_ops.get(expr_op, []), operands, line)
        if target:
            return target

        target = self.matching_primitive_op(binary_ops.get(expr_op, []), operands, line)
        assert target, 'Unsupported binary operation: %s' % expr_op
        return target

    def check_tagged_short_int(self, val: Value, line: int) -> Value:
        """Check if a tagged integer is a short integer.

        The emitted ops compute `(val & 1) == 0` and the bool result is returned.
        """
        tag_mask = self.add(LoadInt(1, line, rtype=c_pyssize_t_rprimitive))
        tag_bit = self.binary_int_op(
            c_pyssize_t_rprimitive, val, tag_mask, BinaryIntOp.AND, line)
        zero = self.add(LoadInt(0, line, rtype=c_pyssize_t_rprimitive))
        return self.binary_int_op(bool_rprimitive, tag_bit, zero, BinaryIntOp.EQ, line)

    def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
        """Compare two tagged integers using the comparison `op`.

        The integer op kind and the C function for `op` come from
        int_logical_op_mapping.  The generated code branches on whether `lhs`
        passes check_tagged_short_int(): if so the operands are compared with
        a direct integer op, otherwise the C function is called.  Either way
        the outcome is assigned to a bool temporary, which is returned.
        """
        int_op_kind, c_func_desc = int_logical_op_mapping[op]
        result = self.alloc_temp(bool_rprimitive)
        short_int_block, int_block, out = BasicBlock(), BasicBlock(), BasicBlock()

        is_short = self.check_tagged_short_int(lhs, line)
        branch = Branch(is_short, short_int_block, int_block, Branch.BOOL_EXPR)
        branch.negated = False
        self.add(branch)

        # Fast path: direct integer comparison.
        self.activate_block(short_int_block)
        fast_result = self.binary_int_op(bool_rprimitive, lhs, rhs, int_op_kind, line)
        self.add(Assign(result, fast_result, line))
        self.goto(out)

        # Slow path: call into C.
        self.activate_block(int_block)
        slow_result = self.call_c(c_func_desc, [lhs, rhs], line)
        self.add(Assign(result, slow_result, line))
        self.goto_and_activate(out)
        return result

    def unary_op(self,
                 lreg: Value,
                 expr_op: str,
                 line: int) -> Value:
        """Emit the unary operation `<expr_op> lreg`.

        A matching C-call op is preferred; otherwise a matching primitive op
        is used.  An operation with no match fails an assertion.
        """
        operand = [lreg]
        target = self.matching_call_c(c_unary_ops.get(expr_op, []), operand, line)
        if target:
            return target

        target = self.matching_primitive_op(unary_ops.get(expr_op, []), operand, line)
        assert target, 'Unsupported unary operation: %s' % expr_op
        return target

    def make_dict(self, key_value_pairs: Sequence[DictEntry], line: int) -> Value:
        """Build a dict from `key: value` entries and `**value` entries.

        An entry whose key is None stands for `**value`.  Leading `key: value`
        entries are collected and passed to _create_dict() in one go.  Once
        the dict exists, further `key: value` entries are stored with
        __setitem__ and `**value` entries are merged with
        dict_update_in_display_op, in the order given.
        """
        result = None  # type: Union[Value, None]
        keys = []  # type: List[Value]
        values = []  # type: List[Value]
        for key, value in key_value_pairs:
            if key is None:
                # **value
                if result is None:
                    result = self._create_dict(keys, values, line)
                self.call_c(dict_update_in_display_op, [result, value], line=line)
            elif result is None:
                # key:value before the dict exists: collect for _create_dict()
                keys.append(key)
                values.append(value)
            else:
                # key:value after the dict was created
                self.translate_special_method_call(
                    result, '__setitem__', [key, value], result_type=None, line=line)

        if result is None:
            result = self._create_dict(keys, values, line)

        return result

    def builtin_call(self,
                     args: List[Value],
                     fn_op: str,
                     line: int) -> Value:
        """Emit a call to the builtin function named `fn_op`.

        A matching C-call op is preferred; otherwise a matching primitive op
        is used.  A function with no match fails an assertion.
        """
        target = self.matching_call_c(c_function_ops.get(fn_op, []), args, line)
        if target:
            return target

        target = self.matching_primitive_op(func_ops.get(fn_op, []), args, line)
        assert target, 'Unsupported builtin function: %s' % fn_op
        return target

    def shortcircuit_helper(self, op: str,
                            expr_type: RType,
                            left: Callable[[], Value],
                            right: Callable[[], Value], line: int) -> Value:
        """Emit a short-circuiting boolean operation.

        Args:
            op: 'and' selects 'and' semantics; any other value is treated as 'or'.
            expr_type: type of the temporary that receives the result.
            left: callback that emits the left operand and returns its value.
            right: callback that emits the right operand; it is invoked while
                the right-hand block is active.
            line: line number used for the coercions.

        Returns the temporary holding the result.
        """
        # Having actual Phi nodes would be really nice here!
        target = self.alloc_temp(expr_type)
        # left_body stores the left value into target, right_body the right one.
        left_body, right_body, done = BasicBlock(), BasicBlock(), BasicBlock()
        # For 'and' the result is the right side when the left is true; for
        # 'or' it is the right side when the left is false.
        if op == 'and':
            true_body, false_body = right_body, left_body
        else:
            true_body, false_body = left_body, right_body

        left_value = left()
        self.add_bool_branch(left_value, true_body, false_body)

        self.activate_block(left_body)
        self.add(Assign(target, self.coerce(left_value, expr_type, line)))
        self.goto(done)

        self.activate_block(right_body)
        right_value = right()
        self.add(Assign(target, self.coerce(right_value, expr_type, line)))
        self.goto(done)

        self.activate_block(done)
        return target

    def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) -> None:
        """Add a branch on the truthiness of value.

        How the condition is computed depends on the type of value:

        * int: compare against 0 with '!='
        * list: compare the list length against 0 with '!='
        * extension class instance defining __bool__: call __bool__ directly
        * optional type: test 'is not' None first and, unless the item type
          is known to be always truthy, test the unwrapped value recursively
        * bool: branch on the value itself
        * anything else: apply bool_op

        Args:
            value: the value to test
            true: block to continue at when value is truthy
            false: block to continue at when value is falsy
        """
        line = value.line
        rtype = value.type
        if is_runtime_subtype(rtype, int_rprimitive):
            condition = self.binary_op(value, self.add(LoadInt(0)), '!=', line)
        elif is_same_type(rtype, list_rprimitive):
            length = self.primitive_op(list_len_op, [value], line)
            condition = self.binary_op(length, self.add(LoadInt(0)), '!=', line)
        elif (isinstance(rtype, RInstance) and rtype.class_ir.is_ext_class
                and rtype.class_ir.has_method('__bool__')):
            # Directly call the __bool__ method on classes that have it.
            condition = self.gen_method_call(value, '__bool__', [], bool_rprimitive, line)
        else:
            inner_type = optional_value_type(rtype)
            if inner_type is not None:
                not_none = self.binary_op(value, self.none_object(), 'is not', line)
                none_check = Branch(not_none, true, false, Branch.BOOL_EXPR)
                self.add(none_check)
                # X is always truthy when X.__bool__ is just the default
                # (object.__bool__) and cannot be overridden.
                needs_value_check = not (
                    isinstance(inner_type, RInstance)
                    and not inner_type.class_ir.has_method('__bool__')
                    and inner_type.class_ir.is_method_final('__bool__'))
                if needs_value_check:
                    # Optional[X] where X may be falsey: retarget the non-None
                    # edge to a new block that tests the unwrapped value.
                    none_check.true = BasicBlock()
                    self.activate_block(none_check.true)
                    # unbox_or_cast instead of coerce because we want the
                    # type to change even if it is a subtype.
                    unwrapped = self.unbox_or_cast(value, inner_type, line)
                    self.add_bool_branch(unwrapped, true, false)
                return
            if is_same_type(rtype, bool_rprimitive):
                condition = value
            else:
                condition = self.primitive_op(bool_op, [value], line)
        self.add(Branch(condition, true, false, Branch.BOOL_EXPR))

    def call_c(self,
               desc: CFunctionDescription,
               args: List[Value],
               line: int,
               result_type: Optional[RType] = None) -> Value:
        """Add a CallC op for the C function described by desc.

        Args:
            desc: description of the C function (argument types, optional
                argument ordering, optional var arg type, return type, ...)
            args: actual arguments; the first len(desc.arg_types) of them are
                coerced to the declared types, and any further ones are only
                used (coerced to desc.var_arg_type) if the description has a
                var arg type
            line: line number
            result_type: if given and the produced value's type is not a
                runtime subtype of it, the result is adjusted (see below)

        Returns:
            The CallC value, or the Truncate of it when desc.truncated_type is
            set. With a mismatching result_type, the none value is returned
            for a None result type and a coerced value otherwise.
        """
        fixed_count = len(desc.arg_types)
        # zip() stops at the shorter sequence, so at most fixed_count
        # arguments are coerced here.
        call_args = [self.coerce(actual, formal, line)
                     for actual, formal in zip(args, desc.arg_types)]
        # reorder args if necessary
        if desc.ordering is not None:
            assert desc.var_arg_type is None
            call_args = [call_args[pos] for pos in desc.ordering]
        # coerce any var_arg; var_arg_idx stays -1 without a var arg type
        if desc.var_arg_type is None:
            var_arg_idx = -1
        else:
            var_arg_idx = fixed_count
            for extra in args[fixed_count:]:
                call_args.append(self.coerce(extra, desc.var_arg_type, line))
        call = self.add(CallC(desc.c_function_name, call_args, desc.return_type, desc.steals,
                              desc.error_kind, line, var_arg_idx))
        if desc.truncated_type is None:
            result = call
        else:
            result = self.add(Truncate(call, desc.return_type, desc.truncated_type))
        if not result_type or is_runtime_subtype(result.type, result_type):
            return result
        if is_none_rprimitive(result_type):
            # Special case None return. The actual result may actually be a bool
            # and so we can't just coerce it.
            return self.none()
        # NOTE(review): this coerces the raw call value, not the truncated one,
        # even when a Truncate op was added above -- confirm this is intended.
        return self.coerce(call, result_type, line)

    def matching_call_c(self,
                        candidates: List[CFunctionDescription],
                        args: List[Value],
                        line: int,
                        result_type: Optional[RType] = None) -> Optional[Value]:
        """Call the best matching C function description, if there is one.

        A candidate matches when it declares exactly len(args) argument types
        and every argument type is a subtype of the corresponding declared
        type. Among several matches the one with the highest priority wins;
        two matches with equal priority fail an assertion.

        Args:
            candidates: descriptions to choose from
            args: actual argument values
            line: line number
            result_type: forwarded to call_c

        Returns:
            The value produced by call_c for the chosen description, or None
            if no candidate matches.
        """
        # TODO: this function is very similar to matching_primitive_op;
        # we should remove the old one or refactor both into one as we move forward
        best = None  # type: Optional[CFunctionDescription]
        for candidate in candidates:
            if len(candidate.arg_types) != len(args):
                continue
            if not all(is_subtype(actual.type, formal)
                       for actual, formal in zip(args, candidate.arg_types)):
                continue
            if not best:
                best = candidate
                continue
            assert best.priority != candidate.priority, 'Ambiguous:\n1) %s\n2) %s' % (
                best, candidate)
            if candidate.priority > best.priority:
                best = candidate
        if not best:
            return None
        return self.call_c(best, args, line, result_type)

    def binary_int_op(self, type: RType, lhs: Value, rhs: Value, op: int, line: int) -> Value:
        """Add a BinaryIntOp built from the given arguments and return its value.

        Args:
            type: forwarded as the first BinaryIntOp argument
            lhs: left operand
            rhs: right operand
            op: operator code (callers in this file use BinaryIntOp constants)
            line: line number
        """
        int_op = BinaryIntOp(type, lhs, rhs, op, line)
        return self.add(int_op)

    # Internal helpers

    def decompose_union_helper(self,
                               obj: Value,
                               rtype: RUnion,
                               result_type: RType,
                               process_item: Callable[[Value], Value],
                               line: int) -> Value:
        """Generate isinstance() checks plus a specialized operation per union item.

        For Union[A, B] the generated ops resemble this pseudocode:

            if isinstance(obj, A):
                result = <process_item(cast(A, obj))>
            else:
                result = <process_item(cast(B, obj))>

        RInstance items each get their own native isinstance check (the check
        is omitted for the last item when nothing else follows). All remaining
        items share one final case that handles obj as a plain object.

        Args:
            obj: value with a union type
            rtype: the union type
            result_type: result of the operation
            process_item: callback to generate op for a single union item (arg is coerced
                to union item type)
            line: line number

        Returns:
            The temporary that receives the (coerced) result in every case.
        """
        # TODO: Optimize cases where a single operation can handle multiple union items
        #     (say a method is implemented in a common base class)
        native_items = [item for item in rtype.items if isinstance(item, RInstance)]
        # For everything but RInstance we fall back to C API
        generic_items = [item for item in rtype.items if not isinstance(item, RInstance)]
        done = BasicBlock()
        result = self.alloc_temp(result_type)

        def store_result(narrowed: Value) -> None:
            # Run the operation on the narrowed value, store the coerced
            # outcome in the shared temporary and jump to the exit block.
            produced = process_item(narrowed)
            self.add(Assign(result, self.coerce(produced, result_type, line)))
            self.goto(done)

        last_index = len(native_items) - 1
        for index, item in enumerate(native_items):
            needs_check = index < last_index or generic_items
            if needs_check:
                # We are not at the final item so we need one more branch
                type_check = self.isinstance_native(obj, item.class_ir, line)
                matched, unmatched = BasicBlock(), BasicBlock()
                self.add_bool_branch(type_check, matched, unmatched)
                self.activate_block(matched)
            store_result(self.coerce(obj, item, line))
            if needs_check:
                self.activate_block(unmatched)
        if generic_items:
            # For everything else we use generic operation. Use force=True to drop the
            # union type.
            store_result(self.coerce(obj, object_rprimitive, line, force=True))
        self.activate_block(done)
        return result

    def op_arg_type(self, desc: OpDescription, n: int) -> RType:
        """Return the declared type of argument n of the op described by desc.

        Positions at or past the end of desc.arg_types are only valid for var
        arg ops (asserted); they resolve to the last declared type.
        """
        declared = desc.arg_types
        if n < len(declared):
            return declared[n]
        assert desc.is_var_arg
        return declared[-1]

    def translate_special_method_call(self,
                                      base_reg: Value,
                                      name: str,
                                      args: List[Value],
                                      result_type: Optional[RType],
                                      line: int) -> Optional[Value]:
        """Translate a method call that has dedicated, non-generic handling.

        Such calls have code generated specifically for them, typically
        because a C equivalent exists that is more direct than going through
        the PyObject API. The C call descriptions registered for the method
        name are tried first, then the primitive method ops; the receiver is
        passed as the first argument in both cases.

        Return None if no translation found; otherwise return the target register.
        """
        call_args = [base_reg] + args
        via_c_call = self.matching_call_c(c_method_call_ops.get(name, []), call_args,
                                          line, result_type)
        if via_c_call is not None:
            return via_c_call
        return self.matching_primitive_op(method_ops.get(name, []), call_args, line,
                                          result_type=result_type)

    def translate_eq_cmp(self,
                         lreg: Value,
                         rreg: Value,
                         expr_op: str,
                         line: int) -> Optional[Value]:
        """Add an equality comparison operation for two native instances.

        Only handles operands whose types are equal RInstance types and whose
        comparison behavior is fixed at compile time.

        Args:
            lreg: left operand
            rreg: right operand
            expr_op: either '==' or '!='
            line: line number

        Returns:
            The comparison result, or None if the comparison cannot be
            translated here.
        """
        ltype = lreg.type
        rtype = rreg.type
        if not isinstance(ltype, RInstance):
            return None
        if not ltype == rtype:
            return None

        class_ir = ltype.class_ir
        # The comparison is only predictable if no subclass of the operand can
        # redefine __eq__ / __ne__, and it cannot be redefined in a Python
        # parent class or by dataclasses.
        cmp_fixed_at_compile_time = (
            class_ir.is_method_final('__eq__')
            and class_ir.is_method_final('__ne__')
            and not class_ir.inherits_python
            and not class_ir.is_augmented
        )
        if not cmp_fixed_at_compile_time:
            # We might need to call left.__eq__(right) or right.__eq__(left)
            # depending on which is the more specific type.
            return None

        if class_ir.has_method('__eq__'):
            # NOTE(review): the operand type is passed as the result type of
            # the comparison method call -- confirm this is intended.
            method_name = op_methods[expr_op]
            return self.gen_method_call(lreg, method_name, [rreg], ltype, line)

        # There's no __eq__ defined, so just use object identity.
        if expr_op == '==':
            identity_op = 'is'
        else:
            identity_op = 'is not'
        return self.binary_op(lreg, rreg, identity_op, line)

    def _create_dict(self,
                     keys: List[Value],
                     values: List[Value],
                     line: int) -> Value:
        """Create a dictionary (possibly empty) from parallel key and value lists.

        keys and values should have the same number of items. With no keys
        dict_new_op is called; otherwise dict_build_op receives the number of
        keys followed by the interleaved key/value pairs.
        """
        size = len(keys)
        if size <= 0:
            return self.call_c(dict_new_op, [], line)
        build_args = [self.add(LoadInt(size, -1, c_pyssize_t_rprimitive))]
        # interleave: key0, value0, key1, value1, ...
        for key, value in zip(keys, values):
            build_args.append(key)
            build_args.append(value)
        return self.call_c(dict_build_op, build_args, line)
Example #4
0
class TestFunctionEmitterVisitor(unittest.TestCase):
    """Check the C code FunctionEmitterVisitor emits for individual IR ops."""

    def setUp(self) -> None:
        """Register locals of several types and build a fresh visitor."""
        env = Environment()
        self.env = env
        self.n = env.add_local(Var('n'), int_rprimitive)
        self.m = env.add_local(Var('m'), int_rprimitive)
        self.k = env.add_local(Var('k'), int_rprimitive)
        self.l = env.add_local(Var('l'), list_rprimitive)  # noqa
        self.ll = env.add_local(Var('ll'), list_rprimitive)
        self.o = env.add_local(Var('o'), object_rprimitive)
        self.o2 = env.add_local(Var('o2'), object_rprimitive)
        self.d = env.add_local(Var('d'), dict_rprimitive)
        self.b = env.add_local(Var('b'), bool_rprimitive)
        # A flat (int, bool) tuple and a tuple nesting another such tuple.
        flat_tuple = RTuple([int_rprimitive, bool_rprimitive])
        self.t = env.add_local(Var('t'), flat_tuple)
        nested_tuple = RTuple(
            [RTuple([int_rprimitive, bool_rprimitive]), bool_rprimitive])
        self.tt = env.add_local(Var('tt'), nested_tuple)
        # Native class 'A' in module 'mod' with attributes x: bool and y: int.
        class_ir = ClassIR('A', 'mod')
        class_ir.attributes = OrderedDict([('x', bool_rprimitive),
                                           ('y', int_rprimitive)])
        compute_vtable(class_ir)
        class_ir.mro = [class_ir]
        self.r = env.add_local(Var('r'), RInstance(class_ir))

        self.context = EmitterContext(NameGenerator([['mod']]))
        self.emitter = Emitter(self.context, env)
        self.declarations = Emitter(self.context, env)
        self.visitor = FunctionEmitterVisitor(self.emitter, self.declarations,
                                              'prog.py', 'prog')

    def test_goto(self) -> None:
        op = Goto(BasicBlock(2))
        self.assert_emit(op, "goto CPyL2;")

    def test_return(self) -> None:
        op = Return(self.m)
        self.assert_emit(op, "return cpy_r_m;")

    def test_load_int(self) -> None:
        # Without an explicit type the literal 5 is expected to be emitted as 10;
        # with c_int_rprimitive it is expected verbatim.
        default_typed = LoadInt(5)
        self.assert_emit(default_typed, "cpy_r_r0 = 10;")
        c_int_typed = LoadInt(5, -1, c_int_rprimitive)
        self.assert_emit(c_int_typed, "cpy_r_r00 = 5;")

    def test_tuple_get(self) -> None:
        op = TupleGet(self.t, 1, 0)
        self.assert_emit(op, 'cpy_r_r0 = cpy_r_t.f1;')

    def test_load_None(self) -> None:
        op = PrimitiveOp([], none_object_op, 0)
        self.assert_emit(op, "cpy_r_r0 = Py_None;")

    def test_load_True(self) -> None:
        op = PrimitiveOp([], true_op, 0)
        self.assert_emit(op, "cpy_r_r0 = 1;")

    def test_load_False(self) -> None:
        op = PrimitiveOp([], false_op, 0)
        self.assert_emit(op, "cpy_r_r0 = 0;")

    def test_assign_int(self) -> None:
        op = Assign(self.m, self.n)
        self.assert_emit(op, "cpy_r_m = cpy_r_n;")

    def test_int_add(self) -> None:
        expected = "cpy_r_r0 = CPyTagged_Add(cpy_r_m, cpy_r_k);"
        self.assert_emit_binary_op('+', self.n, self.m, self.k, expected)

    def test_int_sub(self) -> None:
        expected = "cpy_r_r0 = CPyTagged_Subtract(cpy_r_m, cpy_r_k);"
        self.assert_emit_binary_op('-', self.n, self.m, self.k, expected)

    def test_int_neg(self) -> None:
        op = CallC(int_neg_op.c_function_name, [self.m], int_neg_op.return_type,
                   int_neg_op.steals, int_neg_op.error_kind, 55)
        self.assert_emit(op, "cpy_r_r0 = CPyTagged_Negate(cpy_r_m);")

    def test_list_len(self) -> None:
        op = PrimitiveOp([self.l], list_len_op, 55)
        self.assert_emit(
            op, """Py_ssize_t __tmp1;
                            __tmp1 = PyList_GET_SIZE(cpy_r_l);
                            cpy_r_r0 = CPyTagged_ShortFromSsize_t(__tmp1);
                         """)

    def test_branch(self) -> None:
        plain = Branch(self.b, BasicBlock(8), BasicBlock(9), Branch.BOOL_EXPR)
        self.assert_emit(
            plain,
            """if (cpy_r_b) {
                                goto CPyL8;
                            } else
                                goto CPyL9;
                         """)
        # The same branch with the condition negated.
        negated = Branch(self.b, BasicBlock(8), BasicBlock(9), Branch.BOOL_EXPR)
        negated.negated = True
        self.assert_emit(
            negated, """if (!cpy_r_b) {
                                goto CPyL8;
                            } else
                                goto CPyL9;
                         """)

    def test_call(self) -> None:
        signature = FuncSignature([RuntimeArg('m', int_rprimitive)], int_rprimitive)
        callee = FuncDecl('myfn', None, 'mod', signature)
        self.assert_emit(Call(callee, [self.m], 55),
                         "cpy_r_r0 = CPyDef_myfn(cpy_r_m);")

    def test_call_two_args(self) -> None:
        signature = FuncSignature(
            [RuntimeArg('m', int_rprimitive), RuntimeArg('n', int_rprimitive)],
            int_rprimitive)
        callee = FuncDecl('myfn', None, 'mod', signature)
        self.assert_emit(Call(callee, [self.m, self.k], 55),
                         "cpy_r_r0 = CPyDef_myfn(cpy_r_m, cpy_r_k);")

    def test_inc_ref(self) -> None:
        op = IncRef(self.m)
        self.assert_emit(op, "CPyTagged_IncRef(cpy_r_m);")

    def test_dec_ref(self) -> None:
        op = DecRef(self.m)
        self.assert_emit(op, "CPyTagged_DecRef(cpy_r_m);")

    def test_dec_ref_tuple(self) -> None:
        op = DecRef(self.t)
        self.assert_emit(op, 'CPyTagged_DecRef(cpy_r_t.f0);')

    def test_dec_ref_tuple_nested(self) -> None:
        op = DecRef(self.tt)
        self.assert_emit(op, 'CPyTagged_DecRef(cpy_r_tt.f0.f0);')

    def test_list_get_item(self) -> None:
        op = PrimitiveOp([self.m, self.k], list_get_item_op, 55)
        self.assert_emit(op, """cpy_r_r0 = CPyList_GetItem(cpy_r_m, cpy_r_k);""")

    def test_list_set_item(self) -> None:
        op = PrimitiveOp([self.l, self.n, self.o], list_set_item_op, 55)
        self.assert_emit(
            op, """cpy_r_r0 = CPyList_SetItem(cpy_r_l, cpy_r_n, cpy_r_o);""")

    def test_box(self) -> None:
        op = Box(self.n)
        self.assert_emit(op, """cpy_r_r0 = CPyTagged_StealAsObject(cpy_r_n);""")

    def test_unbox(self) -> None:
        op = Unbox(self.m, int_rprimitive, 55)
        self.assert_emit(
            op,
            """if (likely(PyLong_Check(cpy_r_m)))
                                cpy_r_r0 = CPyTagged_FromObject(cpy_r_m);
                            else {
                                CPy_TypeError("int", cpy_r_m);
                                cpy_r_r0 = CPY_INT_TAG;
                            }
                         """)

    def test_new_list(self) -> None:
        op = PrimitiveOp([self.n, self.m], new_list_op, 55)
        self.assert_emit(
            op,
            """cpy_r_r0 = PyList_New(2);
                            if (likely(cpy_r_r0 != NULL)) {
                                PyList_SET_ITEM(cpy_r_r0, 0, cpy_r_n);
                                PyList_SET_ITEM(cpy_r_r0, 1, cpy_r_m);
                            }
                         """)

    def test_list_append(self) -> None:
        op = PrimitiveOp([self.l, self.o], list_append_op, 1)
        self.assert_emit(
            op, """cpy_r_r0 = PyList_Append(cpy_r_l, cpy_r_o) >= 0;""")

    def test_get_attr(self) -> None:
        op = GetAttr(self.r, 'y', 1)
        self.assert_emit(
            op,
            """cpy_r_r0 = ((mod___AObject *)cpy_r_r)->_y;
               if (unlikely(((mod___AObject *)cpy_r_r)->_y == CPY_INT_TAG)) {
                   PyErr_SetString(PyExc_AttributeError, "attribute 'y' of 'A' undefined");
               } else {
                   CPyTagged_IncRef(((mod___AObject *)cpy_r_r)->_y);
               }
            """)

    def test_set_attr(self) -> None:
        op = SetAttr(self.r, 'y', self.m, 1)
        self.assert_emit(
            op,
            """if (((mod___AObject *)cpy_r_r)->_y != CPY_INT_TAG) {
                   CPyTagged_DecRef(((mod___AObject *)cpy_r_r)->_y);
               }
               ((mod___AObject *)cpy_r_r)->_y = cpy_r_m;
               cpy_r_r0 = 1;
            """)

    def test_dict_get_item(self) -> None:
        op = PrimitiveOp([self.d, self.o2], dict_get_item_op, 1)
        self.assert_emit(op, """cpy_r_r0 = CPyDict_GetItem(cpy_r_d, cpy_r_o2);""")

    def test_dict_set_item(self) -> None:
        op = PrimitiveOp([self.d, self.o, self.o2], dict_set_item_op, 1)
        self.assert_emit(
            op, """cpy_r_r0 = CPyDict_SetItem(cpy_r_d, cpy_r_o, cpy_r_o2) >= 0;""")

    def test_dict_update(self) -> None:
        op = PrimitiveOp([self.d, self.o], dict_update_op, 1)
        self.assert_emit(
            op, """cpy_r_r0 = CPyDict_Update(cpy_r_d, cpy_r_o) >= 0;""")

    def test_new_dict(self) -> None:
        op = PrimitiveOp([], new_dict_op, 1)
        self.assert_emit(op, """cpy_r_r0 = PyDict_New();""")

    def test_dict_contains(self) -> None:
        expected = """int __tmp1 = PyDict_Contains(cpy_r_d, cpy_r_o);
               if (__tmp1 < 0)
                   cpy_r_r0 = 2;
               else
                   cpy_r_r0 = __tmp1;
            """
        self.assert_emit_binary_op('in', self.b, self.o, self.d, expected)

    def assert_emit(self, op: Op, expected: str) -> None:
        """Emit a single op and compare the generated lines with expected.

        Both emitters' fragments and the environment's temp index are reset
        first, and a RegisterOp is added to the environment before being
        visited. Declaration fragments precede body fragments in the
        comparison. Spaces are stripped from both ends of every line on both
        sides, and every emitted fragment must end with a newline.
        """
        self.emitter.fragments = []
        self.declarations.fragments = []
        self.env.temp_index = 0
        if isinstance(op, RegisterOp):
            self.env.add_op(op)
        op.accept(self.visitor)

        emitted = [frag.strip(' ')
                   for frag in self.declarations.fragments + self.emitter.fragments]
        assert all(frag.endswith('\n') for frag in emitted)
        actual = [frag.rstrip('\n') for frag in emitted]
        wanted = [text.strip(' ') for text in expected.rstrip().split('\n')]
        assert_string_arrays_equal(wanted,
                                   actual,
                                   msg='Generated code unexpected')

    def assert_emit_binary_op(self, op: str, dest: Value, left: Value,
                              right: Value, expected: str) -> None:
        """Emit the first binary op for `op` that accepts both operand types.

        The candidates registered in binary_ops are scanned in order; the
        first one whose two declared argument types are supertypes of the
        operand types is emitted as a PrimitiveOp and checked with
        assert_emit. An assertion fails when no candidate fits.
        """
        for candidate in binary_ops[op]:
            left_fits = is_subtype(left.type, candidate.arg_types[0])
            if left_fits and is_subtype(right.type, candidate.arg_types[1]):
                self.assert_emit(PrimitiveOp([left, right], candidate, 55),
                                 expected)
                return
        assert False, 'Could not find matching op'
Example #5
0
class TestFunctionEmitterVisitor(unittest.TestCase):
    def setUp(self) -> None:
        self.env = Environment()
        self.n = self.env.add_local(Var('n'), int_rprimitive)
        self.m = self.env.add_local(Var('m'), int_rprimitive)
        self.k = self.env.add_local(Var('k'), int_rprimitive)
        self.l = self.env.add_local(Var('l'), list_rprimitive)  # noqa
        self.ll = self.env.add_local(Var('ll'), list_rprimitive)
        self.o = self.env.add_local(Var('o'), object_rprimitive)
        self.o2 = self.env.add_local(Var('o2'), object_rprimitive)
        self.d = self.env.add_local(Var('d'), dict_rprimitive)
        self.b = self.env.add_local(Var('b'), bool_rprimitive)
        self.s1 = self.env.add_local(Var('s1'), short_int_rprimitive)
        self.s2 = self.env.add_local(Var('s2'), short_int_rprimitive)
        self.i32 = self.env.add_local(Var('i32'), int32_rprimitive)
        self.i32_1 = self.env.add_local(Var('i32_1'), int32_rprimitive)
        self.i64 = self.env.add_local(Var('i64'), int64_rprimitive)
        self.i64_1 = self.env.add_local(Var('i64_1'), int64_rprimitive)
        self.ptr = self.env.add_local(Var('ptr'), pointer_rprimitive)
        self.t = self.env.add_local(Var('t'),
                                    RTuple([int_rprimitive, bool_rprimitive]))
        self.tt = self.env.add_local(
            Var('tt'),
            RTuple(
                [RTuple([int_rprimitive, bool_rprimitive]), bool_rprimitive]))
        ir = ClassIR('A', 'mod')
        ir.attributes = OrderedDict([('x', bool_rprimitive),
                                     ('y', int_rprimitive)])
        compute_vtable(ir)
        ir.mro = [ir]
        self.r = self.env.add_local(Var('r'), RInstance(ir))

        self.context = EmitterContext(NameGenerator([['mod']]))
        self.emitter = Emitter(self.context, self.env)
        self.declarations = Emitter(self.context, self.env)

        const_int_regs = {}  # type: Dict[str, int]
        self.visitor = FunctionEmitterVisitor(self.emitter, self.declarations,
                                              'prog.py', 'prog',
                                              const_int_regs)

    def test_goto(self) -> None:
        self.assert_emit(Goto(BasicBlock(2)), "goto CPyL2;")

    def test_return(self) -> None:
        self.assert_emit(Return(self.m), "return cpy_r_m;")

    def test_load_int(self) -> None:
        self.assert_emit(LoadInt(5), "cpy_r_i0 = 10;")
        self.assert_emit(LoadInt(5, -1, c_int_rprimitive), "cpy_r_i1 = 5;")

    def test_tuple_get(self) -> None:
        self.assert_emit(TupleGet(self.t, 1, 0), 'cpy_r_r0 = cpy_r_t.f1;')

    def test_load_None(self) -> None:
        self.assert_emit(
            LoadAddress(none_object_op.type, none_object_op.src, 0),
            "cpy_r_r0 = (PyObject *)&_Py_NoneStruct;")

    def test_assign_int(self) -> None:
        self.assert_emit(Assign(self.m, self.n), "cpy_r_m = cpy_r_n;")

    def test_int_add(self) -> None:
        self.assert_emit_binary_op(
            '+', self.n, self.m, self.k,
            "cpy_r_r0 = CPyTagged_Add(cpy_r_m, cpy_r_k);")

    def test_int_sub(self) -> None:
        self.assert_emit_binary_op(
            '-', self.n, self.m, self.k,
            "cpy_r_r0 = CPyTagged_Subtract(cpy_r_m, cpy_r_k);")

    def test_int_neg(self) -> None:
        # CallC is given is_borrowed exactly once, between steals and
        # error_kind, like every other CallC built in this test class.
        # Passing it twice would shift error_kind and the line number into
        # the wrong positional parameters.
        self.assert_emit(
            CallC(int_neg_op.c_function_name, [self.m], int_neg_op.return_type,
                  int_neg_op.steals, int_neg_op.is_borrowed,
                  int_neg_op.error_kind, 55),
            "cpy_r_r0 = CPyTagged_Negate(cpy_r_m);")

    def test_branch(self) -> None:
        self.assert_emit(
            Branch(self.b, BasicBlock(8), BasicBlock(9), Branch.BOOL_EXPR),
            """if (cpy_r_b) {
                                goto CPyL8;
                            } else
                                goto CPyL9;
                         """)
        b = Branch(self.b, BasicBlock(8), BasicBlock(9), Branch.BOOL_EXPR)
        b.negated = True
        self.assert_emit(
            b, """if (!cpy_r_b) {
                                goto CPyL8;
                            } else
                                goto CPyL9;
                         """)

    def test_call(self) -> None:
        decl = FuncDecl(
            'myfn', None, 'mod',
            FuncSignature([RuntimeArg('m', int_rprimitive)], int_rprimitive))
        self.assert_emit(Call(decl, [self.m], 55),
                         "cpy_r_r0 = CPyDef_myfn(cpy_r_m);")

    def test_call_two_args(self) -> None:
        decl = FuncDecl(
            'myfn', None, 'mod',
            FuncSignature([
                RuntimeArg('m', int_rprimitive),
                RuntimeArg('n', int_rprimitive)
            ], int_rprimitive))
        self.assert_emit(Call(decl, [self.m, self.k], 55),
                         "cpy_r_r0 = CPyDef_myfn(cpy_r_m, cpy_r_k);")

    def test_inc_ref(self) -> None:
        self.assert_emit(IncRef(self.m), "CPyTagged_IncRef(cpy_r_m);")

    def test_dec_ref(self) -> None:
        self.assert_emit(DecRef(self.m), "CPyTagged_DecRef(cpy_r_m);")

    def test_dec_ref_tuple(self) -> None:
        self.assert_emit(DecRef(self.t), 'CPyTagged_DecRef(cpy_r_t.f0);')

    def test_dec_ref_tuple_nested(self) -> None:
        self.assert_emit(DecRef(self.tt), 'CPyTagged_DecRef(cpy_r_tt.f0.f0);')

    def test_list_get_item(self) -> None:
        self.assert_emit(
            CallC(list_get_item_op.c_function_name, [self.m, self.k],
                  list_get_item_op.return_type, list_get_item_op.steals,
                  list_get_item_op.is_borrowed, list_get_item_op.error_kind,
                  55), """cpy_r_r0 = CPyList_GetItem(cpy_r_m, cpy_r_k);""")

    def test_list_set_item(self) -> None:
        self.assert_emit(
            CallC(list_set_item_op.c_function_name, [self.l, self.n, self.o],
                  list_set_item_op.return_type, list_set_item_op.steals,
                  list_set_item_op.is_borrowed, list_set_item_op.error_kind,
                  55),
            """cpy_r_r0 = CPyList_SetItem(cpy_r_l, cpy_r_n, cpy_r_o);""")

    def test_box(self) -> None:
        self.assert_emit(Box(self.n),
                         """cpy_r_r0 = CPyTagged_StealAsObject(cpy_r_n);""")

    def test_unbox(self) -> None:
        self.assert_emit(
            Unbox(self.m, int_rprimitive, 55),
            """if (likely(PyLong_Check(cpy_r_m)))
                                cpy_r_r0 = CPyTagged_FromObject(cpy_r_m);
                            else {
                                CPy_TypeError("int", cpy_r_m);
                                cpy_r_r0 = CPY_INT_TAG;
                            }
                         """)

    def test_list_append(self) -> None:
        self.assert_emit(
            CallC(list_append_op.c_function_name, [self.l, self.o],
                  list_append_op.return_type, list_append_op.steals,
                  list_append_op.is_borrowed, list_append_op.error_kind, 1),
            """cpy_r_r0 = PyList_Append(cpy_r_l, cpy_r_o);""")

    def test_get_attr(self) -> None:
        self.assert_emit(
            GetAttr(self.r, 'y', 1),
            """cpy_r_r0 = ((mod___AObject *)cpy_r_r)->_y;
               if (unlikely(((mod___AObject *)cpy_r_r)->_y == CPY_INT_TAG)) {
                   PyErr_SetString(PyExc_AttributeError, "attribute 'y' of 'A' undefined");
               } else {
                   CPyTagged_IncRef(((mod___AObject *)cpy_r_r)->_y);
               }
            """)

    def test_set_attr(self) -> None:
        self.assert_emit(
            SetAttr(self.r, 'y', self.m, 1),
            """if (((mod___AObject *)cpy_r_r)->_y != CPY_INT_TAG) {
                   CPyTagged_DecRef(((mod___AObject *)cpy_r_r)->_y);
               }
               ((mod___AObject *)cpy_r_r)->_y = cpy_r_m;
               cpy_r_r0 = 1;
            """)

    def test_dict_get_item(self) -> None:
        self.assert_emit(
            CallC(dict_get_item_op.c_function_name, [self.d, self.o2],
                  dict_get_item_op.return_type, dict_get_item_op.steals,
                  dict_get_item_op.is_borrowed, dict_get_item_op.error_kind,
                  1), """cpy_r_r0 = CPyDict_GetItem(cpy_r_d, cpy_r_o2);""")

    def test_dict_set_item(self) -> None:
        self.assert_emit(
            CallC(dict_set_item_op.c_function_name, [self.d, self.o, self.o2],
                  dict_set_item_op.return_type, dict_set_item_op.steals,
                  dict_set_item_op.is_borrowed, dict_set_item_op.error_kind,
                  1),
            """cpy_r_r0 = CPyDict_SetItem(cpy_r_d, cpy_r_o, cpy_r_o2);""")

    def test_dict_update(self) -> None:
        self.assert_emit(
            CallC(dict_update_op.c_function_name, [self.d, self.o],
                  dict_update_op.return_type, dict_update_op.steals,
                  dict_update_op.is_borrowed, dict_update_op.error_kind, 1),
            """cpy_r_r0 = CPyDict_Update(cpy_r_d, cpy_r_o);""")

    def test_new_dict(self) -> None:
        self.assert_emit(
            CallC(dict_new_op.c_function_name, [], dict_new_op.return_type,
                  dict_new_op.steals, dict_new_op.is_borrowed,
                  dict_new_op.error_kind, 1), """cpy_r_r0 = PyDict_New();""")

    def test_dict_contains(self) -> None:
        self.assert_emit_binary_op(
            'in', self.b, self.o, self.d,
            """cpy_r_r0 = PyDict_Contains(cpy_r_d, cpy_r_o);""")

    def test_binary_int_op(self) -> None:
        # Each BinaryIntOp kind applied to the short ints s1 and s2 must emit
        # the matching C infix operator. The cases run in this order because
        # the expected destination register name (r0, r00, r01, ...) changes
        # with every op emitted by this test.
        cases = [
            (BinaryIntOp.ADD, """cpy_r_r0 = cpy_r_s1 + cpy_r_s2;"""),
            (BinaryIntOp.SUB, """cpy_r_r00 = cpy_r_s1 - cpy_r_s2;"""),
            (BinaryIntOp.MUL, """cpy_r_r01 = cpy_r_s1 * cpy_r_s2;"""),
            (BinaryIntOp.DIV, """cpy_r_r02 = cpy_r_s1 / cpy_r_s2;"""),
            (BinaryIntOp.MOD, """cpy_r_r03 = cpy_r_s1 % cpy_r_s2;"""),
            (BinaryIntOp.AND, """cpy_r_r04 = cpy_r_s1 & cpy_r_s2;"""),
            (BinaryIntOp.OR, """cpy_r_r05 = cpy_r_s1 | cpy_r_s2;"""),
            (BinaryIntOp.XOR, """cpy_r_r06 = cpy_r_s1 ^ cpy_r_s2;"""),
            (BinaryIntOp.LEFT_SHIFT, """cpy_r_r07 = cpy_r_s1 << cpy_r_s2;"""),
            (BinaryIntOp.RIGHT_SHIFT, """cpy_r_r08 = cpy_r_s1 >> cpy_r_s2;"""),
        ]
        for op_kind, expected_c in cases:
            int_op = BinaryIntOp(short_int_rprimitive, self.s1, self.s2,
                                 op_kind, 1)
            self.assert_emit(int_op, expected_c)

    def test_comparison_op(self) -> None:
        # Table of (lhs, rhs, comparison kind, expected C). The cases must run
        # in this order: the expected destination register name changes with
        # every op emitted by this test.
        cases = [
            # signed
            (self.s1, self.s2, ComparisonOp.SLT,
             """cpy_r_r0 = (Py_ssize_t)cpy_r_s1 < (Py_ssize_t)cpy_r_s2;"""),
            (self.i32, self.i32_1, ComparisonOp.SLT,
             """cpy_r_r00 = cpy_r_i32 < cpy_r_i32_1;"""),
            (self.i64, self.i64_1, ComparisonOp.SLT,
             """cpy_r_r01 = cpy_r_i64 < cpy_r_i64_1;"""),
            # unsigned
            (self.s1, self.s2, ComparisonOp.ULT,
             """cpy_r_r02 = cpy_r_s1 < cpy_r_s2;"""),
            (self.i32, self.i32_1, ComparisonOp.ULT,
             """cpy_r_r03 = (uint32_t)cpy_r_i32 < (uint32_t)cpy_r_i32_1;"""),
            (self.i64, self.i64_1, ComparisonOp.ULT,
             """cpy_r_r04 = (uint64_t)cpy_r_i64 < (uint64_t)cpy_r_i64_1;"""),
            # object type
            (self.o, self.o2, ComparisonOp.EQ,
             """cpy_r_r05 = cpy_r_o == cpy_r_o2;"""),
            (self.o, self.o2, ComparisonOp.NEQ,
             """cpy_r_r06 = cpy_r_o != cpy_r_o2;"""),
        ]
        for lhs, rhs, cmp_kind, expected_c in cases:
            self.assert_emit(ComparisonOp(lhs, rhs, cmp_kind, 1), expected_c)

    def test_load_mem(self) -> None:
        # The expected C is the same dereference whether LoadMem's third
        # argument is None or a value; only the destination name differs.
        cases = [
            (None, """cpy_r_r0 = *(char *)cpy_r_ptr;"""),
            (self.s1, """cpy_r_r00 = *(char *)cpy_r_ptr;"""),
        ]
        for third_arg, expected_c in cases:
            load = LoadMem(bool_rprimitive, self.ptr, third_arg)
            self.assert_emit(load, expected_c)

    def test_set_mem(self) -> None:
        # Storing the bool b through ptr must emit a char-typed pointer write.
        store = SetMem(bool_rprimitive, self.ptr, self.b, None)
        expected_c = """*(char *)cpy_r_ptr = cpy_r_b;"""
        self.assert_emit(store, expected_c)

    def test_get_element_ptr(self) -> None:
        # Taking the address of each field of struct Foo (viewed through the
        # object o) must emit a cast to Foo* followed by &...->field.
        field_names = ["b", "i32", "i64"]
        field_types = [bool_rprimitive, int32_rprimitive, int64_rprimitive]
        struct = RStruct("Foo", field_names, field_types)
        cases = [
            ("b", """cpy_r_r0 = (CPyPtr)&((Foo *)cpy_r_o)->b;"""),
            ("i32", """cpy_r_r00 = (CPyPtr)&((Foo *)cpy_r_o)->i32;"""),
            ("i64", """cpy_r_r01 = (CPyPtr)&((Foo *)cpy_r_o)->i64;"""),
        ]
        for field, expected_c in cases:
            self.assert_emit(GetElementPtr(self.o, struct, field), expected_c)

    def test_load_address(self) -> None:
        # Loading the address of the C symbol PyDict_Type must emit a
        # PyObject* cast of &PyDict_Type.
        load = LoadAddress(object_rprimitive, "PyDict_Type")
        expected_c = """cpy_r_r0 = (PyObject *)&PyDict_Type;"""
        self.assert_emit(load, expected_c)

    def assert_emit(self, op: Op, expected: str) -> None:
        """Emit C for a single op and compare it with the expected text.

        The emitter and declaration buffers are cleared and the environment's
        temp counter is reset before the op is visited. Declaration fragments
        are compared first, followed by emitter fragments. Leading/trailing
        spaces on each line are ignored on both sides, and every generated
        fragment must end with a newline.

        Args:
            op: The op to visit; a RegisterOp is first registered with the
                environment.
            expected: Expected C source, one statement per line.
        """
        # Start from clean output buffers and a fresh temp counter.
        self.emitter.fragments = []
        self.declarations.fragments = []
        self.env.temp_index = 0

        if isinstance(op, RegisterOp):
            self.env.add_op(op)
        op.accept(self.visitor)

        generated = self.declarations.fragments + self.emitter.fragments
        trimmed = [fragment.strip(' ') for fragment in generated]
        assert all(fragment.endswith('\n') for fragment in trimmed)
        actual_lines = [fragment.rstrip('\n') for fragment in trimmed]

        expected_lines = [
            text.strip(' ') for text in expected.rstrip().split('\n')
        ]
        assert_string_arrays_equal(expected_lines,
                                   actual_lines,
                                   msg='Generated code unexpected')

    def assert_emit_binary_op(self, op: str, dest: Value, left: Value,
                              right: Value, expected: str) -> None:
        """Emit the C-function binary op matching the operands and check it.

        Looks up the descriptors registered for ``op`` in ``c_binary_ops`` and
        uses the first one whose two argument types accept the types of
        ``left`` and ``right``. The operands are reordered according to the
        descriptor's ``ordering`` (when set) before the CallC is built and
        checked with ``assert_emit``.

        Args:
            op: Operator string used as the key into ``c_binary_ops``.
            dest: Unused by this helper; kept for call-site compatibility.
            left: Left operand value.
            right: Right operand value.
            expected: Expected generated C source.

        Raises:
            AssertionError: If no descriptor matches -- either the operator is
                not registered at all, or none of its descriptors accepts the
                operand types -- or if the generated code differs.
        """
        # TODO: merge this
        if op in c_binary_ops:
            for c_desc in c_binary_ops[op]:
                if (is_subtype(left.type, c_desc.arg_types[0])
                        and is_subtype(right.type, c_desc.arg_types[1])):
                    args = [left, right]
                    if c_desc.ordering is not None:
                        args = [args[i] for i in c_desc.ordering]
                    self.assert_emit(
                        CallC(c_desc.c_function_name, args, c_desc.return_type,
                              c_desc.steals, c_desc.is_borrowed,
                              c_desc.error_kind, 55), expected)
                    return
        # Reached when the operator is unknown OR when it is known but no
        # descriptor fits the operand types. Failing here in both cases keeps
        # a test from passing without having checked anything. An explicit
        # raise is used so the check survives running Python with -O.
        raise AssertionError('Could not find matching op')
Example #6
0
def split_blocks_at_errors(blocks: List[BasicBlock],
                           default_error_handler: BasicBlock,
                           func_name: Optional[str],
                           env: Environment) -> List[BasicBlock]:
    """Split blocks after each op that can fail and insert an error branch.

    Every RegisterOp whose error kind is not ERR_NEVER ends its block: a
    Branch is appended that jumps to the error handler on failure and to a
    freshly created continuation block otherwise. The remaining ops of the
    original block are then placed in that continuation block.

    The input blocks are modified in place: their op lists are rebuilt and
    their ``error_handler`` attribute is cleared.

    Args:
        blocks: Blocks to process, in order.
        default_error_handler: Branch target on error for blocks that do not
            specify their own ``error_handler``.
        func_name: Function name recorded in traceback entries; when None no
            traceback entry is attached to the branches.
        env: Environment in which helper ops created here are registered.

    Returns:
        The original blocks interleaved with the newly created continuation
        blocks, in control-flow order.
    """
    result = []  # type: List[BasicBlock]

    for block in blocks:
        # Detach the op list; ops are re-appended below, possibly spread
        # over several blocks.
        pending_ops = block.ops
        block.ops = []
        current = block
        result.append(current)

        # A block-specific handler takes precedence over the default one.
        handler = block.error_handler or default_error_handler
        block.error_handler = None

        for op in pending_ops:
            current.ops.append(op)
            if not isinstance(op, RegisterOp) or op.error_kind == ERR_NEVER:
                # This op cannot fail, so the block continues.
                continue

            # The op may raise: everything after it goes to a new block.
            continuation = BasicBlock()
            result.append(continuation)

            checked_value = op
            if op.error_kind == ERR_MAGIC:
                # The error is signalled by a result value that depends on
                # the result RType.
                branch_kind, negate = Branch.IS_ERROR, False
            elif op.error_kind == ERR_FALSE:
                # The error is signalled by a C false value.
                branch_kind, negate = Branch.BOOL, True
            elif op.error_kind == ERR_ALWAYS:
                branch_kind, negate = Branch.BOOL, True
                # Hack to model "always fails": branch on a temporary bool
                # holding false instead of on the op's own result.
                checked_value = LoadInt(0, rtype=bool_rprimitive)
                current.ops.append(checked_value)
                env.add_op(checked_value)
            else:
                assert False, 'unknown error kind %d' % op.error_kind

            # Apart from always-failing ops, the error is reported through
            # the op's result register, so the op must produce a value.
            if op.error_kind != ERR_ALWAYS:
                assert not op.is_void, "void op generating errors?"

            branch = Branch(checked_value,
                            true_label=handler,
                            false_label=continuation,
                            op=branch_kind,
                            line=op.line)
            branch.negated = negate
            if func_name is not None and op.line != NO_TRACEBACK_LINE_NO:
                branch.traceback_entry = (func_name, op.line)
            current.ops.append(branch)
            current = continuation

    return result