def transform_block(block: BasicBlock,
                    pre_live: AnalysisDict[Register],
                    post_live: AnalysisDict[Register],
                    pre_borrow: AnalysisDict[Register],
                    env: Environment) -> None:
    """Rewrite block.ops in place, inserting IncRef/DecRef ops around each op.

    The three analysis results are indexed by ``(block.label, i)`` for the
    i-th op of the block and are queried with ``in`` for registers.
    Judging by their names (not verified here), pre_live/post_live hold the
    registers live before/after the op, and pre_borrow holds the registers
    carrying borrowed (non-owned) references before the op.

    Args:
        block: basic block whose op list is replaced by the rewritten list
        pre_live: per-op analysis result consulted before each op
        post_live: per-op analysis result consulted after each op
        pre_borrow: per-op borrowed-register analysis result
        env: supplies register types (``env.types``) for the inserted ops and
            allocates a temporary register when an op overwrites one of its
            own operands
    """
    old_ops = block.ops
    ops = []  # type: List[Op]
    for i, op in enumerate(old_ops):
        key = (block.label, i)
        if isinstance(op, (Assign, Cast, Box)):
            # These operations just copy/steal a reference and don't create new
            # references.
            if op.src in post_live[key] or op.src in pre_borrow[key]:
                # src is still needed (or is borrowed), so the copy cannot
                # steal its reference; take an extra one instead.
                ops.append(IncRef(op.src, env.types[op.src]))
                # NOTE(review): nesting reconstructed from collapsed source —
                # confirm this old-dest DecRef is meant to apply only when src
                # is incref'd above.
                if (op.dest not in pre_borrow[key]
                        and op.dest in pre_live[key]):
                    ops.append(DecRef(op.dest, env.types[op.dest]))
            ops.append(op)
            # Result is never used afterwards: release it immediately.
            if op.dest not in post_live[key]:
                ops.append(DecRef(op.dest, env.types[op.dest]))
        elif isinstance(op, RegisterOp):
            # These operations construct a new reference.
            tmp_reg = None  # type: Optional[Register]
            if (op.dest not in pre_borrow[key]
                    and op.dest in pre_live[key]):
                # dest holds an owned value that this op will overwrite.
                if op.dest not in op.sources():
                    ops.append(DecRef(op.dest, env.types[op.dest]))
                else:
                    # dest is also an operand, so it can't be released before
                    # the op runs; stash it in a temp and release it afterwards.
                    tmp_reg = env.add_temp(env.types[op.dest])
                    ops.append(Assign(tmp_reg, op.dest))
            ops.append(op)
            for src in op.unique_sources():
                # Decrement source that won't be live afterwards.
                if src not in post_live[key] and src not in pre_borrow[key]:
                    if src != op.dest:
                        ops.append(DecRef(src, env.types[src]))
            if op.dest is not None and op.dest not in post_live[key]:
                ops.append(DecRef(op.dest, env.types[op.dest]))
            if tmp_reg is not None:
                # Release the previous dest value saved before the op.
                ops.append(DecRef(tmp_reg, env.types[tmp_reg]))
        elif isinstance(op, Return) and op.reg in pre_borrow[key]:
            # The return op returns a new reference.
            ops.append(IncRef(op.reg, env.types[op.reg]))
            ops.append(op)
        else:
            ops.append(op)
    block.ops = ops
class TestGenerateFunction(unittest.TestCase):
    """Golden-output tests for generate_native_function().

    Each test builds a single-block FuncIR named ``myfunc`` that takes one
    int argument ``arg`` and compares the emitted C fragments line by line.
    """

    def setUp(self) -> None:
        self.var = Var('arg')
        self.arg = RuntimeArg('arg', IntRType())
        self.env = Environment()
        self.reg = self.env.add_local(self.var, IntRType())
        self.block = BasicBlock(Label(0))

    def _emit(self, fn: FuncIR) -> Emitter:
        """Generate C code for fn into a fresh emitter and hand that emitter back."""
        out = Emitter(EmitterContext())
        generate_native_function(fn, out)
        return out

    def test_simple(self) -> None:
        # The function simply returns its argument.
        self.block.ops.append(Return(self.reg))
        func = FuncIR('myfunc', [self.arg], IntRType(), [self.block], self.env)
        fragments = self._emit(func).fragments
        expected = [
            'static CPyTagged CPyDef_myfunc(CPyTagged cpy_r_arg) {\n',
            'CPyL0: ;\n',
            ' return cpy_r_arg;\n',
            '}\n',
        ]
        assert_string_arrays_equal(expected, fragments, msg='Generated code invalid')

    def test_register(self) -> None:
        # A temp register is declared and loaded; the literal 5 is expected to
        # appear as 10 (presumably the tagged-int encoding of CPyTagged).
        self.temp = self.env.add_temp(IntRType())
        self.block.ops.append(LoadInt(self.temp, 5))
        func = FuncIR('myfunc', [self.arg], ListRType(), [self.block], self.env)
        fragments = self._emit(func).fragments
        expected = [
            'static PyObject *CPyDef_myfunc(CPyTagged cpy_r_arg) {\n',
            ' CPyTagged cpy_r_r0;\n',
            'CPyL0: ;\n',
            ' cpy_r_r0 = 10;\n',
            '}\n',
        ]
        assert_string_arrays_equal(expected, fragments, msg='Generated code invalid')
