示例#1
0
    def prepare(self, node):
        """Build ``self.env``, the namespace used to execute the module body.

        The namespace starts with the ``builtins`` module (also exposed as
        ``__builtins__`` under PyPy).  Every module listed in ``MODULES``,
        except the fake ``__dispatch__`` one, is then imported and stored
        under its mangled name; modules that fail to import are skipped.
        Finally the module body, stripped of its ``import`` statements, is
        executed in that namespace before delegating to the parent
        ``prepare``.
        """
        assert isinstance(node, ast.Module)
        self.env = env = {'builtins': __import__('builtins')}

        if sys.implementation.name == 'pypy':
            env['__builtins__'] = env['builtins']

        for module_name in MODULES:
            # __dispatch__ is the only fake top-level module: nothing to import
            if module_name == '__dispatch__':
                continue
            alias_module_name = mangle(module_name)
            try:
                env[alias_module_name] = __import__(module_name)
            except ImportError:
                # best effort: unavailable modules are simply left out
                pass

        # The whole code must be executed so user-defined pure functions can
        # be applied, but imports were already resolved above; dropping them
        # avoids spurious ImportError (e.g. for operator_).
        kept_statements = [statement for statement in node.body
                           if not isinstance(statement, ast.Import)]
        dummy_module = ast.Module(kept_statements, [])
        ast.fix_missing_locations(dummy_module)
        module_code = compile(ast.gast_to_ast(dummy_module),
                              '<constant_folding>', 'exec')
        eval(module_code, self.env)

        super(ConstantFolding, self).prepare(node)
示例#2
0
def to_static_ast(node, class_node):
    """Rewrite the call ``node`` into a call to ``fluid.layers.<api>``.

    ``<api>`` is obtained from ``to_static_api`` applied to the attribute
    name called by ``class_node``.  The positional and keyword arguments of
    ``class_node`` are appended to ``node``; keywords are then adjusted by
    ``_add_keywords_to`` / ``_delete_keywords_from``, and missing source
    locations are filled in.

    Args:
        node: the ``gast.Call`` to rewrite (mutated in place).
        class_node: the ``gast.Call`` whose ``func.attr`` names the API.

    Returns:
        The mutated ``node``.
    """
    assert isinstance(node, gast.Call)
    assert isinstance(class_node, gast.Call)
    static_api = to_static_api(class_node.func.attr)

    # Build the callee expression `fluid.layers.<static_api>` bottom-up.
    fluid_name = gast.Name(id='fluid',
                           ctx=gast.Load(),
                           annotation=None,
                           type_comment=None)
    layers_attr = gast.Attribute(value=fluid_name,
                                 attr='layers',
                                 ctx=gast.Load())
    node.func = gast.Attribute(value=layers_attr,
                               attr=static_api,
                               ctx=gast.Load())

    update_args_of_func(node, class_node, 'forward')

    node.args.extend(class_node.args)
    node.keywords.extend(class_node.keywords)
    _add_keywords_to(node, class_node.func.attr)
    _delete_keywords_from(node)

    gast.fix_missing_locations(node)
    return node
示例#3
0
 def run(self, node):
     """Apply the transformation and its dependencies to ``node``.

     When ``self.update`` is set (the tree was modified), missing source
     locations on the result are filled in and the pass manager's cache is
     cleared, since cached analysis results may be stale.
     """
     result = super(Transformation, self).run(node)
     if not self.update:
         return result
     ast.fix_missing_locations(result)
     self.passmanager._cache.clear()
     return result
示例#4
0
 def run(self, node):
     """Run the transformation (and dependencies) and fix node locations.

     Only when the transformation reports a change through ``self.update``
     does this fill in missing locations and drop the pass manager cache.
     """
     tree = super(Transformation, self).run(node)
     modified = self.update
     if modified:
         # newly created nodes may lack lineno/col_offset
         ast.fix_missing_locations(tree)
         # the AST changed, so cached analyses may no longer be valid
         self.passmanager._cache.clear()
     return tree
示例#5
0
 def test_fix_missing_locations(self):
     """A child without location inherits it from its parent."""
     child = gast.Constant(value=6, kind=None)
     parent = gast.UnaryOp(gast.USub(), child)
     parent.lineno, parent.col_offset = 1, 2
     gast.fix_missing_locations(parent)
     for attribute in ('lineno', 'col_offset'):
         self.assertEqual(getattr(child, attribute),
                          getattr(parent, attribute))
示例#6
0
 def test_fix_missing_locations(self):
     """A child without location inherits it from its parent."""
     child = gast.Num(n=6)
     parent = gast.UnaryOp(gast.USub(), child)
     parent.lineno, parent.col_offset = 1, 2
     gast.fix_missing_locations(parent)
     for attribute in ('lineno', 'col_offset'):
         self.assertEqual(getattr(child, attribute),
                          getattr(parent, attribute))
