Exemplo n.º 1
0
    def visit_Expr(self, node):
        """Record which document variable and feature an expression accesses.

        The expression must be exactly one attribute access ``<name>.<feature>``
        where ``<feature>`` is listed in ``self.document.features`` and
        ``<name>`` is a key of ``self.locals`` bound to an instance of
        ``self.document``. On success ``self.doc`` and ``self.feature_name`` are
        set; any other shape raises ``ValueError``.
        """
        if self.document is None:
            raise ValueError()

        expr_children = list(ast.iter_child_nodes(node))
        if len(expr_children) != 1:
            raise ValueError()
        attribute = expr_children[0]

        is_feature_access = (isinstance(attribute, ast.Attribute)
                             and attribute.attr in self.document.features)
        if not is_feature_access:
            raise ValueError()
        feature = attribute.attr

        # An Attribute node yields two children: its ``value`` and its ``ctx``.
        attribute_children = list(ast.iter_child_nodes(attribute))
        if len(attribute_children) != 2:
            raise ValueError()
        owner = attribute_children[0]

        is_document_local = (isinstance(owner, ast.Name)
                             and owner.id in self.locals
                             and isinstance(self.locals[owner.id], self.document))
        if not is_document_local:
            raise ValueError()
        self.doc = self.locals[owner.id]
        self.feature_name = feature
Exemplo n.º 2
0
def get_testcases(paths):
    """Walk each path in ``paths`` and return the test cases found.

    :param path: List o directories to find test modules and test cases.
    :return: A dict mapping a test module path and its test cases.
    """
    testmodules = []
    for path in paths:
        for dirpath, _, filenames in os.walk(path):
            for filename in filenames:
                if filename.startswith('test_') and filename.endswith('.py'):
                    testmodules.append(os.path.join(dirpath, filename))
    testcases = collections.OrderedDict()
    for testmodule in testmodules:
        testcases[testmodule] = []
        with open(testmodule) as handler:
            for node in ast.iter_child_nodes(ast.parse(handler.read())):
                if isinstance(node, ast.ClassDef):
                    # Class test methods
                    class_name = node.name
                    testcases[testmodule].extend([
                        TestFunction(subnode, class_name, testmodule)
                        for subnode in ast.iter_child_nodes(node)
                        if isinstance(subnode, ast.FunctionDef) and
                        subnode.name.startswith('test_')
                    ])
                elif (isinstance(node, ast.FunctionDef) and
                      node.name.startswith('test_')):
                    # Module's test functions
                    testcases[testmodule].append(TestFunction(
                        node, testmodule=testmodule))
    return testcases
Exemplo n.º 3
0
 def visit_Subscript(self,node):
   """Report subscripts whose nested-container depth reaches a threshold.

   If ``node`` lies inside the subtree of the last subscript that was
   measured (``self.subscriptnode``), it is only traversed. Otherwise it
   becomes ``self.subscriptnode`` and its subtree is walked depth-first to
   compute ``maxcount``: 1 plus the largest number of ``Subscript``/
   ``Tuple``/``Dict``/``List``/``Set`` nodes on any path below ``node``.
   When ``maxcount >= config['containerdepth']`` the tuple
   ``(6, self.fileName, node.lineno, maxcount)`` is appended to
   ``self.result``. The children of ``node`` are then visited.
   """
   if self.subscriptnode is not None:
     # A subscript nested inside the one measured last was already counted
     # as part of that measurement, so only recurse.
     repnode = list(ast.walk(self.subscriptnode))
     if node in repnode:
       self.generic_visit(node)
       return
   self.subscriptnode = node
   # Explicit DFS stack. Each frame is [iterator over a node's children, flag];
   # the flag is 1 while the child being descended into is a container node.
   t = ast.iter_child_nodes(node)
   res = [[t,0]]
   maxcount = 1
   while len(res) >= 1:
     t = res[-1][0]
     for childnode in t:
       # print childnode
       if isinstance(childnode, _ast.Subscript) or isinstance(childnode, _ast.Tuple) or \
        isinstance(childnode, _ast.Dict) or isinstance(childnode, _ast.List) or isinstance(childnode, _ast.Set):
         res[-1][1] = 1
       else:
         res[-1][1] = 0
       # Descend into this child; the parent frame's iterator keeps its
       # position, so remaining siblings are visited when we come back up.
       res.append([ast.iter_child_nodes(childnode),0])
       break
     else:
       # Frame exhausted: the flags on the stack count the containers along
       # the current path (+1 for the subscript being measured).
       maxcount = max(maxcount,sum([flag for (item,flag) in res]) + 1)
       res.pop()
     continue
   # print maxcount
   if maxcount >= config['containerdepth']:
     self.result.append((6,self.fileName,node.lineno,maxcount))
   self.generic_visit(node) 
Exemplo n.º 4
0
Arquivo: tools.py Projeto: mitnk/xonsh
def nodes_equal(x, y):
    """Assert that two AST trees have the same structure and return ``True``.

    Compares the following and recurses into child nodes:

    * node types;
    * field names and the types of field values;
    * line and column positions of ``Expr``/``FunctionDef``/``ClassDef`` nodes;
    * the number of child nodes.

    Leaf field *values* (identifiers, constants) are not compared, only their
    types.

    :raises AssertionError: describing the first difference found.
    """
    __tracebackhide__ = True  # pytest convention: hide this frame in tracebacks
    assert type(x) == type(y), "Ast nodes do not have the same type: '%s' != '%s' " % (
        type(x),
        type(y),
    )
    if isinstance(x, (ast.Expr, ast.FunctionDef, ast.ClassDef)):
        assert (
            x.lineno == y.lineno
        ), "Ast nodes do not have the same line number : %s != %s" % (
            x.lineno,
            y.lineno,
        )
        assert x.col_offset == y.col_offset, (
            "Ast nodes do not have the same column offset number : %s != %s"
            % (x.col_offset, y.col_offset)
        )
    for (xname, xval), (yname, yval) in zip(ast.iter_fields(x), ast.iter_fields(y)):
        assert xname == yname, (
            "Ast nodes fields differ : %s (of type %s) != %s (of type %s)"
            % (xname, type(xval), yname, type(yval))
        )
        assert type(xval) == type(yval), (
            "Ast nodes fields differ : %s (of type %s) != %s (of type %s)"
            % (xname, type(xval), yname, type(yval))
        )
    xchildren = list(ast.iter_child_nodes(x))
    ychildren = list(ast.iter_child_nodes(y))
    for xchild, ychild in zip(xchildren, ychildren):
        assert nodes_equal(xchild, ychild), "Ast node children differs"
    # zip() stops at the shorter list, so without this check a node with extra
    # children (e.g. an additional call argument) would compare equal.
    assert len(xchildren) == len(ychildren), (
        "Ast nodes do not have the same number of children : %s != %s"
        % (len(xchildren), len(ychildren))
    )
    return True
Exemplo n.º 5
0
    def get_trait_definition(self):
        """ Retrieve the Trait attribute definition.

        Locates the first class in the source of ``self.parent`` and, in that
        class's body, the first plain ``name = ...`` assignment whose name is
        ``self.object_name``. The source lines spanned by the assignment are
        stripped and concatenated, and the text after the first ``=`` is
        returned (left-stripped). Returns ``''`` when no class or no matching
        assignment is found.
        """
        import textwrap

        # Get the class source. Dedent it so that a class nested in another
        # block (whose source is indented) can still be parsed.
        source = textwrap.dedent(inspect.getsource(self.parent))

        nodes = ast.parse(source)
        for node in ast.iter_child_nodes(nodes):
            if isinstance(node, ClassDef):
                parent_node = node
                break
        else:
            return ''

        for node in ast.iter_child_nodes(parent_node):
            # Only plain names can match; tuple-unpacking, attribute and
            # subscript targets have no ``.id`` and used to raise here.
            if isinstance(node, Assign) and isinstance(node.targets[0], ast.Name):
                name = node.targets[0]
                if name.id == self.object_name:
                    break
        else:
            return ''

        # The definition may span several lines; find the last one used.
        endlineno = name.lineno
        for item in ast.walk(node):
            # end_lineno (Python 3.8+) also covers the tail of multi-line
            # tokens such as triple-quoted strings; fall back to lineno.
            last = getattr(item, 'end_lineno', None) or getattr(item, 'lineno', None)
            if last is not None:
                endlineno = max(endlineno, last)

        definition_lines = [
            line.strip()
            for line in source.splitlines()[name.lineno-1:endlineno]]
        definition = ''.join(definition_lines)
        equal = definition.index('=')
        return definition[equal + 1:].lstrip()
Exemplo n.º 6
0
    def _iter_member_names(self):
        '''
        Iterate over assign and def member names, including class bases.

        Classes of ``type(self).__mro__`` are processed from the root down to
        ``type(self)`` (``object`` is skipped), so base members come first.
        For each class, its own source (via ``inspect.getsource``, with
        full-line comments removed) is parsed. The names of ``def``
        statements and of plain-name assignment targets (including names
        unpacked by ``a, b = ...``) are yielded in source order, except those
        in ``Behaviour._omit_member_names``. Names are collected per class, so
        a name redefined in a subclass is yielded again.
        '''
        all_bases_names = set(itertools.chain.from_iterable(
            dir(base) for base in self.__class__.__mro__))
        for base in reversed(self.__class__.__mro__):
            if base.__name__ == 'object':
                continue

            # Drop full-line comments (presumably so a column-0 comment cannot
            # stop dedent() from removing the common indentation).
            class_src = '\n'.join(
                line for line in inspect.getsource(base).split('\n')
                if not line.lstrip().startswith('#'))
            classnodes = ast.iter_child_nodes(ast.parse(dedent(class_src)))
            names = []
            for classnode in classnodes:
                for node in ast.iter_child_nodes(classnode):
                    if isinstance(node, ast.FunctionDef):
                        candidates = [node.name]
                    elif isinstance(node, ast.Assign):
                        # Attribute/subscript targets define no member and have
                        # no ``.id``; the old code raised AttributeError on them.
                        candidates = []
                        for target in node.targets:
                            if isinstance(target, ast.Name):
                                candidates.append(target.id)
                            elif isinstance(target, (ast.Tuple, ast.List)):
                                candidates.extend(elt.id for elt in target.elts
                                                  if isinstance(elt, ast.Name))
                    else:
                        continue
                    for member in candidates:
                        if member in Behaviour._omit_member_names:
                            continue
                        names.append(member)
                        all_bases_names.discard(member)

            # NOTE(review): every collected name was just discarded from
            # all_bases_names, so this filter never excludes anything; confirm
            # what it was meant to filter before relying on it.
            for name in names:
                if name not in all_bases_names:
                    yield name
