Exemple #1
0
    def visit_Subscript(self, node):
        """
        Fold an index into a slice: ``a[l::s][i]`` becomes ``a[l + i * s]``.

        An omitted lower bound defaults to 0 and an omitted step to 1. The
        upper bound of the inner slice is not used, so the rewrite assumes
        the folded index stays within it.

        The fold is skipped when it would not be equivalent:
        - the index is a negative literal (it counts from the end of the
          *sliced* sequence), or
        - the lower bound is omitted and the step is not a positive literal
          (with a negative step the implicit start is the last element,
          not 0).

        >>> import gast as ast
        >>> from pythran import passmanager, backend
        >>> pm = passmanager.PassManager("test")

        >>> node = ast.parse("def foo(a): a[1:][3]")
        >>> _, node = pm.apply(PartialConstantFolding, node)
        >>> _, node = pm.apply(ConstantFolding, node)
        >>> print(pm.dump(backend.Python, node))
        def foo(a):
            a[4]

        >>> node = ast.parse("def foo(a): a[::2][3]")
        >>> _, node = pm.apply(PartialConstantFolding, node)
        >>> _, node = pm.apply(ConstantFolding, node)
        >>> print(pm.dump(backend.Python, node))
        def foo(a):
            a[6]

        >>> node = ast.parse("def foo(a): a[-4:][5]")
        >>> _, node = pm.apply(PartialConstantFolding, node)
        >>> _, node = pm.apply(ConstantFolding, node)
        >>> print(pm.dump(backend.Python, node))
        def foo(a):
            a[1]
        """
        self.generic_visit(node)
        if not isinstance(node.value, ast.Subscript):
            return node
        if not isinstance(node.value.slice, ast.Slice):
            return node
        if not isinstance(node.slice, ast.Index):
            return node

        if not isnum(node.slice.value):
            return node

        def literal_number(expr):
            # Value of an int/float literal, or None for any other node.
            if (isinstance(expr, ast.Constant) and
                    isinstance(expr.value, (int, float))):
                return expr.value
            return None

        slice_ = node.value.slice
        index = node.slice

        # `a[l:][-1]` is the last element of the slice, not `a[l - 1]`.
        index_value = literal_number(index.value)
        if index_value is not None and index_value < 0:
            return node

        # The implicit start of `a[::s]` is 0 only for a positive step; for
        # a negative (or unknown-sign) step it is the end of the sequence.
        if slice_.lower is None and slice_.step is not None:
            step_value = literal_number(slice_.step)
            if step_value is None or step_value <= 0:
                return node

        node = node.value
        lower = slice_.lower or ast.Constant(0, None)
        step = slice_.step or ast.Constant(1, None)
        node.slice = index
        node.slice.value = ast.BinOp(lower,
                                     ast.Add(),
                                     ast.BinOp(index.value,
                                               ast.Mult(),
                                               step))
        self.update = True
        return node
    def visit_ListComp(self, node):
        """Lower a list comprehension.

        A comprehension whose element is an ``ast.Constant`` and whose single
        generator has no filters and iterates a one-argument call resolving
        (via ``attr_to_path``) to ``pythonic.builtins.functor.range`` becomes
        ``[elt] * builtins.len(<that range call>)``. Every other
        comprehension is delegated to ``self.visitComp`` with a builder that
        produces ``builtins.list(builtins.map(*args))``.
        """
        def makeattr(*args):
            map_call = ast.Call(
                ast.Attribute(value=ast.Name(id='builtins',
                                             ctx=ast.Load(),
                                             annotation=None,
                                             type_comment=None),
                              attr='map',
                              ctx=ast.Load()),
                list(args), [])
            list_func = ast.Attribute(
                ast.Name('builtins', ast.Load(), None, None), 'list',
                ast.Load())
            return ast.Call(list_func, [map_call], [])

        generators = node.generators
        if (isinstance(node.elt, ast.Constant) and len(generators) == 1
                and not generators[0].ifs
                and isinstance(generators[0].iter, ast.Call)):
            range_call = generators[0].iter
            try:
                range_path = 'pythonic', 'builtins', 'functor', 'range'
                callee_path = attr_to_path(range_call.func)[1]
                if callee_path == range_path and len(range_call.args) == 1:
                    self.update = True
                    repeated = ast.List([node.elt], ast.Load())
                    count = ast.Call(path_to_attr(('builtins', 'len')),
                                     [range_call], [])
                    return ast.BinOp(repeated, ast.Mult(), count)
            except TypeError:
                # attr_to_path raises TypeError for non-attribute callees.
                pass

        return self.visitComp(node, makeattr)
  def test_ast_to_object(self):
    """compiler.ast_to_object turns a gast FunctionDef into module + source."""
    arguments = gast.arguments(
        args=[gast.Name('a', gast.Param(), None)],
        vararg=None,
        kwonlyargs=[],
        kwarg=None,
        defaults=[],
        kw_defaults=[])
    a_plus_one = gast.BinOp(
        op=gast.Add(),
        left=gast.Name('a', gast.Load(), None),
        right=gast.Num(1))
    node = gast.FunctionDef(
        name='f',
        args=arguments,
        body=[gast.Return(a_plus_one)],
        decorator_list=[],
        returns=None)

    module, source, _ = compiler.ast_to_object(node)

    expected_source = """
      # coding=utf-8
      def f(a):
        return a + 1
    """
    expected = textwrap.dedent(expected_source).strip()
    self.assertEqual(expected, source.strip())
    self.assertEqual(2, module.f(1))
    # The generated module must also have been written to disk verbatim.
    with open(module.__file__, 'r') as written:
      self.assertEqual(expected, written.read().strip())
Exemple #4
0
    def visit_JoinedStr(self, node):
        """Lower an f-string into an equivalent ``%``-formatting expression.

        ``f"a{x:.2f}b"`` becomes ``"a%.2fb" % (x,)``. Literal parts are copied
        with ``%`` escaped as ``%%``. Each replacement field contributes
        ``%`` followed by its literal format spec.

        A JoinedStr made of a single literal part is replaced by that part.
        One without any replacement field is returned unchanged.

        Raises:
            PythranSyntaxError: for a replacement field
                - without a (non-empty) format spec,
                - with a ``!r``/``!a`` conversion, or
                - whose format spec contains nested replacement fields
                  (``f"{x:{w}}"``).
            NotImplementedError: for any other kind of part.
        """
        if len(node.values) == 1 and not isinstance(node.values[0],
                                                    ast.FormattedValue):
            # f-strings with no reference to variable (like `f"bar"`, see #1767)
            return node.values[0]

        if not any(
                isinstance(value, ast.FormattedValue)
                for value in node.values):
            # nothing to do (not a f-string)
            return node

        def printf_spec(field):
            # Text that follows `%` for one replacement field.
            if field.format_spec is None:
                raise PythranSyntaxError(
                    "f-strings without format specifier not supported",
                    field)
            # repr()/ascii() would be applied before formatting; `%` with the
            # same spec would silently format the raw value instead.
            if getattr(field, 'conversion', -1) in (ord('r'), ord('a')):
                raise PythranSyntaxError(
                    "f-strings with !r or !a conversion not supported",
                    field)
            # The spec is itself a JoinedStr; it may be split into several
            # literal parts, or contain nested `{...}` fields we can't lower.
            parts = field.format_spec.values
            if not all(isinstance(part, ast.Constant) and
                       isinstance(part.value, str) for part in parts):
                raise PythranSyntaxError(
                    "f-strings with nested replacement fields in format "
                    "specifier not supported",
                    field)
            spec = "".join(part.value for part in parts)
            if not spec:
                # A lone "%" would be an incomplete format at runtime.
                raise PythranSyntaxError(
                    "f-strings without format specifier not supported",
                    field)
            return spec

        base_str = ""
        elements = []
        for value in node.values:
            if isinstance(value, ast.Constant):
                base_str += value.value.replace("%", "%%")
            elif isinstance(value, ast.FormattedValue):
                base_str += "%" + printf_spec(value)
                elements.append(value.value)
            else:
                raise NotImplementedError

        return ast.BinOp(
            left=ast.Constant(value=base_str, kind=None),
            op=ast.Mod(),
            right=ast.Tuple(elts=elements, ctx=ast.Load()),
        )
