def visit_Name(self, node):
    """Render a name reference.

    The names ``True`` and ``False`` are handed to a fresh
    ``CLikeTranspiler``. Any other name is looked up in ``node.scopes``;
    when ``defined_before`` reports the definition precedes this use, the
    bare identifier is returned, otherwise the value the name was assigned
    from is visited and its rendering is returned instead.
    """
    if node.id in ("True", "False"):
        return CLikeTranspiler().visit(node)

    definition = node.scopes.find(node.id)
    if not defined_before(definition, node):
        # Used before its definition: substitute the assigned expression.
        return self.visit(definition.assigned_from.value)
    return node.id
def _visit_body(self, body):
    """Return the statements of *body* with unpackable ``if`` blocks flattened.

    An ``ast.If`` statement for which ``CLikeTranspiler.is_block`` is true
    and whose ``unpack`` attribute (default ``True``) is truthy is replaced,
    recursively, by the statements of its own body. Every other statement
    is kept as-is and in order.
    """
    flattened = []
    for stmt in body:
        may_unpack = getattr(stmt, "unpack", True)
        is_unpackable_block = (
            isinstance(stmt, ast.If)
            and CLikeTranspiler.is_block(stmt)
            and may_unpack
        )
        if not is_unpackable_block:
            flattened.append(stmt)
            continue
        flattened.extend(self._visit_body(stmt.body))
    return flattened
def visit_NameConstant(self, node):
    """Delegate rendering of a constant node to a fresh ``CLikeTranspiler``."""
    delegate = CLikeTranspiler()
    return delegate.visit(node)
def visit_BinOp(self, node):
    """Render a binary operation as ``"<left> <op> <right>"``.

    Both operands are rendered by this visitor; the operator itself is
    rendered by a fresh ``CLikeTranspiler``.
    """
    lhs = self.visit(node.left)
    operator = CLikeTranspiler().visit(node.op)
    rhs = self.visit(node.right)
    return f"{lhs} {operator} {rhs}"
def __init__(self):
    """Reset the per-run flags and create the helper ``CLikeTranspiler``."""
    # Both flags start out cleared.
    self.handling_annotation = self.has_fixed_width_ints = False
    # TODO: remove this and make the methods into classmethods
    self._clike = CLikeTranspiler()