Exemplo n.º 7
0
 def get_test_methods(fname):
     seen_classes = {}
     out = []
     lines = open(fname).readlines()
     lines = [x.rstrip('\r\n') for x in lines]
     a = ast.parse(("\n".join(lines)).rstrip() + '\n', fname)
     for cls in ast.iter_child_nodes(a):
         if isinstance(cls, ast.ClassDef) and is_test_class(cls):
             if cls.name in seen_classes:
                 raise ValueError("Duplicate class %s in %s" \
                                  % (cls.name, fname))
             seen_classes[cls.name] = {}
             seen_methods = {}
             for meth in ast.iter_child_nodes(cls):
                 if isinstance(meth, ast.FunctionDef) \
                    and meth.name.startswith('test'):
                     if meth.name in seen_methods:
                         raise ValueError("Duplicate method %s in %s" \
                                          % (meth.name, fname))
                     seen_methods[meth.name] = {}
                     testname = get_test_name(meth, fname, cls.name,
                                              meth.name,
                                              ast.get_docstring(meth, False))
                     out.append("%s.%s %s" % (cls.name, meth.name, testname))
     return out
Exemplo n.º 8
0
    def tag_class_functions(self, cls_node):
        """Set ``function_type`` on every method defined directly in a class.

        The value is ``'method'`` by default, ``'classmethod'`` for
        ``__new__``, and ``'classmethod'``/``'staticmethod'`` when the function
        is wrapped either by a later ``name = classmethod(name)`` style
        assignment (which takes precedence) or by a plain ``@classmethod`` /
        ``@staticmethod`` decorator (the first such decorator wins).
        """
        special = ('classmethod', 'staticmethod')
        members = list(iter_child_nodes(cls_node))

        # Old-style decoration: ``m = staticmethod(m)``.
        late_decoration = {}
        for member in members:
            if not isinstance(member, ast.Assign):
                continue
            call = member.value
            if not (isinstance(call, ast.Call) and isinstance(call.func, ast.Name)):
                continue
            if call.func.id not in special:
                continue
            if len(call.args) == 1 and isinstance(call.args[0], ast.Name):
                late_decoration[call.args[0].id] = call.func.id

        for member in members:
            if not isinstance(member, ast.FunctionDef):
                continue
            if member.name in late_decoration:
                kind = late_decoration[member.name]
            else:
                kind = 'classmethod' if member.name == '__new__' else 'method'
                for decorator in member.decorator_list:
                    if isinstance(decorator, ast.Name) and decorator.id in special:
                        kind = decorator.id
                        break
            member.function_type = kind
Exemplo n.º 9
0
    def visit_ClassDef(self, node):
        """Hoist functions defined inside a class's ``setup`` method to methods.

        Every ``def`` directly in the body of the class's ``setup`` method is
        handled as follows:

        * it is removed from ``setup``;
        * it gets a leading ``self`` parameter;
        * it is appended to the class body.

        If ``setup`` is defined more than once, the last definition is used.
        ``self.transforms`` maps each hoisted name to ``'self.<name>'``
        (presumably used elsewhere to rewrite call sites; TODO confirm). The
        class is then visited generically and returned.
        """
        self.transforms = {}
        self.in_class_define = True

        functions_to_promote = []
        setup_func = None

        for class_func in ast.iter_child_nodes(node):
            if isinstance(class_func, ast.FunctionDef):
                if class_func.name == 'setup':
                    # A later ``setup`` replaces an earlier one; collect only
                    # its functions so the .remove() below cannot fail.
                    setup_func = class_func
                    functions_to_promote = [
                        anon_func for anon_func in ast.iter_child_nodes(class_func)
                        if isinstance(anon_func, ast.FunctionDef)
                    ]

        if setup_func:
            for func in functions_to_promote:
                setup_func.body.remove(func)
                # Parameters are ast.arg nodes on Python 3; Python 2 used Name
                # nodes in Param context.
                if hasattr(ast, 'arg'):
                    self_param = ast.arg(arg='self', annotation=None)
                else:
                    self_param = ast.Name(id='self', ctx=ast.Param())
                func.args.args.insert(0, self_param)
                node.body.append(func)
                self.transforms[func.name] = 'self.' + func.name

            # A function body may not be empty.
            if not setup_func.body:
                setup_func.body.append(ast.Pass())

            ast.fix_missing_locations(node)

        self.generic_visit(node)

        return node
Exemplo n.º 10
0
    def process_file(self):
        """Collect the imports in ``self.path`` together with their enclosing scope.

        The file is parsed and traversed breadth-first, entering classes and
        functions (but not other compound statements). Every ``import`` /
        ``from ... import`` found produces one ``self.node_class`` instance per
        imported name, built with keyword arguments:

        * ``module``: ``''`` for a plain ``import``, ``node.module`` otherwise;
        * ``full_path``: the dotted path of the file plus enclosing
          class/function names;
        * ``name`` and ``alias``.
        """
        from collections import deque

        result = []
        with open(self.path) as f:
            tree = ast.parse(f.read())
        # "pkg/mod.py" -> "pkg.mod". Normalise the OS separator first so that
        # Windows paths are split as well; on POSIX this is a no-op.
        stem = os.path.splitext(self.path)[0]
        base_path = ".".join(stem.replace(os.sep, "/").split("/"))

        # deque gives O(1) pops from the front (list.pop(0) is O(n)).
        nodes = deque((child, base_path) for child in ast.iter_child_nodes(tree))
        while nodes:
            current_node, path = nodes.popleft()
            if isinstance(current_node, ast.Import):
                module = ''
            elif isinstance(current_node, ast.ImportFrom):
                module = current_node.module
            elif isinstance(current_node, (ast.FunctionDef, ast.ClassDef)):
                scope = path + "." + current_node.name
                nodes.extend((child, scope) for child in ast.iter_child_nodes(current_node))
                continue
            else:
                continue

            for alias in current_node.names:
                result.append(self.node_class(module=module, full_path=path,
                                              name=alias.name, alias=alias.asname))
        return result
Exemplo n.º 11
0
 def visit_Attribute(self,node):
   """Report attribute chains whose length reaches a threshold.

   If ``node`` lies inside the subtree of the last attribute that was
   measured (``self.messagenode``), it is only traversed. Otherwise it
   becomes ``self.messagenode`` and its subtree is walked depth-first to
   compute ``maxcount``: 2 plus the largest number of ``Attribute`` nodes on
   any path below ``node`` (so ``a.b`` scores 2 and ``a.b.c`` scores 3).
   When ``maxcount >= config['messagechain']`` the tuple
   ``(13, self.fileName, node.lineno, maxcount)`` is appended to
   ``self.result``. The children of ``node`` are then visited.
   """
   if self.messagenode is not None:
     # An attribute nested inside the one measured last was already counted
     # as part of that measurement, so only recurse.
     repnode = list(ast.walk(self.messagenode))
     if node in repnode:
       self.generic_visit(node)
       return
   self.messagenode = node
   # Explicit DFS stack. Each frame is [iterator over a node's children, flag];
   # the flag is 1 while the child being descended into is an Attribute.
   t = ast.iter_child_nodes(node)
   res = [[t,0]]
   maxcount = 1
   while len(res) >= 1:
     t = res[-1][0]
     for childnode in t:
       # print childnode
       if isinstance(childnode, _ast.Attribute):
         res[-1][1] = 1
       else:
         res[-1][1] = 0
       # Descend into this child; the parent frame's iterator keeps its
       # position, so remaining siblings are visited when we come back up.
       res.append([ast.iter_child_nodes(childnode),0])
       break
     else:
       # Frame exhausted: the flags on the stack count the Attribute nodes
       # along the current path.
       maxcount = max(maxcount,sum([flag for (item,flag) in res]) + 2)
       res.pop()
     continue
   # print maxcount
   if maxcount >= config['messagechain']:
     self.result.append((13,self.fileName,node.lineno,maxcount))
   self.generic_visit(node)
Exemplo n.º 12
0
 def test_iter_child_nodes(self):
     """ast.iter_child_nodes on a Call yields func, positional args, then keywords.

     NOTE(review): the expected values use the legacy node API (``.n`` on
     number nodes, ``Str(s=...)`` in the ``ast.dump`` output), which matches
     Python versions before 3.8. Newer versions parse literals as
     ``Constant`` nodes and dump them differently, so confirm which
     interpreter this test targets.
     """
     node = ast.parse("spam(23, 42, eggs='leek')", mode="eval")
     # func + 2 positional args + 1 keyword
     self.assertEqual(len(list(ast.iter_child_nodes(node.body))), 4)
     iterator = ast.iter_child_nodes(node.body)
     self.assertEqual(next(iterator).id, "spam")
     self.assertEqual(next(iterator).n, 23)
     self.assertEqual(next(iterator).n, 42)
     self.assertEqual(ast.dump(next(iterator)), "keyword(arg='eggs', value=Str(s='leek'))")
Exemplo n.º 13
0
def walklocal(root):
    """Recursively yield all descendant nodes but not in a different scope.

    Yields ``(node, newscope)`` pairs breadth-first, starting with
    ``(root, False)``. A node that opens a new function scope (``def`` or
    ``async def``) is yielded with ``newscope=True``, and its children are
    not visited.
    """
    scope_types = (ast.FunctionDef,)
    if hasattr(ast, 'AsyncFunctionDef'):  # Python 3.5+
        # An async function body is a separate scope just like a plain def.
        scope_types += (ast.AsyncFunctionDef,)
    todo = collections.deque(ast.iter_child_nodes(root))
    yield root, False
    while todo:
        node = todo.popleft()
        newscope = isinstance(node, scope_types)
        if not newscope:
            todo.extend(ast.iter_child_nodes(node))
        yield node, newscope