示例#7
0
 def run(self, node):
     """Apply the transformation and its dependencies, then refresh state.

     A transformation that updated the AST may invalidate analyses, so the
     whole pass-manager cache is dropped.  A finer-grained caching system
     could let a transformation flag some analyses as "unmodified", as done
     in LLVM (and PIPS ;-).
     """
     transformed = super(Transformation, self).run(node)
     if not self.update:
         return transformed
     # new nodes may lack lineno/col_offset: inherit them from parents
     ast.fix_missing_locations(transformed)
     # analyses must be recomputed on the modified tree
     self.passmanager._cache.clear()
     return transformed
示例#8
0
    def visit_Module(self, node):
        """
        Visit the whole module and add all imports at the top level.

        Each body statement is visited; statements whose visit result is
        falsy are dropped.  Then one ``import`` per name collected in
        ``self.imports`` is prepended, bound under its mangled name:

        >> import numpy.linalg

        Becomes

        >> import numpy as <mangled name of numpy>

        Missing source locations are filled in before returning ``node``.
        """
        new_body = []
        for statement in node.body:
            visited = self.visit(statement)
            if visited:
                new_body.append(visited)
        node.body = new_body
        # self.imports is read only after visiting, which may populate it
        header = [ast.Import([ast.alias(name, mangle(name))])
                  for name in self.imports]
        node.body = header + node.body
        ast.fix_missing_locations(node)
        return node
示例#9
0
    def visit_Module(self, node):
        """
        Visit the whole module and add all imports at the top level.

        Each body statement is visited; statements whose visit result is
        falsy are dropped.  Then a plain ``import`` (no alias) per name
        collected in ``self.imports`` is prepended:

        >> import numpy.linalg

        Becomes

        >> import numpy

        Missing source locations are filled in before returning ``node``.
        """
        new_body = []
        for statement in node.body:
            visited = self.visit(statement)
            if visited:
                new_body.append(visited)
        node.body = new_body
        # self.imports is read only after visiting, which may populate it
        header = [ast.Import([ast.alias(name, None)])
                  for name in self.imports]
        node.body = header + node.body
        ast.fix_missing_locations(node)
        return node
示例#10
0
    def visit(self, node: AST) -> AST:
        """Visit ``node``'s children first, then transform ``node`` itself.

        The node produced by ``self.transform_fn(node)`` takes ``node``'s
        location, and missing locations in its subtree are filled in.
        """
        # recursion into child nodes happens through the parent visitor
        super().visit(node)
        transformed = self.transform_fn(node)
        transformed = gast.copy_location(new_node=transformed, old_node=node)
        return gast.fix_missing_locations(transformed)
示例#11
0
    def visit_Compare(self, node):
        """ Boolean are possible index.

        A comparison yields a boolean, so its range is within ``[0, 1]``.
        Membership and identity tests are not analysed and give
        ``Interval(0, 1)``.  For the other operators each link of a
        (possibly chained) comparison is evaluated on the operands' ranges.
        As in Python, ``a < b < c`` means ``a < b and b < c``.

        >>> import gast as ast
        >>> from pythran import passmanager, backend
        >>> node = ast.parse('''
        ... def foo():
        ...     a = 2 or 3
        ...     b = 4 or 5
        ...     c = a < b
        ...     d = b < 3
        ...     e = b == 4
        ...     f = 1 < b < 3''')
        >>> pm = passmanager.PassManager("test")
        >>> res = pm.gather(RangeValues, node)
        >>> res['c']
        Interval(low=1, high=1)
        >>> res['d']
        Interval(low=0, high=0)
        >>> res['e']
        Interval(low=0, high=1)
        >>> res['f']
        Interval(low=0, high=0)
        """
        if any(
                isinstance(op, (ast.In, ast.NotIn, ast.Is, ast.IsNot))
                for op in node.ops):
            self.generic_visit(node)
            return self.add(node, Interval(0, 1))

        curr = self.visit(node.left)
        res = []
        for op, comparator in zip(node.ops, node.comparators):
            comparator = self.visit(comparator)
            # Evaluate `x <op> y` with x and y bound to the visited operand
            # values, so the operator applies directly to them.
            fake = ast.Compare(ast.Name('x', ast.Load(), None), [op],
                               [ast.Name('y', ast.Load(), None)])
            fake = ast.Expression(fake)
            ast.fix_missing_locations(fake)
            expr = compile(ast.gast_to_ast(fake), '<range_values>', 'eval')
            res.append(eval(expr, {'x': curr, 'y': comparator}))
            # In a chained comparison the next link compares this right-hand
            # operand with the following one (a < b < c -> a < b, b < c).
            curr = comparator
        if all(res):
            # every link is truthy
            return self.add(node, Interval(1, 1))
        elif any(r.low == r.high == 0 for r in res):
            # at least one link is exactly 0, so the whole chain is false
            return self.add(node, Interval(0, 0))
        else:
            return self.add(node, Interval(0, 1))