class LowLevelIRBuilder:
    """Emit low-level IR ops into a sequence of basic blocks.

    Ops are appended to the most recently activated block (``self.blocks[-1]``).
    Register-producing ops are also recorded in ``self.environment``. Ops are
    emitted in the order these methods are called, so call order is significant.
    """

    def __init__(
        self,
        current_module: str,
        mapper: Mapper,
    ) -> None:
        self.current_module = current_module
        self.mapper = mapper
        self.environment = Environment()
        self.blocks = []  # type: List[BasicBlock]
        # Stack of except handler entry blocks; the initial None is used for
        # blocks activated while no handler has been pushed.
        self.error_handlers = [None]  # type: List[Optional[BasicBlock]]

    def add(self, op: Op) -> Value:
        """Append op to the current block and return it.

        Raises AssertionError if the current block is already terminated.
        """
        assert not self.blocks[-1].terminated, "Can't add to finished block"

        self.blocks[-1].ops.append(op)
        if isinstance(op, RegisterOp):
            self.environment.add_op(op)
        return op

    def goto(self, target: BasicBlock) -> None:
        """Jump to target unless the current block is already terminated."""
        if not self.blocks[-1].terminated:
            self.add(Goto(target))

    def activate_block(self, block: BasicBlock) -> None:
        """Make block the current block, tagging it with the active error handler.

        Every previously active block must be terminated first.
        """
        if self.blocks:
            assert self.blocks[-1].terminated

        block.error_handler = self.error_handlers[-1]
        self.blocks.append(block)

    def goto_and_activate(self, block: BasicBlock) -> None:
        """Fall through into block: jump to it, then make it current."""
        self.goto(block)
        self.activate_block(block)

    def push_error_handler(self, handler: Optional[BasicBlock]) -> None:
        self.error_handlers.append(handler)

    def pop_error_handler(self) -> Optional[BasicBlock]:
        return self.error_handlers.pop()

    ##

    def get_native_type(self, cls: ClassIR) -> Value:
        """Load the type object of native class cls."""
        fullname = '%s.%s' % (cls.module_name, cls.name)
        return self.load_native_type_object(fullname)

    def primitive_op(self, desc: OpDescription, args: List[Value], line: int) -> Value:
        """Emit a PrimitiveOp for desc, first coercing each arg to its formal type."""
        assert desc.result_type is not None
        coerced = []
        for i, arg in enumerate(args):
            formal_type = self.op_arg_type(desc, i)
            arg = self.coerce(arg, formal_type, line)
            coerced.append(arg)
        target = self.add(PrimitiveOp(coerced, desc, line))
        return target

    def alloc_temp(self, type: RType) -> Register:
        return self.environment.add_temp(type)

    def op_arg_type(self, desc: OpDescription, n: int) -> RType:
        """Return the formal type of argument n of desc.

        Indexes past the declared args are allowed only for var-arg ops; they
        reuse the last declared arg type.
        """
        if n >= len(desc.arg_types):
            assert desc.is_var_arg
            return desc.arg_types[-1]
        return desc.arg_types[n]

    def box(self, src: Value) -> Value:
        """Box src if its type is unboxed; otherwise return it unchanged."""
        if src.type.is_unboxed:
            return self.add(Box(src))
        else:
            return src

    def unbox_or_cast(self, src: Value, target_type: RType, line: int) -> Value:
        """Unbox to an unboxed target_type, or cast to a boxed one (always emits an op)."""
        if target_type.is_unboxed:
            return self.add(Unbox(src, target_type, line))
        else:
            return self.add(Cast(src, target_type, line))

    def coerce(self, src: Value, target_type: RType, line: int, force: bool = False) -> Value:
        """Generate a coercion/cast from one type to other (only if needed).

        For example, int -> object boxes the source int; int -> int emits nothing;
        object -> int unboxes the object. All conversions preserve object value.

        If force is true, always generate an op (even if it is just an assignment) so
        that the result will have exactly target_type as the type.

        Returns the register with the converted value (may be same as src).
        """
        if src.type.is_unboxed and not target_type.is_unboxed:
            return self.box(src)
        if ((src.type.is_unboxed and target_type.is_unboxed)
                and not is_runtime_subtype(src.type, target_type)):
            # To go from one unboxed type to another, we go through a boxed
            # in-between value, for simplicity.
            tmp = self.box(src)
            return self.unbox_or_cast(tmp, target_type, line)
        if ((not src.type.is_unboxed and target_type.is_unboxed)
                or not is_subtype(src.type, target_type)):
            return self.unbox_or_cast(src, target_type, line)
        elif force:
            tmp = self.alloc_temp(target_type)
            self.add(Assign(tmp, src))
            return tmp
        return src

    def none(self) -> Value:
        return self.add(PrimitiveOp([], none_op, line=-1))

    def none_object(self) -> Value:
        return self.add(PrimitiveOp([], none_object_op, line=-1))

    def get_attr(self, obj: Value, attr: str, result_type: RType, line: int) -> Value:
        """Read obj.attr: natively for extension classes, per item for unions,
        otherwise through the generic Python attribute lookup."""
        if (isinstance(obj.type, RInstance) and obj.type.class_ir.is_ext_class
                and obj.type.class_ir.has_attr(attr)):
            return self.add(GetAttr(obj, attr, line))
        elif isinstance(obj.type, RUnion):
            return self.union_get_attr(obj, obj.type, attr, result_type, line)
        else:
            return self.py_get_attr(obj, attr, line)

    def union_get_attr(self,
                       obj: Value,
                       rtype: RUnion,
                       attr: str,
                       result_type: RType,
                       line: int) -> Value:
        """Read obj.attr for a union-typed obj by dispatching on each union item."""
        def get_item_attr(value: Value) -> Value:
            return self.get_attr(value, attr, result_type, line)

        return self.decompose_union_helper(obj, rtype, result_type, get_item_attr, line)

    def decompose_union_helper(self,
                               obj: Value,
                               rtype: RUnion,
                               result_type: RType,
                               process_item: Callable[[Value], Value],
                               line: int) -> Value:
        """Generate isinstance() + specialized operations for union items.

        Say, for Union[A, B] generate ops resembling this (pseudocode):

            if isinstance(obj, A):
                result = <result of process_item(cast(A, obj))>
            else:
                result = <result of process_item(cast(B, obj))>

        Args:
            obj: value with a union type
            rtype: the union type
            result_type: result of the operation
            process_item: callback to generate op for a single union item (arg is coerced
                to union item type)
            line: line number
        """
        # TODO: Optimize cases where a single operation can handle multiple union items
        #     (say a method is implemented in a common base class)
        fast_items = []  # type: List[RInstance]
        rest_items = []  # type: List[RType]
        for item in rtype.items:
            if isinstance(item, RInstance):
                fast_items.append(item)
            else:
                # For everything but RInstance we fall back to C API
                rest_items.append(item)
        exit_block = BasicBlock()
        result = self.alloc_temp(result_type)
        for i, item in enumerate(fast_items):
            more_types = i < len(fast_items) - 1 or bool(rest_items)
            if more_types:
                # We are not at the final item so we need one more branch
                op = self.isinstance_native(obj, item.class_ir, line)
                true_block, false_block = BasicBlock(), BasicBlock()
                self.add_bool_branch(op, true_block, false_block)
                self.activate_block(true_block)
            coerced = self.coerce(obj, item, line)
            temp = process_item(coerced)
            temp2 = self.coerce(temp, result_type, line)
            self.add(Assign(result, temp2))
            self.goto(exit_block)
            if more_types:
                self.activate_block(false_block)
        if rest_items:
            # For everything else we use generic operation. Use force=True to drop the
            # union type.
            coerced = self.coerce(obj, object_rprimitive, line, force=True)
            temp = process_item(coerced)
            temp2 = self.coerce(temp, result_type, line)
            self.add(Assign(result, temp2))
            self.goto(exit_block)
        self.activate_block(exit_block)
        return result

    def isinstance_helper(self, obj: Value, class_irs: List[ClassIR], line: int) -> Value:
        """Fast path for isinstance() that checks against a list of native classes."""
        if not class_irs:
            return self.primitive_op(false_op, [], line)
        ret = self.isinstance_native(obj, class_irs[0], line)
        for class_ir in class_irs[1:]:
            # Both closures are invoked inside shortcircuit_helper before the
            # loop variables are rebound, so late binding is not an issue.
            def other() -> Value:
                return self.isinstance_native(obj, class_ir, line)
            ret = self.shortcircuit_helper('or', bool_rprimitive, lambda: ret, other, line)
        return ret

    def isinstance_native(self, obj: Value, class_ir: ClassIR, line: int) -> Value:
        """Fast isinstance() check for a native class.

        If there are at most FAST_ISINSTANCE_MAX_SUBCLASSES + 1 concrete (non-trait)
        classes among the class and all its children, use even faster type
        comparison checks `type(obj) is typ`.
        """
        concrete = all_concrete_classes(class_ir)
        if concrete is None or len(concrete) > FAST_ISINSTANCE_MAX_SUBCLASSES + 1:
            return self.primitive_op(fast_isinstance_op,
                                     [obj, self.get_native_type(class_ir)],
                                     line)
        if not concrete:
            # There can't be any concrete instance that matches this.
            return self.primitive_op(false_op, [], line)
        type_obj = self.get_native_type(concrete[0])
        ret = self.primitive_op(type_is_op, [obj, type_obj], line)
        for c in concrete[1:]:
            # Invoked immediately by shortcircuit_helper, so capturing the loop
            # variable c is safe.
            def other() -> Value:
                return self.primitive_op(type_is_op, [obj, self.get_native_type(c)], line)
            ret = self.shortcircuit_helper('or', bool_rprimitive, lambda: ret, other, line)
        return ret

    def py_get_attr(self, obj: Value, attr: str, line: int) -> Value:
        """Generic Python attribute lookup obj.attr."""
        key = self.load_static_unicode(attr)
        return self.add(PrimitiveOp([obj, key], py_getattr_op, line))

    def py_call(self,
                function: Value,
                arg_values: List[Value],
                line: int,
                arg_kinds: Optional[List[int]] = None,
                arg_names: Optional[Sequence[Optional[str]]] = None) -> Value:
        """Use py_call_op or py_call_with_kwargs_op for function call."""
        # If all arguments are positional, we can use py_call_op.
        if (arg_kinds is None) or all(kind == ARG_POS for kind in arg_kinds):
            return self.primitive_op(py_call_op, [function] + arg_values, line)

        # Otherwise fallback to py_call_with_kwargs_op.
        assert arg_names is not None

        pos_arg_values = []  # type: List[Value]
        kw_arg_key_value_pairs = []  # type: List[DictEntry]
        star_arg_values = []  # type: List[Value]
        for value, kind, name in zip(arg_values, arg_kinds, arg_names):
            if kind == ARG_POS:
                pos_arg_values.append(value)
            elif kind == ARG_NAMED:
                assert name is not None
                key = self.load_static_unicode(name)
                kw_arg_key_value_pairs.append((key, value))
            elif kind == ARG_STAR:
                star_arg_values.append(value)
            elif kind == ARG_STAR2:
                # NOTE: mypy currently only supports a single ** arg, but python supports multiple.
                # This code supports multiple primarily to make the logic easier to follow.
                kw_arg_key_value_pairs.append((None, value))
            else:
                assert False, ("Argument kind should not be possible:", kind)

        if len(star_arg_values) == 0:
            # We can directly construct a tuple if there are no star args.
            pos_args_tuple = self.primitive_op(new_tuple_op, pos_arg_values, line)
        else:
            # Otherwise we construct a list and call extend it with the star args, since tuples
            # don't have an extend method.
            pos_args_list = self.primitive_op(new_list_op, pos_arg_values, line)
            for star_arg_value in star_arg_values:
                self.primitive_op(list_extend_op, [pos_args_list, star_arg_value], line)
            pos_args_tuple = self.primitive_op(list_tuple_op, [pos_args_list], line)

        kw_args_dict = self.make_dict(kw_arg_key_value_pairs, line)

        return self.primitive_op(
            py_call_with_kwargs_op, [function, pos_args_tuple, kw_args_dict], line)

    def py_method_call(self,
                       obj: Value,
                       method_name: str,
                       arg_values: List[Value],
                       line: int,
                       arg_kinds: Optional[List[int]],
                       arg_names: Optional[Sequence[Optional[str]]]) -> Value:
        """Generic Python method call obj.method_name(*arg_values)."""
        if (arg_kinds is None) or all(kind == ARG_POS for kind in arg_kinds):
            method_name_reg = self.load_static_unicode(method_name)
            return self.primitive_op(py_method_call_op, [obj, method_name_reg] + arg_values, line)
        else:
            # Non-positional args: look the method up, then use the kwargs-capable call.
            method = self.py_get_attr(obj, method_name, line)
            return self.py_call(method, arg_values, line, arg_kinds=arg_kinds, arg_names=arg_names)

    def call(self, decl: FuncDecl, args: Sequence[Value],
             arg_kinds: List[int],
             arg_names: Sequence[Optional[str]],
             line: int) -> Value:
        """Emit a direct native call to decl."""
        # Normalize args to positionals.
        args = self.native_args_to_positional(
            args, arg_kinds, arg_names, decl.sig, line)
        return self.add(Call(decl, args, line))

    def native_args_to_positional(self,
                                  args: Sequence[Value],
                                  arg_kinds: List[int],
                                  arg_names: Sequence[Optional[str]],
                                  sig: FuncSignature,
                                  line: int) -> List[Value]:
        """Prepare arguments for a native call.

        Given args/kinds/names and a target signature for a native call, map
        keyword arguments to their appropriate place in the argument list,
        fill in error values for unspecified default arguments,
        package arguments that will go into *args/**kwargs into a tuple/dict,
        and coerce arguments to the appropriate type.
        """

        sig_arg_kinds = [arg.kind for arg in sig.args]
        sig_arg_names = [arg.name for arg in sig.args]
        formal_to_actual = map_actuals_to_formals(arg_kinds,
                                                  arg_names,
                                                  sig_arg_kinds,
                                                  sig_arg_names,
                                                  lambda n: AnyType(TypeOfAny.special_form))

        # Flatten out the arguments, loading error values for default
        # arguments, constructing tuples/dicts for star args, and
        # coercing everything to the expected type.
        output_args = []
        for lst, arg in zip(formal_to_actual, sig.args):
            if arg.kind == ARG_STAR:
                output_arg = self.primitive_op(new_tuple_op, [args[i] for i in lst], line)
            elif arg.kind == ARG_STAR2:
                dict_entries = [(self.load_static_unicode(cast(str, arg_names[i])), args[i])
                                for i in lst]
                output_arg = self.make_dict(dict_entries, line)
            elif not lst:
                # No actual supplied for this formal: pass a borrowed error value.
                output_arg = self.add(LoadErrorValue(arg.type, is_borrowed=True))
            else:
                output_arg = args[lst[0]]
            output_args.append(self.coerce(output_arg, arg.type, line))

        return output_args

    def make_dict(self, key_value_pairs: Sequence[DictEntry], line: int) -> Value:
        """Build a dict from key/value pairs; a None key means a ``**value`` merge.

        Leading key:value pairs are passed to the dict constructor directly; once
        a ``**`` entry has been merged, later pairs are set one by one so that
        insertion order and override semantics follow the source order.
        """
        result = None  # type: Union[Value, None]
        initial_items = []  # type: List[Value]
        for key, value in key_value_pairs:
            if key is not None:
                # key:value
                if result is None:
                    initial_items.extend((key, value))
                    continue

                # NOTE(review): the return value is ignored; this assumes a
                # dict __setitem__ primitive always matches here.
                self.translate_special_method_call(
                    result,
                    '__setitem__',
                    [key, value],
                    result_type=None,
                    line=line)
            else:
                # **value
                if result is None:
                    result = self.primitive_op(new_dict_op, initial_items, line)

                self.primitive_op(
                    dict_update_in_display_op,
                    [result, value],
                    line=line
                )

        if result is None:
            result = self.primitive_op(new_dict_op, initial_items, line)

        return result

    # Loading stuff
    def literal_static_name(
            self, value: Union[int, float, complex, str, bytes]) -> str:
        return self.mapper.literal_static_name(self.current_module, value)

    def load_static_int(self, value: int) -> Value:
        """Loads a static integer Python 'int' object into a register."""
        if abs(value) > MAX_LITERAL_SHORT_INT:
            static_symbol = self.literal_static_name(value)
            return self.add(LoadStatic(int_rprimitive, static_symbol, ann=value))
        else:
            return self.add(LoadInt(value))

    def load_static_float(self, value: float) -> Value:
        """Loads a static float value into a register."""
        static_symbol = self.literal_static_name(value)
        return self.add(LoadStatic(float_rprimitive, static_symbol, ann=value))

    def load_static_bytes(self, value: bytes) -> Value:
        """Loads a static bytes value into a register."""
        static_symbol = self.literal_static_name(value)
        return self.add(LoadStatic(object_rprimitive, static_symbol, ann=value))

    def load_static_complex(self, value: complex) -> Value:
        """Loads a static complex value into a register."""
        static_symbol = self.literal_static_name(value)
        return self.add(LoadStatic(object_rprimitive, static_symbol, ann=value))

    def load_static_unicode(self, value: str) -> Value:
        """Loads a static unicode value into a register.

        This is useful for more than just unicode literals; for example, method calls
        also require a PyObject * form for the name of the method.
        """
        static_symbol = self.literal_static_name(value)
        return self.add(LoadStatic(str_rprimitive, static_symbol, ann=value))

    def load_module(self, name: str) -> Value:
        return self.add(LoadStatic(object_rprimitive, name, namespace=NAMESPACE_MODULE))

    def load_native_type_object(self, fullname: str) -> Value:
        """Load the type object for a native class given as 'module.Name'."""
        module, name = fullname.rsplit('.', 1)
        return self.add(LoadStatic(object_rprimitive, name, module, NAMESPACE_TYPE))

    def matching_primitive_op(self,
                              candidates: List[OpDescription],
                              args: List[Value],
                              line: int,
                              result_type: Optional[RType] = None) -> Optional[Value]:
        """Emit the highest-priority candidate op whose arg types accept args.

        If result_type is given and the op's result is not a runtime subtype of
        it, the result is coerced (or replaced by None for a None result type).
        Returns None if no candidate matches. Two matches with equal priority
        trigger an AssertionError.
        """
        # Find the highest-priority primitive op that matches.
        matching = None  # type: Optional[OpDescription]
        for desc in candidates:
            if len(desc.arg_types) != len(args):
                continue
            if all(is_subtype(actual.type, formal)
                   for actual, formal in zip(args, desc.arg_types)):
                if matching:
                    assert matching.priority != desc.priority, 'Ambiguous:\n1) %s\n2) %s' % (
                        matching, desc)
                    if desc.priority > matching.priority:
                        matching = desc
                else:
                    matching = desc
        if matching:
            target = self.primitive_op(matching, args, line)
            if result_type and not is_runtime_subtype(target.type, result_type):
                if is_none_rprimitive(result_type):
                    # Special case None return. The actual result may actually be a bool
                    # and so we can't just coerce it.
                    target = self.none()
                else:
                    target = self.coerce(target, result_type, line)
            return target
        return None

    def binary_op(self,
                  lreg: Value,
                  rreg: Value,
                  expr_op: str,
                  line: int) -> Value:
        """Emit lreg <expr_op> rreg; fails with AssertionError if unsupported."""
        # Special case == and != when we can resolve the method call statically.
        value = None
        if expr_op in ('==', '!='):
            value = self.translate_eq_cmp(lreg, rreg, expr_op, line)
        if value is not None:
            return value

        ops = binary_ops.get(expr_op, [])
        target = self.matching_primitive_op(ops, [lreg, rreg], line)
        assert target, 'Unsupported binary operation: %s' % expr_op
        return target

    def unary_op(self,
                 lreg: Value,
                 expr_op: str,
                 line: int) -> Value:
        """Emit <expr_op> lreg; fails with AssertionError if unsupported."""
        ops = unary_ops.get(expr_op, [])
        target = self.matching_primitive_op(ops, [lreg], line)
        assert target, 'Unsupported unary operation: %s' % expr_op
        return target

    def shortcircuit_helper(self, op: str,
                            expr_type: RType,
                            left: Callable[[], Value],
                            right: Callable[[], Value], line: int) -> Value:
        """Emit a short-circuiting 'and'/'or' whose result is stored in a new temp.

        left() is emitted first; right() is emitted only on the path where the
        left operand does not decide the result.
        """
        # Having actual Phi nodes would be really nice here!
        target = self.alloc_temp(expr_type)
        # left_body takes the value of the left side, right_body the right
        left_body, right_body, next = BasicBlock(), BasicBlock(), BasicBlock()
        # true_body is taken if the left is true, false_body if it is false.
        # For 'and' the value is the right side if the left is true, and for 'or'
        # it is the right side if the left is false.
        true_body, false_body = (
            (right_body, left_body) if op == 'and' else (left_body, right_body))

        left_value = left()
        self.add_bool_branch(left_value, true_body, false_body)

        self.activate_block(left_body)
        left_coerced = self.coerce(left_value, expr_type, line)
        self.add(Assign(target, left_coerced))
        self.goto(next)

        self.activate_block(right_body)
        right_value = right()
        right_coerced = self.coerce(right_value, expr_type, line)
        self.add(Assign(target, right_coerced))
        self.goto(next)

        self.activate_block(next)
        return target

    def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) -> None:
        """Branch to true or false depending on the truthiness of value.

        ints and lists are tested against zero (lists via their length), native
        classes with __bool__ call it directly, and Optional[X] first checks for
        None and then (unless X is always truthy) tests the unwrapped value.
        Anything else goes through bool_op unless it is already a bool.
        """
        if is_runtime_subtype(value.type, int_rprimitive):
            zero = self.add(LoadInt(0))
            value = self.binary_op(value, zero, '!=', value.line)
        elif is_same_type(value.type, list_rprimitive):
            length = self.primitive_op(list_len_op, [value], value.line)
            zero = self.add(LoadInt(0))
            value = self.binary_op(length, zero, '!=', value.line)
        elif (isinstance(value.type, RInstance) and value.type.class_ir.is_ext_class
                and value.type.class_ir.has_method('__bool__')):
            # Directly call the __bool__ method on classes that have it.
            value = self.gen_method_call(value, '__bool__', [], bool_rprimitive, value.line)
        else:
            value_type = optional_value_type(value.type)
            if value_type is not None:
                is_not_none = self.binary_op(value, self.none_object(), 'is not', value.line)
                branch = Branch(is_not_none, true, false, Branch.BOOL_EXPR)
                self.add(branch)
                always_truthy = False
                if isinstance(value_type, RInstance):
                    # check whether X.__bool__ is always just the default (object.__bool__)
                    if (not value_type.class_ir.has_method('__bool__')
                            and value_type.class_ir.is_method_final('__bool__')):
                        always_truthy = True

                if not always_truthy:
                    # Optional[X] where X may be falsey and requires a check
                    branch.true = BasicBlock()
                    self.activate_block(branch.true)
                    # unbox_or_cast instead of coerce because we want the
                    # type to change even if it is a subtype.
                    remaining = self.unbox_or_cast(value, value_type, value.line)
                    self.add_bool_branch(remaining, true, false)
                return
            elif not is_same_type(value.type, bool_rprimitive):
                value = self.primitive_op(bool_op, [value], value.line)
        self.add(Branch(value, true, false, Branch.BOOL_EXPR))

    def translate_special_method_call(self,
                                      base_reg: Value,
                                      name: str,
                                      args: List[Value],
                                      result_type: Optional[RType],
                                      line: int) -> Optional[Value]:
        """Translate a method call which is handled nongenerically.

        These are special in the sense that we have code generated specifically for them.
        They tend to be method calls which have equivalents in C that are more direct
        than calling with the PyObject api.

        Return None if no translation found; otherwise return the target register.
        """
        ops = method_ops.get(name, [])
        return self.matching_primitive_op(ops, [base_reg] + args, line, result_type=result_type)

    def translate_eq_cmp(self,
                         lreg: Value,
                         rreg: Value,
                         expr_op: str,
                         line: int) -> Optional[Value]:
        """Statically compile ==/!= between two values of the same native class.

        Returns None when the comparison cannot be resolved statically.
        """
        ltype = lreg.type
        rtype = rreg.type
        if not (isinstance(ltype, RInstance) and ltype == rtype):
            return None

        class_ir = ltype.class_ir
        # Check whether any subclasses of the operand redefines __eq__
        # or it might be redefined in a Python parent class or by
        # dataclasses
        cmp_varies_at_runtime = (
            not class_ir.is_method_final('__eq__')
            or not class_ir.is_method_final('__ne__')
            or class_ir.inherits_python
            or class_ir.is_augmented
        )

        if cmp_varies_at_runtime:
            # We might need to call left.__eq__(right) or right.__eq__(left)
            # depending on which is the more specific type.
            return None

        if not class_ir.has_method('__eq__'):
            # There's no __eq__ defined, so just use object identity.
            identity_ref_op = 'is' if expr_op == '==' else 'is not'
            return self.binary_op(lreg, rreg, identity_ref_op, line)

        return self.gen_method_call(
            lreg,
            op_methods[expr_op],
            [rreg],
            ltype,
            line
        )

    def gen_method_call(self,
                        base: Value,
                        name: str,
                        arg_values: List[Value],
                        result_type: Optional[RType],
                        line: int,
                        arg_kinds: Optional[List[int]] = None,
                        arg_names: Optional[List[Optional[str]]] = None) -> Value:
        """Emit base.name(*arg_values), preferring native or specialized calls.

        Falls back to a generic Python method call when no faster path applies.
        """
        # If arg_kinds contains values other than arg_pos and arg_named, then fallback to
        # Python method call.
        if (arg_kinds is not None
                and not all(kind in (ARG_POS, ARG_NAMED) for kind in arg_kinds)):
            # Attribute the call to the call's own line, like every other path below.
            return self.py_method_call(base, name, arg_values, line, arg_kinds, arg_names)

        # If the base type is one of ours, do a MethodCall
        if (isinstance(base.type, RInstance) and base.type.class_ir.is_ext_class
                and not base.type.class_ir.builtin_base):
            if base.type.class_ir.has_method(name):
                decl = base.type.class_ir.method_decl(name)
                if arg_kinds is None:
                    assert arg_names is None, "arg_kinds not present but arg_names is"
                    arg_kinds = [ARG_POS for _ in arg_values]
                    arg_names = [None for _ in arg_values]
                else:
                    assert arg_names is not None, "arg_kinds present but arg_names is not"

                # Normalize args to positionals.
                assert decl.bound_sig
                arg_values = self.native_args_to_positional(
                    arg_values, arg_kinds, arg_names, decl.bound_sig, line)
                return self.add(MethodCall(base, name, arg_values, line))
            elif base.type.class_ir.has_attr(name):
                # A callable stored in an attribute: fetch it, then call it.
                function = self.add(GetAttr(base, name, line))
                return self.py_call(function, arg_values, line,
                                    arg_kinds=arg_kinds, arg_names=arg_names)

        elif isinstance(base.type, RUnion):
            return self.union_method_call(base, base.type, name, arg_values, result_type, line,
                                          arg_kinds, arg_names)

        # Try to do a special-cased method call
        if not arg_kinds or arg_kinds == [ARG_POS] * len(arg_values):
            target = self.translate_special_method_call(base, name, arg_values, result_type, line)
            if target:
                return target

        # Fall back to Python method call
        return self.py_method_call(base, name, arg_values, line, arg_kinds, arg_names)

    def union_method_call(self,
                          base: Value,
                          obj_type: RUnion,
                          name: str,
                          arg_values: List[Value],
                          return_rtype: Optional[RType],
                          line: int,
                          arg_kinds: Optional[List[int]],
                          arg_names: Optional[List[Optional[str]]]) -> Value:
        """Emit base.name(...) for a union-typed base by dispatching per item."""
        # Union method call needs a return_rtype for the type of the output register.
        # If we don't have one, use object_rprimitive.
        return_rtype = return_rtype or object_rprimitive

        def call_union_item(value: Value) -> Value:
            return self.gen_method_call(value, name, arg_values, return_rtype, line,
                                        arg_kinds, arg_names)

        return self.decompose_union_helper(base, obj_type, return_rtype, call_union_item, line)