Exemple #5
0
    def inlineFixedSizeArrayBinOp(self, node):
        """Inline an element-wise binary op between fixed-size arrays.

        When both operands have a size known to ``self.fixedSizeArray`` (and
        are not both plain list/tuple/number literals), ``node`` becomes
        ``numpy.array((e_0, e_1, ...))``. Each ``e_i`` applies ``node.op`` to
        the elements returned by ``self.make_array_index`` for position
        ``i``. A size-1 operand is broadcast against the other one.

        Raises:
            PythranSyntaxError: if the two sizes are not broadcast-compatible.
        """
        alike = ast.List, ast.Tuple, ast.Num
        if isinstance(node.left, alike) and isinstance(node.right, alike):
            return node

        lbase, lsize = self.fixedSizeArray(node.left)
        rbase, rsize = self.fixedSizeArray(node.right)
        if not lbase or not rbase:
            return node

        if rsize != 1 and lsize != 1 and rsize != lsize:
            raise PythranSyntaxError("Invalid numpy broadcasting", node)

        # Broadcasting: a size-1 operand takes the other operand's size, even
        # when that size is 0 (max(lsize, rsize) would wrongly give 1 there).
        size = rsize if lsize == 1 else lsize

        self.update = True

        operands = [
            ast.BinOp(self.make_array_index(lbase, lsize, i),
                      type(node.op)(), self.make_array_index(rbase, rsize, i))
            for i in range(size)
        ]
        res = ast.Call(path_to_attr(('numpy', 'array')),
                       [ast.Tuple(operands, ast.Load())], [])
        self.aliases[res.func] = {path_to_node(('numpy', 'array'))}
        return res
Exemple #6
0
    def test_replace_code_block(self):
        """A list substituted for `block` is spliced in, fixing up ctx."""
        template = """
      def test_fn(a):
        block
        return a
    """

        class ShouldBeReplaced(object):
            pass

        def name_a():
            # ctx is deliberately a bogus class; templates.replace must fix it.
            return gast.Name('a',
                             ctx=ShouldBeReplaced,
                             annotation=None,
                             type_comment=None)

        increment = gast.Assign(
            [name_a()],
            gast.BinOp(name_a(), gast.Add(), gast.Constant(1, kind=None)),
        )
        # The same statement is used twice: 1 + 1 + 1 == 3.
        node = templates.replace(template, block=[increment] * 2)[0]
        result, _, _ = loader.load_ast(node)
        self.assertEqual(3, result.test_fn(1))
Exemple #7
0
    def infer_AugAssign(self, node):
        # AugAssign(expr target, operator op, expr value)
        # The result type is inferred from a temporary `target op value`
        # BinOp, whose nodetype entry is removed again afterwards.
        tmp_binop = gast.BinOp(node.target, node.op, node.value)
        if hasattr(node, 'lineno'):
            tmp_binop.lineno = node.lineno
        ty_val = self.infer_expr(tmp_binop)
        ty_target = self.infer_expr(node.target)
        del self.nodetype[tmp_binop]
        if ty_target.is_mutable():
            unify(ty_target, ty_val)

        target = node.target
        # is_mutable() is re-queried after unify() on purpose.
        if isinstance(target, gast.Name):
            self.tyenv[target.id] = (
                ty_val if ty_target.is_mutable() else copy_ty(ty_val))

        if isinstance(target, gast.Attribute):
            ty_obj = self.nodetype[target.value]
            assert isinstance(ty_obj, TyUserDefinedClass)
            attr_key = (ty_obj.instance, target.attr)
            self.attribute_tyenv[attr_key] = (
                ty_val if ty_target.is_mutable() else copy_ty(ty_val))

        self.nodetype[target] = ty_val
Exemple #8
0
 def to_ast(self):
     """Return the argspec nodes chained with `+` (left-associative).

     With no argspec entries an empty ``gast.Tuple`` is returned. The object
     must have been finalized first.
     """
     assert self._finalized
     if not self._argspec:
         return gast.Tuple([], None)
     combined = self._argspec[0]
     for spec in self._argspec[1:]:
         combined = gast.BinOp(combined, gast.Add(), spec)
     return combined
Exemple #9
0
 def visit_AugAssign(self, node):
   # Rewrite `target op= value` as `target = <target'> op <value'>`, where
   # both operands are first passed through self.trivialize (target first).
   self.trivializing = True
   lhs = self.trivialize(node.target)
   rhs = self.trivialize(node.value)
   self.trivializing = False
   expanded = gast.BinOp(left=lhs, op=node.op, right=rhs)
   return gast.Assign(targets=[node.target], value=expanded)
 def visit_BoolOp(self, node):
     """Turn a BoolOp into left-nested BinOps if visiting changed an operand.

     The BinOp operator class is looked up in ``NormalizeIsNone.table`` by
     the type of the BoolOp operator. An unchanged node is returned as is.
     """
     before = list(node.values)
     self.generic_visit(node)
     changed = any(old != new for old, new in zip(before, node.values))
     if not changed:
         return node
     folded = node.values[0]
     for operand in node.values[1:]:
         folded = ast.BinOp(folded, NormalizeIsNone.table[type(node.op)](),
                            operand)
     return folded
Exemple #11
0
 def visit_BoolOp(self, node):
     """Turn a BoolOp into left-nested BinOps if visiting changed an operand.

     ``a <op> b <op> c`` becomes ``(a OP b) OP c``, where OP is the class
     ``NormalizeIsNone.table`` maps ``type(node.op)`` to. All operands are
     folded. Unpacking into exactly two names raised ValueError on chains
     such as ``a or b or c``, which parse to a single BoolOp with three values.
     """
     values = list(node.values)
     self.generic_visit(node)
     if any(x != y for x, y in zip(values, node.values)):
         binop_type = NormalizeIsNone.table[type(node.op)]
         result = node.values[0]
         for operand in node.values[1:]:
             result = ast.BinOp(result, binop_type(), operand)
         return result
     else:
         return node
Exemple #12
0
class SqrPattern(Pattern):
    """Rewrite ``X * X`` into ``X ** 2``."""

    # Placeholder(0) on both sides: both factors must be the same X.
    pattern = ast.BinOp(left=Placeholder(0),
                        op=ast.Mult(),
                        right=Placeholder(0))

    @staticmethod
    def sub():
        """Build the replacement tree ``X ** 2``."""
        base = Placeholder(0)
        return ast.BinOp(left=base,
                         op=ast.Pow(),
                         right=ast.Constant(2, None))
Exemple #13
0
class StrJoinPattern(Pattern):
    """Rewrite ``a + "..." + b`` into ``"...".join((a, b))``."""

    # (X0 + "<literal X1>") + X2, with X1 constrained to str.
    pattern = ast.BinOp(
        left=ast.BinOp(left=Placeholder(0),
                       op=ast.Add(),
                       right=ast.Constant(Placeholder(1, str), None)),
        op=ast.Add(),
        right=Placeholder(2))

    @staticmethod
    def sub():
        """Build ``__builtin__.str.join(X1, (X0, X2))``."""
        builtin_str = ast.Attribute(
            ast.Name('__builtin__', ast.Load(), None, None), 'str',
            ast.Load())
        str_join = ast.Attribute(builtin_str, 'join', ast.Load())
        separator = ast.Constant(Placeholder(1), None)
        pieces = ast.Tuple([Placeholder(0), Placeholder(2)], ast.Load())
        return ast.Call(func=str_join, args=[separator, pieces], keywords=[])
Exemple #14
0
 def sub():
     # Build `__builtin__.<range_name>(X - 1, -1, -1)`, i.e. a countdown
     # from X - 1 to 0 (range_name comes from the enclosing scope —
     # presumably "range"/"xrange"; confirm).
     range_func = ast.Attribute(value=ast.Name(id='__builtin__',
                                               ctx=ast.Load(),
                                               annotation=None,
                                               type_comment=None),
                                attr=range_name, ctx=ast.Load())
     start = ast.BinOp(left=Placeholder(0), op=ast.Sub(),
                       right=ast.Constant(1, None))
     stop, step = ast.Constant(-1, None), ast.Constant(-1, None)
     return ast.Call(func=range_func, args=[start, stop, step], keywords=[])