Exemplo n.º 14
0
 def test_iter_child_nodes(self):
     """ast.iter_child_nodes on a Call yields func, positional args, then keywords.

     NOTE(review): constants are read via ``.c`` and dumped as
     ``Const(c=..., constant=pure_const())``, which is not CPython's ``ast``
     node layout. This appears to target a modified AST implementation;
     confirm which one before changing the expectations.
     """
     node = ast.parse("spam(23, 42, eggs='leek')", mode='eval')
     # func + 2 positional args + 1 keyword
     self.assertEqual(len(list(ast.iter_child_nodes(node.body))), 4)
     iterator = ast.iter_child_nodes(node.body)
     self.assertEqual(next(iterator).id, 'spam')
     self.assertEqual(next(iterator).c, 23)
     self.assertEqual(next(iterator).c, 42)
     self.assertEqual(ast.dump(next(iterator)),
         "keyword(arg='eggs', value=Const(c='leek', constant=pure_const()))"
     )
Exemplo n.º 15
0
    def find_global_defs(self, func_def_node):
        """Attach the set of names declared ``global`` in a function.

        The function body is searched without descending into nested function
        (``def`` / ``async def``) or class definitions, since a ``global``
        statement there belongs to that inner scope. The resulting set is
        stored on ``func_def_node.global_names``.
        """
        scope_types = (ast.FunctionDef, ast.ClassDef)
        if hasattr(ast, 'AsyncFunctionDef'):  # Python 3.5+
            scope_types += (ast.AsyncFunctionDef,)
        global_names = set()
        # Used as a stack; visiting order is irrelevant for building a set.
        nodes_to_check = deque(iter_child_nodes(func_def_node))
        while nodes_to_check:
            node = nodes_to_check.pop()
            if isinstance(node, ast.Global):
                global_names.update(node.names)

            if not isinstance(node, scope_types):
                nodes_to_check.extend(iter_child_nodes(node))
        func_def_node.global_names = global_names
Exemplo n.º 16
0
 def visit_FunctionDef(self,node):
   """Record metrics for the function ``node``, then visit its children.

   Appends ``(metric, self.fileName, node.lineno, value)`` tuples to
   ``self.result``:

   * metric 1: number of comma-separated pieces of the ``astunparse``
     rendering of ``node.args`` that contain no ``=`` (parameters without
     a default, plus ``*args``/``**kwargs``-style entries);
   * metric 2: number of distinct ``lineno`` values among nodes below
     ``node``. A leading expression statement (docstring) is skipped, and
     nested class/function definitions contribute only their own line;
   * metric 3 (only when > 1): depth of ``def`` statements nested directly
     inside one another. It is computed only for functions not already
     recorded in ``self.scopenodes`` as nested.

   Dunder-style names other than ``__import__``/``__all__`` are also added
   to ``self.defmagic``.
   """
   funcName = node.name.strip()
   p = re.compile("^(__[a-zA-Z0-9]+__)$")
   if p.match(funcName.strip()) and funcName != "__import__" and funcName != "__all__":
     self.defmagic.add((funcName,self.fileName,node.lineno))
   # argsCount
   stmt = astunparse.unparse(node.args)
   # An empty parameter list unparses to empty/whitespace-only text, and
   # splitting that on "," still yields one piece, which used to be counted
   # as an argument. Treat it as zero pieces instead.
   # NOTE(review): commas inside default values (e.g. ``x=(1, 2)``) still
   # produce extra pieces and can inflate the count.
   arguments = stmt.split(",") if stmt.strip() else []
   argsCount = 0
   for element in arguments:
     if '=' not in element:
       argsCount += 1
   self.result.append((1,self.fileName,node.lineno,argsCount))
   #function length
   lines = set()
   res = [node]
   while len(res) >= 1:
     t = res[0]
     for n in ast.iter_child_nodes(t):
       # Skip nodes without a position, and the leading Expr (docstring).
       if not hasattr(n,'lineno') or ((isinstance(t,_ast.FunctionDef) or isinstance(t,_ast.ClassDef)) and n == t.body[0] and isinstance(n,_ast.Expr)):
         continue
       lines.add(n.lineno)
       # Nested definitions count their header line but not their body.
       if isinstance(n,_ast.ClassDef) or isinstance(n,_ast.FunctionDef):
         continue
       else:
         res.append(n)
     del res[0]
   self.result.append((2,self.fileName,node.lineno,len(lines)))
   #nested scope depth
   if node in self.scopenodes:
     # Already counted as part of an enclosing function's nesting depth.
     self.scopenodes.remove(node)
     self.generic_visit(node)
     return
   dep = [[node,1]] #node,nestedlevel
   maxlevel = 1
   while len(dep) >= 1:
     t = dep[0][0]
     currentlevel = dep[0][1]
     # NOTE(review): only direct-child defs are followed, so a def inside an
     # if/for/with block is not counted as nested.
     for n in ast.iter_child_nodes(t):
       if isinstance(n,_ast.FunctionDef):
         self.scopenodes.append(n)
         dep.append([n,currentlevel+1])
     maxlevel = max(maxlevel,currentlevel)
     del dep[0]
   if maxlevel>1:
     self.result.append((3,self.fileName,node.lineno,maxlevel)) #DOC
   self.generic_visit(node)
Exemplo n.º 17
0
 def generic_visit(self, node):
     """Emit control-flow graph text for ``node`` into ``self.graph``.

     A ``Module`` node opens a ``digraph G {`` block, visits the whole tree,
     then writes one label line per entry of ``self.nodenumbers`` and closes
     the block. ``FunctionDef`` / ``For`` nodes link to their direct
     ``For``/``If``/``Return`` children via ``self.makeLink``. ``If`` nodes
     link to such statements in each branch, or to themselves when a branch
     has none. Children are then visited with
     ``ast.NodeVisitor.generic_visit``, which calls ``self.visit`` on each
     child.

     NOTE(review): ``typename``, ``makeLink``, ``getNodeNumber``,
     ``getNodeType`` and ``debug`` are defined outside this view.
     """
     name=typename(node)
     self.debug('debug; visiting: %s'%str(name))
     if name == 'Module':
         self.graph+='digraph G {'+self.delim
         ast.NodeVisitor.generic_visit(self, node)
         # TODO: print labels
         for n in self.nodenumbers.keys():
             self.graph+='%s [label="%s"];%s' % (self.getNodeNumber(n), self.getNodeType(n), self.delim)
         self.graph+='}'
     elif name == 'FunctionDef':
         # todo: number of params
         # always create downward links as needed
         for n in ast.iter_child_nodes(node):
             if typename(n) in ('For','If','Return'):
                 self.debug('%s -> %s' % (str(name),typename(n)))
                 self.makeLink(node, n)
         ast.NodeVisitor.generic_visit(self, node)
     elif name == 'For':
         # Same linking rule as FunctionDef: direct For/If/Return children.
         for n in ast.iter_child_nodes(node):
             if typename(n) in ('For', 'If', 'Return'):
                 self.debug('%s -> %s' % (name,typename(n)))
                 self.makeLink(node, n)
         ast.NodeVisitor.generic_visit(self, node)
     elif name == 'If':
         # check the then branch
         found=False
         for n in node.body:
             if typename(n) in ('If','For','Return'):
                 found=True
                 self.makeLink(node,n)
         if not found:
             self.makeLink(node,node)
         if (len(node.orelse)>0):
             found=False
             for n in node.orelse:
                 if typename(n) in ('If','For','Return'):
                     found=True
                     self.makeLink(node,n)
             if not found:
                 # NOTE(review): the then-branch records its self-loop via
                 # makeLink(node, node), but this branch writes the edge text
                 # directly; confirm whether the difference is intended.
                 self.graph += '%s -> %s;%s' % (self.getNodeNumber(node), self.getNodeNumber(node), self.delim)
         #print('body', node.body, typename(node.body), typename(node.orelse))
         for n in ast.iter_child_nodes(node):
             self.debug('%s ===> %s' % (str(name), typename(n)))
         ast.NodeVisitor.generic_visit(self, node)
     elif name == 'Return':
         # probably no need for recursive step here???
         #print(name, node._fields)
         ast.NodeVisitor.generic_visit(self, node)
     else:
         #print('CATCHALL', name)
         ast.NodeVisitor.generic_visit(self, node)
Exemplo n.º 18
0
def nodes_equal(x, y):
    """Assert that two AST trees have the same structure and return ``True``.

    Compares the following and recurses into child nodes:

    * node types;
    * field names and the types of field values;
    * line and column positions of ``Expr``/``FunctionDef``/``ClassDef`` nodes;
    * the number of child nodes.

    Leaf field *values* are not compared, only their types.

    :raises AssertionError: on the first difference found.
    """
    __tracebackhide__ = True  # pytest convention: hide this frame in tracebacks
    assert type(x) == type(y)
    if isinstance(x, (ast.Expr, ast.FunctionDef, ast.ClassDef)):
        assert x.lineno == y.lineno
        assert x.col_offset == y.col_offset
    for (xname, xval), (yname, yval) in zip(ast.iter_fields(x),
                                            ast.iter_fields(y)):
        assert xname == yname
        assert type(xval) == type(yval)
    xchildren = list(ast.iter_child_nodes(x))
    ychildren = list(ast.iter_child_nodes(y))
    for xchild, ychild in zip(xchildren, ychildren):
        assert nodes_equal(xchild, ychild)
    # zip() stops at the shorter list, so without this check a node with extra
    # children (e.g. an additional call argument) would compare equal.
    assert len(xchildren) == len(ychildren)
    return True
Exemplo n.º 19
0
    def find_node(self, node, name=None, tipe=None):
        """Return the first node, in pre-order from ``node``, matching the filters.

        A node matches when it is an instance of ``tipe`` (if given) and its
        ``__dict__`` has a ``'name'`` entry equal to ``name`` (if given). With
        no filters, ``node`` itself is returned. Children are searched through
        ``self.find_node``; ``None`` is returned when nothing matches.
        """
        if tipe is None or isinstance(node, tipe):
            if name is None or ('name' in node.__dict__
                                and name == node.__dict__['name']):
                return node

        # Lazily search subtrees in order, stopping at the first hit.
        subtree_hits = (self.find_node(child, name=name, tipe=tipe)
                        for child in ast.iter_child_nodes(node))
        return next((hit for hit in subtree_hits if hit is not None), None)