class IRBuilder(NodeVisitor[Register]): def __init__(self, types: Dict[Expression, Type], mapper: Mapper) -> None: self.types = types self.environment = Environment() self.environments = [self.environment] self.blocks = [] # type: List[List[BasicBlock]] self.functions = [] # type: List[FuncIR] self.classes = [] # type: List[ClassIR] self.targets = [] # type: List[Register] # These lists operate as stack frames for loops. Each loop adds a new # frame (i.e. adds a new empty list [] to the outermost list). Each # break or continue is inserted within that frame as they are visited # and at the end of the loop the stack is popped and any break/continue # gotos have their targets rewritten to the next basic block. self.break_gotos = [] # type: List[List[Goto]] self.continue_gotos = [] # type: List[List[Goto]] self.mapper = mapper self.imports = [] # type: List[str] self.current_module_name = None # type: Optional[str] def visit_mypy_file(self, mypyfile: MypyFile) -> Register: if mypyfile.fullname() in ('typing', 'abc'): # These module are special; their contents are currently all # built-in primitives. return INVALID_REGISTER # First pass: Build ClassIRs and TypeInfo-to-ClassIR mapping. for node in mypyfile.defs: if isinstance(node, ClassDef): self.prepare_class_def(node) # Second pass: Generate ops. 
self.current_module_name = mypyfile.fullname() for node in mypyfile.defs: node.accept(self) return INVALID_REGISTER def prepare_class_def(self, cdef: ClassDef) -> None: ir = ClassIR(cdef.name, []) # Populate attributes later in visit_class_def self.classes.append(ir) self.mapper.type_to_ir[cdef.info] = ir def visit_class_def(self, cdef: ClassDef) -> Register: attributes = [] for name, node in cdef.info.names.items(): if isinstance(node.node, Var): attributes.append((name, self.type_to_rtype(node.node.type))) ir = self.mapper.type_to_ir[cdef.info] ir.attributes = attributes return INVALID_REGISTER def visit_import(self, node: Import) -> Register: if node.is_unreachable or node.is_mypy_only: pass if not node.is_top_level: assert False, "non-toplevel imports not supported" for node_id, _ in node.ids: self.imports.append(node_id) return INVALID_REGISTER def visit_import_from(self, node: ImportFrom) -> Register: if node.is_unreachable or node.is_mypy_only: pass if not node.is_top_level: assert False, "non-toplevel imports not supported" self.imports.append(node.id) return INVALID_REGISTER def visit_import_all(self, node: ImportAll) -> Register: if node.is_unreachable or node.is_mypy_only: pass if not node.is_top_level: assert False, "non-toplevel imports not supported" self.imports.append(node.id) return INVALID_REGISTER def visit_func_def(self, fdef: FuncDef) -> Register: self.enter() for arg in fdef.arguments: self.environment.add_local(arg.variable, self.type_to_rtype(arg.variable.type)) fdef.body.accept(self) ret_type = self.convert_return_type(fdef) if ret_type.name == 'None': self.add_implicit_return() else: self.add_implicit_unreachable() blocks, env = self.leave() args = self.convert_args(fdef) func = FuncIR(fdef.name(), args, ret_type, blocks, env) self.functions.append(func) return INVALID_REGISTER def convert_args(self, fdef: FuncDef) -> List[RuntimeArg]: assert isinstance(fdef.type, CallableType) ann = fdef.type return [ RuntimeArg(arg.variable.name(), 
self.type_to_rtype(ann.arg_types[i])) for i, arg in enumerate(fdef.arguments) ] def convert_return_type(self, fdef: FuncDef) -> RType: assert isinstance(fdef.type, CallableType) return self.type_to_rtype(fdef.type.ret_type) def add_implicit_return(self) -> None: block = self.blocks[-1][-1] if not block.ops or not isinstance(block.ops[-1], Return): retval = self.environment.add_temp(NoneRType()) self.add(PrimitiveOp(retval, PrimitiveOp.NONE)) self.add(Return(retval)) def add_implicit_unreachable(self) -> None: block = self.blocks[-1][-1] if not block.ops or not isinstance(block.ops[-1], Return): self.add(Unreachable()) def visit_block(self, block: Block) -> Register: for stmt in block.body: stmt.accept(self) return INVALID_REGISTER def visit_expression_stmt(self, stmt: ExpressionStmt) -> Register: self.accept(stmt.expr) return INVALID_REGISTER def visit_return_stmt(self, stmt: ReturnStmt) -> Register: if stmt.expr: retval = self.accept(stmt.expr) else: retval = self.environment.add_temp(NoneRType()) self.add(PrimitiveOp(retval, PrimitiveOp.NONE)) self.add(Return(retval)) return INVALID_REGISTER def visit_assignment_stmt(self, stmt: AssignmentStmt) -> Register: assert len(stmt.lvalues) == 1 lvalue = stmt.lvalues[0] if stmt.type: lvalue_type = self.type_to_rtype(stmt.type) else: if isinstance(lvalue, IndexExpr): # TODO: This won't be right for user-defined classes. Store the # lvalue type in mypy and remove this special case. 
lvalue_type = ObjectRType() else: lvalue_type = self.node_type(lvalue) rvalue_type = self.node_type(stmt.rvalue) return self.assign(lvalue, stmt.rvalue, rvalue_type, lvalue_type, declare_new=(stmt.type is not None)) def visit_operator_assignment_stmt( self, stmt: OperatorAssignmentStmt) -> Register: target = self.get_assignment_target(stmt.lvalue, declare_new=False) if isinstance(target, AssignmentTargetRegister): ltype = self.environment.types[target.register] rtype = self.node_type(stmt.rvalue) rreg = self.accept(stmt.rvalue) return self.binary_op(ltype, target.register, rtype, rreg, stmt.op, target=target.register) # NOTE: List index not supported yet for compound assignments. assert False, 'Unsupported lvalue: %r' def get_assignment_target(self, lvalue: Lvalue, declare_new: bool) -> AssignmentTarget: if isinstance(lvalue, NameExpr): # Assign to local variable. assert lvalue.kind == LDEF if lvalue.is_def or declare_new: # Define a new variable. assert isinstance(lvalue.node, Var) # TODO: Can this fail? lvalue_num = self.environment.add_local( lvalue.node, self.node_type(lvalue)) else: # Assign to a previously defined variable. assert isinstance(lvalue.node, Var) # TODO: Can this fail? 
lvalue_num = self.environment.lookup(lvalue.node) return AssignmentTargetRegister(lvalue_num) elif isinstance(lvalue, IndexExpr): # Indexed assignment x[y] = e base_type = self.node_type(lvalue.base) index_type = self.node_type(lvalue.index) base_reg = self.accept(lvalue.base) index_reg = self.accept(lvalue.index) if isinstance(base_type, ListRType) and isinstance( index_type, IntRType): # Indexed list set return AssignmentTargetIndex(base_reg, index_reg, base_type) elif isinstance(base_type, DictRType): # Indexed dict set boxed_index = self.box(index_reg, index_type) return AssignmentTargetIndex(base_reg, boxed_index, base_type) elif isinstance(lvalue, MemberExpr): # Attribute assignment x.y = e obj_type = self.node_type(lvalue.expr) assert isinstance( obj_type, UserRType), 'Attribute set only supported for user types' obj_reg = self.accept(lvalue.expr) return AssignmentTargetAttr(obj_reg, lvalue.name, obj_type) assert False, 'Unsupported lvalue: %r' % lvalue def assign_to_target(self, target: AssignmentTarget, rvalue: Expression, rvalue_type: RType, needs_box: bool) -> Register: rvalue_type = rvalue_type or self.node_type(rvalue) if isinstance(target, AssignmentTargetRegister): if needs_box: unboxed = self.accept(rvalue) return self.box(unboxed, rvalue_type, target=target.register) else: return self.accept(rvalue, target=target.register) elif isinstance(target, AssignmentTargetAttr): rvalue_reg = self.accept(rvalue) if needs_box: rvalue_reg = self.box(rvalue_reg, rvalue_type) self.add( SetAttr(target.obj_reg, target.attr, rvalue_reg, target.obj_type)) return INVALID_REGISTER elif isinstance(target, AssignmentTargetIndex): item_reg = self.accept(rvalue) boxed_item_reg = self.box(item_reg, rvalue_type) if isinstance(target.rtype, ListRType): op = PrimitiveOp.LIST_SET elif isinstance(target.rtype, DictRType): op = PrimitiveOp.DICT_SET else: assert False, target.rtype self.add( PrimitiveOp(None, op, target.base_reg, target.index_reg, boxed_item_reg)) return 
INVALID_REGISTER assert False, 'Unsupported assignment target' def assign(self, lvalue: Lvalue, rvalue: Expression, rvalue_type: RType, lvalue_type: RType, declare_new: bool) -> Register: target = self.get_assignment_target(lvalue, declare_new) needs_box = rvalue_type.supports_unbox and not lvalue_type.supports_unbox return self.assign_to_target(target, rvalue, rvalue_type, needs_box) def visit_if_stmt(self, stmt: IfStmt) -> Register: # If statements are normalized assert len(stmt.expr) == 1 branches = self.process_conditional(stmt.expr[0]) if_body = self.new_block() self.set_branches(branches, True, if_body) stmt.body[0].accept(self) if_leave = self.add_leave() if stmt.else_body: else_body = self.new_block() self.set_branches(branches, False, else_body) stmt.else_body.accept(self) else_leave = self.add_leave() next = self.new_block() if else_leave: else_leave.label = next.label else: # No else block. next = self.new_block() self.set_branches(branches, False, next) if if_leave: if_leave.label = next.label return INVALID_REGISTER def add_leave(self) -> Optional[Goto]: if not self.blocks[-1][-1].ops or not isinstance( self.blocks[-1][-1].ops[-1], Return): leave = Goto(INVALID_LABEL) self.add(leave) return leave return None def push_loop_stack(self) -> None: self.break_gotos.append([]) self.continue_gotos.append([]) def pop_loop_stack(self, continue_block: BasicBlock, break_block: BasicBlock) -> None: for continue_goto in self.continue_gotos.pop(): continue_goto.label = continue_block.label for break_goto in self.break_gotos.pop(): break_goto.label = break_block.label def visit_while_stmt(self, s: WhileStmt) -> Register: self.push_loop_stack() # Split block so that we get a handle to the top of the loop. goto = Goto(INVALID_LABEL) self.add(goto) top = self.new_block() goto.label = top.label branches = self.process_conditional(s.expr) body = self.new_block() # Bind "true" branches to the body block. 
self.set_branches(branches, True, body) s.body.accept(self) # Add branch to the top at the end of the body. self.add(Goto(top.label)) next = self.new_block() # Bind "false" branches to the new block. self.set_branches(branches, False, next) self.pop_loop_stack(top, next) return INVALID_REGISTER def visit_for_stmt(self, s: ForStmt) -> Register: if (isinstance(s.expr, CallExpr) and isinstance(s.expr.callee, RefExpr) and s.expr.callee.fullname == 'builtins.range'): self.push_loop_stack() # Special case for x in range(...) # TODO: Check argument counts and kinds; check the lvalue end = s.expr.args[0] end_reg = self.accept(end) # Initialize loop index to 0. index_reg = self.assign(s.index, IntExpr(0), IntRType(), IntRType(), declare_new=True) goto = Goto(INVALID_LABEL) self.add(goto) # Add loop condition check. top = self.new_block() goto.label = top.label branch = Branch(index_reg, end_reg, INVALID_LABEL, INVALID_LABEL, Branch.INT_LT) self.add(branch) branches = [branch] body = self.new_block() self.set_branches(branches, True, body) s.body.accept(self) end_goto = Goto(INVALID_LABEL) self.add(end_goto) end_block = self.new_block() end_goto.label = end_block.label # Increment index register. one_reg = self.alloc_temp(IntRType()) self.add(LoadInt(one_reg, 1)) self.add( PrimitiveOp(index_reg, PrimitiveOp.INT_ADD, index_reg, one_reg)) # Go back to loop condition check. 
self.add(Goto(top.label)) next = self.new_block() self.set_branches(branches, False, next) self.pop_loop_stack(end_block, next) return INVALID_REGISTER if self.node_type(s.expr).name == 'list': self.push_loop_stack() expr_reg = self.accept(s.expr) index_reg = self.alloc_temp(IntRType()) self.add(LoadInt(index_reg, 0)) one_reg = self.alloc_temp(IntRType()) self.add(LoadInt(one_reg, 1)) assert isinstance(s.index, NameExpr) assert isinstance(s.index.node, Var) lvalue_reg = self.environment.add_local(s.index.node, self.node_type(s.index)) condition_block = self.goto_new_block() # For compatibility with python semantics we recalculate the length # at every iteration. len_reg = self.alloc_temp(IntRType()) self.add(PrimitiveOp(len_reg, PrimitiveOp.LIST_LEN, expr_reg)) branch = Branch(index_reg, len_reg, INVALID_LABEL, INVALID_LABEL, Branch.INT_LT) self.add(branch) branches = [branch] body_block = self.new_block() self.set_branches(branches, True, body_block) target_list_type = self.types[s.expr] assert isinstance(target_list_type, Instance) target_type = self.type_to_rtype(target_list_type.args[0]) value_box = self.alloc_temp(ObjectRType()) self.add( PrimitiveOp(value_box, PrimitiveOp.LIST_GET, expr_reg, index_reg)) self.unbox_or_cast(value_box, target_type, target=lvalue_reg) s.body.accept(self) end_block = self.goto_new_block() self.add( PrimitiveOp(index_reg, PrimitiveOp.INT_ADD, index_reg, one_reg)) self.add(Goto(condition_block.label)) next_block = self.new_block() self.set_branches(branches, False, next_block) self.pop_loop_stack(end_block, next_block) return INVALID_REGISTER assert False, 'for not supported' def visit_break_stmt(self, node: BreakStmt) -> Register: self.break_gotos[-1].append(Goto(INVALID_LABEL)) self.add(self.break_gotos[-1][-1]) return INVALID_REGISTER def visit_continue_stmt(self, node: ContinueStmt) -> Register: self.continue_gotos[-1].append(Goto(INVALID_LABEL)) self.add(self.continue_gotos[-1][-1]) return INVALID_REGISTER int_binary_ops = { 
'+': PrimitiveOp.INT_ADD, '-': PrimitiveOp.INT_SUB, '*': PrimitiveOp.INT_MUL, '//': PrimitiveOp.INT_DIV, '%': PrimitiveOp.INT_MOD, '&': PrimitiveOp.INT_AND, '|': PrimitiveOp.INT_OR, '^': PrimitiveOp.INT_XOR, '<<': PrimitiveOp.INT_SHL, '>>': PrimitiveOp.INT_SHR, '>>': PrimitiveOp.INT_SHR, } def visit_unary_expr(self, expr: UnaryExpr) -> Register: if expr.op != '-': assert False, 'Unsupported unary operation' etype = self.node_type(expr.expr) reg = self.accept(expr.expr) if etype.name != 'int': assert False, 'Unsupported unary operation' target = self.alloc_target(IntRType()) zero = self.accept(IntExpr(0)) self.add(PrimitiveOp(target, PrimitiveOp.INT_SUB, zero, reg)) return target def visit_op_expr(self, expr: OpExpr) -> Register: ltype = self.node_type(expr.left) rtype = self.node_type(expr.right) lreg = self.accept(expr.left) rreg = self.accept(expr.right) return self.binary_op(ltype, lreg, rtype, rreg, expr.op) def binary_op(self, ltype: RType, lreg: Register, rtype: RType, rreg: Register, expr_op: str, target: Optional[Register] = None) -> Register: if ltype.name == 'int' and rtype.name == 'int': # Primitive int operation if target is None: target = self.alloc_target(IntRType()) op = self.int_binary_ops[expr_op] elif (ltype.name == 'list' or rtype.name == 'list') and expr_op == '*': if rtype.name == 'list': ltype, rtype = rtype, ltype lreg, rreg = rreg, lreg if rtype.name != 'int': assert False, 'Unsupported binary operation' # TODO: Operator overloading if target is None: target = self.alloc_target(ListRType()) op = PrimitiveOp.LIST_REPEAT elif isinstance(rtype, DictRType): if expr_op == 'in': if target is None: target = self.alloc_target(BoolRType()) lreg = self.box(lreg, ltype) op = PrimitiveOp.DICT_CONTAINS else: assert False, 'Unsupported binary operation' else: assert False, 'Unsupported binary operation' self.add(PrimitiveOp(target, op, lreg, rreg)) return target def visit_index_expr(self, expr: IndexExpr) -> Register: base_rtype = 
self.node_type(expr.base) base_reg = self.accept(expr.base) target_type = self.node_type(expr) if isinstance(base_rtype, (ListRType, SequenceTupleRType, DictRType)): index_type = self.node_type(expr.index) if not isinstance(base_rtype, DictRType): assert isinstance( index_type, IntRType), 'Unsupported indexing operation' # TODO if isinstance(base_rtype, ListRType): op = PrimitiveOp.LIST_GET elif isinstance(base_rtype, DictRType): op = PrimitiveOp.DICT_GET else: op = PrimitiveOp.HOMOGENOUS_TUPLE_GET index_reg = self.accept(expr.index) if isinstance(base_rtype, DictRType): index_reg = self.box(index_reg, index_type) tmp = self.alloc_temp(ObjectRType()) self.add(PrimitiveOp(tmp, op, base_reg, index_reg)) target = self.alloc_target(target_type) return self.unbox_or_cast(tmp, target_type, target) elif isinstance(base_rtype, TupleRType): assert isinstance(expr.index, IntExpr) # TODO target = self.alloc_target(target_type) self.add( TupleGet(target, base_reg, expr.index.value, base_rtype.types[expr.index.value])) return target assert False, 'Unsupported indexing operation' def visit_int_expr(self, expr: IntExpr) -> Register: reg = self.alloc_target(IntRType()) self.add(LoadInt(reg, expr.value)) return reg def is_native_name_expr(self, expr: NameExpr) -> bool: # TODO later we want to support cross-module native calls too if '.' 
in expr.node.fullname(): module_name = '.'.join(expr.node.fullname().split('.')[:-1]) return module_name == self.current_module_name return True def visit_name_expr(self, expr: NameExpr) -> Register: if expr.node.fullname() == 'builtins.None': target = self.alloc_target(NoneRType()) self.add(PrimitiveOp(target, PrimitiveOp.NONE)) return target elif expr.node.fullname() == 'builtins.True': target = self.alloc_target(BoolRType()) self.add(PrimitiveOp(target, PrimitiveOp.TRUE)) return target elif expr.node.fullname() == 'builtins.False': target = self.alloc_target(BoolRType()) self.add(PrimitiveOp(target, PrimitiveOp.FALSE)) return target if not self.is_native_name_expr(expr): return self.load_static_module_attr(expr) # TODO: We assume that this is a Var node, which is very limited assert isinstance(expr.node, Var) reg = self.environment.lookup(expr.node) return self.get_using_binder(reg, expr.node, expr) def get_using_binder(self, reg: Register, var: Var, expr: Expression) -> Register: var_type = self.type_to_rtype(var.type) target_type = self.node_type(expr) if var_type != target_type: # Cast/unbox to the narrower given by the binder. if self.targets[-1] < 0: target = self.alloc_temp(target_type) else: target = self.targets[-1] return self.unbox_or_cast(reg, target_type, target) else: # Regular register access -- binder is not active. 
if self.targets[-1] < 0: return reg else: target = self.targets[-1] self.add(Assign(target, reg)) return target def is_module_member_expr(self, expr: MemberExpr): return isinstance(expr.expr, RefExpr) and expr.expr.kind == MODULE_REF def visit_member_expr(self, expr: MemberExpr) -> Register: if self.is_module_member_expr(expr): return self.load_static_module_attr(expr) else: obj_reg = self.accept(expr.expr) attr_type = self.node_type(expr) target = self.alloc_target(attr_type) obj_type = self.node_type(expr.expr) assert isinstance( obj_type, UserRType), 'Attribute access not supported: %s' % obj_type self.add(GetAttr(target, obj_reg, expr.name, obj_type)) return target def load_static_module_attr(self, expr: RefExpr) -> Register: target = self.alloc_target(self.node_type(expr)) module = '.'.join(expr.node.fullname().split('.')[:-1]) right = expr.node.fullname().split('.')[-1] left = self.alloc_temp(ObjectRType()) self.add(LoadStatic(left, c_module_name(module))) self.add(PyGetAttr(target, left, right)) return target def py_call(self, function: Register, args: List[Expression], target_type: RType) -> Register: target_box = self.alloc_temp(ObjectRType()) arg_boxes = [] # type: List[Register] for arg_expr in args: arg_reg = self.accept(arg_expr) arg_boxes.append(self.box(arg_reg, self.node_type(arg_expr))) self.add(PyCall(target_box, function, arg_boxes)) return self.unbox_or_cast(target_box, target_type) def visit_call_expr(self, expr: CallExpr) -> Register: if isinstance(expr.callee, MemberExpr): is_module_call = self.is_module_member_expr(expr.callee) if expr.callee.expr in self.types and not is_module_call: target = self.translate_special_method_call(expr.callee, expr) if target: return target # Either its a module call or translating to a special method call failed, so we have # to fallback to a PyCall function = self.accept(expr.callee) return self.py_call(function, expr.args, self.node_type(expr)) assert isinstance(expr.callee, NameExpr) fn = expr.callee.name # 
TODO: fullname if fn == 'len' and len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]: target = self.alloc_target(IntRType()) arg = self.accept(expr.args[0]) expr_rtype = self.node_type(expr.args[0]) if expr_rtype.name == 'list': self.add(PrimitiveOp(target, PrimitiveOp.LIST_LEN, arg)) elif expr_rtype.name == 'sequence_tuple': self.add( PrimitiveOp(target, PrimitiveOp.HOMOGENOUS_TUPLE_LEN, arg)) elif isinstance(expr_rtype, TupleRType): self.add(LoadInt(target, len(expr_rtype.types))) else: assert False, "unsupported use of len" # Handle conversion to sequence tuple elif fn == 'tuple' and len( expr.args) == 1 and expr.arg_kinds == [ARG_POS]: target = self.alloc_target(SequenceTupleRType()) arg = self.accept(expr.args[0]) self.add( PrimitiveOp(target, PrimitiveOp.LIST_TO_HOMOGENOUS_TUPLE, arg)) else: target_type = self.node_type(expr) if not (self.is_native_name_expr(expr.callee)): function = self.accept(expr.callee) return self.py_call(function, expr.args, target_type) target = self.alloc_target(target_type) args = [self.accept(arg) for arg in expr.args] self.add(Call(target, fn, args)) return target def visit_conditional_expr(self, expr: ConditionalExpr) -> Register: branches = self.process_conditional(expr.cond) target = self.alloc_target(self.node_type(expr)) if_body = self.new_block() self.set_branches(branches, True, if_body) self.accept(expr.if_expr, target=target) if_goto_next = Goto(INVALID_LABEL) self.add(if_goto_next) else_body = self.new_block() self.set_branches(branches, False, else_body) self.accept(expr.else_expr, target=target) else_goto_next = Goto(INVALID_LABEL) self.add(else_goto_next) next = self.new_block() if_goto_next.label = next.label else_goto_next.label = next.label return target def translate_special_method_call(self, callee: MemberExpr, expr: CallExpr) -> Register: base_type = self.node_type(callee.expr) result_type = self.node_type(expr) base = self.accept(callee.expr) if callee.name == 'append' and base_type.name == 'list': target = 
INVALID_REGISTER # TODO: Do we sometimes need to allocate a register? arg = self.box_expr(expr.args[0]) self.add(PrimitiveOp(target, PrimitiveOp.LIST_APPEND, base, arg)) else: assert False, 'Unsupported method call: %s.%s' % (base_type.name, callee.name) return target def visit_list_expr(self, expr: ListExpr) -> Register: list_type = self.types[expr] assert isinstance(list_type, Instance) item_type = self.type_to_rtype(list_type.args[0]) target = self.alloc_target(ListRType()) items = [] for item in expr.items: item_reg = self.accept(item) boxed = self.box(item_reg, item_type) items.append(boxed) self.add(PrimitiveOp(target, PrimitiveOp.NEW_LIST, *items)) return target def visit_tuple_expr(self, expr: TupleExpr) -> Register: tuple_type = self.types[expr] assert isinstance(tuple_type, TupleType) target = self.alloc_target(self.type_to_rtype(tuple_type)) items = [self.accept(i) for i in expr.items] self.add(PrimitiveOp(target, PrimitiveOp.NEW_TUPLE, *items)) return target def visit_dict_expr(self, expr: DictExpr): assert not expr.items # TODO target = self.alloc_target(DictRType()) self.add(PrimitiveOp(target, PrimitiveOp.NEW_DICT)) return target # Conditional expressions int_relative_ops = { '==': Branch.INT_EQ, '!=': Branch.INT_NE, '<': Branch.INT_LT, '<=': Branch.INT_LE, '>': Branch.INT_GT, '>=': Branch.INT_GE, } def process_conditional(self, e: Node) -> List[Branch]: if isinstance(e, ComparisonExpr): # TODO: Verify operand types. 
assert len(e.operators) == 1, 'more than 1 operator not supported' op = e.operators[0] if op in ['==', '!=', '<', '<=', '>', '>=']: # TODO: check operand types left = self.accept(e.operands[0]) right = self.accept(e.operands[1]) opcode = self.int_relative_ops[op] branch = Branch(left, right, INVALID_LABEL, INVALID_LABEL, opcode) elif op in ['is', 'is not']: # TODO: check if right operand is None left = self.accept(e.operands[0]) branch = Branch(left, INVALID_REGISTER, INVALID_LABEL, INVALID_LABEL, Branch.IS_NONE) if op == 'is not': branch.negated = True elif op in ['in', 'not in']: left = self.accept(e.operands[0]) ltype = self.node_type(e.operands[0]) right = self.accept(e.operands[1]) rtype = self.node_type(e.operands[1]) target = self.alloc_temp(self.node_type(e)) self.binary_op(ltype, left, rtype, right, 'in', target=target) branch = Branch(target, INVALID_REGISTER, INVALID_LABEL, INVALID_LABEL, Branch.BOOL_EXPR) if op == 'not in': branch.negated = True else: assert False, "unsupported comparison epxression" self.add(branch) return [branch] elif isinstance(e, OpExpr) and e.op in ['and', 'or']: if e.op == 'and': # Short circuit 'and' in a conditional context. lbranches = self.process_conditional(e.left) new = self.new_block() self.set_branches(lbranches, True, new) rbranches = self.process_conditional(e.right) return lbranches + rbranches else: # Short circuit 'or' in a conditional context. lbranches = self.process_conditional(e.left) new = self.new_block() self.set_branches(lbranches, False, new) rbranches = self.process_conditional(e.right) return lbranches + rbranches elif isinstance(e, UnaryExpr) and e.op == 'not': branches = self.process_conditional(e.expr) for b in branches: b.invert() return branches # Catch-all for arbitrary expressions. 
else: reg = self.accept(e) branch = Branch(reg, INVALID_REGISTER, INVALID_LABEL, INVALID_LABEL, Branch.BOOL_EXPR) self.add(branch) return [branch] def set_branches(self, branches: List[Branch], condition: bool, target: BasicBlock) -> None: """Set branch targets for the given condition (True or False). If the target has already been set for a branch, skip the branch. """ for b in branches: if condition: if b.true < 0: b.true = target.label else: if b.false < 0: b.false = target.label # Helpers def enter(self) -> None: self.environment = Environment() self.environments.append(self.environment) self.blocks.append([]) self.new_block() def new_block(self) -> BasicBlock: new = BasicBlock(Label(len(self.blocks[-1]))) self.blocks[-1].append(new) return new def goto_new_block(self) -> BasicBlock: goto = Goto(INVALID_LABEL) self.add(goto) block = self.new_block() goto.label = block.label return block def leave(self) -> Tuple[List[BasicBlock], Environment]: blocks = self.blocks.pop() env = self.environments.pop() self.environment = self.environments[-1] return blocks, env def add(self, op: Op) -> None: self.blocks[-1][-1].ops.append(op) def accept(self, node: Node, target: Register = INVALID_REGISTER) -> Register: self.targets.append(target) actual = node.accept(self) self.targets.pop() return actual def alloc_target(self, type: RType) -> Register: if self.targets[-1] < 0: return self.environment.add_temp(type) else: return self.targets[-1] def alloc_temp(self, type: RType) -> Register: return self.environment.add_temp(type) def type_to_rtype(self, typ: Type) -> RType: return self.mapper.type_to_rtype(typ) def node_type(self, node: Expression) -> RType: mypy_type = self.types[node] return self.type_to_rtype(mypy_type) def box(self, src: Register, typ: RType, target: Optional[Register] = None) -> Register: if typ.supports_unbox: if target is None: target = self.alloc_temp(ObjectRType()) self.add(Box(target, src, typ)) return target else: # Already boxed if target is not None: 
self.add(Assign(target, src)) return target else: return src def unbox_or_cast(self, src: Register, target_type: RType, target: Optional[Register] = None) -> Register: if target is None: target = self.alloc_temp(target_type) if target_type.supports_unbox: self.add(Unbox(target, src, target_type)) else: self.add(Cast(target, src, target_type)) return target def box_expr(self, expr: Expression) -> Register: typ = self.node_type(expr) return self.box(self.accept(expr), typ)