Exemple #15
0
 def expand_pow(self, node, n):
     """Expand ``node ** n`` by square-and-multiply.

     Uses ``node ** n == replace(node) ** (n >> 1) * node ** (n & 1)``.
     ``self.replace`` is presumably a squaring rewrite — confirm. Requires
     an integer ``n`` (bit operations); ``n == 0`` yields the literal 1.
     """
     if n == 0:
         return ast.Num(1)
     if n == 1:
         return node
     half_pow = self.expand_pow(self.replace(node), n >> 1)
     if not n & 1:
         return half_pow
     # Odd exponent: one extra factor, on a copy since `node` is already used.
     return ast.BinOp(half_pow, ast.Mult(), copy.deepcopy(node))
Exemple #16
0
 def _build_cond_stmt(self, step_node, compare_node):
     """Build the loop test for a converted ``for`` loop.

     The loop variable (``iter_var_name`` for ``for ... in range``,
     otherwise ``iter_idx_name``) is compared with ``compare_node``:
     - ``var < stop`` for a positive step;
     - ``var > stop`` when ``step_node`` is a negative number literal
       (e.g. ``-1``, which parses as ``UnaryOp(USub, Constant(1))``).

     A non-literal step is assumed to be positive. This assumes the caller
     starts the variable at ``start`` and advances it by ``step`` after each
     iteration — TODO confirm.

     The previous ``var + step <= stop`` is only equivalent for a step of 1.
     It dropped the last iteration for larger steps (``range(0, 5, 2)``
     stopped before 4) and never iterated with negative steps.
     """
     def is_negative_literal(expr):
         sign = 1
         if isinstance(expr, gast.UnaryOp) and isinstance(expr.op, gast.USub):
             expr, sign = expr.operand, -1
         return (isinstance(expr, gast.Constant)
                 and isinstance(expr.value, (int, float))
                 and sign * expr.value < 0)

     loop_var = gast.Name(
         id=self.iter_var_name
         if self.is_for_range_iter() else self.iter_idx_name,
         ctx=gast.Load(),
         annotation=None,
         type_comment=None)
     cmp_op = gast.Gt() if is_negative_literal(step_node) else gast.Lt()
     return gast.Compare(
         left=loop_var,
         ops=[cmp_op],
         comparators=[compare_node])
Exemple #17
0
    def get_for_args_stmts(self, iter_name, args_list):
        '''
        Build the statements emulating ``for iter_name in range(*args_list)``.

        Args:
            iter_name: name of the loop variable.
            args_list: the 1 to 3 gast argument nodes of the ``range`` call.

        Returns:
            (init_stmt, cond_stmt, change_stmt):
            1. initialization of the loop variable to ``start`` (0 by default),
            2. loop condition: ``var < stop``, or ``var > stop`` when the step
               is a negative number literal. A non-literal step is assumed
               positive.
            3. ``var += step``, the update of the loop variable.

        The previous condition ``var + step <= stop`` only matched ``range``
        for a step of 1. It skipped the last value for larger steps and never
        iterated with negative steps.

        NOTE(TODO): Python allows to access iteration variable after loop, such
           as "for i in range(10)" will create i = 9 after the loop. But using
           current conversion will make i = 10. We should find a way to change it
        '''
        len_range_args = len(args_list)
        assert len_range_args >= 1 and len_range_args <= 3, "range() function takes 1 to 3 arguments"
        if len_range_args == 1:
            init_stmt = get_constant_variable_node(iter_name, 0)
        else:
            init_stmt = gast.Assign(
                targets=[
                    gast.Name(
                        id=iter_name,
                        ctx=gast.Store(),
                        annotation=None,
                        type_comment=None)
                ],
                value=args_list[0])

        range_max_node = args_list[0] if len_range_args == 1 else args_list[1]
        step_node = args_list[2] if len_range_args == 3 else gast.Constant(
            value=1, kind=None)

        def is_negative_literal(expr):
            # `-2` parses as UnaryOp(USub, Constant(2)).
            sign = 1
            if isinstance(expr, gast.UnaryOp) and isinstance(expr.op,
                                                             gast.USub):
                expr, sign = expr.operand, -1
            return (isinstance(expr, gast.Constant)
                    and isinstance(expr.value, (int, float))
                    and sign * expr.value < 0)

        cond_stmt = gast.Compare(
            left=gast.Name(
                id=iter_name,
                ctx=gast.Load(),
                annotation=None,
                type_comment=None),
            ops=[gast.Gt() if is_negative_literal(step_node) else gast.Lt()],
            comparators=[range_max_node])

        change_stmt = gast.AugAssign(
            target=gast.Name(
                id=iter_name,
                ctx=gast.Store(),
                annotation=None,
                type_comment=None),
            op=gast.Add(),
            value=step_node)

        return init_stmt, cond_stmt, change_stmt
Exemple #18
0
 def visit_AugAssign(self, node):
     # Expand `target op= value` into `target = target op value`. The value
     # is trivialized before the target. The trivializing flag and
     # namer.target are set for the duration of the rewrite and cleared after
     # the nested generic_visit.
     self.src = quoting.unquote(node)
     self.trivializing = True
     self.namer.target = node.target
     rhs = self.trivialize(node.value)
     new_target = self.trivialize(node.target)
     # Assumes the trivialized target is a Name (it must expose `.id`).
     lhs = gast.Name(id=new_target.id, ctx=gast.Load(), annotation=None)
     expanded = gast.BinOp(left=lhs, op=node.op, right=rhs)
     node = gast.Assign(targets=[new_target], value=expanded)
     self.mark(node)
     node = self.generic_visit(node)
     self.namer.target = None
     self.trivializing = False
     return node
Exemple #19
0
    def test_code_block(self):
        """Substituting a statement list for `block` inlines it in order."""
        def template(block):  # pylint:disable=unused-argument
            def test_fn(a):  # pylint:disable=unused-variable
                block  # pylint:disable=pointless-statement
                return a

        node = templates.replace(
            template,
            block=[
                gast.Assign([gast.Name('a', gast.Store(), None)],
                            gast.BinOp(gast.Name('a', gast.Load(), None),
                                       gast.Add(), gast.Num(1))),
            ] * 2)[0]
        result = compiler.ast_to_object(node)
        # assertEquals is a deprecated alias, removed in Python 3.12.
        self.assertEqual(3, result.test_fn(1))
Exemple #20
0
class CbrtPattern(Pattern):
    """Rewrite ``X ** (1. / 3.)`` into ``numpy.cbrt(X)``.

    Only an exponent literal equal to the float ``1. / 3.`` matches.

    NOTE(review): for negative X, ``X ** (1. / 3.)`` has no real result in
    Python/NumPy (complex or nan), while ``numpy.cbrt`` returns the real cube
    root. Confirm this divergence is intended.
    """

    pattern = ast.BinOp(Placeholder(0), ast.Pow(), ast.Constant(1./3., None))

    @staticmethod
    def sub():
        """Build ``<mangled numpy>.cbrt(X)``."""
        numpy_module = ast.Name(id=mangle('numpy'),
                                ctx=ast.Load(),
                                annotation=None,
                                type_comment=None)
        cbrt_func = ast.Attribute(value=numpy_module, attr="cbrt",
                                  ctx=ast.Load())
        return ast.Call(func=cbrt_func, args=[Placeholder(0)], keywords=[])

    # The replacement refers to numpy through its mangled name.
    extra_imports = [ast.Import([ast.alias('numpy', mangle('numpy'))])]