Exemplo n.º 20
0
def simplifyTree(tree):
    """Return one simplified dict per direct child node of ``tree``.

    For each child, ``lineno`` and ``col_offset`` are popped from the node's
    attribute dict (the tree is mutated). Each remaining attribute is then
    converted by the matching ``simplify*`` helper and stored under its
    name. Values of unhandled types have their type printed and are
    skipped.
    """
    resultList = []
    for node in ast.iter_child_nodes(tree):
        simplenode = {}
        attrs = vars(node)
        attrs.pop('lineno')
        attrs.pop('col_offset')
        for key, value in attrs.items():
            if isinstance(value, ast.Num):
                simplenode[key] = simplifyNum(value)
            elif isinstance(value, ast.Name):
                # NOTE(review): unlike the other branches the result is not
                # stored (same for Dict below); confirm this is intended.
                simplifyName(value)
            elif isinstance(value, ast.Str):
                simplenode[key] = simplifyStr(value)
            elif key == 'targets':
                # ``key is 'targets'`` compared identity and only worked when
                # the string happened to be interned; compare by value.
                simplenode[key] = simplifyName(value[0])
            elif isinstance(value, ast.Call):
                simplenode[key] = simplifyCall(value)
            elif isinstance(value, ast.List):
                simplenode[key] = simplifyList(value)
            elif isinstance(value, ast.Dict):
                simplifyDict(value)
            else:
                print(type(value))
        resultList.append(simplenode)
    return resultList
Exemplo n.º 21
0
 def _parse_file(self, fp, root):
     """Parse a file and return all normalized skbio imports."""
     imports = []
     with open(fp, 'U') as f:
         # Read the file and run it through AST
         source = ast.parse(f.read())
         # Get each top-level element, this is where API imports should be.
         for node in ast.iter_child_nodes(source):
             if isinstance(node, ast.Import):
                 # Standard imports are easy, just get the names from the
                 # ast.Alias list `node.names`
                 imports += [x.name for x in node.names]
             elif isinstance(node, ast.ImportFrom):
                 prefix = ""
                 # Relative import handling.
                 if node.level > 0:
                     prefix = root
                     extra = node.level - 1
                     while(extra > 0):
                         # Keep dropping...
                         prefix = os.path.split(prefix)[0]
                         extra -= 1
                     # We need this in '.' form not '/'
                     prefix = prefix.replace(os.sep, ".") + "."
                 # Prefix should be empty unless node.level > 0
                 imports += [".".join([prefix + node.module, x.name])
                             for x in node.names]
     skbio_imports = []
     for import_ in imports:
         # Filter by skbio
         if import_.split(".")[0] == "skbio":
             skbio_imports.append(import_)
     return skbio_imports
Exemplo n.º 22
0
 def generic_visit(self, node, inside_call=False, inside_yield=False):
     """Recursively check ``node`` and its descendants for disallowed constructs.

     ``inside_call`` / ``inside_yield`` record whether an ``ast.Call`` /
     ``ast.Yield`` node has been seen on the path from the starting node, and
     are passed explicitly to every child. Children are visited through
     ``self.generic_visit`` directly (not ``self.visit``), so any
     ``visit_<Type>`` methods are bypassed for them.

     When ``disable_lookups`` is true:

     * a string literal below a call that starts with ``"__"`` raises
       ``AnsibleError``;
     * a string literal below a ``yield`` is itself parsed as Python and
       checked with a fresh ``CleansingNodeVisitor``.

     NOTE(review): ``disable_lookups``, ``conditional``,
     ``CleansingNodeVisitor`` and ``AnsibleError`` are free names resolved
     from a scope not visible here. ``ast.Str`` / ``node.s`` are deprecated
     since Python 3.8 and removed in 3.14; confirm the supported versions.
     """
     if isinstance(node, ast.Call):
         inside_call = True
     elif isinstance(node, ast.Yield):
         inside_yield = True
     elif isinstance(node, ast.Str):
         if disable_lookups:
             if inside_call and node.s.startswith("__"):
                 # calling things with a dunder is generally bad at this point...
                 raise AnsibleError(
                     "Invalid access found in the conditional: '%s'" % conditional
                 )
             elif inside_yield:
                 # we're inside a yield, so recursively parse and traverse the AST
                 # of the result to catch forbidden syntax from executing
                 parsed = ast.parse(node.s, mode='exec')
                 cnv = CleansingNodeVisitor()
                 cnv.visit(parsed)
     # iterate over all child nodes
     for child_node in ast.iter_child_nodes(node):
         self.generic_visit(
             child_node,
             inside_call=inside_call,
             inside_yield=inside_yield
         )
Exemplo n.º 23
0
 def _fix(node, lineno, col_offset):
     if "lineno" in node._attributes:
         node.lineno = lineno
     if "col_offset" in node._attributes:
         node.col_offset = col_offset
     for child in ast.iter_child_nodes(node):
         _fix(child, lineno, col_offset)
Exemplo n.º 24
0
        def wrapper(self, node):
            """Run the wrapped ``fn`` on ``node``, then visit each direct child.

            Returns whatever ``fn`` returned.
            """
            outcome = fn(self, node)
            children = ast.iter_child_nodes(node)
            for child_node in children:
                self.visit(child_node)
            return outcome
Exemplo n.º 25
0
 def visit_children(self, node):
     """Call ``self.visit_deps`` on each item of a list or each child of a node.

     ``node`` may be a plain list, whose elements are visited in order, or an
     AST node, whose direct children are visited.
     """
     items = node if isinstance(node, list) else ast.iter_child_nodes(node)
     for item in items:
         self.visit_deps(item)
Exemplo n.º 26
0
  def _extract_docstrings(self, bzl_file):
    """Extracts the docstrings for all public rules in the .bzl file.

    This function parses the .bzl file and extracts the docstrings for all
    public rules in the file that were extracted in _process_skylark. It calls
    _add_rule_doc for to parse the attribute documentation in each docstring
    and associate them with the extracted rules and attributes.

    Args:
      bzl_file: The .bzl file to extract docstrings from.
    """
    try:
      tree = None
      with open(bzl_file) as f:
        tree = ast.parse(f.read(), bzl_file)
      key = None
      for node in ast.iter_child_nodes(tree):
        if isinstance(node, ast.Assign):
          name = node.targets[0].id
          if not name.startswith("_"):
            key = name
          continue
        elif isinstance(node, ast.Expr) and key:
          # Python itself does not treat strings defined immediately after a
          # global variable definition as a docstring. Only extract string and
          # parse as docstring if it is defined.
          if hasattr(node.value, 's'):
            self._add_rule_doc(key, node.value.s.strip())
        key = None
    except IOError:
      print("Failed to parse {0}: {1}".format(bzl_file, e.strerror))
      pass
Exemplo n.º 27
0
def _get_ordered_child_nodes(node):
    if isinstance(node, ast.Dict):
        children = []
        for i in range(len(node.keys)):
            children.append(node.keys[i])
            children.append(node.values[i])
        return children
    elif isinstance(node, ast.Call):
        children = [node.func] + node.args

        for kw in node.keywords:
            children.append(kw.value)

        # TODO: take care of Python 3.5 updates (eg. args=[Starred] and keywords)
        if hasattr(node, "starargs") and node.starargs is not None:
            children.append(node.starargs)
        if hasattr(node, "kwargs") and node.kwargs is not None:
            children.append(node.kwargs)

        children.sort(key=lambda x: (x.lineno, x.col_offset))
        return children

    elif isinstance(node, ast.arguments):
        children = node.args + node.kwonlyargs + node.kw_defaults + node.defaults

        if node.vararg is not None:
            children.append(node.vararg)
        if node.kwarg is not None:
            children.append(node.kwarg)

        children.sort(key=lambda x: (x.lineno, x.col_offset))
        return children

    else:
        return ast.iter_child_nodes(node)
Exemplo n.º 28
0
 def rvals(node):
     """Add each node accepted by ``is_name`` under ``node`` to every set in ``targets``.

     The search does not descend below a node that ``is_name`` accepts.
     """
     if not is_name(node):
         for child in ast.iter_child_nodes(node):
             rvals(child)
         return
     for key in targets:
         targets[key].add(node)
