def infer_types(module: astroid.Module, node_type: type,
                expr: Callable) -> Dict[astroid.node_classes.NodeNG, str]:
    """
    Infer the types of an attribute of all nodes of the same type in a module.

    The module source is instrumented with ``reveal_type`` calls (via
    TypeInference.add_reveal_type_calls) and run through mypy. If parsing
    mypy's output raises a ``SyntaxError`` on a line that contains an
    inserted ``"; reveal_type("`` call, that call is stripped from the line.
    Mypy is then run again on the repaired code. The repair is repeated for
    every such faulty line. A line can only be repaired once, so the loop
    always terminates.

    :param module: The module node where all nodes are located in.
    :param node_type: Type of node of which the type will be inferred on a
        certain attribute.
    :param expr: Expression to extract the attribute from the node where the
        type will be inferred on. E.g., lambda node: node.func.expr.name
    :return: All nodes in the module of type 'node_type' with the inferred
        type of the attribute accessible with the expression 'expr'. An empty
        dict is returned (after printing a message) when the faulty line has
        no inserted reveal_type call that could be removed.
    """
    marker = "; reveal_type("
    nodes = ASTUtil.search_nodes(module, node_type)
    source_code = ASTUtil.get_source_code(module)
    mypy_code = TypeInference.add_reveal_type_calls(
        source_code, nodes, expr)
    while True:
        mypy_result = TypeInference.run_mypy(mypy_code)
        try:
            mypy_types = TypeInference.parse_mypy_result(mypy_result)
        except SyntaxError as ex:
            mypy_code_split = mypy_code.splitlines()
            # SyntaxError.lineno may be None; guard before indexing so the
            # error path reports instead of crashing.
            line_index = int(ex.lineno) - 1 if ex.lineno is not None else -1
            if 0 <= line_index < len(mypy_code_split):
                faulty_code = mypy_code_split[line_index]
            else:
                faulty_code = ""
            if marker not in faulty_code:
                print("Skipping type checking of module {}: {}. Line {}: {}".
                      format(module.name, ex.msg, ex.lineno, faulty_code))
                return {}
            # Keep only the original statement on this line. The repaired
            # code must go through mypy again: parse_mypy_result consumes
            # mypy's output, not source code.
            mypy_code_split[line_index] = faulty_code.split(marker)[0]
            mypy_code = "\n".join(mypy_code_split)
            continue
        return TypeInference.combine_nodes_with_inferred_types(
            nodes, mypy_types)
def infer_variable_most_recent_full_types(
        module: astroid.Module) -> Dict[str, str]:
    """
    Map assigned variable names to a syntactically derived "full type".

    Use this instead of infer_types when there is no stub available for a
    library (e.g., missing tensorflow-stubs). Only Assign targets that carry
    a ``name`` are considered:

    * a constant right-hand side is recorded as 'const';
    * a call is recorded as the dotted name of the callee, built by walking
      the ``attrname``/``expr`` chain down to a ``name``
      (e.g. ``x = tf.keras.layers.Dense(3)`` records ``'tf.keras.layers.Dense'``).

    Any other right-hand side is ignored and leaves an earlier entry for that
    name in place. When a name matches several assignments, the one visited
    last by ASTUtil.search_nodes wins.

    :param module: code module
    :return: Dict with variable names and their inferred type
    """
    def dotted_callee(func) -> str:
        # Collect attribute names from the outermost access inwards, then
        # prepend the root name (if any) and join in source order.
        attr_parts = []
        current = func
        if hasattr(current, "attrname"):
            attr_parts.append(current.attrname)
        while hasattr(current, "expr"):
            current = current.expr
            if hasattr(current, "attrname"):
                attr_parts.append(current.attrname)
        root = current.name if hasattr(current, "name") else ""
        return root + "".join("." + part for part in reversed(attr_parts))

    most_recent: Dict[str, str] = {}
    for assign in ASTUtil.search_nodes(module, astroid.Assign):
        if not hasattr(assign, "targets"):
            continue
        for target in assign.targets:
            if not hasattr(target, "name"):
                continue
            value = assign.value
            if isinstance(value, astroid.nodes.Const):
                most_recent[target.name] = 'const'
            elif hasattr(value, "func"):
                most_recent[target.name] = dotted_callee(value.func)
    return most_recent
def test_search_body(self):
    """search_body of a top-level statement should be the module's body list."""
    tree = astroid.parse(
        """
        f()
        """
    )
    first_statement = tree.body[0]
    assert ASTUtil.search_body(first_statement) == tree.body
def test_search_body_parent_module(self):
    """The body parent of a top-level statement should be the module itself."""
    tree = astroid.parse(
        """
        f()
        """
    )
    statement = tree.body[0]
    assert ASTUtil.search_body_parent(statement) == tree
def test_search_body_parent_function(self):
    """A statement in a function body should have that function as body parent."""
    tree = astroid.parse(
        """
        def f():
            return 0
        """
    )
    function_def = tree.body[0]
    return_statement = function_def.body[0]
    assert ASTUtil.search_body_parent(return_statement) == function_def
def handle(checker: BaseChecker, node: astroid.node_classes.NodeNG):
    """
    Print an error message for an exception raised inside a checker.

    The message names the checker, the module containing ``node`` (looked up
    with ASTUtil.search_module) and ``node``'s line number. It is written
    with ``print``; nothing is re-raised here.

    :param checker: Checker in which the exception was raised.
    :param node: Node that was being visited when the exception was raised.
    """
    enclosing_module = ASTUtil.search_module(node)
    message = (
        "ERROR: Could not finish processing the checker {} on module {} at line {}. Continuing."
        .format(checker.name, enclosing_module.name, node.lineno)
    )
    print(message)
def test_search_nodes(self):
    """search_nodes should return every Call, including one nested in a function, in source order."""
    tree = astroid.parse(
        """
        a = b.c(d)
        def e(f):
            return g.h(i)
        """
    )
    calls = ASTUtil.search_nodes(tree, astroid.Call)
    assert len(calls) == 2
    # noinspection PyUnresolvedReferences
    assert calls[0].func.attrname == "c"
    # noinspection PyUnresolvedReferences
    assert calls[1].func.attrname == "h"
def test_get_source_code(self):
    """get_source_code should return exactly the text the module was parsed from."""
    code = "a = b.c(d)"
    parsed = astroid.parse(code)
    assert ASTUtil.get_source_code(parsed) == code