Exemple #21
0
    def test_replace_code_block(self):
        """Replacing `block` with a statement list inlines it in order."""
        template = """
      def test_fn(a):
        block
        return a
    """

        node = templates.replace(
            template,
            block=[
                gast.Assign([gast.Name('a', None, None)],
                            gast.BinOp(gast.Name('a', None, None), gast.Add(),
                                       gast.Num(1))),
            ] * 2)[0]
        result, _ = compiler.ast_to_object(node)
        # assertEquals is a deprecated alias, removed in Python 3.12.
        self.assertEqual(3, result.test_fn(1))
  def test_load_ast(self):
    """loader.load_ast emits the expected source, file and callable module."""
    param_a = gast.Name(
        'a', ctx=gast.Param(), annotation=None, type_comment=None)
    load_a = gast.Name(
        'a', ctx=gast.Load(), annotation=None, type_comment=None)
    a_plus_one = gast.BinOp(
        op=gast.Add(), left=load_a, right=gast.Constant(1, kind=None))
    signature = gast.arguments(
        args=[param_a],
        posonlyargs=[],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[])
    node = gast.FunctionDef(
        name='f',
        args=signature,
        body=[gast.Return(a_plus_one)],
        decorator_list=[],
        returns=None,
        type_comment=None)

    module, source, _ = loader.load_ast(node)

    expected_source = """
      # coding=utf-8
      def f(a):
          return (a + 1)
    """
    expected = textwrap.dedent(expected_source).strip()
    self.assertEqual(expected, source.strip())
    self.assertEqual(2, module.f(1))
    # The module file on disk must contain the same source.
    with open(module.__file__, 'r') as temp_output:
      self.assertEqual(expected, temp_output.read().strip())
Exemple #23
0
    def get_for_args_stmts(self, iter_name, args_list):
        '''
        Build the statements emulating ``for iter_name in range(*args_list)``.

        Args:
            iter_name: name of the loop variable.
            args_list: the 1 to 3 gast argument nodes of the ``range`` call.

        Returns:
            (init_stmt, cond_stmt, change_stmt):
            1. initialization of the loop variable to ``start`` (0 by default),
            2. loop condition: the range test AND ``self.condition_node``.
               The range test is ``var < stop``, or ``var > stop`` when the
               step is a negative number literal. A non-literal step is
               assumed positive.
            3. ``var += step``, the update of the loop variable.

        The previous range test ``var + step <= stop`` only matched ``range``
        for a step of 1. It skipped the last value for larger steps and never
        iterated with negative steps.
        '''
        len_range_args = len(args_list)
        assert len_range_args >= 1 and len_range_args <= 3, "range() function takes 1 to 3 arguments"
        if len_range_args == 1:
            init_stmt = get_constant_variable_node(iter_name, 0)
        else:
            init_stmt = gast.Assign(
                targets=[
                    gast.Name(id=iter_name,
                              ctx=gast.Store(),
                              annotation=None,
                              type_comment=None)
                ],
                value=args_list[0])

        range_max_node = args_list[0] if len_range_args == 1 else args_list[1]
        step_node = args_list[2] if len_range_args == 3 else gast.Constant(
            value=1, kind=None)

        def is_negative_literal(expr):
            # `-2` parses as UnaryOp(USub, Constant(2)).
            sign = 1
            if isinstance(expr, gast.UnaryOp) and isinstance(expr.op,
                                                             gast.USub):
                expr, sign = expr.operand, -1
            return (isinstance(expr, gast.Constant)
                    and isinstance(expr.value, (int, float))
                    and sign * expr.value < 0)

        range_cond = gast.Compare(
            left=gast.Name(id=iter_name,
                           ctx=gast.Load(),
                           annotation=None,
                           type_comment=None),
            ops=[gast.Gt() if is_negative_literal(step_node) else gast.Lt()],
            comparators=[range_max_node])
        cond_stmt = gast.BoolOp(op=gast.And(),
                                values=[range_cond, self.condition_node])

        change_stmt = gast.AugAssign(target=gast.Name(id=iter_name,
                                                      ctx=gast.Store(),
                                                      annotation=None,
                                                      type_comment=None),
                                     op=gast.Add(),
                                     value=step_node)

        return init_stmt, cond_stmt, change_stmt
Exemple #24
0
    def visit_BinOp(self, node):
        """Replace ``x % y`` by a compare-and-subtract when ranges allow it.

        If the inferred ranges guarantee ``y > 0`` and ``0 <= x < 2 * y``,
        then ``x % y == (x if x < y else x - y)``, and the modulo is replaced
        by that conditional expression.

        The previous guards were too weak:
        - a zero divisor was accepted;
        - a negative ``x`` was accepted (Python gives ``x + y`` there);
        - ``x`` was compared with ``2 * max(y)`` instead of ``2 * min(y)``
          (``10 % 1`` became ``9``).
        """
        node = self.generic_visit(node)
        if not isinstance(node.op, ast.Mod):
            return node

        right_range = self.range_values[node.right]
        left_range = self.range_values[node.left]

        # The divisor must be strictly positive and bounded.
        if right_range.low <= 0 or isinf(right_range.high):
            return node

        # A negative dividend would need `x + y`, not `x`.
        if left_range.low < 0:
            return node
        # A single subtraction is only enough if x < 2 * y for the smallest y.
        if left_range.high >= right_range.low * 2:
            return node

        cleft0, cleft1 = deepcopy(node.left), deepcopy(node.left)
        cright = deepcopy(node.right)
        self.update = True
        return ast.IfExp(ast.Compare(node.left, [ast.Lt()], [node.right]),
                         cleft0, ast.BinOp(cleft1, ast.Sub(), cright))
Exemple #25
0
def eval_ast_impl(nast, env):
    """Evaluate ``nast`` (a gast node or a list of statements) in ``env``.

    Dispatches on the node type to the matching ``eval_*`` helper. A list is
    executed statement by statement and yields None; statements for which
    ``is_print_logging`` is true are skipped.

    Raises:
        ValueReturn: carrying the value of a ``return`` statement.
        NameError: for a name found neither in ``env``, ``env.module`` nor
            builtins.
        Exception: for unsupported nodes or operators.
    """
    if isinstance(nast, list):
        # Execute the statements in order.
        for s in nast:
            if is_print_logging(s, env):
                continue
            eval_ast(s, env)
        return None
    elif isinstance(nast, gast.For):
        return eval_for(nast, env)

    elif isinstance(nast, gast.Assign):
        return eval_assign(nast, env)

    elif isinstance(nast, gast.AugAssign):
        # Desugared as `target = target op value`; this is inaccurate for
        # assignment to a reference.
        ca = gast.Assign(targets=[nast.target],
                         value=gast.BinOp(left=nast.target,
                                          op=nast.op,
                                          right=nast.value))
        return eval_ast(ca, env)

    elif isinstance(nast, gast.Call):
        return eval_call(nast, env)

    elif isinstance(nast, gast.UnaryOp):
        return eval_unary_op(nast, env)

    elif isinstance(nast, gast.BinOp):
        return eval_binary_op(nast, env)

    elif isinstance(nast, gast.BoolOp):
        # Only constant (non-tensor) booleans are supported for now. All
        # operands are evaluated (no short-circuit) and the result is a bool.
        vs = list(map(lambda x: eval_ast(x, env), nast.values))
        # NOTE(review): `res` is unused; kept in case new_tensor() has side
        # effects — confirm and remove.
        res = new_tensor()
        if isinstance(nast.op, gast.And):
            opfun = all
        elif isinstance(nast.op, gast.Or):
            opfun = any
        else:
            raise Exception('unknown operator', nast.op)

        if not any(map(istensor, vs)):
            return opfun(vs)

        raise Exception('Unimplemented BoolOp for tensor', nast)

    elif isinstance(nast, gast.Attribute):
        return eval_attribute(nast, env)

    elif isinstance(nast, gast.Compare):
        return eval_compare(nast, env)

    elif isinstance(nast, gast.If):
        return eval_if(nast, env)

    elif isinstance(nast, gast.ListComp):
        return eval_list_comp(nast, env)

    elif isinstance(nast, gast.Subscript):
        return eval_subscript(nast, env)

    elif isinstance(nast, gast.Delete):
        # Simply forget each deleted variable.
        vs = nast.targets
        for v in vs:
            assert isinstance(v, gast.Name)
            env.pop_var(v.id)
        return None

    elif isinstance(nast, gast.Name):
        try:
            return env.get_var(nast.id)
        except NameError:
            # Fall back to module-level names, then to Python builtins.
            if nast.id in dir(env.module):
                return getattr(env.module, nast.id)
            elif nast.id in dir(builtins):
                return getattr(builtins, nast.id)
            raise
    elif isinstance(nast, gast.Constant):
        # Covers string constants too.
        return nast.value
    elif isinstance(nast, gast.Expr):
        return eval_ast(nast.value, env)
    elif isinstance(nast, gast.Tuple):
        return tuple(map(lambda x: eval_ast(x, env), nast.elts))
    elif isinstance(nast, gast.List):
        return eval_list(nast, env)

    elif isinstance(nast, gast.Return):
        raise ValueReturn(eval_ast(nast.value, env))

    elif isinstance(nast, gast.Assert):
        # TODO(hamaji): Emit an assertion?
        return None

    # TODO(hamaji): Implement `with`.
    # elif isinstance(nast, gast.With):
    #     sys.stderr.write(
    #         'WARNING: Currenctly, the context of `with` is just ignored\n')
    #     for s in nast.body:
    #         eval_ast(s, env)
    #     return None

    else:
        # No interactive debugging console here: it blocks on stdin in
        # non-interactive runs. The exception carries the offending node.
        print('unknown ast')
        raise Exception('unknown ast', nast)

    raise Exception("shouldn't reach here", nast)
