Example #1
    def visit_Break(self, node):
        loop_node_index = self._find_ancestor_loop_index(node)
        assert loop_node_index != -1, "SyntaxError: 'break' outside loop"
        loop_node = self.ancestor_nodes[loop_node_index]

        # 1. Map the 'break/continue' stmt to a unique boolean variable V.
        variable_name = unique_name.generate(BREAK_NAME_PREFIX)

        # 2. Find the first ancestor block containing this 'break/continue'; a
        # block is a node containing a stmt list. Remove all stmts after the
        # 'break/continue' and set V to True there.
        first_block_index = self._remove_stmts_after_break_continue(
            node, variable_name, loop_node_index)

        # 3. Add 'if V' for stmts in ancestor blocks between the first one
        # (exclusive) and the ancestor loop (inclusive)
        self._replace_if_stmt(loop_node_index, first_block_index,
                              variable_name)

        # 4. For 'break', initialize V to False before the loop and fold
        # 'not V' into the loop condition.
        assign_false_node = create_fill_constant_node(variable_name, False)
        self._add_stmt_before_cur_node(loop_node_index, assign_false_node)

        cond_var_node = gast.UnaryOp(op=gast.Not(),
                                     operand=gast.Name(id=variable_name,
                                                       ctx=gast.Load(),
                                                       annotation=None,
                                                       type_comment=None))
        if isinstance(loop_node, gast.While):
            loop_node.test = gast.BoolOp(
                op=gast.And(), values=[loop_node.test, cond_var_node])
        elif isinstance(loop_node, gast.For):
            parent_node = self.ancestor_nodes[loop_node_index - 1]
            for_to_while = ForToWhileTransformer(parent_node, loop_node,
                                                 cond_var_node)
            for_to_while.transform()
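
The rewrite above can be summarized in plain Python: the 'break' becomes a boolean flag that guards the statements after it and is ANDed into the loop condition. A minimal sketch of that behaviour (the flag name __break_0 is illustrative, not what unique_name.generate produces):

def original(xs):
    out = []
    for x in xs:
        if x < 0:
            break
        out.append(x)
    return out

def rewritten(xs):
    out = []
    __break_0 = False                      # the unique flag variable V
    i = 0
    while i < len(xs) and not __break_0:   # flag folded into the loop condition
        x = xs[i]
        if x < 0:
            __break_0 = True               # 'break' becomes setting the flag
        if not __break_0:                  # trailing stmts run only if no break
            out.append(x)
        i += 1
    return out

assert original([1, 2, -1, 3]) == rewritten([1, 2, -1, 3]) == [1, 2]
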
Example #2
        def visit_Call(self, node):
            if sys.version_info.minor < 5:
                if node.starargs:
                    star = gast.Starred(self._visit(node.starargs),
                                        gast.Load())
                    gast.copy_location(star, node)
                    starred = [star]
                else:
                    starred = []

                if node.kwargs:
                    kwargs = [gast.keyword(None, self._visit(node.kwargs))]
                else:
                    kwargs = []
            else:
                starred = kwargs = []

            new_node = gast.Call(
                self._visit(node.func),
                self._visit(node.args) + starred,
                self._visit(node.keywords) + kwargs,
            )
            gast.copy_location(new_node, node)
            return new_node
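
For context, on Python 3 the parser already represents f(*a) as a Starred positional argument and f(**kw) as a keyword with arg=None, which is exactly the shape the branch above rebuilds from the legacy starargs/kwargs fields. A quick check, assuming gast is installed:

import gast

call = gast.parse("f(1, *a, **kw)").body[0].value
print(type(call.args[-1]).__name__)   # Starred
print(call.keywords[-1].arg)          # None, i.e. the **kw entry
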
Example #3
  def test_unparse(self):
    node = gast.If(
        test=gast.Constant(1, kind=None),
        body=[
            gast.Assign(
                targets=[
                    gast.Name(
                        'a',
                        ctx=gast.Store(),
                        annotation=None,
                        type_comment=None)
                ],
                value=gast.Name(
                    'b', ctx=gast.Load(), annotation=None, type_comment=None))
        ],
        orelse=[
            gast.Assign(
                targets=[
                    gast.Name(
                        'a',
                        ctx=gast.Store(),
                        annotation=None,
                        type_comment=None)
                ],
                value=gast.Constant('c', kind=None))
        ])

    source = parser.unparse(node, indentation='  ')
    self.assertEqual(
        textwrap.dedent("""
            # coding=utf-8
            if 1:
                a = b
            else:
                a = 'c'
        """).strip(), source.strip())
Example #4
 def interprocedural_type_translator(s, n):
     translated_othernode = ast.Name(
         '__fake__', ast.Load(), None, None)
     s.result[translated_othernode] = (
         parametric_type.instanciate(
             s.current,
             [s.result[arg] for arg in n.args]))
     # look for modified argument
     for p, effective_arg in enumerate(n.args):
         formal_arg = args[p]
         if formal_arg.id == node_id:
             translated_node = effective_arg
             break
     try:
         s.combine(translated_node,
                   translated_othernode,
                   op, unary_op, register=True,
                   aliasing_type=True)
     except NotImplementedError:
         pass
         # this may fail when the effective
         # parameter is an expression
     except UnboundLocalError:
         pass
Example #5
 def tokenize(s):
     '''A simple contextual "parser" for an OpenMP string'''
     # not completely satisfying if there are strings in if expressions
     out = ''
     par_count = 0
     curr_index = 0
     in_reserved_context = False
     while curr_index < len(s):
         m = re.match(r'^([a-zA-Z_]\w*)', s[curr_index:])
         if m:
             word = m.group(0)
             curr_index += len(word)
             if (in_reserved_context
                     or (par_count == 0 and word in keywords)):
                 out += word
                 in_reserved_context = word in reserved_contex
             else:
                 v = '{}'
                 self.deps.append(ast.Name(word, ast.Load(), None))
                 out += v
         elif s[curr_index] == '(':
             par_count += 1
             curr_index += 1
             out += '('
         elif s[curr_index] == ')':
             par_count -= 1
             curr_index += 1
             out += ')'
             if par_count == 0:
                 in_reserved_context = False
         else:
             if s[curr_index] == ',':
                 in_reserved_context = False
             out += s[curr_index]
             curr_index += 1
     return out
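
A self-contained sketch of the same idea, with made-up keyword sets (the real sets live next to this method in pythran's OpenMP handling): identifiers outside reserved clauses are swapped for '{}' placeholders and remembered as dependencies.

import re

keywords = {'omp', 'parallel', 'for', 'private'}   # assumed, for illustration
reserved_contex = set()                            # assumed, for illustration

def tokenize_sketch(s):
    out, deps, par_count, in_reserved, i = '', [], 0, False, 0
    while i < len(s):
        m = re.match(r'^([a-zA-Z_]\w*)', s[i:])
        if m:
            word = m.group(0)
            i += len(word)
            if in_reserved or (par_count == 0 and word in keywords):
                out += word
                in_reserved = word in reserved_contex
            else:
                deps.append(word)   # the real code records an ast.Name here
                out += '{}'
        else:
            if s[i] == '(':
                par_count += 1
            elif s[i] == ')':
                par_count -= 1
                if par_count == 0:
                    in_reserved = False
            elif s[i] == ',':
                in_reserved = False
            out += s[i]
            i += 1
    return out, deps

print(tokenize_sketch('omp parallel for private(i, j)'))
# ('omp parallel for private({}, {})', ['i', 'j'])
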
Example #6
def size_container_folding(value):
    """
    Convert value to ast expression if size is not too big.

    Converter for sized container.
    """

    def size(x):
        return len(getattr(x, 'flatten', lambda: x)())

    if size(value) < MAX_LEN:
        if isinstance(value, list):
            return ast.List([to_ast(elt) for elt in value], ast.Load())
        elif isinstance(value, tuple):
            return ast.Tuple([to_ast(elt) for elt in value], ast.Load())
        elif isinstance(value, set):
            if value:
                return ast.Set([to_ast(elt) for elt in value])
            else:
                return ast.Call(func=ast.Attribute(
                    ast.Name(mangle('builtins'), ast.Load(), None, None),
                    'set',
                    ast.Load()),
                    args=[],
                    keywords=[])
        elif isinstance(value, dict):
            keys = [to_ast(elt) for elt in value.keys()]
            values = [to_ast(elt) for elt in value.values()]
            return ast.Dict(keys, values)
        elif isinstance(value, np.ndarray):
            return ast.Call(func=ast.Attribute(
                ast.Name(mangle('numpy'), ast.Load(), None, None),
                'array',
                ast.Load()),
                args=[to_ast(totuple(value.tolist())),
                      dtype_to_ast(value.dtype.name)],
                keywords=[])
        else:
            raise ConversionError()
    else:
        raise ToNotEval()
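
to_ast, mangle, and the exception types above are pythran helpers; the core of the folding for the simple cases can be shown with the standard library alone (ast.unparse needs Python 3.9+):

import ast

value = [1, 2, 3]
node = ast.List([ast.Constant(v) for v in value], ast.Load())
print(ast.unparse(node))   # [1, 2, 3]
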
Example #7
    def make_dispatcher(static_expr, func_true, func_false,
                        imported_ids):
        dispatcher_args = [static_expr,
                           ast.Name(func_true.name, ast.Load(), None, None),
                           ast.Name(func_false.name, ast.Load(), None, None)]

        dispatcher = ast.Call(
            ast.Attribute(
                ast.Attribute(
                    ast.Name("builtins", ast.Load(), None, None),
                    "pythran",
                    ast.Load()),
                "static_if",
                ast.Load()),
            dispatcher_args, [])

        actual_call = ast.Call(
            dispatcher,
            [ast.Name(ii, ast.Load(), None, None) for ii in imported_ids],
            [])

        return actual_call
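
The node returned above corresponds to source of roughly this shape, where builtins.pythran.static_if selects one of the two specialized functions and the outer call forwards the captured identifiers (the identifier names below are placeholders):

import gast

equivalent = gast.parse(
    "builtins.pythran.static_if(static_expr, func_true, func_false)(x, y)")
outer_call = equivalent.body[0].value
print(type(outer_call).__name__, type(outer_call.func).__name__)   # Call Call
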
Example #8
 def visit_ExtSlice(self, node):
     new_dims = self._visit(node.dims)
     new_node = gast.Tuple(new_dims, gast.Load())
     gast.copy_location(new_node, node)
     new_node.end_lineno = new_node.end_col_offset = None
     return new_node
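
This adapter mirrors the Python 3.9 change that removed ExtSlice: multi-dimensional subscripts now parse directly to a Tuple slice, which is also what gast exposes on older interpreters. A quick check:

import gast

sub = gast.parse("x[1:2, 3]").body[0].value
print(type(sub.slice).__name__)   # Tuple
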
Example #9
class PyASTGraphsTest(parameterized.TestCase):
    def test_schema(self):
        """Check that the generated schema is reasonable."""
        # Building should succeed
        schema = py_ast_graphs.SCHEMA

        # Should have one graph node for each AST node, plus sorted helpers.
        seq_helpers = [
            "BoolOp_values-seq-helper",
            "Call_args-seq-helper",
            "For_body-seq-helper",
            "FunctionDef_body-seq-helper",
            "If_body-seq-helper",
            "If_orelse-seq-helper",
            "Module_body-seq-helper",
            "While_body-seq-helper",
            "arguments_args-seq-helper",
        ]
        expected_keys = [
            *py_ast_graphs.PY_AST_SPECS.keys(),
            *seq_helpers,
        ]
        self.assertEqual(list(schema.keys()), expected_keys)

        # Check a few elements
        expected_fundef_in_edges = {
            "parent_in",
            "returns_in",
            "returns_missing",
            "body_in",
            "args_in",
        }
        expected_fundef_out_edges = {
            "parent_out",
            "returns_out",
            "body_out_all",
            "body_out_first",
            "body_out_last",
            "args_out",
        }
        self.assertEqual(set(schema["FunctionDef"].in_edges),
                         expected_fundef_in_edges)
        self.assertEqual(set(schema["FunctionDef"].out_edges),
                         expected_fundef_out_edges)

        expected_expr_in_edges = {"parent_in", "value_in"}
        expected_expr_out_edges = {"parent_out", "value_out"}
        self.assertEqual(set(schema["Expr"].in_edges), expected_expr_in_edges)
        self.assertEqual(set(schema["Expr"].out_edges),
                         expected_expr_out_edges)

        expected_seq_helper_in_edges = {
            "parent_in",
            "item_in",
            "next_in",
            "next_missing",
            "prev_in",
            "prev_missing",
        }
        expected_seq_helper_out_edges = {
            "parent_out",
            "item_out",
            "next_out",
            "prev_out",
        }
        for seq_helper in seq_helpers:
            self.assertEqual(set(schema[seq_helper].in_edges),
                             expected_seq_helper_in_edges)
            self.assertEqual(set(schema[seq_helper].out_edges),
                             expected_seq_helper_out_edges)

    def test_ast_graph_conforms_to_schema(self):
        # Some example code using a few different syntactic constructs, to cover
        # a large set of nodes in the schema
        root = gast.parse(
            textwrap.dedent("""\
        def foo(n):
          if n <= 1:
            return 1
          else:
            return foo(n-1) + foo(n-2)

        def bar(m, n) -> int:
          x = n
          for i in range(m):
            if False:
              continue
            x = x + i
          while True:
            break
          return x

        x0 = 1 + 2 - 3 * 4 / 5
        x1 = (1 == 2) and (3 < 4) and (5 > 6)
        x2 = (7 <= 8) and (9 >= 10) or (11 != 12)
        x2 = bar(13, 14 + 15)
        """))

        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        # Graph should match the schema
        schema_util.assert_conforms_to_schema(graph, py_ast_graphs.SCHEMA)

    def test_ast_graph_nodes(self):
        """Check node IDs, node types, and forward mapping."""
        root = gast.parse(
            textwrap.dedent("""\
        pass
        def foo(n):
            if n <= 1:
              return 1
        """))

        graph, forward_map = py_ast_graphs.py_ast_to_graph(root)

        # pytype: disable=attribute-error
        self.assertIn("root__Module", graph)
        self.assertEqual(graph["root__Module"].node_type, "Module")
        self.assertEqual(forward_map[id(root)], "root__Module")

        self.assertIn("root_body_1__Module_body-seq-helper", graph)
        self.assertEqual(
            graph["root_body_1__Module_body-seq-helper"].node_type,
            "Module_body-seq-helper")

        self.assertIn("root_body_1_item_body_0_item__If", graph)
        self.assertEqual(graph["root_body_1_item_body_0_item__If"].node_type,
                         "If")
        self.assertEqual(forward_map[id(root.body[1].body[0])],
                         "root_body_1_item_body_0_item__If")

        self.assertIn("root_body_1_item_body_0_item_test_left__Name", graph)
        self.assertEqual(
            graph["root_body_1_item_body_0_item_test_left__Name"].node_type,
            "Name")
        self.assertEqual(forward_map[id(root.body[1].body[0].test.left)],
                         "root_body_1_item_body_0_item_test_left__Name")
        # pytype: enable=attribute-error

    def test_ast_graph_unique_field_edges(self):
        """Test that edges for unique fields are correct."""
        root = gast.parse("print(1)")
        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        self.assertEqual(
            graph["root_body_0_item__Expr"].out_edges["value_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_0_item_value__Call"),
                    in_edge=graph_types.InEdgeType("parent_in"))
            ])

        self.assertEqual(
            graph["root_body_0_item_value__Call"].out_edges["parent_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_0_item__Expr"),
                    in_edge=graph_types.InEdgeType("value_in"))
            ])

    def test_ast_graph_optional_field_edges(self):
        """Test that edges for optional fields are correct."""
        root = gast.parse("return 1\nreturn")
        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        self.assertEqual(
            graph["root_body_0_item__Return"].out_edges["value_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId(
                        "root_body_0_item_value__Constant"),
                    in_edge=graph_types.InEdgeType("parent_in"))
            ])

        self.assertEqual(
            graph["root_body_0_item_value__Constant"].out_edges["parent_out"],
            [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_0_item__Return"),
                    in_edge=graph_types.InEdgeType("value_in"))
            ])

        self.assertEqual(
            graph["root_body_1_item__Return"].out_edges["value_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_1_item__Return"),
                    in_edge=graph_types.InEdgeType("value_missing"))
            ])

    def test_ast_graph_sequence_field_edges(self):
        """Test that edges for sequence fields are correct.

        Note that sequence fields produce connections between three nodes: the
        parent, the helper node, and the child.
        """
        root = gast.parse(
            textwrap.dedent("""\
        print(1)
        print(2)
        print(3)
        print(4)
        print(5)
        print(6)
        """))

        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        # Child edges from the parent node
        node = graph["root__Module"]
        self.assertLen(node.out_edges["body_out_all"], 6)
        self.assertEqual(node.out_edges["body_out_first"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_0__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("parent_in"))
        ])
        self.assertEqual(node.out_edges["body_out_last"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_5__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("parent_in"))
        ])

        # Edges from the sequence helper
        node = graph["root_body_0__Module_body-seq-helper"]
        self.assertEqual(node.out_edges["parent_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId("root__Module"),
                in_edge=graph_types.InEdgeType("body_in"))
        ])
        self.assertEqual(node.out_edges["item_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId("root_body_0_item__Expr"),
                in_edge=graph_types.InEdgeType("parent_in"))
        ])
        self.assertEqual(node.out_edges["prev_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_0__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("prev_missing"))
        ])
        self.assertEqual(node.out_edges["next_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_1__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("prev_in"))
        ])

        # Parent edge of the item
        node = graph["root_body_0_item__Expr"]
        self.assertEqual(node.out_edges["parent_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_0__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("item_in"))
        ])

    @parameterized.named_parameters(
        {
            "testcase_name": "unexpected_type",
            "ast": gast.Subscript(value=None, slice=None, ctx=None),
            "expected_error": "Unknown AST node type 'Subscript'",
        }, {
            "testcase_name":
            "too_many_unique",
            "ast":
            gast.Assign(targets=[
                gast.Name("foo", gast.Store(), None, None),
                gast.Name("bar", gast.Store(), None, None)
            ],
                        value=gast.Constant(True, None)),
            "expected_error":
            "Expected 1 child for field 'targets' of node .*; got 2",
        }, {
            "testcase_name":
            "missing_unique",
            "ast":
            gast.Assign(targets=[], value=gast.Constant(True, None)),
            "expected_error":
            "Expected 1 child for field 'targets' of node .*; got 0",
        }, {
            "testcase_name":
            "too_many_optional",
            "ast":
            gast.Return(value=[
                gast.Name("foo", gast.Load(), None, None),
                gast.Name("bar", gast.Load(), None, None)
            ]),
            "expected_error":
            "Expected at most 1 child for field 'value' of node .*; got 2",
        })
    def test_invalid_graphs(self, ast, expected_error):
        with self.assertRaisesRegex(ValueError, expected_error):
            py_ast_graphs.py_ast_to_graph(ast)
Example #10
    def visit_Call(self, node):
        """
        Transform call site to have normal function call.

        Examples
        --------
        For methods:
        >> a = [1, 2, 3]

        >> a.append(1)

        Becomes

        >> __list__.append(a, 1)


        For functions:
        >> __builtin__.dict.fromkeys([1, 2, 3])

        Becomes

        >> __builtin__.__dict__.fromkeys([1, 2, 3])
        """
        node = self.generic_visit(node)
        # Only calls through attributes can be pythran methods/functions and
        # need to be normalized
        if isinstance(node.func, ast.Attribute):
            if node.func.attr in methods:
                # Get the object targeted by the method
                obj = lhs = node.func.value
                # Get the left-most identifier to check that it is not an
                # imported module
                while isinstance(obj, ast.Attribute):
                    obj = obj.value
                is_not_module = (not isinstance(obj, ast.Name)
                                 or obj.id not in self.imports)

                if is_not_module:
                    self.update = True
                    # As it was a method call, push the targeted object as the
                    # first argument and add the correct module prefix
                    node.args.insert(0, lhs)
                    mod = methods[node.func.attr][0]
                    # Submodules import full module
                    self.to_import.add(mangle(mod[0]))
                    node.func = reduce(
                        lambda v, o: ast.Attribute(v, o, ast.Load()),
                        mod[1:] + (node.func.attr, ),
                        ast.Name(mangle(mod[0]), ast.Load(), None))
                # else methods have been called using function syntax
            if node.func.attr in methods or node.func.attr in functions:
                # Now, methods and functions both use function syntax
                def rec(path, cur_module):
                    """
                    Recursively rename path content looking in matching module.

                    Prefers __module__ to module if it exists.
                    This recursion is done as modules are visited top->bottom
                    while attributes have to be visited bottom->top.
                    """
                    err = "Function path is chained attributes and name"
                    assert isinstance(path, (ast.Name, ast.Attribute)), err
                    if isinstance(path, ast.Attribute):
                        new_node, cur_module = rec(path.value, cur_module)
                        new_id, mname = self.renamer(path.attr, cur_module)
                        return (ast.Attribute(new_node, new_id,
                                              ast.Load()), cur_module[mname])
                    else:
                        new_id, mname = self.renamer(path.id, cur_module)
                        if mname not in cur_module:
                            raise PythranSyntaxError(
                                "Unbound identifier '{}'".format(mname), node)

                        return (ast.Name(new_id, ast.Load(),
                                         None), cur_module[mname])

                # Rename module path to avoid naming issue.
                node.func.value, _ = rec(node.func.value, MODULES)
                self.update = True

        return node
Example #11
 def add_stararg(self, a):
     self._consume_args()
     self._argspec.append(
         gast.Call(gast.Name('tuple', gast.Load(), None), [a], ()))
Example #12
class Square(Transformation):
    """
    Replaces **2 by a call to numpy.square.

    >>> import gast as ast
    >>> from pythran import passmanager, backend
    >>> node = ast.parse('a**2')
    >>> pm = passmanager.PassManager("test")
    >>> _, node = pm.apply(Square, node)
    >>> print(pm.dump(backend.Python, node))
    import numpy
    numpy.square(a)
    >>> node = ast.parse('numpy.power(a,2)')
    >>> pm = passmanager.PassManager("test")
    >>> _, node = pm.apply(Square, node)
    >>> print(pm.dump(backend.Python, node))
    import numpy
    numpy.square(a)
    """

    POW_PATTERN = ast.BinOp(AST_any(), ast.Pow(), ast.Num(2))
    POWER_PATTERN = ast.Call(
        ast.Attribute(ast.Name('numpy', ast.Load(), None), 'power',
                      ast.Load()), [AST_any(), ast.Num(2)], [])

    def __init__(self):
        Transformation.__init__(self)

    def replace(self, value):
        self.update = self.need_import = True
        return ast.Call(
            ast.Attribute(ast.Name('numpy', ast.Load(), None), 'square',
                          ast.Load()), [value], [])

    def visit_Module(self, node):
        self.need_import = False
        self.generic_visit(node)
        if self.need_import:
            importIt = ast.Import(names=[ast.alias(name='numpy', asname=None)])
            node.body.insert(0, importIt)
        return node

    def expand_pow(self, node, n):
        if n == 0:
            return ast.Num(1)
        elif n == 1:
            return node
        else:
            node_square = self.replace(node)
            node_pow = self.expand_pow(node_square, n >> 1)
            if n & 1:
                return ast.BinOp(node_pow, ast.Mult(), copy.deepcopy(node))
            else:
                return node_pow

    def visit_BinOp(self, node):
        self.generic_visit(node)
        if ASTMatcher(Square.POW_PATTERN).search(node):
            return self.replace(node.left)
        elif isinstance(node.op, ast.Pow) and isinstance(node.right, ast.Num):
            n = node.right.n
            if int(n) == n and n > 0:
                return self.expand_pow(node.left, n)
            else:
                return node
        else:
            return node

    def visit_Call(self, node):
        self.generic_visit(node)
        if ASTMatcher(Square.POWER_PATTERN).search(node):
            return self.replace(node.args[0])
        else:
            return node
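
expand_pow is exponentiation by squaring applied to AST nodes; the same decomposition on plain numbers is an easy sanity check (a sketch, not pythran code):

def expand_pow_value(x, n):
    # mirrors expand_pow above, but on numbers instead of AST nodes
    if n == 0:
        return 1
    if n == 1:
        return x
    half = expand_pow_value(x * x, n >> 1)
    return half * x if n & 1 else half

assert expand_pow_value(3, 13) == 3 ** 13
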
Example #13
 def replace(self, value):
     self.update = self.need_import = True
     return ast.Call(
         ast.Attribute(ast.Name('numpy', ast.Load(), None), 'square',
                       ast.Load()), [value], [])
Example #14
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes."""
    converted_members = {}
    method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        node, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            owner_type=c)
        if class_namespace is None:
            class_namespace = namespace
        else:
            class_namespace.update(namespace)
        converted_members[m] = node[0]
    namer = program_ctx.new_namer(class_namespace)
    class_name = namer.compiled_class_name(c.__name__, c)

    # TODO(mdan): This needs to be explained more thoroughly.
    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated. Otherwise, it is marked for conversion
    # (as a side effect of the call to namer.compiled_class_name() followed by
    # program_ctx.update_name_map(namer)).
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        if isinstance(object, base):
            base_names.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            # This will trigger a conversion into a class with this name.
            alias = namer.compiled_class_name(base.__name__, base)
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
    program_ctx.update_name_map(namer)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    return output_nodes, class_name, class_namespace
Example #15
File: types.py Project: yws/pythran
    def combine_(self, node, othernode, op, unary_op, register):
        try:
            if register:  # this comes from an assignment,
                # so we must check where the value is assigned
                node_id, depth = self.node_to_id(node)
                if depth > 0:
                    node = ast.Name(node_id, ast.Load(), None)
                    former_unary_op = unary_op

                    # update the type to reflect container nesting
                    def unary_op(x):
                        return reduce(lambda t, n: ContainerType(t),
                                      range(depth), former_unary_op(x))

                    # patch the op, as we no longer apply op, but infer content
                    def op(*types):
                        if len(types) == 1:
                            return types[0]
                        else:
                            return CombinedTypes(*types)

                self.name_to_nodes.setdefault(node_id, set()).add(node)

            # only perform inter procedural combination upon stage 0
            if register and self.isargument(node) and self.stage == 0:
                node_id, _ = self.node_to_id(node)
                if node not in self.result:
                    self.result[node] = unary_op(self.result[othernode])
                assert self.result[node], "found an alias with a type"

                parametric_type = PType(self.current, self.result[othernode])
                if self.register(parametric_type):

                    current_function = self.combiners[self.current]

                    def translator_generator(args, op, unary_op):
                        ''' capture args for translator generation'''
                        def interprocedural_type_translator(s, n):
                            translated_othernode = ast.Name(
                                '__fake__', ast.Load(), None)
                            s.result[translated_othernode] = (
                                parametric_type.instanciate(
                                    s.current,
                                    [s.result[arg] for arg in n.args]))
                            # look for modified argument
                            for p, effective_arg in enumerate(n.args):
                                formal_arg = args[p]
                                if formal_arg.id == node_id:
                                    translated_node = effective_arg
                                    break
                            try:
                                s.combine(translated_node,
                                          translated_othernode,
                                          op,
                                          unary_op,
                                          register=True,
                                          aliasing_type=True)
                            except NotImplementedError:
                                pass
                                # this may fail when the effective
                                # parameter is an expression
                            except UnboundLocalError:
                                pass
                                # this may fail when translated_node
                                # is a default parameter

                        return interprocedural_type_translator

                    translator = translator_generator(
                        self.current.args.args, op,
                        unary_op)  # deferred combination
                    current_function.add_combiner(translator)
            else:
                new_type = unary_op(self.result[othernode])
                if node not in self.result or self.result[node] is UnknownType:
                    self.result[node] = new_type
                else:
                    self.result[node] = op(self.result[node], new_type)
        except UnboundableRValue:
            pass
Example #16
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes."""
    # TODO(mdan): Revisit this altogether. Not sure we still need it.
    converted_members = {}
    method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        nodes, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            do_rename=False)
        if class_namespace is None:
            class_namespace = namespace
        else:
            class_namespace.update(namespace)
        converted_members[m] = nodes[0]
    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        if isinstance(object, base):
            base_names.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    return output_nodes, class_name, class_namespace
Example #17
  def visit_Call(self, node):
    if not self.target:
      return node
    func = anno.getanno(node, 'func')

    if func in tangents.UNIMPLEMENTED_TANGENTS:
      raise errors.ForwardNotImplementedError(func)

    if func == tracing.Traceable:
      raise NotImplementedError('Tracing of %s is not enabled in forward mode' %
                                quoting.unquote(node))

    if func not in tangents.tangents:
      try:
        quoting.parse_function(func)
      except:
        raise ValueError('No tangent found for %s, and could not get source.' %
                         func.__name__)

      # z = f(x,y) -> d[z],z = df(x,y,dx=dx,dy=dy)
      active_args = tuple(i for i, arg in enumerate(node.args)
                          if isinstance(arg, gast.Name))
      # TODO: Stack arguments are currently not considered
      # active, but for forward-mode applied to call trees,
      # they have to be. When we figure out how to update activity
      # analysis to do the right thing, we'll want to add the extra check:
      # `and arg.id in self.active_variables`

      # TODO: Duplicate of code in reverse_ad.
      already_counted = False
      for f, a in self.required:
        if f.__name__ == func.__name__ and set(a) == set(active_args):
          already_counted = True
          break
      if not already_counted:
        self.required.append((func, active_args))

      fn_name = naming.tangent_name(func, active_args)
      orig_args = quoting.parse_function(func).body[0].args
      tangent_keywords = []
      for i in active_args:
        grad_node = create.create_grad(node.args[i], self.namer, tangent=True)
        arg_grad_node = create.create_grad(
            orig_args.args[i], self.namer, tangent=True)
        grad_node.ctx = gast.Load()
        tangent_keywords.append(
            gast.keyword(arg=arg_grad_node.id, value=grad_node))
      # Update the original call
      rhs = gast.Call(
          func=gast.Name(id=fn_name, ctx=gast.Load(), annotation=None),
          args=node.args,
          keywords=tangent_keywords + node.keywords)
      # Set self.value to False to trigger whole primal replacement
      self.value = False
      return [rhs]

    template_ = tangents.tangents[func]

    # Match the function call to the template
    sig = funcsigs.signature(template_)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)
    bound_args.apply_defaults()

    # If any keyword arguments weren't passed, we fill them using the
    # defaults of the original function
    if grads.DEFAULT in bound_args.arguments.values():
      # Build a mapping from names to defaults
      args = quoting.parse_function(func).body[0].args
      defaults = {}
      for arg, default in zip(*map(reversed, [args.args, args.defaults])):
        defaults[arg.id] = default
      for arg, default in zip(args.kwonlyargs, args.kw_defaults):
        if default is not None:
          defaults[arg.id] = default
      for name, value in bound_args.arguments.items():
        if value is grads.DEFAULT:
          bound_args.arguments[name] = defaults[name]

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = six.get_function_code(template_).co_varnames[0]
    arg_replacements = {output_name: self.tmp_node}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    flags = six.get_function_code(template_).co_flags

    if flags & inspect.CO_VARARGS:
      to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
      vararg_name = six.get_function_code(template_).co_varnames[-1]
      target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
      value = gast.Tuple(elts=to_pack, ctx=gast.Load())

      # And we fill in the packed tuple into the template
      arg_replacements[six.get_function_code(template_).co_varnames[
          -1]] = target
    tangent_node = template.replace(
        template_,
        replace_grad=template.Replace.TANGENT,
        namer=self.namer,
        **arg_replacements)

    # If the template uses the answer in the RHS of the tangent,
    # we need to make sure that the regular answer is replaced
    # with self.tmp_node, but that the gradient is not. We have
    # to be extra careful for statements like a = exp(a), because
    # both the target and RHS variables have the same name.
    tmp_grad_node = create.create_grad(self.tmp_node, self.namer, tangent=True)
    tmp_grad_name = tmp_grad_node.id
    ans_grad_node = create.create_grad(self.target, self.namer, tangent=True)
    for _node in tangent_node:
      for succ in gast.walk(_node):
        if isinstance(succ, gast.Name) and succ.id == tmp_grad_name:
          succ.id = ans_grad_node.id

    if flags & inspect.CO_VARARGS:
      # If the template packs arguments, then we have to unpack the
      # derivatives afterwards
      # We also have to update the replacements tuple then
      dto_pack = [
          create.create_temp_grad(arg, self.namer, True) for arg in to_pack
      ]
      value = create.create_grad(target, self.namer, tangent=True)
      target = gast.Tuple(elts=dto_pack, ctx=gast.Store())

    # Stack pops have to be special-cased, we have
    # to set the 'push' attribute, so we know that if we
    # remove this pop, we have to remove the equivalent push.
    # NOTE: this only works if we're doing forward-over-reverse,
    # where reverse is applied in joint mode, with no call tree.
    # Otherwise, the pushes and pops won't be matched within a single
    # function call.
    if func == tangent.pop:
      if len(self.metastack):
        anno.setanno(tangent_node[0], 'push', self.metastack.pop())
      else:
        anno.setanno(tangent_node[0], 'push', None)
    return tangent_node
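
The `z = f(x, y) -> dz, z = df(x, y, dx=dx, dy=dy)` rewrite mentioned in the comments, written out by hand for a concrete function in plain Python, just to illustrate the calling convention the generated tangent follows:

def f(x, y):
    return x * y

def df(x, y, dx=1.0, dy=0.0):
    # returns (dz, z): the tangent of x * y followed by the primal value
    return x * dy + y * dx, x * y

dz, z = df(3.0, 4.0, dx=1.0, dy=0.0)
print(dz, z)   # 4.0 12.0
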
Example #18
    def get_while_stmt_nodes(self, node):
        # TODO: consider while - else in python
        if not self.name_visitor.is_control_flow_loop(node):
            return [node]

        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)
        new_stmts = []

        # Python can create a variable in a loop and use it outside the loop, e.g.
        #
        # while x < 10:
        #     x += 1
        #     y = x
        # z = y
        #
        # We need to create static variables for those variables
        for name in create_var_names:
            if "." not in name:
                new_stmts.append(create_static_variable_gast_node(name))

        # `while x < 10` in dygraph should be converted into `static tensor < 10`
        for name in loop_var_names:
            new_stmts.append(to_static_variable_gast_node(name))

        logical_op_transformer = LogicalOpTransformer(node.test)
        cond_value_node = logical_op_transformer.transform()

        condition_func_node = gast.FunctionDef(
            name=unique_name.generate(WHILE_CONDITION_PREFIX),
            args=gast.arguments(args=[
                gast.Name(id=name,
                          ctx=gast.Param(),
                          annotation=None,
                          type_comment=None) for name in loop_var_names
            ],
                                posonlyargs=[],
                                vararg=None,
                                kwonlyargs=[],
                                kw_defaults=None,
                                kwarg=None,
                                defaults=[]),
            body=[gast.Return(value=cond_value_node)],
            decorator_list=[],
            returns=None,
            type_comment=None)
        for name in loop_var_names:
            if "." in name:
                rename_transformer = RenameTransformer(condition_func_node)
                rename_transformer.rename(
                    name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
        new_stmts.append(condition_func_node)

        new_body = node.body
        new_body.append(
            gast.Return(
                value=generate_name_node(loop_var_names, ctx=gast.Load())))
        body_func_node = gast.FunctionDef(
            name=unique_name.generate(WHILE_BODY_PREFIX),
            args=gast.arguments(args=[
                gast.Name(id=name,
                          ctx=gast.Param(),
                          annotation=None,
                          type_comment=None) for name in loop_var_names
            ],
                                posonlyargs=[],
                                vararg=None,
                                kwonlyargs=[],
                                kw_defaults=None,
                                kwarg=None,
                                defaults=[]),
            body=new_body,
            decorator_list=[],
            returns=None,
            type_comment=None)
        for name in loop_var_names:
            if "." in name:
                rename_transformer = RenameTransformer(body_func_node)
                rename_transformer.rename(
                    name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
        new_stmts.append(body_func_node)

        while_loop_node = create_while_node(condition_func_node.name,
                                            body_func_node.name,
                                            loop_var_names)
        new_stmts.append(while_loop_node)
        return new_stmts
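
The generated condition/body functions feed a functional while-loop primitive; a plain-Python sketch of that contract (the while_loop below is a hypothetical stand-in for what create_while_node wires up, which in Paddle operates on static tensors):

def while_loop(cond, body, loop_vars):
    # hypothetical stand-in for the functional loop the transformer targets
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars

def while_condition_0(x):   # shape of the generated condition FunctionDef
    return x < 10

def while_body_0(x):        # shape of the generated body FunctionDef
    x += 1
    return [x]

print(while_loop(while_condition_0, while_body_0, [0]))   # [10]
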
Example #19
    def get_for_stmt_nodes(self, node):
        # TODO: consider for - else in python
        if not self.name_visitor.is_control_flow_loop(node):
            return [node]

        # TODO: support non-range case
        range_call_node = self.get_for_range_node(node)
        if range_call_node is None:
            return [node]

        if not isinstance(node.target, gast.Name):
            return [node]
        iter_var_name = node.target.id

        init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts(
            iter_var_name, range_call_node.args)

        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)
        new_stmts = []
        # Python can create a variable in a loop and use it outside the loop, e.g.
        #
        # for x in range(10):
        #     y += x
        # print(x) # x = 10
        #
        # We need to create static variables for those variables
        for name in create_var_names:
            if "." not in name:
                new_stmts.append(create_static_variable_gast_node(name))

        new_stmts.append(init_stmt)

        # `for x in range(10)` in dygraph should be converted into `static tensor + 1 <= 10`
        for name in loop_var_names:
            new_stmts.append(to_static_variable_gast_node(name))

        condition_func_node = gast.FunctionDef(
            name=unique_name.generate(FOR_CONDITION_PREFIX),
            args=gast.arguments(args=[
                gast.Name(id=name,
                          ctx=gast.Param(),
                          annotation=None,
                          type_comment=None) for name in loop_var_names
            ],
                                posonlyargs=[],
                                vararg=None,
                                kwonlyargs=[],
                                kw_defaults=None,
                                kwarg=None,
                                defaults=[]),
            body=[gast.Return(value=cond_stmt)],
            decorator_list=[],
            returns=None,
            type_comment=None)
        for name in loop_var_names:
            if "." in name:
                rename_transformer = RenameTransformer(condition_func_node)
                rename_transformer.rename(
                    name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
        new_stmts.append(condition_func_node)

        new_body = node.body
        new_body.append(change_stmt)
        new_body.append(
            gast.Return(
                value=generate_name_node(loop_var_names, ctx=gast.Load())))
        body_func_node = gast.FunctionDef(
            name=unique_name.generate(FOR_BODY_PREFIX),
            args=gast.arguments(args=[
                gast.Name(id=name,
                          ctx=gast.Param(),
                          annotation=None,
                          type_comment=None) for name in loop_var_names
            ],
                                posonlyargs=[],
                                vararg=None,
                                kwonlyargs=[],
                                kw_defaults=None,
                                kwarg=None,
                                defaults=[]),
            body=new_body,
            decorator_list=[],
            returns=None,
            type_comment=None)
        for name in loop_var_names:
            if "." in name:
                rename_transformer = RenameTransformer(body_func_node)
                rename_transformer.rename(
                    name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
        new_stmts.append(body_func_node)

        while_loop_node = create_while_node(condition_func_node.name,
                                            body_func_node.name,
                                            loop_var_names)
        new_stmts.append(while_loop_node)

        return new_stmts
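
For the range-based case, the init/cond/change statements produced by get_for_args_stmts reduce the for loop to the same while form; roughly, in plain Python (positive step assumed):

def as_while(start, stop, step):
    seen = []
    i = start            # init_stmt
    while i < stop:      # cond_stmt (positive step assumed)
        seen.append(i)   # original loop body
        i += step        # change_stmt appended to the body
    return seen

assert as_while(0, 10, 2) == list(range(0, 10, 2))
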
Example #20
 def makeattr():
     return ast.Attribute(
         value=ast.Name(id='__builtin__',
                        ctx=ast.Load(),
                        annotation=None),
         attr='map', ctx=ast.Load())
Example #21
""" Optimization for Python costly pattern. """

from pythran.conversion import mangle
from pythran.analyses import Check, Placeholder
from pythran.passmanager import Transformation

import gast as ast

# Tuple of: (pattern, replacement)
# The replacement has to be a lambda function so that a fresh AST is created
# each time the replacement is inserted into the main AST.
know_pattern = [
    # __builtin__.len(__builtin__.set(X)) => __builtin__.pythran.len_set(X)
    (ast.Call(func=ast.Attribute(value=ast.Name(id='__builtin__',
                                                ctx=ast.Load(),
                                                annotation=None),
                                 attr="len",
                                 ctx=ast.Load()),
              args=[
                  ast.Call(func=ast.Attribute(value=ast.Name(id='__builtin__',
                                                             ctx=ast.Load(),
                                                             annotation=None),
                                              attr="set",
                                              ctx=ast.Load()),
                           args=[Placeholder(0)],
                           keywords=[])
              ],
              keywords=[]),
     lambda: ast.Call(func=ast.Attribute(value=ast.Attribute(value=ast.Name(
         id='__builtin__', ctx=ast.Load(), annotation=None),
                                                             attr="pythran",
Example #22
 def _consume_args(self):
     if self._arg_accumulator:
         self._argspec.append(
             gast.Tuple(elts=self._arg_accumulator, ctx=gast.Load()))
         self._arg_accumulator = []
Example #23
    def combine_(self, node, othernode, op, unary_op, register):
        try:
            # This comes from an assignment, so we must check where the value
            # is assigned
            if register:
                try:
                    node_id, depth = self.node_to_id(node)
                    if depth:
                        node = ast.Name(node_id, ast.Load(), None, None)
                        former_unary_op = unary_op

                        # update the type to reflect container nesting
                        def merge_container_type(ty, index):
                            # integral index make it possible to correctly
                            # update tuple type
                            if isinstance(index, int):
                                kty = self.builder.NamedType(
                                    'std::integral_constant<long,{}>'.format(
                                        index))
                                return self.builder.IndexableContainerType(
                                    kty, ty)
                            else:
                                return self.builder.ContainerType(ty)

                        def unary_op(x):
                            return reduce(merge_container_type, depth,
                                          former_unary_op(x))

                        # patch the op, as we no longer apply op,
                        # but infer content
                        op = self.combined

                    self.name_to_nodes[node_id].append(node)
                except UnboundableRValue:
                    pass

            # only perform inter procedural combination upon stage 0
            if register and self.isargument(node) and self.stage == 0:
                node_id, _ = self.node_to_id(node)
                if node not in self.result:
                    self.result[node] = unary_op(self.result[othernode])
                assert self.result[node], "found an alias with a type"

                parametric_type = self.builder.PType(self.current,
                                                     self.result[othernode])
                if self.register(parametric_type):

                    current_function = self.combiners[self.current]

                    def translator_generator(args, op, unary_op):
                        ''' capture args for translator generation'''
                        def interprocedural_type_translator(s, n):
                            translated_othernode = ast.Name(
                                '__fake__', ast.Load(), None, None)
                            s.result[translated_othernode] = (
                                parametric_type.instanciate(
                                    s.current,
                                    [s.result[arg] for arg in n.args]))
                            # look for modified argument
                            for p, effective_arg in enumerate(n.args):
                                formal_arg = args[p]
                                if formal_arg.id == node_id:
                                    translated_node = effective_arg
                                    break
                            try:
                                s.combine(translated_node,
                                          translated_othernode,
                                          op,
                                          unary_op,
                                          register=True,
                                          aliasing_type=True)
                            except NotImplementedError:
                                pass
                                # this may fail when the effective
                                # parameter is an expression
                            except UnboundLocalError:
                                pass
                                # this may fail when translated_node
                                # is a default parameter

                        return interprocedural_type_translator

                    translator = translator_generator(
                        self.current.args.args, op,
                        unary_op)  # deferred combination
                    current_function.add_combiner(translator)
            else:
                new_type = unary_op(self.result[othernode])
                UnknownType = self.builder.UnknownType
                if node not in self.result or self.result[node] is UnknownType:
                    self.result[node] = new_type
                else:
                    if isinstance(self.result[node], tuple):
                        raise UnboundableRValue
                    self.result[node] = op(self.result[node], new_type)

        except UnboundableRValue:
            pass
Example #24
 def sub():
     return ast.Tuple(Placeholder(0), ast.Load())
Example #25
def convert_class_to_ast(c, program_ctx):
    """Specialization of `convert_entity_to_ast` for classes."""
    # TODO(mdan): Revisit this altogether. Not sure we still need it.
    converted_members = {}
    method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('cannot convert %s: no member methods' % c)

    # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
    # The assumption that one namespace suffices for all methods only holds if
    # all methods were defined in the same module.
    # If, instead, functions are imported from multiple modules and then spliced
    # into the class, then each function has its own globals and __future__
    # imports that need to stay separate.

    # For example, C's methods could both have `global x` statements referring to
    # mod1.x and mod2.x, but using one namespace for C would cause a conflict.
    # from mod1 import f1
    # from mod2 import f2
    # class C(object):
    #   method1 = f1
    #   method2 = f2

    class_namespace = {}
    future_features = None
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        (node, ), _, entity_info = convert_func_to_ast(m,
                                                       program_ctx=program_ctx,
                                                       do_rename=False)
        class_namespace.update(entity_info.namespace)
        converted_members[m] = node

        # TODO(mdan): Similarly check the globals.
        if future_features is None:
            future_features = entity_info.future_features
        elif frozenset(future_features) ^ frozenset(
                entity_info.future_features):
            # Note: we can support this case if ever needed.
            raise ValueError(
                'cannot convert {}: it has methods built with mismatched future'
                ' features: {} and {}'.format(c, future_features,
                                              entity_info.future_features))
    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        if isinstance(object, base):
            base_names.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    # TODO(mdan): Find a way better than forging this object.
    entity_info = transformer.EntityInfo(source_code=None,
                                         source_file=None,
                                         future_features=future_features,
                                         namespace=class_namespace)

    return output_nodes, class_name, entity_info
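For reference, the ClassDef, Name and ImportFrom nodes that convert_class_to_ast assembles by hand can be inspected by parsing an equivalent snippet (a minimal sketch; the module and class names are made up):

import gast

mod = gast.parse("from some_module import Base\n"
                 "class MyClass(Base):\n"
                 "    def method(self):\n"
                 "        return 1\n")
imp, cls = mod.body
print(gast.dump(imp))                               # the ImportFrom node
print(cls.name, [gast.dump(b) for b in cls.bases])  # 'MyClass' and its bases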
Exemplo n.º 26
0
    def visit_If(self, node):
        if node.test not in self.static_expressions:
            return self.generic_visit(node)

        imported_ids = self.gather(ImportedIds, node)

        assigned_ids_left = self.escaping_ids(node, node.body)
        assigned_ids_right = self.escaping_ids(node, node.orelse)
        assigned_ids_both = assigned_ids_left.union(assigned_ids_right)

        imported_ids.update(i for i in assigned_ids_left
                            if i not in assigned_ids_right)
        imported_ids.update(i for i in assigned_ids_right
                            if i not in assigned_ids_left)
        imported_ids = sorted(imported_ids)

        assigned_ids = sorted(assigned_ids_both)

        fbody = self.make_fake(node.body)
        true_has_return = self.gather(HasReturn, fbody)
        true_has_break = self.gather(HasBreak, fbody)
        true_has_cont = self.gather(HasContinue, fbody)

        felse = self.make_fake(node.orelse)
        false_has_return = self.gather(HasReturn, felse)
        false_has_break = self.gather(HasBreak, felse)
        false_has_cont = self.gather(HasContinue, felse)

        has_return = true_has_return or false_has_return
        has_break = true_has_break or false_has_break
        has_cont = true_has_cont or false_has_cont

        self.generic_visit(node)

        func_true = outline(self.true_name(), imported_ids, assigned_ids,
                            node.body, has_return, has_break, has_cont)
        func_false = outline(self.false_name(), imported_ids, assigned_ids,
                             node.orelse, has_return, has_break, has_cont)
        self.new_functions.extend((func_true, func_false))

        actual_call = self.make_dispatcher(node.test,
                                           func_true, func_false, imported_ids)

        # variables modified within the static_if
        expected_return = [ast.Name(ii, ast.Store(), None, None)
                           for ii in assigned_ids]

        self.update = True

        # names for the various variables resulting from the static_if
        n = len(self.new_functions)
        status_n = "$status{}".format(n)
        return_n = "$return{}".format(n)
        cont_n = "$cont{}".format(n)

        if has_return:
            cfg = self.cfgs[-1]
            always_return = all(isinstance(x, (ast.Return, ast.Yield))
                                for x in cfg[node])
            always_return &= true_has_return and false_has_return

            fast_return = [ast.Name(status_n, ast.Store(), None, None),
                           ast.Name(return_n, ast.Store(), None, None),
                           ast.Name(cont_n, ast.Store(), None, None)]

            if always_return:
                return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                                   actual_call, None),
                        ast.Return(ast.Name(return_n, ast.Load(), None, None))]
            else:
                cont_ass = self.make_control_flow_handlers(cont_n, status_n,
                                                           expected_return,
                                                           has_cont, has_break)

                cmpr = ast.Compare(ast.Name(status_n, ast.Load(), None, None),
                                   [ast.Eq()], [ast.Constant(EARLY_RET, None)])
                return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                                   actual_call, None),
                        ast.If(cmpr,
                               [ast.Return(ast.Name(return_n, ast.Load(),
                                                    None, None))],
                               cont_ass)]
        elif has_break or has_cont:
            cont_ass = self.make_control_flow_handlers(cont_n, status_n,
                                                       expected_return,
                                                       has_cont, has_break)

            fast_return = [ast.Name(status_n, ast.Store(), None, None),
                           ast.Name(cont_n, ast.Store(), None, None)]
            return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                               actual_call, None)] + cont_ass
        elif expected_return:
            return ast.Assign([ast.Tuple(expected_return, ast.Store())],
                              actual_call, None)
        else:
            return ast.Expr(actual_call)
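Schematically, the outlining above turns an if whose branches assign escaping variables into two functions plus a dispatcher call that returns the assigned names. A hand-written sketch of that shape (the dispatcher and the names are illustrative, not pythran's actual helpers):

def _true_branch(a):       # parameters correspond to imported_ids
    x = a + 1              # body of the original 'if'
    return x               # returned names correspond to assigned_ids

def _false_branch(a):
    x = a - 1
    return x

def dispatch(cond, true_fn, false_fn):
    # stands in for the dispatcher built by self.make_dispatcher(...)
    return true_fn if cond else false_fn

a = 10
x = dispatch(a > 5, _true_branch, _false_branch)(a)
print(x)  # 11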
Exemplo n.º 27
0
 def _to_reference_list(self, names):
     return gast.List([self._to_reference(name) for name in names],
                      ctx=gast.Load())
Exemplo n.º 28
0
 def _process_body_item(self, node):
     if isinstance(node, gast.Assign) and (node.value.id == 'y'):
         if_node = gast.If(gast.Name('x', gast.Load(), None),
                           [node], [])
         return if_node, if_node.body
     return node, None
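The same wrapping can be reproduced with plain gast; a minimal sketch, assuming a recent gast where Name takes (id, ctx, annotation, type_comment) rather than the three-argument form used above:

import gast

stmt = gast.parse("a = y").body[0]       # the assignment to wrap
if_node = gast.If(test=gast.Name('x', gast.Load(), None, None),
                  body=[stmt],
                  orelse=[])
print(gast.dump(if_node))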
Exemplo n.º 29
0
def analyse(node, env, non_generic=None):
    """Computes the type of the expression given by node.

    The type of the node is computed in the context of the supplied type
    environment env. Data types can be introduced into the language simply by
    having a predefined set of identifiers in the initial environment; this
    way there is no need to change the syntax or, more importantly, the
    type-checking program when extending the language.

    Args:
        node: The root of the abstract syntax tree.
        env: The type environment is a mapping of expression identifier names
            to type assignments.
        non_generic: A set of non-generic variables, or None

    Returns:
        The computed type of the expression.

    Raises:
        InferenceError: The type of the expression could not be inferred,
        PythranTypeError: InferenceError with user friendly message + location
    """

    if non_generic is None:
        non_generic = set()

    # expr
    if isinstance(node, gast.Name):
        if isinstance(node.ctx, (gast.Store)):
            new_type = TypeVariable()
            non_generic.add(new_type)
            env[node.id] = new_type
        return get_type(node.id, env, non_generic)
    elif isinstance(node, gast.Num):
        if isinstance(node.n, (int, long)):
            return Integer()
        elif isinstance(node.n, float):
            return Float()
        elif isinstance(node.n, complex):
            return Complex()
        else:
            raise NotImplementedError
    elif isinstance(node, gast.Str):
        return Str()
    elif isinstance(node, gast.Compare):
        left_type = analyse(node.left, env, non_generic)
        comparators_type = [analyse(comparator, env, non_generic)
                            for comparator in node.comparators]
        ops_type = [analyse(op, env, non_generic)
                    for op in node.ops]
        prev_type = left_type
        result_type = TypeVariable()
        for op_type, comparator_type in zip(ops_type, comparators_type):
            try:
                unify(Function([prev_type, comparator_type], result_type),
                      op_type)
                prev_type = comparator_type
            except InferenceError:
                raise PythranTypeError(
                    "Invalid comparison, between `{}` and `{}`".format(
                        prev_type,
                        comparator_type
                    ),
                    node)
        return result_type
    elif isinstance(node, gast.Call):
        if is_getattr(node):
            self_type = analyse(node.args[0], env, non_generic)
            attr_name = node.args[1].s
            _, attr_signature = attributes[attr_name]
            attr_type = tr(attr_signature)
            result_type = TypeVariable()
            try:
                unify(Function([self_type], result_type), attr_type)
            except InferenceError:
                if isinstance(prune(attr_type), MultiType):
                    msg = 'no attribute found, tried:\n{}'.format(attr_type)
                else:
                    msg = 'tried {}'.format(attr_type)
                raise PythranTypeError(
                    "Invalid attribute for getattr call with self"
                    "of type `{}`, {}".format(self_type, msg), node)

        else:
            fun_type = analyse(node.func, env, non_generic)
            arg_types = [analyse(arg, env, non_generic) for arg in node.args]
            result_type = TypeVariable()
            try:
                unify(Function(arg_types, result_type), fun_type)
            except InferenceError:
                # recover original type
                fun_type = analyse(node.func, env, non_generic)
                if isinstance(prune(fun_type), MultiType):
                    msg = 'no overload found, tried:\n{}'.format(fun_type)
                else:
                    msg = 'tried {}'.format(fun_type)
                raise PythranTypeError(
                    "Invalid argument type for function call to "
                    "`Callable[[{}], ...]`, {}"
                    .format(', '.join('{}'.format(at) for at in arg_types),
                            msg),
                    node)
        return result_type

    elif isinstance(node, gast.IfExp):
        test_type = analyse(node.test, env, non_generic)
        unify(Function([test_type], Bool()),
              tr(MODULES['__builtin__']['bool_']))

        if is_test_is_none(node.test):
            none_id = node.test.left.id
            body_env = env.copy()
            body_env[none_id] = NoneType
        else:
            none_id = None
            body_env = env

        body_type = analyse(node.body, body_env, non_generic)

        if none_id:
            orelse_env = env.copy()
            if is_option_type(env[none_id]):
                orelse_env[none_id] = prune(env[none_id]).types[0]
            else:
                orelse_env[none_id] = TypeVariable()
        else:
            orelse_env = env

        orelse_type = analyse(node.orelse, orelse_env, non_generic)

        try:
            return merge_unify(body_type, orelse_type)
        except InferenceError:
            raise PythranTypeError(
                "Incompatible types from different branches:"
                "`{}` and `{}`".format(
                    body_type,
                    orelse_type
                ),
                node
            )
    elif isinstance(node, gast.UnaryOp):
        operand_type = analyse(node.operand, env, non_generic)
        op_type = analyse(node.op, env, non_generic)
        result_type = TypeVariable()
        try:
            unify(Function([operand_type], result_type), op_type)
            return result_type
        except InferenceError:
            raise PythranTypeError(
                "Invalid operand for `{}`: `{}`".format(
                    symbol_of[type(node.op)],
                    operand_type
                ),
                node
            )
    elif isinstance(node, gast.BinOp):
        left_type = analyse(node.left, env, non_generic)
        op_type = analyse(node.op, env, non_generic)
        right_type = analyse(node.right, env, non_generic)
        result_type = TypeVariable()
        try:
            unify(Function([left_type, right_type], result_type), op_type)
        except InferenceError:
            raise PythranTypeError(
                "Invalid operand for `{}`: `{}` and `{}`".format(
                    symbol_of[type(node.op)],
                    left_type,
                    right_type),
                node
            )
        return result_type
    elif isinstance(node, gast.Pow):
        return tr(MODULES['numpy']['power'])
    elif isinstance(node, gast.Sub):
        return tr(MODULES['operator_']['sub'])
    elif isinstance(node, (gast.USub, gast.UAdd)):
        return tr(MODULES['operator_']['pos'])
    elif isinstance(node, (gast.Eq, gast.NotEq, gast.Lt, gast.LtE, gast.Gt,
                           gast.GtE, gast.Is, gast.IsNot)):
        return tr(MODULES['operator_']['eq'])
    elif isinstance(node, (gast.In, gast.NotIn)):
        contains_sig = tr(MODULES['operator_']['contains'])
        contains_sig.types[:-1] = reversed(contains_sig.types[:-1])
        return contains_sig
    elif isinstance(node, gast.Add):
        return tr(MODULES['operator_']['add'])
    elif isinstance(node, gast.Mult):
        return tr(MODULES['operator_']['mul'])
    elif isinstance(node, (gast.Div, gast.FloorDiv)):
        return tr(MODULES['operator_']['floordiv'])
    elif isinstance(node, gast.Mod):
        return tr(MODULES['operator_']['mod'])
    elif isinstance(node, (gast.LShift, gast.RShift)):
        return tr(MODULES['operator_']['lshift'])
    elif isinstance(node, (gast.BitXor, gast.BitAnd, gast.BitOr)):
        return tr(MODULES['operator_']['lshift'])
    elif isinstance(node, gast.List):
        new_type = TypeVariable()
        for elt in node.elts:
            elt_type = analyse(elt, env, non_generic)
            try:
                unify(new_type, elt_type)
            except InferenceError:
                raise PythranTypeError(
                    "Incompatible list element type `{}` and `{}`".format(
                        new_type, elt_type),
                    node
                )
        return List(new_type)
    elif isinstance(node, gast.Set):
        new_type = TypeVariable()
        for elt in node.elts:
            elt_type = analyse(elt, env, non_generic)
            try:
                unify(new_type, elt_type)
            except InferenceError:
                raise PythranTypeError(
                    "Incompatible set element type `{}` and `{}`".format(
                        new_type, elt_type),
                    node
                )
        return Set(new_type)
    elif isinstance(node, gast.Dict):
        new_key_type = TypeVariable()
        for key in node.keys:
            key_type = analyse(key, env, non_generic)
            try:
                unify(new_key_type, key_type)
            except InferenceError:
                raise PythranTypeError(
                    "Incompatible dict key type `{}` and `{}`".format(
                        new_key_type, key_type),
                    node
                )
        new_value_type = TypeVariable()
        for value in node.values:
            value_type = analyse(value, env, non_generic)
            try:
                unify(new_value_type, value_type)
            except InferenceError:
                raise PythranTypeError(
                    "Incompatible dict value type `{}` and `{}`".format(
                        new_value_type, value_type),
                    node
                )
        return Dict(new_key_type, new_value_type)
    elif isinstance(node, gast.Tuple):
        return Tuple([analyse(elt, env, non_generic) for elt in node.elts])
    elif isinstance(node, gast.Index):
        return analyse(node.value, env, non_generic)
    elif isinstance(node, gast.Slice):
        def unify_int_or_none(t, name):
            try:
                unify(t, Integer())
            except InferenceError:
                try:
                    unify(t, NoneType)
                except InferenceError:
                    raise PythranTypeError(
                        "Invalid slice {} type `{}`, expecting int or None"
                        .format(name, t)
                    )
        if node.lower:
            lower_type = analyse(node.lower, env, non_generic)
            unify_int_or_none(lower_type, 'lower bound')
        else:
            lower_type = Integer()
        if node.upper:
            upper_type = analyse(node.upper, env, non_generic)
            unify_int_or_none(upper_type, 'upper bound')
        else:
            upper_type = Integer()
        if node.step:
            step_type = analyse(node.step, env, non_generic)
            unify_int_or_none(step_type, 'step')
        else:
            step_type = Integer()
        return Slice
    elif isinstance(node, gast.ExtSlice):
        return [analyse(dim, env, non_generic) for dim in node.dims]
    elif isinstance(node, gast.NameConstant):
        if node.value is None:
            return env['None']
    elif isinstance(node, gast.Subscript):
        new_type = TypeVariable()
        value_type = prune(analyse(node.value, env, non_generic))
        try:
            slice_type = prune(analyse(node.slice, env, non_generic))
        except PythranTypeError as e:
            raise PythranTypeError(e.msg, node)

        if isinstance(node.slice, gast.ExtSlice):
            nbslice = len(node.slice.dims)
            dtype = TypeVariable()
            try:
                unify(Array(dtype, nbslice), clone(value_type))
            except InferenceError:
                raise PythranTypeError(
                    "Dimension mismatch when slicing `{}`".format(value_type),
                    node)
            return TypeVariable()  # FIXME
        elif isinstance(node.slice, gast.Index):
            # handle tuples in a special way
            isnum = isinstance(node.slice.value, gast.Num)
            if isnum and is_tuple_type(value_type):
                try:
                    unify(prune(prune(value_type.types[0]).types[0])
                          .types[node.slice.value.n],
                          new_type)
                    return new_type
                except IndexError:
                    raise PythranTypeError(
                        "Invalid tuple indexing, "
                        "out-of-bound index `{}` for type `{}`".format(
                            node.slice.value.n,
                            value_type),
                        node)
        try:
            unify(tr(MODULES['operator_']['getitem']),
                  Function([value_type, slice_type], new_type))
        except InferenceError:
            raise PythranTypeError(
                "Invalid subscripting of `{}` by `{}`".format(
                    value_type,
                    slice_type),
                node)
        return new_type
    elif isinstance(node, gast.Attribute):
        from pythran.utils import attr_to_path
        obj, path = attr_to_path(node)
        if obj.signature is typing.Any:
            return TypeVariable()
        else:
            return tr(obj)

    # stmt
    elif isinstance(node, gast.Import):
        for alias in node.names:
            if alias.name not in MODULES:
                raise NotImplementedError("unknown module: %s " % alias.name)
            if alias.asname is None:
                target = alias.name
            else:
                target = alias.asname
            env[target] = tr(MODULES[alias.name])
        return env
    elif isinstance(node, gast.ImportFrom):
        if node.module not in MODULES:
            raise NotImplementedError("unknown module: %s" % node.module)
        for alias in node.names:
            if alias.name not in MODULES[node.module]:
                raise NotImplementedError(
                    "unknown function: %s in %s" % (alias.name, node.module))
            if alias.asname is None:
                target = alias.name
            else:
                target = alias.asname
            env[target] = tr(MODULES[node.module][alias.name])
        return env
    elif isinstance(node, gast.FunctionDef):
        ftypes = []
        for i in range(1 + len(node.args.defaults)):
            old_type = env[node.name]
            new_env = env.copy()
            new_non_generic = non_generic.copy()

            # reset return special variables
            new_env.pop('@ret', None)
            new_env.pop('@gen', None)

            hy = HasYield()
            for stmt in node.body:
                hy.visit(stmt)
            new_env['@gen'] = hy.has_yield

            arg_types = []
            istop = len(node.args.args) - i
            for arg in node.args.args[:istop]:
                arg_type = TypeVariable()
                new_env[arg.id] = arg_type
                new_non_generic.add(arg_type)
                arg_types.append(arg_type)
            for arg, expr in zip(node.args.args[istop:],
                                 node.args.defaults[-i:]):
                arg_type = analyse(expr, new_env, new_non_generic)
                new_env[arg.id] = arg_type

            analyse_body(node.body, new_env, new_non_generic)

            result_type = new_env.get('@ret', NoneType)

            if new_env['@gen']:
                result_type = Generator(result_type)

            ftype = Function(arg_types, result_type)
            ftypes.append(ftype)
        if len(ftypes) == 1:
            ftype = ftypes[0]
            env[node.name] = ftype
        else:
            env[node.name] = MultiType(ftypes)
        return env
    elif isinstance(node, gast.Module):
        analyse_body(node.body, env, non_generic)
        return env
    elif isinstance(node, (gast.Pass, gast.Break, gast.Continue)):
        return env
    elif isinstance(node, gast.Expr):
        analyse(node.value, env, non_generic)
        return env
    elif isinstance(node, gast.Delete):
        for target in node.targets:
            if isinstance(target, gast.Name):
                if target.id in env:
                    del env[target.id]
                else:
                    raise PythranTypeError(
                        "Invalid del: unbound identifier `{}`".format(
                            target.id),
                        node)
            else:
                analyse(target, env, non_generic)
        return env
    elif isinstance(node, gast.Print):
        if node.dest is not None:
            analyse(node.dest, env, non_generic)
        for value in node.values:
            analyse(value, env, non_generic)
        return env
    elif isinstance(node, gast.Assign):
        defn_type = analyse(node.value, env, non_generic)
        for target in node.targets:
            target_type = analyse(target, env, non_generic)
            try:
                unify(target_type, defn_type)
            except InferenceError:
                raise PythranTypeError(
                    "Invalid assignment from type `{}` to type `{}`".format(
                        target_type,
                        defn_type),
                    node)
        return env
    elif isinstance(node, gast.AugAssign):
        # FIXME: not optimal: evaluates the type of node.value twice
        fake_target = deepcopy(node.target)
        fake_target.ctx = gast.Load()
        fake_op = gast.BinOp(fake_target, node.op, node.value)
        gast.copy_location(fake_op, node)
        analyse(fake_op, env, non_generic)

        value_type = analyse(node.value, env, non_generic)
        target_type = analyse(node.target, env, non_generic)

        try:
            unify(target_type, value_type)
        except InferenceError:
            raise PythranTypeError(
                "Invalid update operand for `{}`: `{}` and `{}`".format(
                    symbol_of[type(node.op)],
                    value_type,
                    target_type
                ),
                node
            )
        return env
    elif isinstance(node, gast.Raise):
        return env  # TODO
    elif isinstance(node, gast.Return):
        if env['@gen']:
            return env

        if node.value is None:
            ret_type = NoneType
        else:
            ret_type = analyse(node.value, env, non_generic)
        if '@ret' in env:
            try:
                ret_type = merge_unify(env['@ret'], ret_type)
            except InferenceError:
                raise PythranTypeError(
                    "function may returns with incompatible types "
                    "`{}` and `{}`".format(env['@ret'], ret_type),
                    node
                )

        env['@ret'] = ret_type
        return env
    elif isinstance(node, gast.Yield):
        assert env['@gen']
        assert node.value is not None

        if node.value is None:
            ret_type = NoneType
        else:
            ret_type = analyse(node.value, env, non_generic)
        if '@ret' in env:
            try:
                ret_type = merge_unify(env['@ret'], ret_type)
            except InferenceError:
                raise PythranTypeError(
                    "function may yields incompatible types "
                    "`{}` and `{}`".format(env['@ret'], ret_type),
                    node
                )

        env['@ret'] = ret_type
        return env
    elif isinstance(node, gast.For):
        iter_type = analyse(node.iter, env, non_generic)
        target_type = analyse(node.target, env, non_generic)
        unify(Collection(TypeVariable(), TypeVariable(), TypeVariable(),
                         target_type),
              iter_type)
        analyse_body(node.body, env, non_generic)
        analyse_body(node.orelse, env, non_generic)
        return env
    elif isinstance(node, gast.If):
        test_type = analyse(node.test, env, non_generic)
        unify(Function([test_type], Bool()),
              tr(MODULES['__builtin__']['bool_']))

        body_env = env.copy()
        body_non_generic = non_generic.copy()

        if is_test_is_none(node.test):
            none_id = node.test.left.id
            body_env[none_id] = NoneType
        else:
            none_id = None

        analyse_body(node.body, body_env, body_non_generic)

        orelse_env = env.copy()
        orelse_non_generic = non_generic.copy()

        if none_id:
            if is_option_type(env[none_id]):
                orelse_env[none_id] = prune(env[none_id]).types[0]
            else:
                orelse_env[none_id] = TypeVariable()
        analyse_body(node.orelse, orelse_env, orelse_non_generic)

        for var in body_env:
            if var not in env:
                if var in orelse_env:
                    try:
                        new_type = merge_unify(body_env[var], orelse_env[var])
                    except InferenceError:
                        raise PythranTypeError(
                            "Incompatible types from different branches for "
                            "`{}`: `{}` and `{}`".format(
                                var,
                                body_env[var],
                                orelse_env[var]
                            ),
                            node
                        )
                else:
                    new_type = body_env[var]
                env[var] = new_type

        for var in orelse_env:
            if var not in env:
                # may not be unified by the previous loop if a del occurred
                if var in body_env:
                    new_type = merge_unify(orelse_env[var], body_env[var])
                else:
                    new_type = orelse_env[var]
                env[var] = new_type

        if none_id:
            try:
                new_type = merge_unify(body_env[none_id], orelse_env[none_id])
            except InferenceError:
                msg = ("Inconsistent types while merging values of `{}` from "
                       "conditional branches: `{}` and `{}`")
                err = msg.format(none_id,
                                 body_env[none_id],
                                 orelse_env[none_id])
                raise PythranTypeError(err, node)
            env[none_id] = new_type

        return env
    elif isinstance(node, gast.While):
        test_type = analyse(node.test, env, non_generic)
        unify(Function([test_type], Bool()),
              tr(MODULES['__builtin__']['bool_']))

        analyse_body(node.body, env, non_generic)
        analyse_body(node.orelse, env, non_generic)
        return env
    elif isinstance(node, gast.Try):
        analyse_body(node.body, env, non_generic)
        for handler in node.handlers:
            analyse(handler, env, non_generic)
        analyse_body(node.orelse, env, non_generic)
        analyse_body(node.finalbody, env, non_generic)
        return env
    elif isinstance(node, gast.ExceptHandler):
        if node.name:
            new_type = ExceptionType
            non_generic.add(new_type)
            if node.name.id in env:
                unify(env[node.name.id], new_type)
            else:
                env[node.name.id] = new_type
        analyse_body(node.body, env, non_generic)
        return env
    elif isinstance(node, gast.Assert):
        if node.msg:
            analyse(node.msg, env, non_generic)
        analyse(node.test, env, non_generic)
        return env
    elif isinstance(node, gast.UnaryOp):
        operand_type = analyse(node.operand, env, non_generic)
        return_type = TypeVariable()
        op_type = analyse(node.op, env, non_generic)
        unify(Function([operand_type], return_type), op_type)
        return return_type
    elif isinstance(node, gast.Invert):
        return MultiType([Function([Bool()], Integer()),
                          Function([Integer()], Integer())])
    elif isinstance(node, gast.Not):
        return tr(MODULES['__builtin__']['bool_'])
    elif isinstance(node, gast.BoolOp):
        op_type = analyse(node.op, env, non_generic)
        value_types = [analyse(value, env, non_generic)
                       for value in node.values]

        for value_type in value_types:
            unify(Function([value_type], Bool()),
                  tr(MODULES['__builtin__']['bool_']))

        return_type = TypeVariable()
        prev_type = value_types[0]
        for value_type in value_types[1:]:
            unify(Function([prev_type, value_type], return_type), op_type)
            prev_type = value_type
        return return_type
    elif isinstance(node, (gast.And, gast.Or)):
        x_type = TypeVariable()
        return MultiType([
            Function([x_type, x_type], x_type),
            Function([TypeVariable(), TypeVariable()], TypeVariable()),
        ])

    raise RuntimeError("Unhandled syntax node {0}".format(type(node)))
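The docstring above describes Hindley-Milner style inference driven by unification of type variables. The real TypeVariable, Function, unify and prune come from pythran.types.tog; the toy sketch below only illustrates the idea in miniature and is not pythran's implementation:

class InferenceError(Exception):
    pass

class TypeVariable(object):
    _count = 0
    def __init__(self):
        TypeVariable._count += 1
        self.name = 't{}'.format(TypeVariable._count)
        self.instance = None
    def __repr__(self):
        return repr(self.instance) if self.instance is not None else self.name

class TypeOperator(object):
    def __init__(self, name, types):
        self.name, self.types = name, types
    def __repr__(self):
        if not self.types:
            return self.name
        return '{}({})'.format(self.name, ', '.join(map(repr, self.types)))

def Function(arg_types, result_type):
    return TypeOperator('fun', list(arg_types) + [result_type])

def prune(t):
    # follow instantiated variables down to the concrete type
    while isinstance(t, TypeVariable) and t.instance is not None:
        t = t.instance
    return t

def unify(a, b):
    a, b = prune(a), prune(b)
    if isinstance(a, TypeVariable):
        if a is not b:
            a.instance = b
    elif isinstance(b, TypeVariable):
        unify(b, a)
    elif a.name != b.name or len(a.types) != len(b.types):
        raise InferenceError('cannot unify {} with {}'.format(a, b))
    else:
        for x, y in zip(a.types, b.types):
            unify(x, y)

Int = TypeOperator('int', [])
result_type = TypeVariable()
unify(Function([Int], result_type), Function([Int], Int))
print(result_type)  # int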
Exemplo n.º 30
0
 def tokenize(s):
     '''A simple contextual "parser" for an OpenMP string'''
     # not completely satisfactory if there are string literals in if expressions
     out = ''
     par_count = 0
     curr_index = 0
     in_reserved_context = False
     in_declare = False
     in_shared = in_private = False
     while curr_index < len(s):
         bounds = []
         if in_declare and is_declare_typename(s, curr_index, bounds):
             start, stop = bounds
             pytypes = parse_pytypes(s[start:stop])
             out += ', '.join(map(pytype_to_ctype, pytypes))
             curr_index = stop
             continue
         m = re.match(r'^([a-zA-Z_]\w*)', s[curr_index:])
         if m:
             word = m.group(0)
             curr_index += len(word)
             if (in_reserved_context
                     or (in_declare and word in declare_keywords)
                     or (par_count == 0 and word in keywords)):
                 out += word
                 in_reserved_context = word in reserved_contex
                 in_declare |= word == 'declare'
                 in_private |= word == 'private'
                 in_shared |= word == 'shared'
             else:
                 out += '{}'
                 self.deps.append(ast.Name(word, ast.Load(), None,
                                           None))
                 isattr = re.match(r'^\s*(\.\s*[a-zA-Z_]\w*)',
                                   s[curr_index:])
                 if isattr:
                     attr = isattr.group(0)
                     curr_index += len(attr)
                     self.deps[-1] = ast.Attribute(
                         self.deps[-1], attr[1:], ast.Load())
                 if in_private:
                     self.private_deps.append(self.deps[-1])
                 if in_shared:
                     self.shared_deps.append(self.deps[-1])
         elif s[curr_index] == '(':
             par_count += 1
             curr_index += 1
             out += '('
         elif s[curr_index] == ')':
             par_count -= 1
             curr_index += 1
             out += ')'
             if par_count == 0:
                 in_reserved_context = False
                 in_shared = in_private = False
         else:
             if s[curr_index] in ',:':
                 in_reserved_context = False
             out += s[curr_index]
             curr_index += 1
     return out
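The scanning idea is: walk the clause string word by word, keep OpenMP keywords verbatim, and replace every other identifier with a '{}' placeholder while recording it as a dependency. A standalone sketch of just that core loop (not pythran's tokenize; the keyword set is an assumed subset):

import re

KEYWORDS = {'omp', 'parallel', 'for', 'if', 'private', 'shared'}

def scan(s):
    out, deps, i = '', [], 0
    while i < len(s):
        m = re.match(r'[A-Za-z_]\w*', s[i:])
        if m:
            word = m.group(0)
            i += len(word)
            if word in KEYWORDS:
                out += word            # keep directive keywords verbatim
            else:
                out += '{}'            # everything else becomes a placeholder
                deps.append(word)
        else:
            out += s[i]
            i += 1
    return out, deps

print(scan('omp parallel for private(a, b)'))
# ('omp parallel for private({}, {})', ['a', 'b'])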