示例#12
0
def anf_function(f, globals_=None):
    """Transform ``f`` with ``anf.anf`` and execute the resulting code.

    The source of ``f`` is parsed with ``quoting.parse_function``, passed
    through ``anf.anf`` (presumably A-normal form conversion), converted to
    a standard ``ast`` tree and executed with ``globals_`` as namespace.
    When ``globals_`` is None, ``exec`` uses this function's own scope.

    Note that the returned object is the original ``f``; the transformed
    definition only exists in the namespace used by ``exec``.
    """
    transformed = anf.anf(quoting.parse_function(f))
    module = gast.fix_missing_locations(gast.gast_to_ast(transformed))
    code = compile(module, '<string>', 'exec')
    exec(code, globals_)
    return f
示例#13
0
 def run(self, node, ctx):
     """Apply the transformation and its dependencies, then fix locations.

     Missing source locations on the resulting tree are always filled in,
     whether or not the tree was modified.
     """
     transformed = super(Transformation, self).run(node, ctx)
     ast.fix_missing_locations(transformed)
     return transformed
示例#14
0
 def run(self, node, ctx):
     """Apply the transformation and its dependencies to ``node``.

     Missing source locations are filled in only when ``self.update`` is
     set, i.e. when the transformation modified the tree.
     """
     transformed = super(Transformation, self).run(node, ctx)
     if not self.update:
         return transformed
     ast.fix_missing_locations(transformed)
     return transformed
示例#15
0
    def visit_For(self, node):
        """Fully unroll ``node`` when its iteration space is small and known.

        The loop is kept as is when it carries an OpenMP directive, when its
        body contains ``break``/``continue``, when its iterable cannot be
        evaluated at compile time, or when the estimated unrolled size
        (loop node count times iteration count) reaches
        ``LoopFullUnrolling.MAX_NODE_COUNT``.

        Otherwise the loop becomes, for each iteration, a ``target = value``
        assignment followed by a copy of the body, then the ``else`` clause.
        The ``else`` clause always runs here because the loop has no
        ``break``.  ``self.update`` is set whenever the tree is rewritten.

        Returns ``node`` itself or the list of statements replacing it.
        """
        # if the user added some OpenMP directive, trust him and no unroll
        if metadata.get(node, OMPDirective):
            return node  # don't visit children because of collapse

        # first unroll children if needed or possible
        self.generic_visit(node)

        # a break or continue in the loop prevents unrolling too
        has_break = any(self.gather(HasBreak, n) for n in node.body)
        has_cont = any(self.gather(HasContinue, n) for n in node.body)

        if has_break or has_cont:
            return node

        # do not unroll too much to prevent code growth
        node_count = self.gather(NodeCount, node)

        def unroll(elt, body):
            # bind the loop target to this iteration's value, then run body
            return [ast.Assign([deepcopy(node.target)], elt, None)] + body

        def dc(body, i, n):
            # the last iteration can reuse the original body nodes
            if i == n - 1:
                return body
            else:
                return deepcopy(body)

        def unrolled(elts):
            # Flatten the per-iteration statement lists (linear, unlike
            # sum(lists, [])), then append the else clause.
            count = len(elts)
            new_body = [stmt
                        for i, elt in enumerate(elts)
                        for stmt in unroll(elt, dc(node.body, i, count))]
            new_body.extend(node.orelse)
            # An empty iteration space with no else clause must not leave
            # an enclosing statement list empty.
            return new_body or [ast.Pass()]

        if isinstance(node.iter, (ast.Tuple, ast.List)):
            elts = node.iter.elts
            # `*xs` cannot be bound to the target as-is; let the evaluation
            # path below expand it if possible
            has_starred = any(isinstance(elt, ast.Starred) for elt in elts)
            total_count = node_count * len(elts)
            issmall = total_count < LoopFullUnrolling.MAX_NODE_COUNT
            if issmall and not has_starred:
                self.update = True
                return unrolled(elts)

        # otherwise try to evaluate the iterable at compile time
        ast.fix_missing_locations(node.iter)
        code = compile(ast.gast_to_ast(ast.Expression(node.iter)),
                       '<loop unrolling>', 'eval')
        try:
            values = list(eval(code, {'builtins': __import__('builtins')}))
        except Exception:
            return node

        total_count = node_count * len(values)
        issmall = total_count < LoopFullUnrolling.MAX_NODE_COUNT
        if not issmall:
            return node
        try:
            new_node = unrolled([to_ast(elt) for elt in values])
        except Exception:
            # presumably a value that to_ast cannot convert back to an AST
            return node
        self.update = True
        return new_node