class InferTypesTransformer(ast.NodeTransformer):
    """
    Tries to infer types

    Each ``visit_*`` method attaches an ``annotation`` attribute (an AST
    node naming the type) to the nodes whose type it can work out, and
    returns the node unchanged otherwise.
    """

    # Python primitive type -> annotation name.
    TYPE_DICT = {int: "int", float: "float", str: "str", bool: "bool"}
    FIXED_WIDTH_INTS = {
        bool,
        c_int8,
        c_int16,
        c_int32,
        c_int64,
        c_uint8,
        c_uint16,
        c_uint32,
        c_uint64,
    }
    # Ordered so that the entry after a type is the next wider one
    # (see _handle_overflow).
    FIXED_WIDTH_INTS_NAME_LIST = [
        "bool",
        "c_int8",
        "c_int16",
        "c_int32",
        "c_int64",
        "c_uint8",
        "c_uint16",
        "c_uint32",
        "c_uint64",
    ]
    FIXED_WIDTH_INTS_NAME = set(FIXED_WIDTH_INTS_NAME_LIST)

    def __init__(self):
        self.handling_annotation = False
        self.has_fixed_width_ints = False
        # TODO: remove this and make the methods into classmethods
        self._clike = CLikeTranspiler()

    @staticmethod
    def _infer_primitive(value) -> Optional[ast.AST]:
        """Return an ``ast.Name`` naming the type of *value*.

        Returns ``None`` when *value* is ``None``.

        Raises:
            NotImplementedError: the type of *value* is neither in
                ``TYPE_DICT`` nor in ``FIXED_WIDTH_INTS``.
        """
        t = type(value)
        annotation = None
        if t in InferTypesTransformer.TYPE_DICT:
            annotation = ast.Name(id=InferTypesTransformer.TYPE_DICT[t])
        elif t in InferTypesTransformer.FIXED_WIDTH_INTS:
            # NOTE(review): str(t) is the class repr, not a bare type name -
            # confirm that this is what consumers of the annotation expect.
            annotation = ast.Name(id=str(t))
        elif t != type(None):
            raise NotImplementedError(f"{t} not found in TYPE_DICT")
        return annotation

    def visit_NameConstant(self, node):
        """Annotate a constant with the type of its value (``...`` is skipped)."""
        if node.value is Ellipsis:
            return node
        annotation = self._infer_primitive(node.value)
        if annotation is not None:
            node.annotation = annotation
        self.generic_visit(node)
        return node

    def visit_Name(self, node):
        """Annotate a name with whatever ``get_inferred_type`` reports for it."""
        annotation = get_inferred_type(node)
        if annotation is not None:
            node.annotation = annotation
        return node

    def visit_Constant(self, node):
        """Handle ``ast.Constant`` exactly like ``ast.NameConstant``."""
        return self.visit_NameConstant(node)

    @staticmethod
    def _annotate(node, typename: str):
        """Parse *typename* as an expression and store it as ``node.annotation``."""
        # ast.parse produces a Module object that needs to be destructured
        type_annotation = ast.parse(typename).body[0].value
        node.annotation = type_annotation

    def visit_List(self, node):
        """Annotate a list literal as ``List[T]`` (homogeneous) or ``List`` (empty)."""
        self.generic_visit(node)
        if len(node.elts) > 0:
            elements = [self.visit(e) for e in node.elts]
            if getattr(node, "is_annotation", False):
                return node
            elt_types = {get_id(get_inferred_type(e)) for e in elements}
            if len(elt_types) == 1 and hasattr(elements[0], "annotation"):
                elt_type = get_id(elements[0].annotation)
                self._annotate(node, f"List[{elt_type}]")
        elif not hasattr(node, "annotation"):
            node.annotation = ast.Name(id="List")
        return node

    def visit_Set(self, node):
        """Annotate a set literal as ``Set[T]`` (homogeneous) or ``Set`` (empty)."""
        self.generic_visit(node)
        if len(node.elts) > 0:
            elements = [self.visit(e) for e in node.elts]
            elt_types = {get_id(get_inferred_type(e)) for e in elements}
            # Same guard as visit_List: a single inferred type may be "unknown",
            # in which case the first element carries no annotation to read.
            if len(elt_types) == 1 and hasattr(elements[0], "annotation"):
                elt_type = get_id(elements[0].annotation)
                self._annotate(node, f"Set[{elt_type}]")
        elif not hasattr(node, "annotation"):
            node.annotation = ast.Name(id="Set")
        return node

    def visit_Dict(self, node):
        """Annotate a dict literal as ``Dict[K, V]``.

        ``K`` / ``V`` fall back to ``Any`` when the keys / values do not all
        share one type name. An empty literal gets a bare ``Dict``.
        """
        self.generic_visit(node)
        if len(node.keys) > 0:

            def typename(e):
                get_inferred_type(e)  # populates e.annotation
                return self._clike._generic_typename_from_annotation(e)

            key_types = {typename(e) for e in node.keys}
            key_type = next(iter(key_types)) if len(key_types) == 1 else "Any"
            value_types = {typename(e) for e in node.values}
            value_type = next(iter(value_types)) if len(value_types) == 1 else "Any"
            self._annotate(node, f"Dict[{key_type}, {value_type}]")
        elif not hasattr(node, "annotation"):
            node.annotation = ast.Name(id="Dict")
        return node

    def visit_Assign(self, node: ast.Assign) -> ast.AST:
        """Copy the inferred type of the assigned value onto the first target."""
        self.generic_visit(node)
        target = node.targets[0]
        annotation = get_inferred_type(node.value)
        if annotation is not None:
            target.annotation = annotation
        return node

    def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AST:
        """Copy the source annotation onto the target; note fixed-width ints."""
        self.generic_visit(node)
        node.target.annotation = node.annotation
        if get_id(node.annotation) in self.FIXED_WIDTH_INTS_NAME:
            self.has_fixed_width_ints = True
        return node

    def visit_AugAssign(self, node: ast.AugAssign) -> ast.AST:
        """Annotate the target of ``x op= value``.

        The target's own inferred type is used; the value's annotation is
        used only when the target has no inferred type.
        """
        self.generic_visit(node)
        target = node.target
        annotation = get_inferred_type(target)
        if hasattr(node.value, "annotation") and not annotation:
            target.annotation = node.value.annotation
        else:
            target.annotation = annotation
        return node

    def visit_Compare(self, node):
        """A comparison is always annotated as ``bool``."""
        self.generic_visit(node)
        node.annotation = ast.Name(id="bool")
        return node

    def visit_Return(self, node):
        """Propagate the type of the returned value to enclosing functions.

        For every ``ast.FunctionDef`` in ``node.scopes``: a function without
        a return annotation receives the inferred one; a function whose
        return annotation names a different type has that name rewritten to
        ``Union[<existing>,<inferred>]``.
        """
        self.generic_visit(node)
        new_type_str = (
            get_id(node.value.annotation)
            if hasattr(node.value, "annotation")
            else None
        )
        if new_type_str is None:
            return node
        for scope in node.scopes:
            if not isinstance(scope, ast.FunctionDef):
                continue
            type_str = get_id(scope.returns)
            if type_str is not None:
                if new_type_str != type_str:
                    type_str = f"Union[{type_str},{new_type_str}]"
                    scope.returns.id = type_str
            else:
                # Do not overwrite source annotation with inferred
                if scope.returns is None:
                    scope.returns = ast.Name(id=new_type_str)
        return node

    def visit_UnaryOp(self, node):
        """Give a unary operation the annotation of its operand."""
        self.generic_visit(node)
        if isinstance(node.operand, ast.Name):
            operand = node.scopes.find(get_id(node.operand))
        else:
            operand = node.operand
        if hasattr(operand, "annotation"):
            node.annotation = operand.annotation
        return node

    def _handle_overflow(self, op, left_id, right_id):
        """Pick the result type name for a binary op on fixed-width operands.

        For ``+`` and ``*`` the wider operand is bumped to the next entry of
        ``FIXED_WIDTH_INTS_NAME_LIST`` unless it is ``c_int64`` or the last
        entry. Otherwise ``float`` wins if either side is ``float``, else
        the operand that sits later in the list.
        """
        widening_op = isinstance(op, (ast.Add, ast.Mult))
        left_idx = (
            self.FIXED_WIDTH_INTS_NAME_LIST.index(left_id)
            if left_id in self.FIXED_WIDTH_INTS_NAME
            else -1
        )
        right_idx = (
            self.FIXED_WIDTH_INTS_NAME_LIST.index(right_id)
            if right_id in self.FIXED_WIDTH_INTS_NAME
            else -1
        )
        max_idx = max(left_idx, right_idx)
        cint64_idx = self.FIXED_WIDTH_INTS_NAME_LIST.index("c_int64")
        if widening_op:
            if max_idx not in {
                -1,
                cint64_idx,
                len(self.FIXED_WIDTH_INTS_NAME_LIST) - 1,
            }:
                # i8 + i8 => i16 for example
                return self.FIXED_WIDTH_INTS_NAME_LIST[max_idx + 1]
        if left_id == "float" or right_id == "float":
            return "float"
        return left_id if left_idx > right_idx else right_id

    def visit_BinOp(self, node):
        """Annotate a binary operation from the annotations of its operands.

        Raises:
            Exception: both operands are annotated with different types and
                the pair is not an int/float mix.
        """
        self.generic_visit(node)

        if isinstance(node.left, ast.Name):
            lvar = node.scopes.find(get_id(node.left))
        else:
            lvar = node.left

        if isinstance(node.right, ast.Name):
            rvar = node.scopes.find(get_id(node.right))
        else:
            rvar = node.right

        left = lvar.annotation if lvar and hasattr(lvar, "annotation") else None
        right = rvar.annotation if rvar and hasattr(rvar, "annotation") else None

        if left is None and right is not None:
            node.annotation = right
            return node

        if right is None and left is not None:
            node.annotation = left
            return node

        if right is None and left is None:
            return node

        # Both operands are annotated. Now we have interesting cases
        left_id = get_id(left)
        right_id = get_id(right)

        if left_id == right_id and left_id == "int":
            if not isinstance(node.op, ast.Div) or getattr(
                node, "use_integer_div", False
            ):
                node.annotation = left
            else:
                # TODO: This is not true for dart when using integer division
                node.annotation = ast.Name(id="float")
            return node

        # Does this hold across all languages?
        if left_id == "int":
            left_id = "c_int32"
        if right_id == "int":
            right_id = "c_int32"

        if (
            left_id in self.FIXED_WIDTH_INTS_NAME
            and right_id in self.FIXED_WIDTH_INTS_NAME
        ):
            ret = self._handle_overflow(node.op, left_id, right_id)
            node.annotation = ast.Name(id=ret)
            return node

        if left_id == right_id:
            # Exceptions: division operator
            if isinstance(node.op, ast.Div):
                if left_id == "int":
                    node.annotation = ast.Name(id="float")
                    return node
            node.annotation = left
            return node

        if left_id in self.FIXED_WIDTH_INTS_NAME:
            left_id = "int"
        if right_id in self.FIXED_WIDTH_INTS_NAME:
            right_id = "int"
        if (left_id, right_id) in {("int", "float"), ("float", "int")}:
            node.annotation = ast.Name(id="float")
            return node

        raise Exception(f"type error: {left_id} {type(node.op)} {right_id}")

    def visit_ClassDef(self, node):
        """Annotate a class with its own name, then infer inside its body."""
        node.annotation = ast.Name(id=node.name)
        # Without descending, nothing inside the class would ever be visited.
        self.generic_visit(node)
        return node

    def visit_Attribute(self, node):
        """Annotate ``Enum.member`` style accesses with the enum's definition."""
        value_id = get_id(node.value)
        if value_id is not None and hasattr(node, "scopes"):
            if is_enum(value_id, node.scopes):
                node.annotation = node.scopes.find(value_id)
        return node

    def visit_Call(self, node):
        """Annotate a call from what its callee resolves to.

        A class call is annotated with the class definition, a function call
        with the function's return annotation, ``max``/``min`` with the
        inferred type of the first argument, and a call to one of the
        ``TYPE_DICT`` names with that name.
        """
        fname = get_id(node.func)
        if fname is not None:
            fn = node.scopes.find(fname)
            if isinstance(fn, ast.ClassDef):
                node.annotation = fn
            elif isinstance(fn, ast.FunctionDef):
                return_type = (
                    fn.returns if hasattr(fn, "returns") and fn.returns else None
                )
                if return_type is not None:
                    node.annotation = return_type
            elif fname in {"max", "min"}:
                return_type = get_inferred_type(node.args[0])
                if return_type is not None:
                    node.annotation = return_type
            elif fname in self.TYPE_DICT.values():
                node.annotation = ast.Name(id=fname)
        self.generic_visit(node)
        return node

    def visit_Subscript(self, node):
        """Annotate ``container[index]`` with the container's element type."""
        definition = node.scopes.find(get_id(node.value))
        if hasattr(definition, "annotation"):
            # Called for its effect on ``definition``; the return value is unused.
            self._clike._typename_from_annotation(definition)
            if hasattr(definition, "container_type"):
                container_type, element_type = definition.container_type
                if container_type == "Dict" or isinstance(element_type, list):
                    element_type = element_type[1]
                node.annotation = ast.Name(id=element_type)
        self.generic_visit(node)
        return node
