def visit_Expr(self, node):
    if self.document is None:
        raise ValueError()
    children = list(ast.iter_child_nodes(node))
    if len(children) != 1:
        raise ValueError()
    feature_name = None
    child = children[0]
    if isinstance(child, ast.Attribute) \
            and child.attr in self.document.features:
        feature_name = child.attr
    else:
        raise ValueError()
    grandchildren = list(ast.iter_child_nodes(child))
    if len(grandchildren) != 2:
        raise ValueError()
    grandchild = grandchildren[0]
    if isinstance(grandchild, ast.Name) \
            and grandchild.id in self.locals \
            and isinstance(self.locals[grandchild.id], self.document):
        self.doc = self.locals[grandchild.id]
        self.feature_name = feature_name
    else:
        raise ValueError()
def get_testcases(paths):
    """Walk each path in ``paths`` and return the test cases found.

    :param paths: List of directories to find test modules and test cases.
    :return: A dict mapping each test module path to its test cases.
    """
    testmodules = []
    for path in paths:
        for dirpath, _, filenames in os.walk(path):
            for filename in filenames:
                if filename.startswith('test_') and filename.endswith('.py'):
                    testmodules.append(os.path.join(dirpath, filename))
    testcases = collections.OrderedDict()
    for testmodule in testmodules:
        testcases[testmodule] = []
        with open(testmodule) as handler:
            for node in ast.iter_child_nodes(ast.parse(handler.read())):
                if isinstance(node, ast.ClassDef):
                    # Class test methods
                    class_name = node.name
                    testcases[testmodule].extend([
                        TestFunction(subnode, class_name, testmodule)
                        for subnode in ast.iter_child_nodes(node)
                        if isinstance(subnode, ast.FunctionDef)
                        and subnode.name.startswith('test_')
                    ])
                elif (isinstance(node, ast.FunctionDef)
                        and node.name.startswith('test_')):
                    # Module's test functions
                    testcases[testmodule].append(TestFunction(
                        node, testmodule=testmodule))
    return testcases
def visit_Subscript(self, node):
    if self.subscriptnode is not None:
        repnode = list(ast.walk(self.subscriptnode))
        if node in repnode:
            self.generic_visit(node)
            return
    self.subscriptnode = node
    t = ast.iter_child_nodes(node)
    res = [[t, 0]]
    maxcount = 1
    while len(res) >= 1:
        t = res[-1][0]
        for childnode in t:
            # print childnode
            if isinstance(childnode, _ast.Subscript) or isinstance(childnode, _ast.Tuple) or \
                    isinstance(childnode, _ast.Dict) or isinstance(childnode, _ast.List) or \
                    isinstance(childnode, _ast.Set):
                res[-1][1] = 1
            else:
                res[-1][1] = 0
            res.append([ast.iter_child_nodes(childnode), 0])
            break
        else:
            maxcount = max(maxcount, sum([flag for (item, flag) in res]) + 1)
            res.pop()
            continue
    # print maxcount
    if maxcount >= config['containerdepth']:
        self.result.append((6, self.fileName, node.lineno, maxcount))
    self.generic_visit(node)
def nodes_equal(x, y):
    __tracebackhide__ = True
    assert type(x) == type(y), "Ast nodes do not have the same type: '%s' != '%s'" % (
        type(x),
        type(y),
    )
    if isinstance(x, (ast.Expr, ast.FunctionDef, ast.ClassDef)):
        assert (
            x.lineno == y.lineno
        ), "Ast nodes do not have the same line number: %s != %s" % (
            x.lineno,
            y.lineno,
        )
        assert x.col_offset == y.col_offset, (
            "Ast nodes do not have the same column offset number: %s != %s"
            % (x.col_offset, y.col_offset)
        )
    for (xname, xval), (yname, yval) in zip(ast.iter_fields(x), ast.iter_fields(y)):
        assert xname == yname, (
            "Ast nodes fields differ: %s (of type %s) != %s (of type %s)"
            % (xname, type(xval), yname, type(yval))
        )
        assert type(xval) == type(yval), (
            "Ast nodes fields differ: %s (of type %s) != %s (of type %s)"
            % (xname, type(xval), yname, type(yval))
        )
    for xchild, ychild in zip(ast.iter_child_nodes(x), ast.iter_child_nodes(y)):
        assert nodes_equal(xchild, ychild), "Ast node children differ"
    return True
def get_trait_definition(self):
    """ Retrieve the Trait attribute definition """
    # Get the class source and tokenize it.
    source = inspect.getsource(self.parent)
    nodes = ast.parse(source)
    for node in ast.iter_child_nodes(nodes):
        if isinstance(node, ClassDef):
            parent_node = node
            break
    else:
        return ''
    for node in ast.iter_child_nodes(parent_node):
        if isinstance(node, Assign):
            name = node.targets[0]
            if name.id == self.object_name:
                break
    else:
        return ''
    endlineno = name.lineno
    for item in ast.walk(node):
        if hasattr(item, 'lineno'):
            endlineno = max(endlineno, item.lineno)
    definition_lines = [
        line.strip()
        for line in source.splitlines()[name.lineno - 1:endlineno]]
    definition = ''.join(definition_lines)
    equal = definition.index('=')
    return definition[equal + 1:].lstrip()
def _iter_member_names(self):
    '''
    iterate over assign and def member names, including class bases
    preserves order (bases' members go first)
    '''
    all_bases_names = set(itertools.chain(
        *[dir(base) for base in self.__class__.__mro__]))
    for base in reversed(self.__class__.__mro__):
        if base.__name__ == 'object':
            continue
        class_src = '\n'.join(
            l for l in inspect.getsource(base).split('\n')
            if not l.lstrip().startswith('#'))
        classnodes = ast.iter_child_nodes(ast.parse(dedent(class_src)))
        names = []
        for classnode in classnodes:
            for node in ast.iter_child_nodes(classnode):
                if isinstance(node, ast.FunctionDef):
                    if node.name in Behaviour._omit_member_names:
                        continue
                    names += node.name,
                    if node.name in all_bases_names:
                        all_bases_names.remove(node.name)
                elif isinstance(node, ast.Assign):
                    for target in node.targets:
                        if target.id in Behaviour._omit_member_names:
                            continue
                        names += target.id,
                        if target.id in all_bases_names:
                            all_bases_names.remove(target.id)
        for name in names:
            if name not in all_bases_names:
                yield name
def get_test_methods(fname):
    seen_classes = {}
    out = []
    lines = open(fname).readlines()
    lines = [x.rstrip('\r\n') for x in lines]
    a = ast.parse(("\n".join(lines)).rstrip() + '\n', fname)
    for cls in ast.iter_child_nodes(a):
        if isinstance(cls, ast.ClassDef) and is_test_class(cls):
            if cls.name in seen_classes:
                raise ValueError("Duplicate class %s in %s"
                                 % (cls.name, fname))
            seen_classes[cls.name] = {}
            seen_methods = {}
            for meth in ast.iter_child_nodes(cls):
                if isinstance(meth, ast.FunctionDef) \
                        and meth.name.startswith('test'):
                    if meth.name in seen_methods:
                        raise ValueError("Duplicate method %s in %s"
                                         % (meth.name, fname))
                    seen_methods[meth.name] = {}
                    testname = get_test_name(meth, fname, cls.name, meth.name,
                                             ast.get_docstring(meth, False))
                    out.append("%s.%s %s" % (cls.name, meth.name, testname))
    return out
def tag_class_functions(self, cls_node):
    """Tag functions if they are methods, classmethods, staticmethods"""
    # tries to find all 'old style decorators' like
    # m = staticmethod(m)
    late_decoration = {}
    for node in iter_child_nodes(cls_node):
        if not (isinstance(node, ast.Assign) and
                isinstance(node.value, ast.Call) and
                isinstance(node.value.func, ast.Name)):
            continue
        func_name = node.value.func.id
        if func_name in ('classmethod', 'staticmethod'):
            meth = (len(node.value.args) == 1 and node.value.args[0])
            if isinstance(meth, ast.Name):
                late_decoration[meth.id] = func_name
    # iterate over all functions and tag them
    for node in iter_child_nodes(cls_node):
        if not isinstance(node, ast.FunctionDef):
            continue
        node.function_type = 'method'
        if node.name == '__new__':
            node.function_type = 'classmethod'
        if node.name in late_decoration:
            node.function_type = late_decoration[node.name]
        elif node.decorator_list:
            names = [d.id for d in node.decorator_list
                     if isinstance(d, ast.Name) and
                     d.id in ('classmethod', 'staticmethod')]
            if names:
                node.function_type = names[0]
def visit_ClassDef(self, node):
    self.transforms = {}
    self.in_class_define = True

    functions_to_promote = []
    setup_func = None

    for class_func in ast.iter_child_nodes(node):
        if isinstance(class_func, ast.FunctionDef):
            if class_func.name == 'setup':
                setup_func = class_func
                for anon_func in ast.iter_child_nodes(class_func):
                    if isinstance(anon_func, ast.FunctionDef):
                        functions_to_promote.append(anon_func)

    if setup_func:
        for func in functions_to_promote:
            setup_func.body.remove(func)
            func.args.args.insert(0, ast.Name(id='self', ctx=ast.Load()))
            node.body.append(func)
            self.transforms[func.name] = 'self.' + func.name
        ast.fix_missing_locations(node)

    self.generic_visit(node)
    return node
def process_file(self):
    result = []
    with open(self.path) as f:
        current_node = ast.parse(f.read())
    paths = os.path.splitext(self.path)[0]
    paths = paths.split("/")
    base_path = ".".join(paths)
    nodes = [(i, base_path) for i in ast.iter_child_nodes(current_node)]
    while len(nodes) > 0:
        current_node, path = nodes.pop(0)
        if isinstance(current_node, ast.Import):
            module = ''
        elif isinstance(current_node, ast.ImportFrom):
            module = current_node.module
        elif isinstance(current_node, (ast.FunctionDef, ast.ClassDef)):
            path += "." + current_node.name
            next_nodes = [(i, path) for i in ast.iter_child_nodes(current_node)]
            nodes.extend(next_nodes)
            continue
        else:
            continue
        for n in current_node.names:
            result.append(self.node_class(module=module,
                                          full_path=path,
                                          name=n.name,
                                          alias=n.asname))
    return result
def visit_Attribute(self, node):
    if self.messagenode is not None:
        repnode = list(ast.walk(self.messagenode))
        if node in repnode:
            self.generic_visit(node)
            return
    self.messagenode = node
    t = ast.iter_child_nodes(node)
    res = [[t, 0]]
    maxcount = 1
    while len(res) >= 1:
        t = res[-1][0]
        for childnode in t:
            # print childnode
            if isinstance(childnode, _ast.Attribute):
                res[-1][1] = 1
            else:
                res[-1][1] = 0
            res.append([ast.iter_child_nodes(childnode), 0])
            break
        else:
            maxcount = max(maxcount, sum([flag for (item, flag) in res]) + 2)
            res.pop()
            continue
    # print maxcount
    if maxcount >= config['messagechain']:
        self.result.append((13, self.fileName, node.lineno, maxcount))
    self.generic_visit(node)
def test_iter_child_nodes(self):
    node = ast.parse("spam(23, 42, eggs='leek')", mode="eval")
    self.assertEqual(len(list(ast.iter_child_nodes(node.body))), 4)
    iterator = ast.iter_child_nodes(node.body)
    self.assertEqual(next(iterator).id, "spam")
    self.assertEqual(next(iterator).n, 23)
    self.assertEqual(next(iterator).n, 42)
    self.assertEqual(ast.dump(next(iterator)),
                     "keyword(arg='eggs', value=Str(s='leek'))")
def walklocal(root):
    """Recursively yield all descendant nodes but not in a different scope"""
    todo = collections.deque(ast.iter_child_nodes(root))
    yield root, False
    while todo:
        node = todo.popleft()
        newscope = isinstance(node, ast.FunctionDef)
        if not newscope:
            todo.extend(ast.iter_child_nodes(node))
        yield node, newscope
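# --- Hedged usage sketch for walklocal (added; not from the original source).
# It shows that nodes inside a nested function are never yielded because a
# FunctionDef opens a new scope. Assumes the module-level ``import ast`` and
# ``import collections`` used by the snippet above; ``_demo_walklocal`` is a
# hypothetical name for illustration only.
def _demo_walklocal():
    tree = ast.parse("x = 1\ndef f():\n    y = 2\n")
    for node, newscope in walklocal(tree):
        # yields Module, the top-level Assign (and its leaves), and the
        # FunctionDef itself -- but never the Assign to ``y`` inside ``f``
        print(type(node).__name__, newscope)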
def test_iter_child_nodes(self):
    node = ast.parse("spam(23, 42, eggs='leek')", mode='eval')
    self.assertEqual(len(list(ast.iter_child_nodes(node.body))), 4)
    iterator = ast.iter_child_nodes(node.body)
    self.assertEqual(next(iterator).id, 'spam')
    self.assertEqual(next(iterator).c, 23)
    self.assertEqual(next(iterator).c, 42)
    self.assertEqual(ast.dump(next(iterator)),
                     "keyword(arg='eggs', value=Const(c='leek', constant=pure_const()))")
def find_global_defs(self, func_def_node):
    global_names = set()
    nodes_to_check = deque(iter_child_nodes(func_def_node))
    while nodes_to_check:
        node = nodes_to_check.pop()
        if isinstance(node, ast.Global):
            global_names.update(node.names)
        if not isinstance(node, (ast.FunctionDef, ast.ClassDef)):
            nodes_to_check.extend(iter_child_nodes(node))
    func_def_node.global_names = global_names
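# --- Hedged usage sketch for find_global_defs (added; not from the original
# source). find_global_defs is a method, so ``checker`` stands in for whatever
# instance defines it; after the call the FunctionDef node carries a
# ``global_names`` set. Assumes ``import ast`` plus the ``deque`` and
# ``iter_child_nodes`` imports used by the snippet above.
def _demo_find_global_defs(checker):
    func = ast.parse("def f():\n    global counter\n    counter = 1\n").body[0]
    checker.find_global_defs(func)
    assert func.global_names == {'counter'}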
def visit_FunctionDef(self, node):
    # argsCount
    def findCharacter(s, d):
        try:
            value = s.index(d)
        except ValueError:
            return -1
        else:
            return value

    funcName = node.name.strip()
    p = re.compile("^(__[a-zA-Z0-9]+__)$")
    if p.match(funcName.strip()) and funcName != "__import__" and funcName != "__all__":
        self.defmagic.add((funcName, self.fileName, node.lineno))
    stmt = astunparse.unparse(node.args)
    arguments = stmt.split(",")
    argsCount = 0
    for element in arguments:
        if findCharacter(element, '=') == -1:
            argsCount += 1
    self.result.append((1, self.fileName, node.lineno, argsCount))

    # function length
    lines = set()
    res = [node]
    while len(res) >= 1:
        t = res[0]
        for n in ast.iter_child_nodes(t):
            if not hasattr(n, 'lineno') or \
                    ((isinstance(t, _ast.FunctionDef) or isinstance(t, _ast.ClassDef))
                     and n == t.body[0] and isinstance(n, _ast.Expr)):
                continue
            lines.add(n.lineno)
            if isinstance(n, _ast.ClassDef) or isinstance(n, _ast.FunctionDef):
                continue
            else:
                res.append(n)
        del res[0]
    self.result.append((2, self.fileName, node.lineno, len(lines)))

    # nested scope depth
    if node in self.scopenodes:
        self.scopenodes.remove(node)
        self.generic_visit(node)
        return
    dep = [[node, 1]]  # node, nestedlevel
    maxlevel = 1
    while len(dep) >= 1:
        t = dep[0][0]
        currentlevel = dep[0][1]
        for n in ast.iter_child_nodes(t):
            if isinstance(n, _ast.FunctionDef):
                self.scopenodes.append(n)
                dep.append([n, currentlevel + 1])
        maxlevel = max(maxlevel, currentlevel)
        del dep[0]
    if maxlevel > 1:
        self.result.append((3, self.fileName, node.lineno, maxlevel))
    # DOC
    self.generic_visit(node)
def generic_visit(self, node):
    name = typename(node)
    self.debug('debug; visiting: %s' % str(name))
    if name == 'Module':
        self.graph += 'digraph G {' + self.delim
        ast.NodeVisitor.generic_visit(self, node)
        # TODO: print labels
        for n in self.nodenumbers.keys():
            self.graph += '%s [label="%s"];%s' % (self.getNodeNumber(n),
                                                  self.getNodeType(n),
                                                  self.delim)
        self.graph += '}'
    elif name == 'FunctionDef':
        # todo: number of params
        # always create downward links as needed
        for n in ast.iter_child_nodes(node):
            if typename(n) in ('For', 'If', 'Return'):
                self.debug('%s -> %s' % (str(name), typename(n)))
                self.makeLink(node, n)
        ast.NodeVisitor.generic_visit(self, node)
    elif name == 'For':
        for n in ast.iter_child_nodes(node):
            if typename(n) in ('For', 'If', 'Return'):
                self.debug('%s -> %s' % (name, typename(n)))
                self.makeLink(node, n)
        ast.NodeVisitor.generic_visit(self, node)
    elif name == 'If':
        # check the then branch
        found = False
        for n in node.body:
            if typename(n) in ('If', 'For', 'Return'):
                found = True
                self.makeLink(node, n)
        if not found:
            self.makeLink(node, node)
        if len(node.orelse) > 0:
            found = False
            for n in node.orelse:
                if typename(n) in ('If', 'For', 'Return'):
                    found = True
                    self.makeLink(node, n)
            if not found:
                self.graph += '%s -> %s;%s' % (self.getNodeNumber(node),
                                               self.getNodeNumber(node),
                                               self.delim)
        #print('body', node.body, typename(node.body), typename(node.orelse))
        for n in ast.iter_child_nodes(node):
            self.debug('%s ===> %s' % (str(name), typename(n)))
        ast.NodeVisitor.generic_visit(self, node)
    elif name == 'Return':
        # probably no need for recursive step here???
        #print(name, node._fields)
        ast.NodeVisitor.generic_visit(self, node)
    else:
        #print('CATCHALL', name)
        ast.NodeVisitor.generic_visit(self, node)
def nodes_equal(x, y):
    __tracebackhide__ = True
    assert type(x) == type(y)
    if isinstance(x, (ast.Expr, ast.FunctionDef, ast.ClassDef)):
        assert x.lineno == y.lineno
        assert x.col_offset == y.col_offset
    for (xname, xval), (yname, yval) in zip(ast.iter_fields(x),
                                            ast.iter_fields(y)):
        assert xname == yname
        assert type(xval) == type(yval)
    for xchild, ychild in zip(ast.iter_child_nodes(x),
                              ast.iter_child_nodes(y)):
        assert nodes_equal(xchild, ychild)
    return True
def find_node(self, node, name=None, tipe=None):
    def name_match():
        if name is None:
            return True
        else:
            if 'name' in node.__dict__:
                return name == node.__dict__['name']
            return False

    def type_match():
        if tipe is None:
            return True
        else:
            return isinstance(node, tipe)

    if type_match() and name_match():
        return node
    for child_node in ast.iter_child_nodes(node):
        found_node = self.find_node(child_node, name=name, tipe=tipe)
        if found_node is not None:
            return found_node
    return None
def simplifyTree(tree):
    resultList = []
    for node in list(ast.iter_child_nodes(tree)):
        simplenode = {}
        vars(node).pop('lineno')
        vars(node).pop('col_offset')
        #print(vars(node))
        for key, value in vars(node).items():
            if isinstance(value, ast.Num):
                simplenode[key] = simplifyNum(value)
            elif isinstance(value, ast.Name):
                simplifyName(value)
            elif isinstance(value, ast.Str):
                simplenode[key] = simplifyStr(value)
            elif key == 'targets':
                simplenode[key] = simplifyName(value[0])
            elif isinstance(value, ast.Call):
                simplenode[key] = simplifyCall(value)
            elif isinstance(value, ast.List):
                simplenode[key] = simplifyList(value)
            elif isinstance(value, ast.Dict):
                simplifyDict(value)
            else:
                print(type(value))
        resultList.append(simplenode)
    return resultList
def _parse_file(self, fp, root):
    """Parse a file and return all normalized skbio imports."""
    imports = []
    with open(fp, 'U') as f:
        # Read the file and run it through AST
        source = ast.parse(f.read())
        # Get each top-level element, this is where API imports should be.
        for node in ast.iter_child_nodes(source):
            if isinstance(node, ast.Import):
                # Standard imports are easy, just get the names from the
                # ast.Alias list `node.names`
                imports += [x.name for x in node.names]
            elif isinstance(node, ast.ImportFrom):
                prefix = ""
                # Relative import handling.
                if node.level > 0:
                    prefix = root
                    extra = node.level - 1
                    while extra > 0:
                        # Keep dropping...
                        prefix = os.path.split(prefix)[0]
                        extra -= 1
                    # We need this in '.' form not '/'
                    prefix = prefix.replace(os.sep, ".") + "."
                # Prefix should be empty unless node.level > 0
                imports += [".".join([prefix + node.module, x.name])
                            for x in node.names]
    skbio_imports = []
    for import_ in imports:
        # Filter by skbio
        if import_.split(".")[0] == "skbio":
            skbio_imports.append(import_)
    return skbio_imports
def generic_visit(self, node, inside_call=False, inside_yield=False):
    if isinstance(node, ast.Call):
        inside_call = True
    elif isinstance(node, ast.Yield):
        inside_yield = True
    elif isinstance(node, ast.Str):
        if disable_lookups:
            if inside_call and node.s.startswith("__"):
                # calling things with a dunder is generally bad at this point...
                raise AnsibleError(
                    "Invalid access found in the conditional: '%s'" % conditional
                )
            elif inside_yield:
                # we're inside a yield, so recursively parse and traverse the AST
                # of the result to catch forbidden syntax from executing
                parsed = ast.parse(node.s, mode='exec')
                cnv = CleansingNodeVisitor()
                cnv.visit(parsed)
    # iterate over all child nodes
    for child_node in ast.iter_child_nodes(node):
        self.generic_visit(
            child_node,
            inside_call=inside_call,
            inside_yield=inside_yield
        )
def _fix(node, lineno, col_offset):
    if "lineno" in node._attributes:
        node.lineno = lineno
    if "col_offset" in node._attributes:
        node.col_offset = col_offset
    for child in ast.iter_child_nodes(node):
        _fix(child, lineno, col_offset)
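# --- Hedged usage sketch for _fix (added; not from the original source).
# _fix is similar in spirit to ast.fix_missing_locations, but overwrites
# positions unconditionally. Assumes ``import ast``; ``_demo_fix`` is a
# hypothetical name for illustration only.
def _demo_fix():
    # a synthesized Call has no position info until _fix stamps one on
    call = ast.Call(func=ast.Name(id='f', ctx=ast.Load()), args=[], keywords=[])
    _fix(call, 1, 0)
    assert call.func.lineno == 1 and call.func.col_offset == 0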
def wrapper(self, node):
    result = fn(self, node)
    for child in ast.iter_child_nodes(node):
        self.visit(child)
    return result
def visit_children(self, node):
    if isinstance(node, list):
        for child in node:
            self.visit_deps(child)
    else:
        for child in ast.iter_child_nodes(node):
            self.visit_deps(child)
def _extract_docstrings(self, bzl_file):
    """Extracts the docstrings for all public rules in the .bzl file.

    This function parses the .bzl file and extracts the docstrings for all
    public rules in the file that were extracted in _process_skylark. It
    calls _add_rule_doc to parse the attribute documentation in each
    docstring and associate it with the extracted rules and attributes.

    Args:
      bzl_file: The .bzl file to extract docstrings from.
    """
    try:
        tree = None
        with open(bzl_file) as f:
            tree = ast.parse(f.read(), bzl_file)
        key = None
        for node in ast.iter_child_nodes(tree):
            if isinstance(node, ast.Assign):
                name = node.targets[0].id
                if not name.startswith("_"):
                    key = name
                continue
            elif isinstance(node, ast.Expr) and key:
                # Python itself does not treat strings defined immediately
                # after a global variable definition as a docstring. Only
                # extract string and parse as docstring if it is defined.
                if hasattr(node.value, 's'):
                    self._add_rule_doc(key, node.value.s.strip())
            key = None
    except IOError as e:
        print("Failed to parse {0}: {1}".format(bzl_file, e.strerror))
def _get_ordered_child_nodes(node):
    if isinstance(node, ast.Dict):
        children = []
        for i in range(len(node.keys)):
            children.append(node.keys[i])
            children.append(node.values[i])
        return children
    elif isinstance(node, ast.Call):
        children = [node.func] + node.args
        for kw in node.keywords:
            children.append(kw.value)
        # TODO: take care of Python 3.5 updates (eg. args=[Starred] and keywords)
        if hasattr(node, "starargs") and node.starargs is not None:
            children.append(node.starargs)
        if hasattr(node, "kwargs") and node.kwargs is not None:
            children.append(node.kwargs)
        children.sort(key=lambda x: (x.lineno, x.col_offset))
        return children
    elif isinstance(node, ast.arguments):
        children = node.args + node.kwonlyargs + node.kw_defaults + node.defaults
        if node.vararg is not None:
            children.append(node.vararg)
        if node.kwarg is not None:
            children.append(node.kwarg)
        children.sort(key=lambda x: (x.lineno, x.col_offset))
        return children
    else:
        return ast.iter_child_nodes(node)
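# --- Hedged usage sketch for _get_ordered_child_nodes (added; not from the
# original source). For a Dict literal it interleaves keys and values in
# source order ('a', 1, 'b', 2) rather than ast's field order (all keys,
# then all values). Assumes ``import ast`` and Python 3.8+ Constant nodes;
# ``_demo_ordered_children`` is a hypothetical name for illustration only.
def _demo_ordered_children():
    d = ast.parse("{'a': 1, 'b': 2}", mode="eval").body
    children = list(_get_ordered_child_nodes(d))
    assert [c.value for c in children] == ['a', 1, 'b', 2]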
def rvals(node):
    if is_name(node):
        for k in targets:
            targets[k].add(node)
    else:
        for child in ast.iter_child_nodes(node):
            rvals(child)
def find_attribute_with_name(node, name):
    if isinstance(node, ast.Attribute) and node.attr == name:
        return node
    for item in ast.iter_child_nodes(node):
        r = find_attribute_with_name(item, name)
        if r:
            return r
def _statement_assign(self, env, node):
    itr = self._fold_expr(env, node.value)
    while itr:
        try:
            yield next(itr)
        except StopIteration:
            break
    val = env.pop()

    # TODO what else is supposed to hold??
    assert len(node.targets) == 1
    target = node.targets[0]

    # fold the lists
    if isinstance(val, list):
        l = []
        for v in val:
            itr = self._fold_expr(env, v)
            while itr:
                try:
                    yield next(itr)
                except StopIteration:
                    break
            l.append(env.pop())
        val = l

    # assign to a list of variables
    if (isinstance(target, ast.List) or isinstance(target, ast.Tuple)) \
            and isinstance(val, list):
        i = 0
        for var in ast.iter_child_nodes(target):
            if isinstance(var, ast.Name):
                env.setvar(var.id, val[i])
                i += 1
    else:
        env.setvar(target.id, val)
    yield node
def recurse(node: ast.Expression, temp: list):
    '''
    Local function for recursion over all leaves in the tree.

    Parameters:
        node - current node and all its children
        temp - array where the answer lives
    '''
    if isinstance(node, ast.BinOp):
        # go to the left part of the expr
        recurse(node.left, temp)
        # go to the operation of the expr
        recurse(node.op, temp)
        # go to the right part of the expr
        recurse(node.right, temp)
    # if it's the counter
    elif isinstance(node, ast.Name):
        temp.append([node.id])
    # in other cases, go through all children of the expression
    else:
        for child in ast.iter_child_nodes(node):
            recurse(child, temp)
    return temp
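# --- Hedged usage sketch for recurse (added; not from the original source).
# recurse is defined as a local helper, so this sketch assumes it is in
# scope; assumes ``import ast``. For "i + j + 1" it collects the Name leaves
# into ``temp``. ``_demo_recurse`` is a hypothetical name for illustration.
def _demo_recurse():
    expr = ast.parse("i + j + 1", mode="eval").body
    assert recurse(expr, []) == [['i'], ['j']]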
def __get_imports(self, file_path):
    """
    Extract all import statements from the provided file including aliases.

    :param file_path: file to be analyzed
    :return: list of `ImportedModule`
    """
    # get the AST of the provided file
    with open(file_path) as file:
        root = ast.parse(file.read(), file_path)
    # extract import statements
    for node in ast.iter_child_nodes(root):
        if isinstance(node, ast.Import):
            module = ''
        elif isinstance(node, ast.ImportFrom):
            module = node.module
        else:
            continue
        for n in node.names:
            yield ImportedModule(module, n.name, n.asname,
                                 node.lineno, node.col_offset, None)
def get_contrib_requirements(filepath: str) -> Dict:
    """
    Parse the python file from filepath to identify a "library_metadata"
    dictionary in any defined classes, and return a requirements_info object
    that includes a list of pip-installable requirements for each class that
    defines them.

    Note, currently we are handling all dependencies at the module level. To
    support future expandability and detail, this method also returns
    per-class requirements in addition to the concatenated list.

    Args:
        filepath: the path to the file to parse and analyze

    Returns:
        A dictionary:
            {
                "requirements": [ all_requirements_found_in_any_library_metadata_in_file ],
                class_name: [ requirements ]
            }
    """
    with open(filepath) as file:
        tree = ast.parse(file.read())

    requirements_info = {"requirements": []}
    for child in ast.iter_child_nodes(tree):
        if not isinstance(child, ast.ClassDef):
            continue
        current_class = child.name
        for node in ast.walk(child):
            if isinstance(node, ast.Assign):
                try:
                    target_ids = [target.id for target in node.targets]
                except (ValueError, AttributeError):
                    # some assignment types assign to non-Name nodes (e.g. Tuple)
                    target_ids = []
                if "library_metadata" in target_ids:
                    library_metadata = ast.literal_eval(node.value)
                    requirements = library_metadata.get("requirements", [])
                    requirements_info[current_class] = requirements
                    requirements_info["requirements"] += requirements
    return requirements_info
def class_parse(args):
    if args.cl:
        objsearch = 'ClassDef'
    elif args.func:
        objsearch = 'FunctionDef'
    with open(args.path) as f:
        content = f.read()
    results = []
    added_lines = []
    root = ast.parse(content)
    for node in ast.walk(root):
        for child in ast.iter_child_nodes(node):
            child.parent = node
    # search for pattern
    for num, line in enumerate(content.splitlines(), 1):
        if (args.pattern in line
                or (args.regexp and re.search(args.pattern, line))) \
                and (num, line) not in added_lines:
            pattern_node = find_match_node(results, num, root, args)
            if pattern_node is None:
                continue
            else:
                while objsearch not in str(pattern_node):
                    if pattern_node.parent is root:
                        break
                    pattern_node = pattern_node.parent
            curres = []
            if objsearch in str(pattern_node):
                first = pattern_node.lineno
                end = get_end(pattern_node)
                curres += [
                    mhighlight(num, line, args.pattern, args.regexp)
                    for num, line in enumerate(
                        content.splitlines()[first - 1:end], first)
                ]
                added_lines += [(num, line) for num, line in enumerate(
                    content.splitlines()[first - 1:end], first)]
                results.append(''.join(curres))
    return results
def _iter_child_nodes_in_order_internal_1(node):
    if not isinstance(node, ast.AST):
        raise TypeError
    if isinstance(node, ast.Dict):
        assert node._fields == ("keys", "values")
        yield list(zip(node.keys, node.values))
    elif isinstance(node, ast.FunctionDef):
        if six.PY2:
            assert node._fields == ('name', 'args', 'body', 'decorator_list')
        else:
            assert node._fields == ('name', 'args', 'body', 'decorator_list',
                                    'returns')
        yield node.decorator_list, node.args, node.body
        # node.name is a string, not an AST node
    elif isinstance(node, ast.arguments):
        if six.PY2:
            assert node._fields == ('args', 'vararg', 'kwarg', 'defaults')
        else:
            assert node._fields == ('args', 'vararg', 'kwonlyargs',
                                    'kw_defaults', 'kwarg', 'defaults')
        defaults = node.defaults or ()
        num_no_default = len(node.args) - len(defaults)
        yield node.args[:num_no_default]
        yield list(zip(node.args[num_no_default:], defaults))
        # node.vararg and node.kwarg are strings, not AST nodes.
    elif isinstance(node, ast.IfExp):
        assert node._fields == ('test', 'body', 'orelse')
        yield node.body, node.test, node.orelse
    elif isinstance(node, ast.ClassDef):
        if six.PY2:
            assert node._fields == ('name', 'bases', 'body', 'decorator_list')
        else:
            assert node._fields == ('name', 'bases', 'keywords', 'body',
                                    'decorator_list')
        yield node.decorator_list, node.bases, node.body
        # node.name is a string, not an AST node
    else:
        # Default behavior.
        yield ast.iter_child_nodes(node)
def _get_ast_trees(self, repository_path):
    filenames = []
    main_file_contents = []
    ast_trees = []
    for dirname, _, files in os.walk(repository_path, topdown=True):
        for file in files:
            if file.endswith('.py'):
                filenames.append(os.path.join(dirname, file))
    for filename in filenames:
        with open(filename, 'r', encoding='utf-8') as file_handler:
            main_file_content = file_handler.read()
        try:
            tree = ast.parse(main_file_content)
            for node in ast.walk(tree):
                for child in ast.iter_child_nodes(node):
                    child.parent = node
        except SyntaxError as e:
            print(e)
            tree = None
        main_file_contents.append(main_file_content)
        ast_trees.append(tree)
    return filenames, main_file_contents, ast_trees
def visit_While(self, node):
    if node and not config.mutated:
        for child in ast.iter_child_nodes(node):
            config.parent_dict[child] = node
        if self.operator[1] is StatementDeletion:
            for anode in node.body:
                if anode.__class__ in [
                        ast.Raise, ast.Continue, ast.Break,
                        ast.Assign, ast.AugAssign, ast.Call
                ]:
                    config.nodes_to_remove.add(anode)
                elif anode.__class__ in [ast.Expr]:
                    config.nodes_to_potential.add(anode)
            node = self.mutate_single_node(node, self.operator)
        else:
            node = self.mutate_single_node(node, self.operator)
    if node and not config.mutated:
        self.dfs_visit(node)
    elif node and config.mutated and config.recovering:
        return self.recover_node(node)
    return node
def checkGlobalIds(a, l):
    if not isinstance(a, ast.AST):
        return
    elif type(a) in [ast.Load, ast.Store, ast.Del, ast.AugLoad, ast.AugStore,
                     ast.Param]:
        return
    if not hasattr(a, "global_id"):
        addedNodes = [
            "propagatedVariable", "orderedBinOp",
            "augAssignVal", "augAssignBinOp",
            "combinedConditional", "combinedConditionalOp",
            "multiCompPart", "multiCompOp",
            "second_global_id", "moved_line",
            # above this line has individualize functions. below does not.
            "addedNot", "addedNotOp", "addedOther", "addedOtherOp",
            "collapsedExpr", "removedLines",
            "helperVar", "helperReturn",
            "typeCastFunction",
        ]
        for t in addedNodes:
            if hasattr(a, t):
                break
        else:
            # only enter the else if none of the provided types are an attribute of a
            log("canonicalize\tcheckGlobalIds\tNo global id: " + str(l) + ","
                + str(a.__dict__) + "," + printFunction(a, 0), "bug")
    for f in ast.iter_child_nodes(a):
        checkGlobalIds(f, l + [type(a)])
def process_child_nodes(
    node: ast.AST,
    increment_by: int,
    verbose: bool,
    complexity_calculator: Callable,
) -> int:
    child_complexity = 0
    child_nodes = ast.iter_child_nodes(node)

    for node_num, child_node in enumerate(child_nodes):
        if isinstance(node, ast.Try):
            if node_num == 1:
                # add +1 for all try nodes except body
                increment_by += 1
            if node_num:
                child_complexity += max(1, increment_by)
        child_complexity += complexity_calculator(
            child_node,
            increment_by=increment_by,
            verbose=verbose,
        )
    return child_complexity
def visit_FunctionDef(self, node):
    # self._filename is 'stdin' in the unit test for this check.
    if (not os.path.basename(self._filename).startswith('test_')
            and self._filename != 'stdin'):
        return

    closures = []
    references = []
    # Walk just the direct nodes of the test method
    for child_node in ast.iter_child_nodes(node):
        if isinstance(child_node, ast.FunctionDef):
            closures.append(child_node.name)

    # Walk all nodes to find references
    find_references = _FindVariableReferences()
    find_references.generic_visit(node)
    references = find_references._references

    missed = set(closures) - set(references)
    if missed:
        self.add_error(node, 'N349: Test closures not called: %s'
                       % ','.join(missed))
def _prove_names_defined(
        env: SymbolTable,
        names: tp.AbstractSet[str],
        node: tp.Union[ast.AST, tp.Sequence[ast.AST]]) -> tp.AbstractSet[str]:
    '''
    Prove that all names are defined, i.e. if a name is used at some point,
    then all paths leading to that point must either return or define the
    name.
    '''
    names = set(names)
    if isinstance(node, ast.Name):
        if isinstance(node.ctx, ast.Store):
            names.add(node.id)
        elif node.id not in names and node.id not in env and \
                node.id not in env["__builtins__"].__dict__:
            if hasattr(node, 'lineno'):
                raise SyntaxError(
                    f'Cannot prove name, {node.id}, is defined at line {node.lineno}')
            else:
                raise SyntaxError(f'Cannot prove name, {node.id}, is defined')
    elif isinstance(node, ast.If):
        t_returns = _always_returns(node.body)
        f_returns = _always_returns(node.orelse)
        t_names = _prove_names_defined(env, names, node.body)
        f_names = _prove_names_defined(env, names, node.orelse)
        if not (t_returns or f_returns):
            names |= t_names & f_names
        elif t_returns:
            names |= f_names
        elif f_returns:
            names |= t_names
    elif isinstance(node, ast.AST):
        for child in ast.iter_child_nodes(node):
            names |= _prove_names_defined(env, names, child)
    else:
        assert isinstance(node, tp.Sequence)
        for child in node:
            names |= _prove_names_defined(env, names, child)
    return names
def _move_var_decls_to_top_of_scope(node, vars_already_scoped=None):
    """Initializes all variables used in a scope to None at the beginning of
    that scope.

    Parameters:
        node: ast node
        vars_already_scoped: hash table of variables that have already been
            reinitialized

    Returns:
        node with variables initialized at top of scope.
        vars_already_scoped: hash table with variables already moved."""
    if vars_already_scoped is None:
        vars_already_scoped = {}
    current_scope_new_vars = []
    if isinstance(node, ast.FunctionDef) or isinstance(node, ast.Module):
        for child_node in ast.iter_child_nodes(node):
            new_vars, vars_already_scoped = _find_vars_in_same_scope(
                child_node, vars_already_scoped)
            current_scope_new_vars = current_scope_new_vars + new_vars
    current_scope_new_vars.reverse()
    _move_initilizations_to_top_of_scope(node, current_scope_new_vars)
    return node, vars_already_scoped
def get_test_methods(test_path: str) -> List[Any]:
    """Gets the top-level methods within a test file

    Args:
        test_path: path to the file to process

    Returns:
        List[ast.AST]: a list of the top-level methods within the provided file
    """
    try:
        with open(test_path, 'r') as file:
            content = ''.join(file.readlines())
            parsed_nodes = list(ast.iter_child_nodes(ast.parse(content)))
            test_nodes = _get_test_nodes(parsed_nodes)
            for node in test_nodes:
                node.test_path = os.path.abspath(test_path)
            # Verify file contains no duplicate method names
            # (Only relevant for test methods wrapped in classes)
            used_test_names = set()
            for node in test_nodes:
                if node.name in used_test_names:
                    raise ValueError(f'Test name {node.name} in file'
                                     f' {test_path} must be unique.')
                used_test_names.add(node.name)
            return test_nodes
    except IOError as err:
        # Fail gracefully if a file can't be read
        # (This shouldn't happen, but if it does
        # we don't want to 'break the build'.)
        sys.stderr.write(f'WARNING: could not read file: {test_path}\n')
        sys.stderr.write(f'\t{str(err)}\n')
        return []
def visit_Name(self, node):
    self.currentAttribute.append(node.id)
    fullyQualified = '.'.join(reversed(self.currentAttribute))
    self._nnPrint(node, '%s --> %s' % (node.id, fullyQualified))
    if not self.checkScopeAllowed(reversed(self.currentAttribute)):
        if '.' in fullyQualified \
                or fullyQualified in self.TRUSTED_BUILTINS:
            self.disallowed.append(fullyQualified)
        # allow assignments in iterators
        elif self.iterNodeChildren:
            if node in self.iterNodeChildren:
                self.iterNodeChildren.remove(node)
                # naive pass - just accept the name
                self.LOCAL_VARIABLES.add(node.id)
        # strictly allow assignments that are single-label
        elif self.assignmentNodeChildren and node in self.assignmentNodeChildren \
                and any([isinstance(child, ast.Store)
                         for child in ast.iter_child_nodes(node)]):
            self.LOCAL_VARIABLES.add(fullyQualified)
        else:
            self.disallowed.append(fullyQualified)
    else:
        rootObj = self.currentAttribute[-1]
        if rootObj in Validator.TRUSTED_IMPORTS:
            self.requiredImports.add(rootObj)
    # clear for the next
    self.currentAttribute = []
    self.assignmentNodeChildren = None
def lstm_unit(self, ast_node, depth=0):
    """
    Process of one LSTM unit. Recursively calls learning processes on all
    children in one tree.

    :param ast_node: one Python AST node; the first call will be with the root node
    :returns: hidden state and context of the node; eventually for the whole AST
    """
    weight = torch.tensor([])  # TODO weights with lstm calculation!!
    w_t = ast2vec(ast_node, self.dictionary, self.emb_matrix)  # embedding of tree
    # sum of children hidden outputs
    h_ = 0
    # child hidden state
    h_k = 0
    # context of child
    c_k = 0
    # forget gates
    f_tk = 0
    # children forget rates times the context
    c_ = 0
    for k in ast.iter_child_nodes(ast_node):
        print(k, depth)
        h_k, c_k = self.lstm_unit(k, depth + 1)
        f_tk = torch.nn.Sigmoid()(weight)
        h_ += h_k
        c_ += (f_tk * c_k)
    # input gate
    i_t = torch.nn.Sigmoid()(weight)
    # vector of new candidate values for t
    c_t_ = torch.nn.Tanh()(weight)
    # context
    c_t = i_t * c_t_ + c_
    # output gate
    o_t = torch.nn.Sigmoid()(weight)
    h_t = o_t * torch.nn.Tanh()(c_t)
    return h_t, c_t
def _fix(node, lineno, col_offset, end_lineno, end_col_offset):
    if 'lineno' in node._attributes:
        if not hasattr(node, 'lineno'):
            node.lineno = lineno
        else:
            lineno = node.lineno
    if 'end_lineno' in node._attributes:
        if not hasattr(node, 'end_lineno'):
            node.end_lineno = end_lineno
        else:
            end_lineno = node.end_lineno
    if 'col_offset' in node._attributes:
        if not hasattr(node, 'col_offset'):
            node.col_offset = col_offset
        else:
            col_offset = node.col_offset
    if 'end_col_offset' in node._attributes:
        if not hasattr(node, 'end_col_offset'):
            node.end_col_offset = end_col_offset
        else:
            end_col_offset = node.end_col_offset
    for child in iter_child_nodes(node):
        _fix(child, lineno, col_offset, end_lineno, end_col_offset)
def _get_ordered_child_nodes(node):
    if isinstance(node, ast.Dict):
        children = []
        for i in range(len(node.keys)):
            children.append(node.keys[i])
            children.append(node.values[i])
        return children
    elif isinstance(node, ast.Call):
        children = [node.func] + node.args
        for kw in node.keywords:
            children.append(kw.value)
        # TODO: take care of Python 3.5 updates (eg. args=[Starred] and keywords)
        if hasattr(node, "starargs") and node.starargs is not None:
            children.append(node.starargs)
        if hasattr(node, "kwargs") and node.kwargs is not None:
            children.append(node.kwargs)
        children.sort(key=lambda x: (x.lineno, x.col_offset))
        return children
    else:
        return ast.iter_child_nodes(node)
def check_module(
    module: ast.Module, ignore_ambiguous_signatures: bool = True
) -> Iterator[Tuple[ast.AST, List[str], List[str]]]:
    """Check a module.

    Parameters
    ----------
    module : ast.Module
        The module in which to check functions and classes.
    ignore_ambiguous_signatures : bool, optional
        Whether to ignore extra documented arguments if the function has an
        ambiguous (*args / **kwargs) signature (the default is True).

    Yields
    ------
    Tuple[ast.AST, List[str], List[str]]
        The check results for each public node in the module.
    """
    for node in ast.iter_child_nodes(module):
        if not is_private(node):
            check_result = check(node, ignore_ambiguous_signatures)
            if check_result is not None:
                yield from check_result
def find_lca(root, node1, node2):
    if root is None:
        return None
    if root == node1 or root == node2:
        return root
    # look for lca in all subtrees
    lca_list = []
    for child_node in ast.iter_child_nodes(root):
        if child_node not in IGNORE_TYPES:
            lca_list.append(find_lca(child_node, node1, node2))
    # Remove None in the list
    lca_list = [i for i in lca_list if i is not None]
    # print('for node:', type(root).__name__)
    # print(lca_list)
    # If two or more of the above calls return non-None, the nodes are in separate subtrees
    # if len(lca_list) == 2:
    if len(lca_list) >= 2:
        return root
    # If the above calls return one non-None, both nodes are in one subtree
    if len(lca_list) == 1:
        return lca_list[0]
    return None
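# --- Hedged usage sketch for find_lca (added; not from the original source).
# find_lca reads a module-level IGNORE_TYPES collection that is not shown
# here; the assignment below is a hypothetical stand-in (assumed empty) so the
# sketch runs. Assumes ``import ast``; ``_demo_find_lca`` is illustrative.
IGNORE_TYPES = []  # assumption: stands in for the snippet's real constant

def _demo_find_lca():
    tree = ast.parse("a = 1\nb = 2\n")
    # the two assignments sit in separate subtrees, so the Module is their LCA
    assert find_lca(tree, tree.body[0], tree.body[1]) is tree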
def walk(node, stop_at=tuple(), ignore=tuple()):
    """Walk through the children of an ast node.

    Args:
        node: an ast node
        stop_at: stop traversing through these nodes, including the matching node
        ignore: stop traversing through these nodes, excluding the matching node

    Returns:
        a generator of ast nodes
    """
    todo = deque([node])
    while todo:
        node = todo.popleft()
        if isinstance(node, ignore):
            # dequeue next node
            continue
        if not isinstance(node, stop_at):
            next_nodes = ast.iter_child_nodes(node)
            for n in next_nodes:
                todo.extend([n])
        yield node
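# --- Hedged usage sketch for walk (added; not from the original source).
# A ``stop_at`` match is yielded itself, but its children are pruned.
# Assumes ``import ast`` and the ``deque`` import used by the snippet above;
# ``_demo_walk`` is a hypothetical name for illustration only.
def _demo_walk():
    tree = ast.parse("def f():\n    pass\n\nx = 1\n")
    kinds = [type(n).__name__ for n in walk(tree, stop_at=(ast.FunctionDef,))]
    # the FunctionDef itself appears, but nothing from inside its body
    assert 'FunctionDef' in kinds and 'Pass' not in kinds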
def createNameMap(a, d=None):
    if d is None:
        d = {}
    if not isinstance(a, ast.AST):
        return d
    if type(a) == ast.Module:
        # Need to go through the functions backwards to make this right
        for i in range(len(a.body) - 1, -1, -1):
            createNameMap(a.body[i], d)
        return d
    if type(a) in [ast.FunctionDef, ast.ClassDef]:
        if hasattr(a, "originalId") and a.name not in d:
            d[a.name] = a.originalId
    elif type(a) == ast.arg:
        if hasattr(a, "originalId") and a.arg not in d:
            d[a.arg] = a.originalId
        return d
    elif type(a) == ast.Name:
        if hasattr(a, "originalId") and a.id not in d:
            d[a.id] = a.originalId
        return d
    for child in ast.iter_child_nodes(a):
        createNameMap(child, d)
    return d
def get_imports(root: ast.AST):
    # https://stackoverflow.com/a/9049549
    Import = collections.namedtuple("Import", ["module", "name", "alias"])
    result = []
    for node in ast.iter_child_nodes(root):
        if isinstance(node, ast.Import):
            module = []
        elif isinstance(node, ast.ImportFrom):
            module = node.module.split('.')
        elif isinstance(node, (ast.FunctionDef, ast.ClassDef)):
            function_result = get_imports(node)
            if function_result:
                result += function_result
            continue
        else:
            continue
        for n in node.names:
            result.append(Import(module, n.name.split('.'), n.asname))
    return result
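# --- Hedged usage sketch for get_imports (added; not from the original
# source). Dotted paths come back split into lists, and aliases are kept.
# Assumes ``import ast`` and ``import collections``; ``_demo_get_imports``
# is a hypothetical name for illustration only.
def _demo_get_imports():
    tree = ast.parse("import os\nfrom a.b import c as d\n")
    imports = get_imports(tree)
    assert imports[0].name == ['os'] and imports[0].alias is None
    assert imports[1].module == ['a', 'b'] and imports[1].alias == 'd'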
def _relevant_statements(ancestors: List[ast.AST]) -> Iterable[ast.AST]:
    """
    Given a list of ancestors, finds the statements to analyze.

    Given a list of ancestors, as described by :func:`_ast_ancestors`, finds
    all statements in the tree that are considered "worthy to be analyzed".
    In general, these are all child nodes of any ancestor that represent
    import statements or assign statements.

    :param ancestors: The list of ancestors, as described by
                      :func:`_ast_ancestors`.
    :return: An iteration over the relevant nodes.
    """
    acceptable_types = [ast.Import, ast.ImportFrom, ast.Assign]
    for ancestor in ancestors:
        for child in ast.iter_child_nodes(ancestor):
            if any(isinstance(child, t) for t in acceptable_types):
                yield child
            elif isinstance(child, ast.Try):
                for grandchild in itertools.chain(
                    child.body, child.finalbody, child.orelse
                ):
                    if any(isinstance(grandchild, t) for t in acceptable_types):
                        yield grandchild
def get_meta(filename):
    """Get top level module metadata without execution."""
    result = {
        '__file__': filename,
        '__name__': os.path.splitext(os.path.basename(filename))[0],
        '__package__': '',
    }
    with open(filename) as fp:
        root = ast.parse(fp.read(), fp.name)
    result['__doc__'] = ast.get_docstring(root)
    for node in ast.iter_child_nodes(root):
        if isinstance(node, ast.Assign):
            for target in node.targets:
                try:
                    result[target.id] = ast.literal_eval(node.value)
                except ValueError:
                    pass
    return result
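# --- Hedged usage sketch for get_meta (added; not from the original source).
# It reads module-level dunder assignments without importing the module.
# Assumes ``import ast`` and ``import os``; ``_demo_get_meta`` is a
# hypothetical name for illustration only.
def _demo_get_meta():
    import tempfile
    with tempfile.NamedTemporaryFile('w', suffix='.py', delete=False) as f:
        f.write('"""Docs."""\n__version__ = "1.0"\n')
        path = f.name
    meta = get_meta(path)
    assert meta['__doc__'] == 'Docs.' and meta['__version__'] == '1.0'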
def get_name_import_path(name_node: ast.Name,
                         pyfilepath: str) -> Optional[EntityImportInfo]:
    current_node = name_node.parent  # type: ignore
    while True:
        for child in ast.iter_child_nodes(current_node):
            # check for Import, not only ImportFrom
            if isinstance(child, ast.ImportFrom) \
                    and name_node.id in (n.name for n in child.names):
                return extract_import_info_from_import_node(child, name_node)
            elif (isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef))
                    and child.name == name_node.id):
                return {
                    'file_path': pyfilepath if pyfilepath != 'built-in' else None,
                    'import_path': None,
                    'name': name_node.id,
                }
        current_node = current_node.parent
        if not current_node:
            break
    return None
def parse_flags(synth_py_path="synth.py") -> typing.Dict[str, typing.Any]:
    """Extracts flags from a synth.py file.

    Keyword Arguments:
        synth_py_path {str or Path} -- Path to synth.py (default: {"synth.py"})

    Returns:
        typing.Dict[str, typing.Any] -- A map of all possible flags. Flags not
        found in synth.py will be set to their default value.
    """
    flags = copy.copy(_SYNTH_PY_FLAGS)
    path = pathlib.Path(synth_py_path)
    try:
        module = ast.parse(path.read_text())
    except SyntaxError:
        # Later attempt to execute synth.py will give a clearer error message.
        return flags
    for child in ast.iter_child_nodes(module):
        if isinstance(child, ast.Assign):
            for target in child.targets:
                if isinstance(target, ast.Name) and target.id in flags:
                    flags[target.id] = ast.literal_eval(child.value)
    return flags
def embed_subtree(self, node):
    children_embeddings = []
    # c_0 = torch.zeros(1, self.embedding_dim)
    # h_0 = torch.zeros(1, self.embedding_dim)
    # n_children = 0
    # print("Embedding the subtree of ", node)
    # print("\t\t\t", list(ast.iter_child_nodes(node)))
    for child in ast.iter_child_nodes(node):
        # print(child)
        child_embedding = self.embed_node(child).view(1, -1)
        children_embeddings.append(child_embedding)
        # (c_0, h_0) = self.subtree_network(child_embedding, (c_0, h_0))
        # h_0 += child_embedding
        # n_children += 1
    # h_0 /= n_children
    # indices = torch.from_numpy()
    # embeddings = self.embedding_layer(indices)
    # children_embeddings = torch.cat(children_embeddings, 0)[None, :, :]
    # print(children_embeddings)
    return children_embeddings
def exec_eval(script, globals=None, locals=None, name=''):
    '''Execute a script and return the value of the last expression'''
    stmts = list(ast.iter_child_nodes(ast.parse(script)))
    if not stmts:
        return None
    if isinstance(stmts[-1], ast.Expr):
        # the last one is an expression and we will try to return the results
        # so we first execute the previous statements
        if len(stmts) > 1:
            if sys.version_info >= (3, 8):
                mod = ast.Module(stmts[:-1], [])
            else:
                mod = ast.Module(stmts[:-1])
            exec(compile(mod, filename=name, mode="exec"), globals, locals)
        # then we eval the last one
        return eval(
            compile(ast.Expression(body=stmts[-1].value), filename=name,
                    mode="eval"),
            globals, locals)
    else:
        # otherwise we just execute the entire code
        return exec(compile(script, filename=name, mode="exec"),
                    globals, locals)
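# --- Hedged usage sketch for exec_eval (added; not from the original source).
# It behaves like a REPL cell: run the statements, then return the value of a
# trailing expression (or None if there is none). Explicit namespace dicts are
# passed so the assignment is visible to the final eval. Assumes ``import ast``
# and ``import sys``; ``_demo_exec_eval`` is a hypothetical name.
def _demo_exec_eval():
    env = {}
    assert exec_eval("x = 2\nx + 3", env, env) == 5
    assert exec_eval("x = 2", env, env) is None  # no trailing expression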
def get_info(a, depth=0):
    "Print detailed information about an AST"
    nm = a.__class__.__name__
    print(" " * depth, end="")
    iter_children = True
    if nm == "Num":
        if type(a.n) == int:
            print("%s=%d" % (nm, a.n))
        else:
            print("%s=%f" % (nm, a.n))
    elif nm == "Global":
        print("Global:", dir(a))
    elif nm == "Str":
        print("%s='%s'" % (nm, a.s))
    elif nm == "Name":
        print("%s='%s'" % (nm, a.id))
    elif nm == "arg":
        print("%s='%s'" % (nm, a.arg))
    elif nm == "If":
        iter_children = False
        print(nm)
        get_info(a.test, depth)
        for n in a.body:
            get_info(n, depth + 1)
        if len(a.orelse) > 0:
            print(" " * depth, end="")
            print("Else")
            for n in a.orelse:
                get_info(n, depth + 1)
    else:
        print(nm)
    for (f, v) in ast.iter_fields(a):
        if type(f) == str and type(v) == str:
            print("%s:attr[%s]=%s" % (" " * (depth + 1), f, v))
    if iter_children:
        for n in ast.iter_child_nodes(a):
            get_info(n, depth + 1)
def set_node_context(tree: ast.AST) -> ast.AST:
    """
    Used to set proper context to all nodes.

    What do we call "a context"?
    Context is where exactly this node belongs on a global level.

    Example::

        if some_value > 2:
            test = 'passed'

    Despite the fact the ``test`` variable has ``Assign`` as its parent,
    it will have ``Module`` as a context.

    What contexts do we respect?

    - :py:class:`ast.Module`
    - :py:class:`ast.ClassDef`
    - :py:class:`ast.FunctionDef` and :py:class:`ast.AsyncFunctionDef`

    """
    contexts = (
        ast.Module,
        ast.ClassDef,
        ast.FunctionDef,
        ast.AsyncFunctionDef,
    )

    current_context = None
    for statement in ast.walk(tree):
        if isinstance(statement, contexts):
            current_context = statement
        for child in ast.iter_child_nodes(statement):
            setattr(child, 'context', current_context)
    return tree
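# --- Hedged usage sketch for set_node_context (added; not from the original
# source). It checks the docstring's own example: the assignment nested in
# the ``if`` reports the Module, not the If, as its context. Assumes
# ``import ast``; ``_demo_set_node_context`` is a hypothetical name.
def _demo_set_node_context():
    tree = set_node_context(ast.parse("if some_value > 2:\n    test = 'passed'\n"))
    assign = tree.body[0].body[0]
    assert isinstance(assign.context, ast.Module)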