Exemplo n.º 29
0
def find_attribute_with_name(node, name):
    """Return the first ``ast.Attribute`` whose ``attr`` is ``name``, else ``None``.

    The search is depth-first pre-order and includes ``node`` itself.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if isinstance(current, ast.Attribute) and current.attr == name:
            return current
        # Push children reversed so the first child is examined next.
        pending.extend(reversed(list(ast.iter_child_nodes(current))))
    return None
Exemplo n.º 30
0
    def _statement_assign(self, env, node):
        itr = self._fold_expr(env, node.value)
        while itr:
            try: yield next(itr)
            except StopIteration: break

        val = env.pop()
        # TODO what else is supposed to hold??
        assert len(node.targets) == 1
        target = node.targets[0]

        # fold the lists
        if isinstance(val, list):
            l = []
            for v in val:
                itr = self._fold_expr(env, v)
                while itr:
                    try: yield next(itr)
                    except StopIteration: break
                l.append(env.pop())
            val = l

        # assign to a list of variables
        if (isinstance(target, ast.List) or isinstance(target, ast.Tuple)) and isinstance(val, list):
            i = 0
            for var in ast.iter_child_nodes(target):
                if isinstance(var, ast.Name):
                    env.setvar(var.id, val[i])
                    i += 1
        else:
            env.setvar(target.id, val)
        yield node
Exemplo n.º 31
0
 def recurse(node: ast.AST, temp: list):
     '''
     Collect, left to right, the identifiers of all Name nodes under ``node``.

     Parameters:
         node - current node; it and all its descendants are searched
         temp - list that receives one ``[identifier]`` entry per name found
     Returns:
         temp, after the names found below ``node`` have been appended
     '''
     if isinstance(node, ast.Name):
         temp.append([node.id])
     else:
         # iter_child_nodes yields a BinOp's children as left, op, right, so
         # no special case is needed to keep operands in order. The old
         # explicit BinOp branch visited operands twice, and a ``return``
         # inside this loop stopped after the first child.
         for child in ast.iter_child_nodes(node):
             recurse(child, temp)
     return temp
Exemplo n.º 32
0
    def __get_imports(self, file_path):
        """
        Extract all top-level import statements from a file, including aliases.

        Only ``import`` / ``from ... import`` statements directly in the module
        body are considered; imports nested in functions, classes or blocks
        are ignored.

        :param file_path: file to be analyzed
        :return: generator of `ImportedModule`, one per imported name, built
            from ``(module, name, asname, lineno, col_offset, None)``.
            ``module`` is ``''`` for plain ``import`` statements and
            ``node.module`` for ``from`` imports (``None`` for
            ``from . import x``).
        """
        # get the AST of the provided file
        with open(file_path) as file:
            root = ast.parse(file.read(), file_path)

        # extract import statements
        for node in ast.iter_child_nodes(root):
            if isinstance(node, ast.Import):
                module = ''
            elif isinstance(node, ast.ImportFrom):
                # NOTE(review): node.level (relative-import depth) is not
                # passed on, so relative imports are indistinguishable from
                # absolute ones for callers.
                module = node.module
            else:
                continue

            for n in node.names:
                yield ImportedModule(module, n.name, n.asname, node.lineno,
                                     node.col_offset, None)
Exemplo n.º 33
0
def get_contrib_requirements(filepath: str) -> Dict:
    """
    Collect pip requirements declared in ``library_metadata`` dicts of a module's classes.

    Every top-level class in the file at ``filepath`` is searched, including
    nested bodies, for an assignment to ``library_metadata``. Its literal
    value is evaluated with ``ast.literal_eval`` and its ``"requirements"``
    list (default ``[]``) is recorded for that class and appended to the
    module-wide list.

    Note, currently we are handling all dependencies at the module level. To
    support future expandability and detail, this method also returns
    per-class requirements in addition to the concatenated list.

    Args:
        filepath: the path to the file to parse and analyze

    Returns:
        A dictionary:
        {
            "requirements": [ all_requirements_found_in_any_library_metadata_in_file ],
            class_name: [ requirements ]
        }
    """
    with open(filepath) as source_file:
        module = ast.parse(source_file.read())

    info = {"requirements": []}
    class_nodes = (n for n in ast.iter_child_nodes(module) if isinstance(n, ast.ClassDef))
    for class_node in class_nodes:
        for candidate in ast.walk(class_node):
            if not isinstance(candidate, ast.Assign):
                continue
            try:
                assigned = [target.id for target in candidate.targets]
            except (ValueError, AttributeError):
                # Non-name targets (e.g. tuple unpacking) have no ``.id``.
                assigned = []
            if "library_metadata" not in assigned:
                continue
            metadata = ast.literal_eval(candidate.value)
            found = metadata.get("requirements", [])
            info[class_node.name] = found
            info["requirements"] += found

    return info
Exemplo n.º 34
0
def class_parse(args):
    """Collect the enclosing class/function source for each matching line.

    For every line of ``args.path`` that contains ``args.pattern`` (or matches
    it as a regex when ``args.regexp`` is set) and has not already been
    emitted, the AST node located by ``find_match_node`` is walked up through
    its ``parent`` links to the nearest ``ClassDef`` (``args.cl``) or
    ``FunctionDef`` (``args.func``; async functions also match). When such a
    node is found, its lines ``lineno`` .. ``get_end(node)`` are passed through
    ``mhighlight`` and joined. One string is appended per processed match; it
    is empty when no enclosing definition was found.

    Raises:
        ValueError: if neither ``args.cl`` nor ``args.func`` is set.
    """
    if args.cl:
        objsearch = 'ClassDef'
    elif args.func:
        objsearch = 'FunctionDef'
    else:
        raise ValueError('class_parse needs args.cl or args.func to be set')
    with open(args.path) as f:
        content = f.read()
    lines = content.splitlines()
    results = []
    added_lines = set()  # (line number, text) pairs already emitted
    root = ast.parse(content)
    for node in ast.walk(root):
        for child in ast.iter_child_nodes(node):
            child.parent = node

    def is_target(candidate):
        # Match on the class name, not str(node): an AST repr may list child
        # nodes, so a substring test on it can match ancestors too.
        return objsearch in type(candidate).__name__

    for num, line in enumerate(lines, 1):
        matched = args.pattern in line or (
            args.regexp and re.search(args.pattern, line))
        if not matched or (num, line) in added_lines:
            continue
        pattern_node = find_match_node(results, num, root, args)
        if pattern_node is None:
            continue
        while not is_target(pattern_node):
            # The root never receives a ``parent`` attribute; treat a missing
            # link like reaching the root.
            parent = getattr(pattern_node, 'parent', root)
            if parent is root:
                break
            pattern_node = parent
        curres = []
        if is_target(pattern_node):
            first = pattern_node.lineno
            end = get_end(pattern_node)
            span = list(enumerate(lines[first - 1:end], first))
            curres = [
                mhighlight(line_no, text, args.pattern, args.regexp)
                for line_no, text in span
            ]
            added_lines.update(span)
        results.append(''.join(curres))
    return results
Exemplo n.º 35
0
def _iter_child_nodes_in_order_internal_1(node):
    """Yield *node*'s children grouped in source order for selected node types.

    ``ast.iter_child_nodes`` follows ``_fields`` order, which differs from
    source order for some constructs (for example, ``decorator_list`` comes
    after ``body`` in ``FunctionDef._fields``). For the node types handled
    below, this generator yields groups rather than individual nodes:

    * ``Dict``: a single list of ``(key, value)`` pairs.
    * ``FunctionDef``: one tuple ``(decorator_list, args, body)``.
    * ``arguments``: the list of positional args without defaults, then a list
      of ``(arg, default)`` pairs.
    * ``IfExp``: one tuple ``(body, test, orelse)``, matching the source
      order of ``body if test else orelse``.
    * ``ClassDef``: one tuple ``(decorator_list, bases, body)``.
    * anything else: the ``ast.iter_child_nodes`` iterator as a single item.

    Callers presumably flatten these groups recursively; verify against the
    caller.

    Raises:
        TypeError: if *node* is not an ``ast.AST``.
        AssertionError: if a handled node's ``_fields`` differ from the
            layout this function was written for.
    """
    if not isinstance(node, ast.AST):
        raise TypeError
    if isinstance(node, ast.Dict):
        assert node._fields == ("keys", "values")
        yield list(zip(node.keys, node.values))
    elif isinstance(node, ast.FunctionDef):
        # NOTE(review): these exact-_fields asserts pin the interpreter
        # version; newer Pythons add AST fields (e.g. type_comment,
        # type_params) and would trip them -- confirm the supported range.
        if six.PY2:
            assert node._fields == ('name', 'args', 'body', 'decorator_list')
        else:
            assert node._fields == ('name', 'args', 'body', 'decorator_list',
                                    'returns')
        # NOTE(review): on Python 3 the ``returns`` annotation is not yielded.
        yield node.decorator_list, node.args, node.body
        # node.name is a string, not an AST node
    elif isinstance(node, ast.arguments):
        if six.PY2:
            assert node._fields == ('args', 'vararg', 'kwarg', 'defaults')
        else:
            assert node._fields == ('args', 'vararg', 'kwonlyargs',
                                    'kw_defaults', 'kwarg', 'defaults')
        # Defaults align with the *last* len(defaults) positional args.
        defaults = node.defaults or ()
        num_no_default = len(node.args) - len(defaults)
        yield node.args[:num_no_default]
        yield list(zip(node.args[num_no_default:], defaults))
        # node.vararg and node.kwarg are strings, not AST nodes.
        # NOTE(review): that holds for Python 2; on Python 3 vararg/kwarg and
        # kwonlyargs/kw_defaults are listed in _fields above but are not
        # yielded here -- confirm this is intended.
    elif isinstance(node, ast.IfExp):
        assert node._fields == ('test', 'body', 'orelse')
        yield node.body, node.test, node.orelse
    elif isinstance(node, ast.ClassDef):
        if six.PY2:
            assert node._fields == ('name', 'bases', 'body', 'decorator_list')
        else:
            assert node._fields == ('name', 'bases', 'keywords', 'body',
                                    'decorator_list')
        # NOTE(review): on Python 3 ``keywords`` (e.g. metaclass=) is not yielded.
        yield node.decorator_list, node.bases, node.body
        # node.name is a string, not an AST node
    else:
        # Default behavior.
        yield ast.iter_child_nodes(node)
Exemplo n.º 36
0
 def _get_ast_trees(self, repository_path):
     filenames = []
     main_file_contents = []
     ast_trees = []
     for dirname, _, files in os.walk(repository_path, topdown=True):
         for file in files:
             if file.endswith('.py'):
                 filenames.append(os.path.join(dirname, file))
     for filename in filenames:
         with open(filename, 'r', encoding='utf-8') as file_handler:
             main_file_content = file_handler.read()
         try:
             tree = ast.parse(main_file_content)
             for node in ast.walk(tree):
                 for child in ast.iter_child_nodes(node):
                     child.parent = node
         except SyntaxError as e:
             print(e)
             tree = None
         main_file_contents.append(main_file_content)
         ast_trees.append(tree)
     return filenames, main_file_contents, ast_trees
Exemplo n.º 37
0
    def visit_While(self, node):
        """Mutate a ``while`` loop, or restore it during recovery.

        While no mutation has been applied yet, every direct child is mapped
        to this loop in ``config.parent_dict``. For a ``StatementDeletion``
        operator, body statements are first classified into
        ``config.nodes_to_remove`` (raise/continue/break/assign/augassign/call)
        or ``config.nodes_to_potential`` (bare expressions). The loop is then
        passed to ``mutate_single_node``, and its subtree is visited if
        nothing has been mutated yet. Once a mutation exists and
        ``config.recovering`` is set, the node is handed to ``recover_node``.
        A falsy node is returned unchanged.
        """
        if not node:
            return node
        if config.mutated:
            if config.recovering:
                return self.recover_node(node)
            return node

        for child in ast.iter_child_nodes(node):
            config.parent_dict[child] = node

        if self.operator[1] is StatementDeletion:
            removable = (ast.Raise, ast.Continue, ast.Break, ast.Assign,
                         ast.AugAssign, ast.Call)
            for stmt in node.body:
                stmt_class = stmt.__class__
                if stmt_class in removable:
                    config.nodes_to_remove.add(stmt)
                elif stmt_class is ast.Expr:
                    config.nodes_to_potential.add(stmt)

        # Both operator kinds end with the same mutation attempt.
        node = self.mutate_single_node(node, self.operator)
        if node and not config.mutated:
            self.dfs_visit(node)
        return node
Exemplo n.º 38
0
def checkGlobalIds(a, l):
    """Recursively check that every AST node under *a* has a ``global_id``.

    Non-AST values and expression-context nodes (Load/Store/Del and the
    legacy AugLoad/AugStore/Param) are skipped. A node with no ``global_id``
    is reported through ``log`` unless it carries one of the marker
    attributes below, which flag nodes added by later transformations.

    Args:
        a: the node to check.
        l: the list of ancestor node types, used in the log message.
    """
    if not isinstance(a, ast.AST):
        return
    # Every expression-context class subclasses ast.expr_context. Testing the
    # base class avoids naming the deprecated AugLoad/AugStore/Param directly.
    if isinstance(a, ast.expr_context):
        return
    if not hasattr(a, "global_id"):
        addedNodes = (
            "propagatedVariable",
            "orderedBinOp",
            "augAssignVal",
            "augAssignBinOp",
            "combinedConditional",
            "combinedConditionalOp",
            "multiCompPart",
            "multiCompOp",
            "second_global_id",
            "moved_line",
            # above this line has individualize functions. below does not.
            "addedNot",
            "addedNotOp",
            "addedOther",
            "addedOtherOp",
            "collapsedExpr",
            "removedLines",
            "helperVar",
            "helperReturn",
            "typeCastFunction",
        )
        # Only report when none of the marker attributes are present.
        if not any(hasattr(a, t) for t in addedNodes):
            log(
                "canonicalize\tcheckGlobalIds\tNo global id: " + str(l) + "," +
                str(a.__dict__) + "," + printFunction(a, 0), "bug")
    for f in ast.iter_child_nodes(a):
        checkGlobalIds(f, l + [type(a)])
Exemplo n.º 39
0
def process_child_nodes(
    node: ast.AST,
    increment_by: int,
    verbose: bool,
    complexity_calculator: Callable,
) -> int:
    """Sum the complexity contributions of *node*'s direct children.

    Each child is scored with
    ``complexity_calculator(child, increment_by=..., verbose=...)``.
    When *node* is an ``ast.Try``, every child after the first also adds
    ``max(1, increment_by)``, and ``increment_by`` is raised by one from the
    second child onward. Children come in ``ast.iter_child_nodes`` order
    (body, handlers, orelse, finalbody).
    """
    in_try = isinstance(node, ast.Try)
    total = 0
    for position, child in enumerate(ast.iter_child_nodes(node)):
        if in_try and position > 0:
            if position == 1:
                # add +1 for all try nodes except body
                # NOTE(review): position 1 is still a body statement when the
                # body has several statements -- confirm this is intended.
                increment_by += 1
            total += max(1, increment_by)
        total += complexity_calculator(
            child,
            increment_by=increment_by,
            verbose=verbose,
        )
    return total
Exemplo n.º 40
0
    def visit_FunctionDef(self, node):
        """Report nested functions in a test method that are never referenced (N349).

        Only files whose basename starts with ``test_`` are checked, plus the
        pseudo filename ``'stdin'`` used by the unit test for this check.
        Closures are the ``FunctionDef`` nodes directly inside *node*.
        References are collected by ``_FindVariableReferences`` over the whole
        method.
        """
        # self._filename is 'stdin' in the unit test for this check. Compare
        # explicitly: a bare non-empty string literal is always truthy, which
        # would disable this filter entirely.
        if (not os.path.basename(self._filename).startswith('test_')
                and self._filename != 'stdin'):
            return

        # Walk just the direct nodes of the test method
        closures = [child_node.name
                    for child_node in ast.iter_child_nodes(node)
                    if isinstance(child_node, ast.FunctionDef)]

        # Walk all nodes to find references
        find_references = _FindVariableReferences()
        find_references.generic_visit(node)
        references = find_references._references

        missed = set(closures) - set(references)
        if missed:
            # Sort so the message is deterministic (set order is not).
            self.add_error(node, 'N349: Test closures not called: %s'
                           % ','.join(sorted(missed)))
Exemplo n.º 41
0
def _prove_names_defined(
        env: SymbolTable,
        names: tp.AbstractSet[str],
        node: tp.Union[ast.AST, tp.Sequence[ast.AST]]) -> tp.AbstractSet[str]:
    '''
    Prove that all names are defined.
    i.e. if a name is used at some point then all paths leading to that point
    must either return or define the name.

    Args:
        env: the enclosing symbol table; names found in it, or in
            ``env["__builtins__"].__dict__``, count as defined.
        names: names already proven defined before *node*.
        node: a node or a sequence of statements to analyse, in order.

    Returns:
        The set of names proven defined after *node*.

    Raises:
        SyntaxError: if a name is read before it can be proven defined.
    '''
    names = set(names)
    if isinstance(node, ast.Name):
        if isinstance(node.ctx, ast.Store):
            names.add(node.id)
        elif node.id not in names and node.id not in env and \
                node.id not in env["__builtins__"].__dict__:
            if hasattr(node, 'lineno'):
                raise SyntaxError(f'Cannot prove name, {node.id}, is defined at line {node.lineno}')
            else:
                raise SyntaxError(f'Cannot prove name, {node.id}, is defined')

    elif isinstance(node, ast.If):
        # The condition is evaluated on every path through the ``if``, so the
        # names it reads must already be provable.
        names |= _prove_names_defined(env, names, node.test)
        t_returns = _always_returns(node.body)
        f_returns = _always_returns(node.orelse)
        t_names = _prove_names_defined(env, names, node.body)
        f_names = _prove_names_defined(env, names, node.orelse)
        if not (t_returns or f_returns):
            names |= t_names & f_names
        elif t_returns:
            names |= f_names
        elif f_returns:
            names |= t_names
    elif isinstance(node, (ast.Assign, ast.AugAssign, ast.AnnAssign)):
        # The right-hand side is evaluated before any target is bound. Visiting
        # targets first (field order) would let ``x = x`` "prove" x defined.
        if node.value is not None:
            names |= _prove_names_defined(env, names, node.value)
        if isinstance(node, ast.Assign):
            targets = node.targets
        else:
            targets = [node.target]
            if isinstance(node, ast.AnnAssign):
                names |= _prove_names_defined(env, names, node.annotation)
                if node.value is None and isinstance(node.target, ast.Name):
                    # A bare annotation (``x: int``) declares x but binds nothing.
                    targets = []
            elif isinstance(node.target, ast.Name):
                # ``x += 1`` reads x before rebinding it.
                read = ast.copy_location(
                    ast.Name(id=node.target.id, ctx=ast.Load()), node.target)
                names |= _prove_names_defined(env, names, read)
        for target in targets:
            names |= _prove_names_defined(env, names, target)
    elif isinstance(node, ast.AST):
        for child in ast.iter_child_nodes(node):
            names |= _prove_names_defined(env, names, child)
    else:
        assert isinstance(node, tp.Sequence)
        for child in node:
            names |= _prove_names_defined(env, names, child)
    return names
Exemplo n.º 42
0
def _move_var_decls_to_top_of_scope(node, vars_already_scoped={}):
    """Initilizes all variables used in a scope to None
    at the begining of that scope.
    Parameters:
        node: ast node
        vars_already_scoped: hash table of variables that have already been
            reinitilized
    Returns:
        node with variables initialized at top of scope.
        vars_already_scoped: hashtable with variables already moved."""

    current_scope_new_vars = []

    if isinstance(node, ast.FunctionDef) or isinstance(node, ast.Module):
        for child_node in ast.iter_child_nodes(node):
            new_vars, vars_already_scoped = _find_vars_in_same_scope(
                                            child_node, vars_already_scoped)
            current_scope_new_vars = current_scope_new_vars + new_vars
        current_scope_new_vars.reverse()
        _move_initilizations_to_top_of_scope(node, current_scope_new_vars)

    return node, vars_already_scoped
Exemplo n.º 43
0
def get_test_methods(test_path: str) -> List[Any]:
    """Gets the top-level test methods within a test file

    Args:
        test_path: path to the test file to process

    Returns:
        List[ast.AST]: the nodes selected by ``_get_test_nodes`` from the
                       file's top-level statements, each tagged with a
                       ``test_path`` attribute holding the file's absolute
                       path; an empty list if the file cannot be read

    Raises:
        ValueError: if two of the selected test nodes share a name.
        SyntaxError: if the file is not valid Python.
    """
    try:
        with open(test_path, 'r') as file:
            content = file.read()
            parsed_nodes = list(ast.iter_child_nodes(ast.parse(content)))

            test_nodes = _get_test_nodes(parsed_nodes)

            for node in test_nodes:
                node.test_path = os.path.abspath(test_path)

            # Verify file contains no duplicate method names
            # (Only relevant for test methods wrapped in classes)
            used_test_names = set()
            for node in test_nodes:
                if node.name in used_test_names:
                    raise ValueError(f'Test name {node.name} in file'
                                     f' {test_path} must be unique.')
                used_test_names.add(node.name)

            return test_nodes
    except IOError as err:
        # Fail gracefully if a file can't be read
        # (This shouldn't happen, but if it does
        #  we don't want to 'break the build'.)
        sys.stderr.write(f'WARNING: could not read file: {test_path}\n')
        sys.stderr.write(f'\t{str(err)}\n')

        return []
Exemplo n.º 44
0
	def visit_Name(self, node):
		"""Classify a (possibly dotted) name and reset the per-name state.

		``node.id`` is appended to ``self.currentAttribute``. That list stores
		the dotted parts in reverse order with the root name last; any earlier
		parts were presumably pushed by an attribute visitor. The joined
		dotted name is then handled as follows:

		* if ``checkScopeAllowed`` accepts it, the root is recorded in
		  ``requiredImports`` when it is one of ``Validator.TRUSTED_IMPORTS``;
		* otherwise it is disallowed if dotted or in ``TRUSTED_BUILTINS``;
		* otherwise, while ``iterNodeChildren`` is non-empty, the node is
		  accepted once as a local if it is listed there (and nothing
		  happens if it is not);
		* otherwise it is accepted as a local when it is a pending assignment
		  child with a Store context;
		* otherwise it is disallowed.

		``currentAttribute`` and ``assignmentNodeChildren`` are always reset.
		"""
		self.currentAttribute.append(node.id)
		dotted = '.'.join(reversed(self.currentAttribute))
		self._nnPrint(node, '%s --> %s' % (node.id, dotted))

		if self.checkScopeAllowed(reversed(self.currentAttribute)):
			root = self.currentAttribute[-1]
			if root in Validator.TRUSTED_IMPORTS:
				self.requiredImports.add(root)
		elif '.' in dotted or dotted in self.TRUSTED_BUILTINS:
			self.disallowed.append(dotted)
		elif self.iterNodeChildren:
			# allow assignments in iterators: naive pass, just accept the name
			if node in self.iterNodeChildren:
				self.iterNodeChildren.remove(node)
				self.LOCAL_VARIABLES.add(node.id)
		elif (self.assignmentNodeChildren
				and node in self.assignmentNodeChildren
				and any(isinstance(part, ast.Store) for part in ast.iter_child_nodes(node))):
			# strictly allow assignments that are single-label
			self.LOCAL_VARIABLES.add(dotted)
		else:
			self.disallowed.append(dotted)

		# clear for the next
		self.currentAttribute = []
		self.assignmentNodeChildren = None
    def lstm_unit(self, ast_node, depth=0, verbose=False):
        """
        Process of one LSTM unit.
        Recursively calls learning processes on all children in one tree

        : param ast_node: one Python AST node; First call will be with root Node
        : param depth: recursion depth of ast_node (0 for the root); only used
            for the optional trace output
        : param verbose: when True, print each child and its depth while
            recursing (this previously happened unconditionally)
        : returns: hidden state and context of node; eventually for the whole AST
        """
        weight = torch.tensor([])  # TODO weights with lstm calculation!!
        # Embedding of the current node. NOTE(review): not yet used by the
        # gate computations below (they all use the placeholder ``weight``).
        w_t = ast2vec(ast_node, self.dictionary,
                      self.emb_matrix)  # noqa: F841
        # sum of children hidden outputs
        # NOTE(review): accumulated but not yet used in the output.
        h_sum = 0
        # children forget rates times their context
        c_sum = 0
        for child in ast.iter_child_nodes(ast_node):
            if verbose:
                print(child, depth)
            h_k, c_k = self.lstm_unit(child, depth + 1, verbose=verbose)
            # forget gate for this child
            f_tk = torch.sigmoid(weight)
            h_sum += h_k
            c_sum += f_tk * c_k
        # input gate
        i_t = torch.sigmoid(weight)
        # vector of new candidate values for t
        c_t_ = torch.tanh(weight)
        # context
        c_t = i_t * c_t_ + c_sum
        # output gate
        o_t = torch.sigmoid(weight)
        h_t = o_t * torch.tanh(c_t)

        return h_t, c_t
Exemplo n.º 46
0
 def _fix(node, lineno, col_offset, end_lineno, end_col_offset):
     if 'lineno' in node._attributes:
         if not hasattr(node, 'lineno'):
             node.lineno = lineno
         else:
             lineno = node.lineno
     if 'end_lineno' in node._attributes:
         if not hasattr(node, 'end_lineno'):
             node.end_lineno = end_lineno
         else:
             end_lineno = node.end_lineno
     if 'col_offset' in node._attributes:
         if not hasattr(node, 'col_offset'):
             node.col_offset = col_offset
         else:
             col_offset = node.col_offset
     if 'end_col_offset' in node._attributes:
         if not hasattr(node, 'end_col_offset'):
             node.end_col_offset = end_col_offset
         else:
             end_col_offset = node.end_col_offset
     for child in iter_child_nodes(node):
         _fix(child, lineno, col_offset, end_lineno, end_col_offset)
Exemplo n.º 47
0
def _get_ordered_child_nodes(node):
    if isinstance(node, ast.Dict):
        children = []
        for i in range(len(node.keys)):
            children.append(node.keys[i])
            children.append(node.values[i])
        return children
    elif isinstance(node, ast.Call):
        children = [node.func] + node.args

        for kw in node.keywords:
            children.append(kw.value)

        # TODO: take care of Python 3.5 updates (eg. args=[Starred] and keywords)
        if hasattr(node, "starargs") and node.starargs is not None:
            children.append(node.starargs)
        if hasattr(node, "kwargs") and node.kwargs is not None:
            children.append(node.kwargs)

        children.sort(key=lambda x: (x.lineno, x.col_offset))
        return children
    else:
        return ast.iter_child_nodes(node)
Exemplo n.º 48
0
def check_module(
    module: ast.Module, ignore_ambiguous_signatures: bool = True
) -> Iterator[Tuple[ast.AST, List[str], List[str]]]:
    """Check the non-private top-level definitions of a module.

    Every direct child of *module* that ``is_private`` does not exclude is
    passed to ``check``. Whatever that returns (when not ``None``) is
    re-yielded item by item.

    Parameters
    ----------
    module : ast.Module
        The module in which to check functions and classes.
    ignore_ambiguous_signatures : bool, optional
        Whether to ignore extra documented arguments if the function has an
        ambiguous (*args / **kwargs) signature (the default is True).

    Yields
    ------
    tuple
        Items produced by ``check``, typed as ``(ast.AST, List[str],
        List[str])``; presumably the offending node and two lists of argument
        names -- see ``check`` for their exact meaning.
    """

    for node in ast.iter_child_nodes(module):
        if not is_private(node):
            check_result = check(node, ignore_ambiguous_signatures)
            if check_result is not None:
                yield from check_result
Exemplo n.º 49
0
def find_lca(root, node1, node2):
    """Return the lowest common ancestor of *node1* and *node2* under *root*.

    Nodes are compared by identity. Children whose type is listed in
    ``IGNORE_TYPES`` are not searched. Returns *node1* or *node2* if only
    one of them is found, and ``None`` if neither is.
    """
    if root is None:
        return None
    if root is node1 or root is node2:
        return root

    # Look for the LCA in all subtrees. IGNORE_TYPES is assumed to hold AST
    # node classes (NOTE(review): confirm). Testing the node instance itself
    # for membership in a collection of classes would never exclude anything.
    lca_list = [
        find_lca(child_node, node1, node2)
        for child_node in ast.iter_child_nodes(root)
        if type(child_node) not in IGNORE_TYPES
    ]
    found = [lca for lca in lca_list if lca is not None]

    # Found in two different subtrees: root is the lowest common ancestor.
    if len(found) >= 2:
        return root
    # Found in a single subtree: that subtree's answer propagates upward.
    if len(found) == 1:
        return found[0]
    return None
Exemplo n.º 50
0
def walk(node, stop_at=tuple(), ignore=tuple()):
    """Yield *node* and its descendants in breadth-first order.

    Args:
        node: an ast node
        stop_at: node types that are yielded but whose children are not
            visited
        ignore: node types that are neither yielded nor descended into

    Returns: a generator of ast nodes
    """
    queue = deque((node,))
    while queue:
        current = queue.popleft()
        if isinstance(current, ignore):
            continue
        # Enqueue children before yielding, so the caller sees the node only
        # after its children have been scheduled.
        if not isinstance(current, stop_at):
            queue.extend(ast.iter_child_nodes(current))
        yield current
Exemplo n.º 51
0
def createNameMap(a, d=None):
	"""Map identifiers in AST *a* back to their ``originalId`` values.

	For every FunctionDef/ClassDef name, ``arg`` and ``Name`` node that has an
	``originalId`` attribute, ``d[identifier] = node.originalId`` is recorded
	unless the identifier is already mapped, so the first occurrence visited
	wins. A Module's statements are visited last-to-first. *d* (a new dict
	when omitted) is updated in place and returned; non-AST input returns it
	unchanged.
	"""
	if d is None:
		d = {}
	if not isinstance(a, ast.AST):
		return d
	kind = type(a)
	if kind is ast.Module:
		# Need to go through the functions backwards to make this right
		for stmt in reversed(a.body):
			createNameMap(stmt, d)
		return d
	if kind is ast.arg or kind is ast.Name:
		# Leaf identifiers: record and stop descending.
		if hasattr(a, "originalId"):
			ident = a.arg if kind is ast.arg else a.id
			if ident not in d:
				d[ident] = a.originalId
		return d
	if kind in (ast.FunctionDef, ast.ClassDef) and hasattr(a, "originalId") and a.name not in d:
		d[a.name] = a.originalId
	for child in ast.iter_child_nodes(a):
		createNameMap(child, d)
	return d
Exemplo n.º 52
0
def get_imports(root: ast.AST):
    """Collect imports under *root*, descending into (async) function and class bodies.

    Returns a list of ``Import(module, name, alias)`` named tuples:

    * *module*: the ``from`` module split on ``.``. It is empty for plain
      ``import`` statements and for bare relative imports
      (``from . import x``); the relative ``level`` is not recorded.
    * *name*: the imported name split on ``.``.
    * *alias*: the ``as`` name, or ``None``.
    """
    # https://stackoverflow.com/a/9049549
    Import = collections.namedtuple("Import", ["module", "name", "alias"])

    result = []

    for node in ast.iter_child_nodes(root):
        if isinstance(node, ast.Import):
            module = []
        elif isinstance(node, ast.ImportFrom):
            # ``from . import x`` has module=None.
            module = node.module.split('.') if node.module else []
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            result.extend(get_imports(node))
            continue
        else:
            continue

        for n in node.names:
            result.append(Import(module, n.name.split('.'), n.asname))

    return result
Exemplo n.º 53
0
def _relevant_statements(ancestors: List[ast.AST]) -> Iterable[ast.AST]:
    """
    Given a list of ancestors, finds the statements to analyze.

    Given a list of ancestors, as described by :func:`_ast_ancestors`, finds
    all statements in the tree that are considered "worthy to be analyzed".
    In general, these are all child nodes of any ancestor that represent
    import statements or assign statements.

    :param ancestors: The list of ancestors, as described by :func:`_ast_ancestors`.
    :return: An iteration over the relevant nodes.
    """
    acceptable_types = [ast.Import, ast.ImportFrom, ast.Assign]
    for ancestor in ancestors:
        for child in ast.iter_child_nodes(ancestor):
            if any(isinstance(child, t) for t in acceptable_types):
                yield child
            elif isinstance(child, ast.Try):
                for grandchild in itertools.chain(
                    child.body, child.finalbody, child.orelse
                ):
                    if any(isinstance(grandchild, t) for t in acceptable_types):
                        yield grandchild
Exemplo n.º 54
0
def get_meta(filename):
    """Get top level module metadata without execution.

    The result holds ``__file__``, ``__name__`` (the file's base name without
    extension), ``__package__`` (``''``) and ``__doc__`` (the module
    docstring, or None). It also holds every top-level ``NAME = <literal>``
    or ``NAME: T = <literal>`` assignment whose value ``ast.literal_eval``
    accepts. Other targets (tuples, attributes, subscripts) and non-literal
    values are skipped.
    """
    result = {
        '__file__': filename,
        '__name__': os.path.splitext(os.path.basename(filename))[0],
        '__package__': '',
    }

    # Read bytes so ast.parse honours a PEP 263 coding declaration instead of
    # decoding with the platform's locale encoding.
    with open(filename, 'rb') as fp:
        root = ast.parse(fp.read(), fp.name)

    result['__doc__'] = ast.get_docstring(root)

    for node in ast.iter_child_nodes(root):
        if isinstance(node, ast.Assign):
            targets, value = node.targets, node.value
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            targets, value = [node.target], node.value
        else:
            continue
        for target in targets:
            if not isinstance(target, ast.Name):
                continue
            try:
                result[target.id] = ast.literal_eval(value)
            except (ValueError, TypeError):
                # ValueError: not a literal; TypeError: e.g. unhashable key.
                pass

    return result
Exemplo n.º 55
0
def get_name_import_path(name_node: ast.Name,
                         pyfilepath: str) -> Optional[EntityImportInfo]:
    """Find where *name_node* is defined by scanning its enclosing nodes.

    Starting at ``name_node.parent`` and following ``parent`` links outwards,
    the direct children of each ancestor are searched for either:

    * a ``from ... import`` that imports the name, which is delegated to
      ``extract_import_info_from_import_node``; or
    * a (possibly async) function definition with that name, which gives a
      local result.

    Returns ``None`` when no ancestor provides the name.
    """
    current_node = name_node.parent  # type: ignore
    while current_node:
        for child in ast.iter_child_nodes(current_node):
            # check for Import, not only ImportFrom
            if (isinstance(child, ast.ImportFrom)
                    and name_node.id in (n.name for n in child.names)):
                return extract_import_info_from_import_node(child, name_node)
            if (isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef))
                    and child.name == name_node.id):
                return {
                    'file_path':
                    pyfilepath if pyfilepath != 'built-in' else None,
                    'import_path': None,
                    'name': name_node.id,
                }
        # The tree root typically has no ``parent`` attribute at all; treat a
        # missing link as reaching the top instead of raising AttributeError.
        current_node = getattr(current_node, 'parent', None)
    return None
Exemplo n.º 56
0
def parse_flags(synth_py_path="synth.py") -> typing.Dict[str, typing.Any]:
    """Extracts flags from a synth.py file.

    A flag is any top-level ``NAME = <literal>`` assignment whose NAME is a
    known flag. Its value replaces the default, and a later assignment wins
    over an earlier one. If the file does not parse, the defaults are
    returned unchanged.

    Keyword Arguments:
        synth_py_path {str or Path} -- Path to synth.py (default: {"synth.py"})

    Returns:
        typing.Dict[str, typing.Any] -- A map of all possible flags.  Flags not
          found in synth.py will be set to their default value.
    """
    flags = copy.copy(_SYNTH_PY_FLAGS)
    try:
        module = ast.parse(pathlib.Path(synth_py_path).read_text())
    except SyntaxError:
        return flags  # Later attempt to execute synth.py will give a clearer error message.

    assignments = (stmt for stmt in ast.iter_child_nodes(module)
                   if isinstance(stmt, ast.Assign))
    for assignment in assignments:
        for target in assignment.targets:
            if isinstance(target, ast.Name) and target.id in flags:
                flags[target.id] = ast.literal_eval(assignment.value)
    return flags
Exemplo n.º 57
0
    def embed_subtree(self, node):
        """Embed each direct child of *node* with ``self.embed_node``.

        Returns a list with one entry per child, in ``ast.iter_child_nodes``
        order. Each entry is the child's embedding reshaped with
        ``.view(1, -1)`` into a single row.
        """
        return [
            self.embed_node(child).view(1, -1)
            for child in ast.iter_child_nodes(node)
        ]
Exemplo n.º 58
0
def exec_eval(script, globals=None, locals=None, name=''):
    '''Execute a script and return the value of the last expression.

    If the script's last statement is an expression, every earlier statement
    is executed and that expression's value is returned. Otherwise the whole
    script is executed and ``None`` is returned. When *globals* is omitted, a
    fresh namespace dict is created for the script.

    Security: this executes arbitrary code; never pass untrusted input.
    '''
    if globals is None:
        # Use one fresh namespace for both the exec and eval steps. Passing
        # None would fall back to this function's own frame, whose locals are
        # not guaranteed to persist between the two calls (PEP 667 snapshots
        # on 3.13+). It would also expose this module's globals to the script.
        globals = {}
    stmts = list(ast.iter_child_nodes(ast.parse(script)))
    if not stmts:
        return None
    if isinstance(stmts[-1], ast.Expr):
        # the last one is an expression and we will try to return the results
        # so we first execute the previous statements
        if len(stmts) > 1:
            if sys.version_info >= (3, 8):
                mod = ast.Module(stmts[:-1], [])
            else:
                mod = ast.Module(stmts[:-1])
            exec(compile(mod, filename=name, mode="exec"), globals, locals)
        # then we eval the last one
        return eval(
            compile(ast.Expression(body=stmts[-1].value),
                    filename=name,
                    mode="eval"), globals, locals)
    else:
        # otherwise we just execute the entire code
        return exec(compile(script, filename=name, mode="exec"), globals,
                    locals)
Exemplo n.º 59
0
def get_info(a, depth=0):
    """Print detailed information about an AST.

    Each node is printed on its own line, indented two spaces per depth
    level. Numeric and string literals show their value (``Constant`` from
    Python 3.8+, or legacy ``Num``/``Str``); names and args show their
    identifier. String-valued fields are listed as ``:attr[field]=value``.
    ``If`` nodes print the test at the same depth and an ``Else`` marker
    before the orelse branch.
    """
    nm = a.__class__.__name__
    print("  " * depth, end="")
    iter_children = True
    if nm == "Num":
        if type(a.n) == int:
            print("%s=%d" % (nm, a.n))
        else:
            print("%s=%f" % (nm, a.n))
    elif nm == "Constant" and type(a.value) in (int, float, str):
        # Python 3.8+ parses every literal as Constant, so Num/Str above
        # only occur in older trees.
        if type(a.value) == int:
            print("%s=%d" % (nm, a.value))
        elif type(a.value) == float:
            print("%s=%f" % (nm, a.value))
        else:
            print("%s='%s'" % (nm, a.value))
    elif nm == "Global":
        print("Global:", dir(a))
    elif nm == "Str":
        print("%s='%s'" % (nm, a.s))
    elif nm == "Name":
        print("%s='%s'" % (nm, a.id))
    elif nm == "arg":
        print("%s='%s'" % (nm, a.arg))
    elif nm == "If":
        iter_children = False
        print(nm)
        get_info(a.test, depth)
        for n in a.body:
            get_info(n, depth + 1)
        if len(a.orelse) > 0:
            print("  " * depth, end="")
            print("Else")
            for n in a.orelse:
                get_info(n, depth + 1)
    else:
        print(nm)
    for (f, v) in ast.iter_fields(a):
        if type(f) == str and type(v) == str:
            print("%s:attr[%s]=%s" % ("  " * (depth + 1), f, v))
    if iter_children:
        for n in ast.iter_child_nodes(a):
            get_info(n, depth + 1)
def set_node_context(tree: ast.AST) -> ast.AST:
    """
    Used to set proper context to all nodes.

    What we call "a context"?
    Context is where exactly this node belongs on a global level.

    Example::

        if some_value > 2:
            test = 'passed'

    Despite the fact ``test`` variable has ``Assign`` as it parent
    it will have ``Module`` as a context.

    What contexts do we respect?

    - :py:class:`ast.Module`
    - :py:class:`ast.ClassDef`
    - :py:class:`ast.FunctionDef` and :py:class:`ast.AsyncFunctionDef`

    Every node below *tree* gets a ``context`` attribute: the nearest
    enclosing context node (a context node's own ``context`` is the one
    enclosing it). *tree* itself is not annotated. Its children get ``None``
    when *tree* is not a context node and has no context ancestor. The tree
    is modified in place and returned.
    """
    contexts = (
        ast.Module,
        ast.ClassDef,
        ast.FunctionDef,
        ast.AsyncFunctionDef,
    )

    # Propagate the context from each parent to its children. A flat
    # ast.walk() is breadth-first, so "the last context seen" there would be
    # a sibling or cousin scope for most nested nodes, not their own.
    pending = [(tree, None)]
    while pending:
        node, inherited = pending.pop()
        own_context = node if isinstance(node, contexts) else inherited
        for child in ast.iter_child_nodes(node):
            setattr(child, 'context', own_context)
            pending.append((child, own_context))
    return tree