Exemple #26
0
                  ast.Call(func=ast.Attribute(value=ast.Name(id='__builtin__',
                                                             ctx=ast.Load(),
                                                             annotation=None),
                                              attr="xrange",
                                              ctx=ast.Load()),
                           args=[Placeholder(0)],
                           keywords=[])
              ],
              keywords=[]),
     lambda: ast.Call(func=ast.Attribute(value=ast.Name(
         id='__builtin__', ctx=ast.Load(), annotation=None),
                                         attr="xrange",
                                         ctx=ast.Load()),
                      args=[
                          ast.BinOp(left=Placeholder(0),
                                    op=ast.Sub(),
                                    right=ast.Num(n=1)),
                          ast.Num(n=-1),
                          ast.Num(n=-1)
                      ],
                      keywords=[])),

    # X * X => X ** 2
    (ast.BinOp(left=Placeholder(0), op=ast.Mult(), right=Placeholder(0)),
     lambda: ast.BinOp(left=Placeholder(0), op=ast.Pow(), right=ast.Num(n=2))),

    # a + "..." + b => "...".join((a, b))
    (ast.BinOp(left=ast.BinOp(left=Placeholder(0),
                              op=ast.Add(),
                              right=ast.Str(Placeholder(1))),
               op=ast.Add(),
Exemple #27
0
class Square(Transformation):
    """
    Replaces **2 by a call to numpy.square.

    Also rewrites ``numpy.power(x, 2)`` the same way. Other positive integer
    powers ``x ** n`` are expanded into products of ``numpy.square`` calls
    (square-and-multiply). ``import numpy`` is inserted at the top of the
    module when any replacement was made.

    >>> import gast as ast
    >>> from pythran import passmanager, backend
    >>> node = ast.parse('a**2')
    >>> pm = passmanager.PassManager("test")
    >>> _, node = pm.apply(Square, node)
    >>> print(pm.dump(backend.Python, node))
    import numpy
    numpy.square(a)
    >>> node = ast.parse('numpy.power(a,2)')
    >>> pm = passmanager.PassManager("test")
    >>> _, node = pm.apply(Square, node)
    >>> print(pm.dump(backend.Python, node))
    import numpy
    numpy.square(a)
    """

    POW_PATTERN = ast.BinOp(AST_any(), ast.Pow(), ast.Num(2))
    POWER_PATTERN = ast.Call(
        ast.Attribute(ast.Name('numpy', ast.Load(), None), 'power',
                      ast.Load()), [AST_any(), ast.Num(2)], [])

    def __init__(self):
        Transformation.__init__(self)

    def replace(self, value):
        """Return ``numpy.square(value)`` and record that numpy is needed."""
        self.update = self.need_import = True
        return ast.Call(
            ast.Attribute(ast.Name('numpy', ast.Load(), None), 'square',
                          ast.Load()), [value], [])

    def visit_Module(self, node):
        """Transform the module, then prepend ``import numpy`` if needed."""
        self.need_import = False
        self.generic_visit(node)
        if self.need_import:
            importIt = ast.Import(names=[ast.alias(name='numpy', asname=None)])
            node.body.insert(0, importIt)
        return node

    def expand_pow(self, node, n):
        """Expand ``node ** n`` (integer ``n >= 0``) by square-and-multiply."""
        if n == 0:
            return ast.Num(1)
        elif n == 1:
            return node
        else:
            node_square = self.replace(node)
            node_pow = self.expand_pow(node_square, n >> 1)
            if n & 1:
                # `node` is already inside node_square: use a copy.
                return ast.BinOp(node_pow, ast.Mult(), copy.deepcopy(node))
            else:
                return node_pow

    def visit_BinOp(self, node):
        """Rewrite ``x ** 2`` and expand other positive integer powers."""
        self.generic_visit(node)
        if ASTMatcher(Square.POW_PATTERN).search(node):
            return self.replace(node.left)
        elif isinstance(node.op, ast.Pow) and isinstance(node.right, ast.Num):
            n = node.right.n
            # Only int exponents are expanded. expand_pow needs `>>`/`&`, so
            # `x ** 3.0` used to raise TypeError there and `x ** 2j` raised
            # in int(n). Expanding a float power would also lose its float
            # result type.
            if isinstance(n, int) and n > 0:
                return self.expand_pow(node.left, n)
            else:
                return node
        else:
            return node

    def visit_Call(self, node):
        """Rewrite ``numpy.power(x, 2)`` into ``numpy.square(x)``."""
        self.generic_visit(node)
        if ASTMatcher(Square.POWER_PATTERN).search(node):
            return self.replace(node.args[0])
        else:
            return node
    def visit_Cond(self, node):
        '''
        Split static expressions out of the test of an ``If`` or ``IfExp``.

        The same algorithm serves ``ast.If`` (statement) and ``ast.IfExp``
        (expression). W(rap) and U(n)W(rap) hide the difference: for ``If`` a
        branch is a list of statements, for ``IfExp`` a single expression.

        Splitting applies when ``node.test`` is a ``BinOp`` that contains a
        static expression (per ``HasStaticExpression``) but is not itself in
        ``self.static_expressions``. The operands are split into three parts:
            1. a (possibly empty) leading operand with no static expression,
            2. the operand(s) containing a static expression,
            3. a (possibly empty) trailing operand.
        The node is then rebuilt as nested conditionals, one per part:
        - for ``&``, each test guards the next and a failure goes to the
          original ``orelse``;
        - for ``|``, a success goes to the original ``body`` and a failure
          falls through to the next test.
        The result is visited again, so splitting repeats until every static
        expression is alone in a test.

        Returns the rewritten node, or ``self.generic_visit(node)`` when no
        split applies. Sets ``self.update`` after a split.

        Raises:
            PythranSyntaxError: if the test operator is neither ``&`` nor ``|``.
        '''
        NodeTy = type(node)
        if NodeTy is ast.IfExp:
            def W(x):
                return x

            def UW(x):
                return x
        else:
            def W(x):
                return [x]

            def UW(x):
                return x[0]

        has_static_expr = self.gather(HasStaticExpression, node.test)

        if not has_static_expr:
            return self.generic_visit(node)

        if node.test in self.static_expressions:
            return self.generic_visit(node)

        if not isinstance(node.test, ast.BinOp):
            return self.generic_visit(node)

        before, static = [], []
        # Reversed so that pop() yields the left operand first.
        values = [node.test.right, node.test.left]

        def has_static_expression(n):
            return self.gather(HasStaticExpression, n)

        # Leading operands without a static expression ...
        while values and not has_static_expression(values[-1]):
            before.append(values.pop())

        # ... then the operands containing one ...
        while values and has_static_expression(values[-1]):
            static.append(values.pop())

        # ... and whatever is left, in source order.
        after = list(reversed(values))

        test_before = NodeTy(None, None, None)
        if before:
            assert len(before) == 1
            test_before.test = before[0]

        test_static = NodeTy(None, None, None)
        if static:
            test_static.test = static[0]
            if len(static) > 1:
                # Only the first static operand gets its own test; the others
                # are merged (with the original operator) into `after`.
                if after:
                    assert len(after) == 1
                    after = [ast.BinOp(static[1], node.test.op, after[0])]
                else:
                    after = static[1:]

        test_after = NodeTy(None, None, None)
        if after:
            assert len(after) == 1
            test_after.test = after[0]

        if isinstance(node.test.op, ast.BitAnd):
            # a & b & c: every test must hold to reach body; any failure
            # goes to orelse. Built innermost (after) first.
            if after:
                test_after.body = deepcopy(node.body)
                test_after.orelse = deepcopy(node.orelse)
                test_after = W(test_after)
            else:
                test_after = deepcopy(node.body)

            if static:
                test_static.body = test_after
                test_static.orelse = deepcopy(node.orelse)
                test_static = W(test_static)
            else:
                test_static = test_after

            if before:
                test_before.body = test_static
                test_before.orelse = node.orelse
                node = test_before
            else:
                node = UW(test_static)

        elif isinstance(node.test.op, ast.BitOr):
            # a | b | c: any test holding reaches body; only when all fail
            # do we reach orelse. Built innermost (after) first.
            if after:
                test_after.body = deepcopy(node.body)
                test_after.orelse = deepcopy(node.orelse)
                test_after = W(test_after)
            else:
                test_after = deepcopy(node.orelse)

            if static:
                test_static.body = deepcopy(node.body)
                test_static.orelse = test_after
                test_static = W(test_static)
            else:
                test_static = test_after

            if before:
                test_before.body = deepcopy(node.body)
                test_before.orelse = test_static
                node = test_before
            else:
                node = UW(test_static)
        else:
            raise PythranSyntaxError("operator not supported in a static if",
                                     node)

        self.update = True
        # Visit again so remaining static expressions get split as well.
        return self.generic_visit(node)
Exemple #29
0
def analyse(node, env, non_generic=None):
    """Computes the type of the expression given by node.

    The type of the node is computed in the context of the supplied type
    environment env. Data types can be introduced into the language simply
    by having a predefined set of identifiers in the initial environment;
    this way there is no need to change the syntax or, more importantly,
    the type-checking program when extending the language.

    Expression nodes return their computed type. Statement nodes (and
    ``Yield``) are analysed for their effect on env, which is updated in
    place and then returned.

    Args:
        node: The root of the abstract syntax tree.
        env: The type environment is a mapping of expression identifier names
            to type assignments. Statement nodes may add, rebind or delete
            entries in place.
        non_generic: A set of non-generic variables, or None

    Returns:
        The computed type of the expression for expression nodes, or the
        (updated) environment env for statement nodes.

    Raises:
        InferenceError: The type of the expression could not be inferred,
        PythranTypeError: InferenceError with user friendly message + location
        NotImplementedError: unsupported numeric literal, unknown module or
            unknown function in an import
        RuntimeError: the node kind is not handled by this function
    """

    if non_generic is None:
        non_generic = set()

    # expr
    if isinstance(node, gast.Name):
        if isinstance(node.ctx, gast.Store):
            # a store binds the name to a fresh non-generic type variable
            new_type = TypeVariable()
            non_generic.add(new_type)
            env[node.id] = new_type
        return get_type(node.id, env, non_generic)
    elif isinstance(node, gast.Num):
        if isinstance(node.n, (int, long)):
            return Integer()
        elif isinstance(node.n, float):
            return Float()
        elif isinstance(node.n, complex):
            return Complex()
        else:
            raise NotImplementedError
    elif isinstance(node, gast.Str):
        return Str()
    elif isinstance(node, gast.Compare):
        left_type = analyse(node.left, env, non_generic)
        comparators_type = [analyse(comparator, env, non_generic)
                            for comparator in node.comparators]
        ops_type = [analyse(op, env, non_generic)
                    for op in node.ops]
        # chained comparisons `a < b < c` are checked pairwise, (a, b) then
        # (b, c), and every pair shares the same result type
        prev_type = left_type
        result_type = TypeVariable()
        for op_type, comparator_type in zip(ops_type, comparators_type):
            try:
                unify(Function([prev_type, comparator_type], result_type),
                      op_type)
                prev_type = comparator_type
            except InferenceError:
                raise PythranTypeError(
                    "Invalid comparison, between `{}` and `{}`".format(
                        prev_type,
                        comparator_type
                    ),
                    node)
        return result_type
    elif isinstance(node, gast.Call):
        if is_getattr(node):
            self_type = analyse(node.args[0], env, non_generic)
            attr_name = node.args[1].s
            _, attr_signature = attributes[attr_name]
            attr_type = tr(attr_signature)
            result_type = TypeVariable()
            try:
                unify(Function([self_type], result_type), attr_type)
            except InferenceError:
                if isinstance(prune(attr_type), MultiType):
                    msg = 'no attribute found, tried:\n{}'.format(attr_type)
                else:
                    msg = 'tried {}'.format(attr_type)
                raise PythranTypeError(
                    "Invalid attribute for getattr call with self "
                    "of type `{}`, {}".format(self_type, msg), node)

        else:
            fun_type = analyse(node.func, env, non_generic)
            arg_types = [analyse(arg, env, non_generic) for arg in node.args]
            result_type = TypeVariable()
            try:
                unify(Function(arg_types, result_type), fun_type)
            except InferenceError:
                # recover original type
                fun_type = analyse(node.func, env, non_generic)
                if isinstance(prune(fun_type), MultiType):
                    msg = 'no overload found, tried:\n{}'.format(fun_type)
                else:
                    msg = 'tried {}'.format(fun_type)
                raise PythranTypeError(
                    "Invalid argument type for function call to "
                    "`Callable[[{}], ...]`, {}"
                    .format(', '.join('{}'.format(at) for at in arg_types),
                            msg),
                    node)
        return result_type

    elif isinstance(node, gast.IfExp):
        test_type = analyse(node.test, env, non_generic)
        unify(Function([test_type], Bool()),
              tr(MODULES['__builtin__']['bool_']))

        # `x if x is None else y`: x is narrowed to NoneType in the body...
        if is_test_is_none(node.test):
            none_id = node.test.left.id
            body_env = env.copy()
            body_env[none_id] = NoneType
        else:
            none_id = None
            body_env = env

        body_type = analyse(node.body, body_env, non_generic)

        # ...and to the first component of its option type (if it has one)
        # in the orelse branch
        if none_id:
            orelse_env = env.copy()
            if is_option_type(env[none_id]):
                orelse_env[none_id] = prune(env[none_id]).types[0]
            else:
                orelse_env[none_id] = TypeVariable()
        else:
            orelse_env = env

        orelse_type = analyse(node.orelse, orelse_env, non_generic)

        try:
            return merge_unify(body_type, orelse_type)
        except InferenceError:
            raise PythranTypeError(
                "Incompatible types from different branches:"
                "`{}` and `{}`".format(
                    body_type,
                    orelse_type
                ),
                node
            )
    elif isinstance(node, gast.UnaryOp):
        operand_type = analyse(node.operand, env, non_generic)
        op_type = analyse(node.op, env, non_generic)
        result_type = TypeVariable()
        try:
            unify(Function([operand_type], result_type), op_type)
            return result_type
        except InferenceError:
            raise PythranTypeError(
                "Invalid operand for `{}`: `{}`".format(
                    symbol_of[type(node.op)],
                    operand_type
                ),
                node
            )
    elif isinstance(node, gast.BinOp):
        left_type = analyse(node.left, env, non_generic)
        op_type = analyse(node.op, env, non_generic)
        right_type = analyse(node.right, env, non_generic)
        result_type = TypeVariable()
        try:
            unify(Function([left_type, right_type], result_type), op_type)
        except InferenceError:
            raise PythranTypeError(
                "Invalid operand for `{}`: `{}` and `{}`".format(
                    symbol_of[type(node.op)],
                    left_type,
                    right_type),
                node
            )
        return result_type
    # operators: each one evaluates to the signature of a known function
    elif isinstance(node, gast.Pow):
        return tr(MODULES['numpy']['power'])
    elif isinstance(node, gast.Sub):
        return tr(MODULES['operator_']['sub'])
    elif isinstance(node, (gast.USub, gast.UAdd)):
        return tr(MODULES['operator_']['pos'])
    elif isinstance(node, (gast.Eq, gast.NotEq, gast.Lt, gast.LtE, gast.Gt,
                           gast.GtE, gast.Is, gast.IsNot)):
        return tr(MODULES['operator_']['eq'])
    elif isinstance(node, (gast.In, gast.NotIn)):
        # `x in c` is `contains(c, x)`: swap the two argument types
        contains_sig = tr(MODULES['operator_']['contains'])
        contains_sig.types[:-1] = reversed(contains_sig.types[:-1])
        return contains_sig
    elif isinstance(node, gast.Add):
        return tr(MODULES['operator_']['add'])
    elif isinstance(node, gast.Mult):
        return tr(MODULES['operator_']['mul'])
    elif isinstance(node, (gast.Div, gast.FloorDiv)):
        # NOTE(review): true division reuses the floordiv signature;
        # presumably an intentional approximation -- confirm.
        return tr(MODULES['operator_']['floordiv'])
    elif isinstance(node, gast.Mod):
        return tr(MODULES['operator_']['mod'])
    elif isinstance(node, (gast.LShift, gast.RShift)):
        return tr(MODULES['operator_']['lshift'])
    elif isinstance(node, (gast.BitXor, gast.BitAnd, gast.BitOr)):
        # NOTE(review): bitwise operators reuse the lshift signature;
        # presumably an intentional approximation -- confirm.
        return tr(MODULES['operator_']['lshift'])
    elif isinstance(node, gast.List):
        new_type = TypeVariable()
        for elt in node.elts:
            elt_type = analyse(elt, env, non_generic)
            try:
                unify(new_type, elt_type)
            except InferenceError:
                raise PythranTypeError(
                    "Incompatible list element type `{}` and `{}`".format(
                        new_type, elt_type),
                    node
                )
        return List(new_type)
    elif isinstance(node, gast.Set):
        new_type = TypeVariable()
        for elt in node.elts:
            elt_type = analyse(elt, env, non_generic)
            try:
                unify(new_type, elt_type)
            except InferenceError:
                raise PythranTypeError(
                    "Incompatible set element type `{}` and `{}`".format(
                        new_type, elt_type),
                    node
                )
        return Set(new_type)
    elif isinstance(node, gast.Dict):
        new_key_type = TypeVariable()
        for key in node.keys:
            key_type = analyse(key, env, non_generic)
            try:
                unify(new_key_type, key_type)
            except InferenceError:
                raise PythranTypeError(
                    "Incompatible dict key type `{}` and `{}`".format(
                        new_key_type, key_type),
                    node
                )
        new_value_type = TypeVariable()
        for value in node.values:
            value_type = analyse(value, env, non_generic)
            try:
                unify(new_value_type, value_type)
            except InferenceError:
                raise PythranTypeError(
                    "Incompatible dict value type `{}` and `{}`".format(
                        new_value_type, value_type),
                    node
                )
        return Dict(new_key_type, new_value_type)
    elif isinstance(node, gast.Tuple):
        return Tuple([analyse(elt, env, non_generic) for elt in node.elts])
    elif isinstance(node, gast.Index):
        return analyse(node.value, env, non_generic)
    elif isinstance(node, gast.Slice):
        def unify_int_or_none(t, name):
            # raised without a node: the enclosing Subscript re-raises the
            # error with its own location attached
            try:
                unify(t, Integer())
            except InferenceError:
                try:
                    unify(t, NoneType)
                except InferenceError:
                    raise PythranTypeError(
                        "Invalid slice {} type `{}`, expecting int or None"
                        .format(name, t)
                    )
        # omitted bounds need no check; present ones are checked in order
        for bound, bound_name in ((node.lower, 'lower bound'),
                                  (node.upper, 'upper bound'),
                                  (node.step, 'step')):
            if bound:
                unify_int_or_none(analyse(bound, env, non_generic),
                                  bound_name)
        return Slice
    elif isinstance(node, gast.ExtSlice):
        return [analyse(dim, env, non_generic) for dim in node.dims]
    elif isinstance(node, gast.NameConstant):
        if node.value is None:
            return env['None']
        elif isinstance(node.value, bool):
            # True / False: previously fell through to "Unhandled syntax node"
            return Bool()
    elif isinstance(node, gast.Subscript):
        new_type = TypeVariable()
        value_type = prune(analyse(node.value, env, non_generic))
        try:
            slice_type = prune(analyse(node.slice, env, non_generic))
        except PythranTypeError as e:
            raise PythranTypeError(e.msg, node)

        if isinstance(node.slice, gast.ExtSlice):
            nbslice = len(node.slice.dims)
            dtype = TypeVariable()
            try:
                unify(Array(dtype, nbslice), clone(value_type))
            except InferenceError:
                raise PythranTypeError(
                    "Dimension mismatch when slicing `{}`".format(value_type),
                    node)
            return TypeVariable()  # FIXME
        elif isinstance(node.slice, gast.Index):
            # handle tuples in a special way
            isnum = isinstance(node.slice.value, gast.Num)
            if isnum and is_tuple_type(value_type):
                try:
                    unify(prune(prune(value_type.types[0]).types[0])
                          .types[node.slice.value.n],
                          new_type)
                    return new_type
                except IndexError:
                    raise PythranTypeError(
                        "Invalid tuple indexing, "
                        "out-of-bound index `{}` for type `{}`".format(
                            node.slice.value.n,
                            value_type),
                        node)
        try:
            unify(tr(MODULES['operator_']['getitem']),
                  Function([value_type, slice_type], new_type))
        except InferenceError:
            raise PythranTypeError(
                "Invalid subscripting of `{}` by `{}`".format(
                    value_type,
                    slice_type),
                node)
        return new_type
    elif isinstance(node, gast.Attribute):
        from pythran.utils import attr_to_path
        obj, path = attr_to_path(node)
        if obj.signature is typing.Any:
            return TypeVariable()
        else:
            return tr(obj)

    # stmt
    elif isinstance(node, gast.Import):
        for alias in node.names:
            if alias.name not in MODULES:
                raise NotImplementedError("unknown module: %s " % alias.name)
            if alias.asname is None:
                target = alias.name
            else:
                target = alias.asname
            env[target] = tr(MODULES[alias.name])
        return env
    elif isinstance(node, gast.ImportFrom):
        if node.module not in MODULES:
            raise NotImplementedError("unknown module: %s" % node.module)
        for alias in node.names:
            if alias.name not in MODULES[node.module]:
                raise NotImplementedError(
                    "unknown function: %s in %s" % (alias.name, node.module))
            if alias.asname is None:
                target = alias.name
            else:
                target = alias.asname
            env[target] = tr(MODULES[node.module][alias.name])
        return env
    elif isinstance(node, gast.FunctionDef):
        # One signature per number i of trailing defaulted parameters left
        # out by the caller (i = 0 .. len(defaults)); the left-out
        # parameters are typed from their default expressions instead.
        ftypes = []
        for i in range(1 + len(node.args.defaults)):
            # NOTE(review): old_type is unused; the lookup only raises
            # KeyError when node.name is not already bound in env.
            old_type = env[node.name]
            new_env = env.copy()
            new_non_generic = non_generic.copy()

            # reset return special variables
            new_env.pop('@ret', None)
            new_env.pop('@gen', None)

            hy = HasYield()
            for stmt in node.body:
                hy.visit(stmt)
            new_env['@gen'] = hy.has_yield

            arg_types = []
            istop = len(node.args.args) - i
            for arg in node.args.args[:istop]:
                arg_type = TypeVariable()
                new_env[arg.id] = arg_type
                new_non_generic.add(arg_type)
                arg_types.append(arg_type)
            # for i == 0 the args slice is empty, so zip yields nothing
            for arg, expr in zip(node.args.args[istop:],
                                 node.args.defaults[-i:]):
                arg_type = analyse(expr, new_env, new_non_generic)
                new_env[arg.id] = arg_type

            analyse_body(node.body, new_env, new_non_generic)

            result_type = new_env.get('@ret', NoneType)

            if new_env['@gen']:
                result_type = Generator(result_type)

            ftype = Function(arg_types, result_type)
            ftypes.append(ftype)
        if len(ftypes) == 1:
            ftype = ftypes[0]
            env[node.name] = ftype
        else:
            env[node.name] = MultiType(ftypes)
        return env
    elif isinstance(node, gast.Module):
        analyse_body(node.body, env, non_generic)
        return env
    elif isinstance(node, (gast.Pass, gast.Break, gast.Continue)):
        return env
    elif isinstance(node, gast.Expr):
        analyse(node.value, env, non_generic)
        return env
    elif isinstance(node, gast.Delete):
        for target in node.targets:
            if isinstance(target, gast.Name):
                if target.id in env:
                    del env[target.id]
                else:
                    raise PythranTypeError(
                        "Invalid del: unbound identifier `{}`".format(
                            target.id),
                        node)
            else:
                analyse(target, env, non_generic)
        return env
    elif isinstance(node, gast.Print):
        if node.dest is not None:
            analyse(node.dest, env, non_generic)
        for value in node.values:
            analyse(value, env, non_generic)
        return env
    elif isinstance(node, gast.Assign):
        defn_type = analyse(node.value, env, non_generic)
        for target in node.targets:
            target_type = analyse(target, env, non_generic)
            try:
                unify(target_type, defn_type)
            except InferenceError:
                raise PythranTypeError(
                    "Invalid assignment from type `{}` to type `{}`".format(
                        target_type,
                        defn_type),
                    node)
        return env
    elif isinstance(node, gast.AugAssign):
        # FIXME: not optimal: evaluates type of node.value twice
        # First check `target op value` as a plain binary operation (on a
        # Load copy of the target), then check the value fits the target.
        fake_target = deepcopy(node.target)
        fake_target.ctx = gast.Load()
        fake_op = gast.BinOp(fake_target, node.op, node.value)
        gast.copy_location(fake_op, node)
        analyse(fake_op, env, non_generic)

        value_type = analyse(node.value, env, non_generic)
        target_type = analyse(node.target, env, non_generic)

        try:
            unify(target_type, value_type)
        except InferenceError:
            raise PythranTypeError(
                "Invalid update operand for `{}`: `{}` and `{}`".format(
                    symbol_of[type(node.op)],
                    value_type,
                    target_type
                ),
                node
            )
        return env
    elif isinstance(node, gast.Raise):
        return env  # TODO
    elif isinstance(node, gast.Return):
        # in a generator, the function type comes from yields (see Yield)
        if env['@gen']:
            return env

        if node.value is None:
            ret_type = NoneType
        else:
            ret_type = analyse(node.value, env, non_generic)
        if '@ret' in env:
            try:
                ret_type = merge_unify(env['@ret'], ret_type)
            except InferenceError:
                raise PythranTypeError(
                    "function may returns with incompatible types "
                    "`{}` and `{}`".format(env['@ret'], ret_type),
                    node
                )

        env['@ret'] = ret_type
        return env
    elif isinstance(node, gast.Yield):
        assert env['@gen']

        # a bare `yield` produces None
        if node.value is None:
            ret_type = NoneType
        else:
            ret_type = analyse(node.value, env, non_generic)
        if '@ret' in env:
            try:
                ret_type = merge_unify(env['@ret'], ret_type)
            except InferenceError:
                raise PythranTypeError(
                    "function may yields incompatible types "
                    "`{}` and `{}`".format(env['@ret'], ret_type),
                    node
                )

        env['@ret'] = ret_type
        return env
    elif isinstance(node, gast.For):
        iter_type = analyse(node.iter, env, non_generic)
        target_type = analyse(node.target, env, non_generic)
        unify(Collection(TypeVariable(), TypeVariable(), TypeVariable(),
                         target_type),
              iter_type)
        analyse_body(node.body, env, non_generic)
        analyse_body(node.orelse, env, non_generic)
        return env
    elif isinstance(node, gast.If):
        test_type = analyse(node.test, env, non_generic)
        unify(Function([test_type], Bool()),
              tr(MODULES['__builtin__']['bool_']))

        # each branch is analysed in its own copy of env / non_generic
        body_env = env.copy()
        body_non_generic = non_generic.copy()

        if is_test_is_none(node.test):
            none_id = node.test.left.id
            body_env[none_id] = NoneType
        else:
            none_id = None

        analyse_body(node.body, body_env, body_non_generic)

        orelse_env = env.copy()
        orelse_non_generic = non_generic.copy()

        if none_id:
            if is_option_type(env[none_id]):
                orelse_env[none_id] = prune(env[none_id]).types[0]
            else:
                orelse_env[none_id] = TypeVariable()
        analyse_body(node.orelse, orelse_env, orelse_non_generic)

        # names newly bound in either branch are merged back into env
        for var in body_env:
            if var not in env:
                if var in orelse_env:
                    try:
                        new_type = merge_unify(body_env[var], orelse_env[var])
                    except InferenceError:
                        raise PythranTypeError(
                            "Incompatible types from different branches for "
                            "`{}`: `{}` and `{}`".format(
                                var,
                                body_env[var],
                                orelse_env[var]
                            ),
                            node
                        )
                else:
                    new_type = body_env[var]
                env[var] = new_type

        for var in orelse_env:
            if var not in env:
                # may not be unified by the prev loop if a del occured
                # NOTE(review): the loop above already bound every new
                # body_env name, so this inner branch looks unreachable.
                if var in body_env:
                    new_type = merge_unify(orelse_env[var], body_env[var])
                else:
                    new_type = orelse_env[var]
                env[var] = new_type

        if none_id:
            try:
                new_type = merge_unify(body_env[none_id], orelse_env[none_id])
            except InferenceError:
                msg = ("Inconsistent types while merging values of `{}` from "
                       "conditional branches: `{}` and `{}`")
                err = msg.format(none_id,
                                 body_env[none_id],
                                 orelse_env[none_id])
                raise PythranTypeError(err, node)
            env[none_id] = new_type

        return env
    elif isinstance(node, gast.While):
        test_type = analyse(node.test, env, non_generic)
        unify(Function([test_type], Bool()),
              tr(MODULES['__builtin__']['bool_']))

        analyse_body(node.body, env, non_generic)
        analyse_body(node.orelse, env, non_generic)
        return env
    elif isinstance(node, gast.Try):
        analyse_body(node.body, env, non_generic)
        for handler in node.handlers:
            analyse(handler, env, non_generic)
        analyse_body(node.orelse, env, non_generic)
        analyse_body(node.finalbody, env, non_generic)
        return env
    elif isinstance(node, gast.ExceptHandler):
        if node.name:
            new_type = ExceptionType
            non_generic.add(new_type)
            if node.name.id in env:
                unify(env[node.name.id], new_type)
            else:
                env[node.name.id] = new_type
        analyse_body(node.body, env, non_generic)
        return env
    elif isinstance(node, gast.Assert):
        if node.msg:
            analyse(node.msg, env, non_generic)
        analyse(node.test, env, non_generic)
        return env
    elif isinstance(node, gast.Invert):
        return MultiType([Function([Bool()], Integer()),
                          Function([Integer()], Integer())])
    elif isinstance(node, gast.Not):
        return tr(MODULES['__builtin__']['bool_'])
    elif isinstance(node, gast.BoolOp):
        op_type = analyse(node.op, env, non_generic)
        value_types = [analyse(value, env, non_generic)
                       for value in node.values]

        for value_type in value_types:
            unify(Function([value_type], Bool()),
                  tr(MODULES['__builtin__']['bool_']))

        return_type = TypeVariable()
        prev_type = value_types[0]
        for value_type in value_types[1:]:
            unify(Function([prev_type, value_type], return_type), op_type)
            prev_type = value_type
        return return_type
    elif isinstance(node, (gast.And, gast.Or)):
        x_type = TypeVariable()
        return MultiType([
            Function([x_type, x_type], x_type),
            Function([TypeVariable(), TypeVariable()], TypeVariable()),
        ])

    raise RuntimeError("Unhandled syntax node {0}".format(type(node)))
Exemple #30
0
 def sub():
     return ast.BinOp(left=Placeholder(0),
                      op=ast.Pow(),
                      right=ast.Constant(2, None))