class InferTypesTransformer(ast.NodeTransformer):
    """
    Tries to infer types

    Each ``visit_*`` method attaches an ``annotation`` attribute (an AST
    node naming the type) to the nodes whose type it can work out. Some
    annotations additionally carry a ``lifetime`` attribute.
    """

    # Python primitive type -> annotation name.
    TYPE_DICT = {
        int: "int",
        float: "float",
        str: "str",
        bool: "bool",
        bytes: "bytes",
        complex: "complex",
        type(...): "...",
    }
    FIXED_WIDTH_INTS_LIST = [
        bool,
        c_int8,
        c_int16,
        c_int32,
        c_int64,
        c_uint8,
        c_uint16,
        c_uint32,
        c_uint64,
    ]
    FIXED_WIDTH_INTS = set(FIXED_WIDTH_INTS_LIST)
    FIXED_WIDTH_BIT_LENGTH = {
        bool: 1,
        c_int8: 7,
        c_uint8: 8,
        c_int16: 15,
        c_uint16: 16,
        # This is based on how int maps to i32 on many platforms
        int: 31,
        c_int32: 31,
        c_uint32: 32,
        c_int64: 63,
        c_uint64: 64,
    }
    # The order needs to match FIXED_WIDTH_INTS_LIST. Extra elements ok.
    FIXED_WIDTH_INTS_NAME_LIST = [
        "bool",
        "c_int8",
        "c_int16",
        "c_int32",
        "c_int64",
        "c_uint8",
        "c_uint16",
        "c_uint32",
        "c_uint64",
        "i8",
        "i16",
        "i32",
        "i64",
        "isize",
        "ilong",
        "u8",
        "u16",
        "u32",
        "u64",
        "usize",
        "ulong",
    ]
    FIXED_WIDTH_INTS_NAME = set(FIXED_WIDTH_INTS_NAME_LIST)
    # Type names that can be repeated by multiplying with an int.
    _REPEATABLE_CONTAINERS = frozenset({"bytes", "str", "tuple", "List"})

    def __init__(self):
        self.handling_annotation = False
        self.has_fixed_width_ints = False
        # TODO: remove this and make the methods into classmethods
        self._clike = CLikeTranspiler()

    @staticmethod
    def _infer_primitive(value) -> Optional[ast.AST]:
        """Return an ``ast.Name`` naming the type of *value*.

        Returns ``None`` when *value* is ``None``.

        Raises:
            NotImplementedError: the type of *value* is neither in
                ``TYPE_DICT`` nor in ``FIXED_WIDTH_INTS``.
        """
        t = type(value)
        annotation = None
        if t in InferTypesTransformer.TYPE_DICT:
            annotation = ast.Name(id=InferTypesTransformer.TYPE_DICT[t])
        elif t in InferTypesTransformer.FIXED_WIDTH_INTS:
            # NOTE(review): str(t) is the class repr, not a bare type name -
            # confirm that this is what consumers of the annotation expect.
            annotation = ast.Name(id=str(t))
        elif t != type(None):
            raise NotImplementedError(f"{t} not found in TYPE_DICT")
        return annotation

    @staticmethod
    def _container_mult_type(left_id, right_id):
        """Return the container type name of ``container * int`` in either order.

        Returns ``None`` when the operand pair is not one repeatable
        container (see ``_REPEATABLE_CONTAINERS``) and one ``int``.
        """
        repeatable = InferTypesTransformer._REPEATABLE_CONTAINERS
        for container_id, count_id in ((left_id, right_id), (right_id, left_id)):
            if count_id == "int" and container_id in repeatable:
                return container_id
        return None

    def visit_NameConstant(self, node):
        """Annotate a constant with the type of its value (``...`` is skipped).

        The annotation's ``lifetime`` is ``LifeTime.STATIC`` for string
        values and ``LifeTime.UNKNOWN`` for everything else.
        """
        if node.value is Ellipsis:
            return node
        annotation = self._infer_primitive(node.value)
        if annotation is not None:
            node.annotation = annotation
            node.annotation.lifetime = (
                LifeTime.STATIC if isinstance(node.value, str) else LifeTime.UNKNOWN
            )
        self.generic_visit(node)
        return node

    def visit_Name(self, node):
        """Annotate a name with whatever ``get_inferred_type`` reports for it."""
        annotation = get_inferred_type(node)
        if annotation is not None:
            node.annotation = annotation
        return node

    def visit_Constant(self, node):
        """Handle ``ast.Constant`` exactly like ``ast.NameConstant``."""
        return self.visit_NameConstant(node)

    @staticmethod
    def _annotate(node, typename: str):
        """Build an AST for *typename* and store it as ``node.annotation``."""
        # The created node is treated as an ast.Expr wrapper that needs to be
        # destructured.
        type_annotation = cast(ast.Expr, create_ast_node(typename, node)).value
        node.annotation = type_annotation

    def visit_List(self, node):
        """Annotate a list literal as ``List[T]`` (homogeneous) or ``List`` (empty)."""
        self.generic_visit(node)
        if len(node.elts) > 0:
            elements = [self.visit(e) for e in node.elts]
            if getattr(node, "is_annotation", False):
                return node
            elt_types = {get_id(get_inferred_type(e)) for e in elements}
            if len(elt_types) == 1 and hasattr(elements[0], "annotation"):
                elt_type = get_id(elements[0].annotation)
                self._annotate(node, f"List[{elt_type}]")
        elif not hasattr(node, "annotation"):
            node.annotation = ast.Name(id="List")
        return node

    def visit_Set(self, node):
        """Annotate a set literal as ``Set[T]`` when homogeneous, else ``Set``."""
        self.generic_visit(node)
        if len(node.elts) > 0:
            elements = [self.visit(e) for e in node.elts]
            elt_types = {get_id(get_inferred_type(e)) for e in elements}
            if len(elt_types) == 1 and hasattr(elements[0], "annotation"):
                elt_type = get_id(elements[0].annotation)
                self._annotate(node, f"Set[{elt_type}]")
                return node
        if not hasattr(node, "annotation"):
            node.annotation = ast.Name(id="Set")
        return node

    def visit_Dict(self, node):
        """Annotate a dict literal as ``Dict[K, V]``.

        ``K`` / ``V`` fall back to ``Any`` when the keys / values do not all
        share one type name. The annotation's ``lifetime`` is the lifetime
        shared by all annotated values, or ``LifeTime.UNKNOWN`` when they
        disagree or have none. An empty literal gets a bare ``Dict``.
        """
        self.generic_visit(node)
        if len(node.keys) > 0:

            def typename(e):
                get_inferred_type(e)  # populates e.annotation
                return self._clike._generic_typename_from_annotation(e)

            key_types = {typename(e) for e in node.keys}
            key_type = next(iter(key_types)) if len(key_types) == 1 else "Any"
            value_types = {typename(e) for e in node.values}
            value_type = next(iter(value_types)) if len(value_types) == 1 else "Any"
            self._annotate(node, f"Dict[{key_type}, {value_type}]")

            lifetimes = {
                getattr(e.annotation, "lifetime", None)
                for e in node.values
                if hasattr(e, "annotation")
            }
            only_lifetime = next(iter(lifetimes)) if len(lifetimes) == 1 else None
            node.annotation.lifetime = (
                only_lifetime if only_lifetime is not None else LifeTime.UNKNOWN
            )
        elif not hasattr(node, "annotation"):
            node.annotation = ast.Name(id="Dict")
        return node

    def visit_Assign(self, node: ast.Assign) -> ast.AST:
        """Copy the value's annotation onto every target that may take it.

        A target takes the annotation when it has none yet or when its
        current annotation is itself marked ``inferred``.
        """
        self.generic_visit(node)

        annotation = getattr(node.value, "annotation", None)
        if annotation is None:
            return node

        for target in node.targets:
            target_has_annotation = hasattr(target, "annotation")
            inferred = (
                getattr(target.annotation, "inferred", False)
                if target_has_annotation
                else False
            )
            if not target_has_annotation or inferred:
                target.annotation = annotation
                target.annotation.inferred = True
        # TODO: Call is_compatible to check if the inferred and user provided annotations conflict
        return node

    def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AST:
        """Copy the source annotation onto the target and check the value fits.

        Raises:
            AstIncompatibleAssign: ``is_compatible`` rejects the value's
                class for a target whose class is known.
        """
        self.generic_visit(node)

        node.target.annotation = node.annotation
        target = node.target
        target_typename = self._clike._typename_from_annotation(target)
        if target_typename in self.FIXED_WIDTH_INTS_NAME:
            self.has_fixed_width_ints = True
        annotation = get_inferred_type(node.value)
        value_typename = self._clike._generic_typename_from_type_node(annotation)
        target_class = class_for_typename(target_typename, None)
        value_class = class_for_typename(value_typename, None)
        if (
            not is_compatible(target_class, value_class, target, node.value)
            and target_class is not None
        ):
            raise AstIncompatibleAssign(
                f"{target_class} incompatible with {value_class}", node
            )
        return node

    def visit_AugAssign(self, node: ast.AugAssign) -> ast.AST:
        """Give an unannotated ``op=`` target the annotation of the value."""
        self.generic_visit(node)
        target = node.target
        annotation = getattr(node.value, "annotation", None)
        if annotation is not None and not hasattr(target, "annotation"):
            target.annotation = annotation
        return node

    def visit_Compare(self, node):
        """A comparison is always annotated as ``bool``."""
        self.generic_visit(node)
        node.annotation = ast.Name(id="bool")
        return node

    def visit_Return(self, node):
        """Propagate the type of the returned value to enclosing functions.

        For every ``ast.FunctionDef`` in ``node.scopes``: a function without
        a return annotation receives the inferred one (plus the value's
        ``lifetime`` when it has one); a function whose return annotation
        names a different type has that name rewritten to
        ``Union[<existing>,<inferred>]``.
        """
        self.generic_visit(node)
        new_type_str = (
            get_id(node.value.annotation)
            if hasattr(node.value, "annotation")
            else None
        )
        if new_type_str is None:
            return node
        for scope in node.scopes:
            if not isinstance(scope, ast.FunctionDef):
                continue
            type_str = get_id(scope.returns)
            if type_str is not None:
                if new_type_str != type_str:
                    type_str = f"Union[{type_str},{new_type_str}]"
                    scope.returns.id = type_str
            else:
                # Do not overwrite source annotation with inferred
                if scope.returns is None:
                    scope.returns = ast.Name(id=new_type_str)
                    lifetime = getattr(node.value.annotation, "lifetime", None)
                    if lifetime is not None:
                        scope.returns.lifetime = lifetime
        return node

    def visit_UnaryOp(self, node):
        """Give a unary operation the annotation of its operand."""
        self.generic_visit(node)

        if isinstance(node.operand, ast.Name):
            operand = node.scopes.find(get_id(node.operand))
        else:
            operand = node.operand

        if hasattr(operand, "annotation"):
            node.annotation = operand.annotation

        return node

    def _handle_overflow(self, op, left_id, right_id):
        """Pick the result type name for a binary op on fixed-width operands.

        Operand names are resolved with ``class_for_typename`` and ranked by
        position in ``FIXED_WIDTH_INTS_LIST``. For ``+`` and ``*`` the wider
        operand is bumped to the next list entry unless it is ``c_int64`` or
        the last entry. Otherwise ``float`` wins if either side is
        ``float``, else the operand that ranks later.
        """
        widening_op = isinstance(op, (ast.Add, ast.Mult))
        left_class = class_for_typename(left_id, None)
        right_class = class_for_typename(right_id, None)
        left_idx = (
            self.FIXED_WIDTH_INTS_LIST.index(left_class)
            if left_class in self.FIXED_WIDTH_INTS
            else -1
        )
        right_idx = (
            self.FIXED_WIDTH_INTS_LIST.index(right_class)
            if right_class in self.FIXED_WIDTH_INTS
            else -1
        )
        max_idx = max(left_idx, right_idx)
        cint64_idx = self.FIXED_WIDTH_INTS_LIST.index(c_int64)
        if widening_op:
            if max_idx not in {-1, cint64_idx, len(self.FIXED_WIDTH_INTS_LIST) - 1}:
                # i8 + i8 => i16 for example
                return self.FIXED_WIDTH_INTS_NAME_LIST[max_idx + 1]
        if left_id == "float" or right_id == "float":
            return "float"
        return left_id if left_idx > right_idx else right_id

    def visit_BinOp(self, node):
        """Annotate a binary operation from the annotations of its operands.

        Raises:
            AstUnrecognisedBinOp: both operands are annotated, no rule
                matches, and (left type, operator) is not a known legal
                combination.
        """
        self.generic_visit(node)

        if isinstance(node.left, ast.Name):
            lvar = node.scopes.find(get_id(node.left))
        else:
            lvar = node.left

        if isinstance(node.right, ast.Name):
            rvar = node.scopes.find(get_id(node.right))
        else:
            rvar = node.right

        left = lvar.annotation if lvar and hasattr(lvar, "annotation") else None
        right = rvar.annotation if rvar and hasattr(rvar, "annotation") else None

        if left is None and right is not None:
            node.annotation = right
            return node

        if right is None and left is not None:
            node.annotation = left
            return node

        if right is None and left is None:
            return node

        # Both operands are annotated. Now we have interesting cases
        left_id = get_id(left)
        right_id = get_id(right)

        if left_id == right_id and left_id == "int":
            if not isinstance(node.op, ast.Div) or getattr(
                node, "use_integer_div", False
            ):
                node.annotation = left
            else:
                node.annotation = ast.Name(id="float")
            return node

        # Does this hold across all languages?
        if left_id == "int":
            left_id = "c_int32"
        if right_id == "int":
            right_id = "c_int32"

        if (
            left_id in self.FIXED_WIDTH_INTS_NAME
            and right_id in self.FIXED_WIDTH_INTS_NAME
        ):
            ret = self._handle_overflow(node.op, left_id, right_id)
            node.annotation = ast.Name(id=ret)
            return node

        if left_id == right_id:
            # Exceptions: division operator
            if isinstance(node.op, ast.Div):
                if left_id == "int":
                    node.annotation = ast.Name(id="float")
                    return node
            node.annotation = left
            return node

        if left_id in self.FIXED_WIDTH_INTS_NAME:
            left_id = "int"
        if right_id in self.FIXED_WIDTH_INTS_NAME:
            right_id = "int"

        if (left_id, right_id) in {("int", "float"), ("float", "int")}:
            node.annotation = ast.Name(id="float")
            return node
        if (left_id, right_id) in {
            ("int", "complex"),
            ("complex", "int"),
            ("float", "complex"),
            ("complex", "float"),
        }:
            node.annotation = ast.Name(id="complex")
            return node

        # Container multiplication: the result is the container, whichever
        # side of the operator it appears on (3 * "ab" is a str, not an int).
        if isinstance(node.op, ast.Mult):
            container_id = self._container_mult_type(left_id, right_id)
            if container_id is not None:
                node.annotation = ast.Name(id=container_id)
                return node

        LEGAL_COMBINATIONS = {
            ("str", ast.Mod),
            ("List", ast.Add),
        }

        if left_id is not None and (left_id, type(node.op)) not in LEGAL_COMBINATIONS:
            raise AstUnrecognisedBinOp(left_id, right_id, node)

        return node

    def visit_ClassDef(self, node):
        """Annotate a class with its own name, then infer inside its body."""
        node.annotation = ast.Name(id=node.name)
        self.generic_visit(node)
        return node

    def visit_Attribute(self, node):
        """Annotate ``Enum.member`` style accesses with the enum's definition."""
        value_id = get_id(node.value)
        if value_id is not None and hasattr(node, "scopes"):
            if is_enum(value_id, node.scopes):
                node.annotation = node.scopes.find(value_id)
        return node

    def visit_Call(self, node):
        """Annotate a call from what its callee resolves to.

        A class call is annotated with the class definition, a function call
        with the function's return annotation, ``max``/``min`` with the
        inferred type of the first argument, and a call to one of the
        ``TYPE_DICT`` names with that name.
        """
        fname = get_id(node.func)
        if fname is not None:
            fn = node.scopes.find(fname)
            if isinstance(fn, ast.ClassDef):
                node.annotation = fn
            elif isinstance(fn, ast.FunctionDef):
                return_type = (
                    fn.returns if hasattr(fn, "returns") and fn.returns else None
                )
                if return_type is not None:
                    node.annotation = return_type
                    lifetime = getattr(fn.returns, "lifetime", None)
                    if lifetime is not None:
                        node.annotation.lifetime = lifetime
            elif fname in {"max", "min"}:
                return_type = get_inferred_type(node.args[0])
                if return_type is not None:
                    node.annotation = return_type
            elif fname in self.TYPE_DICT.values():
                node.annotation = ast.Name(id=fname)
        self.generic_visit(node)
        return node

    def visit_Subscript(self, node):
        """Annotate ``container[index]`` with the container's element type.

        The container annotation's ``lifetime``, when present, is copied
        onto the new annotation.
        """
        definition = node.scopes.find(get_id(node.value))
        if hasattr(definition, "annotation"):
            # Called for its effect on ``definition``; the return value is unused.
            self._clike._typename_from_annotation(definition)
            if hasattr(definition, "container_type"):
                container_type, element_type = definition.container_type
                if container_type == "Dict" or isinstance(element_type, list):
                    element_type = element_type[1]
                node.annotation = ast.Name(id=element_type)
                if hasattr(definition.annotation, "lifetime"):
                    node.annotation.lifetime = definition.annotation.lifetime
        self.generic_visit(node)
        return node