Beispiel #1
0
    def visit_If(self, node):
        """Converts an `if` statement into a functional conditional.

        Both branches are outlined into local functions (named from the
        `if_true` / `if_false` prefixes), the test is saved into a fresh `cond`
        variable, and the statement is replaced by the expression built by
        `_create_cond_expr`, preceded by the helper definitions it needs.

        Args:
          node: the `If` node being visited. It must carry the BODY_SCOPE and
            ORELSE_SCOPE annotations and the static DEFINED_VARS_IN and
            LIVE_VARS_OUT annotations read below.

        Returns:
          The concatenation, in order, of: the output of
          `_create_undefined_assigns` for symbols created in only one branch,
          the state getter/setter definitions for composite symbols, the two
          branch function definitions, the assignment of the test to the
          `cond` variable, and the conditional expression itself.
        """
        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
        defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
        live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

        # Note: this information needs to be extracted before the body conversion
        # that happens in the call to generic_visit below, because the conversion
        # generates nodes that lack static analysis annotations.
        need_alias_in_body = self._determine_aliased_symbols(
            body_scope, defined_in, node.body)
        need_alias_in_orelse = self._determine_aliased_symbols(
            orelse_scope, defined_in, node.orelse)

        node = self.generic_visit(node)

        # Symbols written in either branch.
        modified_in_cond = body_scope.modified | orelse_scope.modified
        returned_from_cond = set()
        composites = set()
        for s in modified_in_cond:
            # Non-composite symbols still live after the `if` become outputs
            # of the branch functions.
            if s in live_out and not s.is_composite():
                returned_from_cond.add(s)
            if s.is_composite():
                # Special treatment for compound objects, always return them.
                # This allows special handling within the if_stmt itself.
                # For example, in TensorFlow we need to restore the state of composite
                # symbols to ensure that only effects from the executed branch are seen.
                composites.add(s)

        # Returned symbols first created inside a branch (not defined before the
        # `if`). `-` binds tighter than `&`, which yields the same set as
        # intersecting first and then subtracting.
        created_in_body = body_scope.modified & returned_from_cond - defined_in
        created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

        # returned_from_cond never holds composites (see the loop above), so
        # these filters currently drop nothing.
        basic_created_in_body = tuple(s for s in created_in_body
                                      if not s.is_composite())
        basic_created_in_orelse = tuple(s for s in created_in_orelse
                                        if not s.is_composite())

        # These variables are defined only in a single branch. This is fine in
        # Python so we pass them through. Another backend, e.g. Tensorflow, may need
        # to handle these cases specially or throw an Error.
        possibly_undefined = (set(basic_created_in_body)
                              ^ set(basic_created_in_orelse))

        # Alias the closure variables inside the conditional functions, to allow
        # the functions access to the respective variables.
        # We will alias variables independently for body and orelse scope,
        # because different branches might write different variables.
        aliased_body_orig_names = tuple(need_alias_in_body)
        aliased_orelse_orig_names = tuple(need_alias_in_orelse)
        aliased_body_new_names = tuple(
            self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
            for s in aliased_body_orig_names)
        aliased_orelse_new_names = tuple(
            self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
            for s in aliased_orelse_orig_names)

        alias_body_map = dict(
            zip(aliased_body_orig_names, aliased_body_new_names))
        alias_orelse_map = dict(
            zip(aliased_orelse_orig_names, aliased_orelse_new_names))

        # Rewrite each branch so it refers to the aliased names.
        node_body = ast_util.rename_symbols(node.body, alias_body_map)
        node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

        # Fresh names for the test variable, the branch functions and the state
        # accessors; the second argument presumably lists symbols to avoid.
        cond_var_name = self.ctx.namer.new_symbol('cond',
                                                  body_scope.referenced)
        body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
        orelse_name = self.ctx.namer.new_symbol('if_false',
                                                orelse_scope.referenced)
        all_referenced = body_scope.referenced | orelse_scope.referenced
        state_getter_name = self.ctx.namer.new_symbol('get_state',
                                                      all_referenced)
        state_setter_name = self.ctx.namer.new_symbol('set_state',
                                                      all_referenced)

        # Freeze the sets into tuples so every sequence derived from them below
        # (results, branch returns, symbol names) uses one consistent order.
        returned_from_cond = tuple(returned_from_cond)
        composites = tuple(composites)

        if returned_from_cond:
            if len(returned_from_cond) == 1:
                cond_results = returned_from_cond[0]
            else:
                cond_results = gast.Tuple(
                    [s.ast() for s in returned_from_cond], None)

            # Each branch returns the same symbols, substituting the alias for
            # symbols that were renamed inside that branch.
            returned_from_body = tuple(
                alias_body_map[s] if s in need_alias_in_body else s
                for s in returned_from_cond)
            returned_from_orelse = tuple(
                alias_orelse_map[s] if s in need_alias_in_orelse else s
                for s in returned_from_cond)

        else:
            # When the cond would return no value, we leave the cond called without
            # results. That in turn should trigger the side effect guards. The
            # branch functions will return a dummy value that ensures cond
            # actually has some return value as well.
            cond_results = None
            # TODO(mdan): Replace with None once side_effect_guards is retired.
            returned_from_body = (templates.replace_as_expression(
                'ag__.match_staging_level(1, cond_var_name)',
                cond_var_name=cond_var_name), )
            returned_from_orelse = (templates.replace_as_expression(
                'ag__.match_staging_level(1, cond_var_name)',
                cond_var_name=cond_var_name), )

        # Build the replacement statements.
        cond_assign = self.create_assignment(cond_var_name, node.test)
        body_def = self._create_cond_branch(
            body_name,
            aliased_orig_names=aliased_body_orig_names,
            aliased_new_names=aliased_body_new_names,
            body=node_body,
            returns=returned_from_body)
        orelse_def = self._create_cond_branch(
            orelse_name,
            aliased_orig_names=aliased_orelse_orig_names,
            aliased_new_names=aliased_orelse_new_names,
            body=node_orelse,
            returns=returned_from_orelse)
        undefined_assigns = self._create_undefined_assigns(possibly_undefined)
        composite_defs = self._create_state_functions(composites, [],
                                                      state_getter_name,
                                                      state_setter_name)

        # String names of the returned and composite symbols, passed to the
        # cond expression.
        basic_symbol_names = tuple(
            gast.Constant(str(symbol), kind=None)
            for symbol in returned_from_cond)
        composite_symbol_names = tuple(
            gast.Constant(str(symbol), kind=None) for symbol in composites)

        cond_expr = self._create_cond_expr(cond_results, cond_var_name,
                                           body_name, orelse_name,
                                           state_getter_name,
                                           state_setter_name,
                                           basic_symbol_names,
                                           composite_symbol_names)

        if_ast = (undefined_assigns + composite_defs + body_def + orelse_def +
                  cond_assign + cond_expr)
        return if_ast
Beispiel #2
0
class Square(Transformation):
    """
    Replaces **2 by a call to numpy.square.

    Other positive integer powers are expanded into products of
    ``numpy.square`` calls (exponentiation by squaring), e.g. ``a**3``
    becomes ``numpy.square(a) * a``. Powers whose exponent is not an
    integer constant are left untouched.

    >>> import gast as ast
    >>> from pythran import passmanager, backend
    >>> node = ast.parse('a**2')
    >>> pm = passmanager.PassManager("test")
    >>> _, node = pm.apply(Square, node)
    >>> print(pm.dump(backend.Python, node))
    import numpy as __pythran_import_numpy
    __pythran_import_numpy.square(a)
    >>> node = ast.parse('__pythran_import_numpy.power(a,2)')
    >>> pm = passmanager.PassManager("test")
    >>> _, node = pm.apply(Square, node)
    >>> print(pm.dump(backend.Python, node))
    import numpy as __pythran_import_numpy
    __pythran_import_numpy.square(a)
    """

    # ``<anything> ** 2``
    POW_PATTERN = ast.BinOp(AST_any(), ast.Pow(), ast.Constant(2, None))
    # ``<mangled numpy>.power(<anything>, 2)``
    POWER_PATTERN = ast.Call(
        ast.Attribute(ast.Name(mangle('numpy'), ast.Load(), None, None),
                      'power', ast.Load()),
        [AST_any(), ast.Constant(2, None)], [])

    def __init__(self):
        Transformation.__init__(self)

    def replace(self, value):
        """Return ``<mangled numpy>.square(value)`` and flag the rewrite.

        Setting ``need_import`` makes ``visit_Module`` add the numpy import.
        """
        self.update = self.need_import = True
        module_name = ast.Name(mangle('numpy'), ast.Load(), None, None)
        return ast.Call(ast.Attribute(module_name, 'square', ast.Load()),
                        [value], [])

    def visit_Module(self, node):
        """Rewrite the module and prepend the numpy import if it was used."""
        self.need_import = False
        self.generic_visit(node)
        if self.need_import:
            import_alias = ast.alias(name='numpy', asname=mangle('numpy'))
            importIt = ast.Import(names=[import_alias])
            node.body.insert(0, importIt)
        return node

    def expand_pow(self, node, n):
        """Return an AST computing ``node ** n`` for a non-negative int ``n``.

        Uses exponentiation by squaring: ``n`` is halved with ``>>`` at each
        step (so ``n`` must be an ``int``), and an extra factor of a deep
        copy of ``node`` is multiplied in when ``n`` is odd.
        """
        if n == 0:
            return ast.Constant(1, None)
        elif n == 1:
            return node
        else:
            node_square = self.replace(node)
            node_pow = self.expand_pow(node_square, n >> 1)
            if n & 1:
                # ``node`` already lives inside ``node_square``; use a copy
                # so the same AST object is not shared by two parents.
                return ast.BinOp(node_pow, ast.Mult(), copy.deepcopy(node))
            else:
                return node_pow

    def visit_BinOp(self, node):
        """Rewrite ``x ** 2`` and ``x ** <positive int>`` into square calls."""
        self.generic_visit(node)
        if ASTMatcher(Square.POW_PATTERN).search(node):
            return self.replace(node.left)
        elif isinstance(node.op, ast.Pow) and isnum(node.right):
            n = node.right.value
            # Only expand genuine integer exponents. ``expand_pow`` relies on
            # ``>>`` / ``&``, which raise TypeError for float or complex
            # values (``int(n)`` itself fails on complex), and rewriting
            # ``a ** 3.0`` as a product would also lose the float result type.
            if isinstance(n, int) and n > 0:
                return self.expand_pow(node.left, n)
            else:
                return node
        else:
            return node

    def visit_Call(self, node):
        """Rewrite ``numpy.power(x, 2)`` into ``numpy.square(x)``."""
        self.generic_visit(node)
        if ASTMatcher(Square.POWER_PATTERN).search(node):
            return self.replace(node.args[0])
        else:
            return node
Beispiel #3
0
 def sub():
     """Build the AST for ``Placeholder(0) ** 2`` (presumably a rewrite pattern)."""
     base = Placeholder(0)
     power_op = ast.Pow()
     exponent = ast.Constant(2, None)
     return ast.BinOp(left=base, op=power_op, right=exponent)
Beispiel #4
0
class PyASTGraphsTest(parameterized.TestCase):
    """Tests for `py_ast_graphs.SCHEMA` and `py_ast_graphs.py_ast_to_graph`."""

    def test_schema(self):
        """Check that the generated schema is reasonable."""
        # Building should succeed
        schema = py_ast_graphs.SCHEMA

        # Should have one graph node for each AST node, plus sorted helpers.
        seq_helpers = [
            "BoolOp_values-seq-helper",
            "Call_args-seq-helper",
            "For_body-seq-helper",
            "FunctionDef_body-seq-helper",
            "If_body-seq-helper",
            "If_orelse-seq-helper",
            "Module_body-seq-helper",
            "While_body-seq-helper",
            "arguments_args-seq-helper",
        ]
        expected_keys = [
            *py_ast_graphs.PY_AST_SPECS.keys(),
            *seq_helpers,
        ]
        self.assertEqual(list(schema.keys()), expected_keys)

        # Check a few elements
        expected_fundef_in_edges = {
            "parent_in",
            "returns_in",
            "returns_missing",
            "body_in",
            "args_in",
        }
        expected_fundef_out_edges = {
            "parent_out",
            "returns_out",
            "body_out_all",
            "body_out_first",
            "body_out_last",
            "args_out",
        }
        self.assertEqual(set(schema["FunctionDef"].in_edges),
                         expected_fundef_in_edges)
        self.assertEqual(set(schema["FunctionDef"].out_edges),
                         expected_fundef_out_edges)

        expected_expr_in_edges = {"parent_in", "value_in"}
        expected_expr_out_edges = {"parent_out", "value_out"}
        self.assertEqual(set(schema["Expr"].in_edges), expected_expr_in_edges)
        self.assertEqual(set(schema["Expr"].out_edges),
                         expected_expr_out_edges)

        expected_seq_helper_in_edges = {
            "parent_in",
            "item_in",
            "next_in",
            "next_missing",
            "prev_in",
            "prev_missing",
        }
        expected_seq_helper_out_edges = {
            "parent_out",
            "item_out",
            "next_out",
            "prev_out",
        }
        for seq_helper in seq_helpers:
            self.assertEqual(set(schema[seq_helper].in_edges),
                             expected_seq_helper_in_edges)
            self.assertEqual(set(schema[seq_helper].out_edges),
                             expected_seq_helper_out_edges)

    def test_ast_graph_conforms_to_schema(self):
        """Check that a graph built from varied code conforms to the schema."""
        # Some example code using a few different syntactic constructs, to cover
        # a large set of nodes in the schema
        root = gast.parse(
            textwrap.dedent("""\
        def foo(n):
          if n <= 1:
            return 1
          else:
            return foo(n-1) + foo(n-2)

        def bar(m, n) -> int:
          x = n
          for i in range(m):
            if False:
              continue
            x = x + i
          while True:
            break
          return x

        x0 = 1 + 2 - 3 * 4 / 5
        x1 = (1 == 2) and (3 < 4) and (5 > 6)
        x2 = (7 <= 8) and (9 >= 10) or (11 != 12)
        x3 = bar(13, 14 + 15)
        """))

        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        # Graph should match the schema
        schema_util.assert_conforms_to_schema(graph, py_ast_graphs.SCHEMA)

    def test_ast_graph_nodes(self):
        """Check node IDs, node types, and forward mapping."""
        root = gast.parse(
            textwrap.dedent("""\
        pass
        def foo(n):
            if n <= 1:
              return 1
        """))

        graph, forward_map = py_ast_graphs.py_ast_to_graph(root)

        # pytype: disable=attribute-error
        self.assertIn("root__Module", graph)
        self.assertEqual(graph["root__Module"].node_type, "Module")
        self.assertEqual(forward_map[id(root)], "root__Module")

        self.assertIn("root_body_1__Module_body-seq-helper", graph)
        self.assertEqual(
            graph["root_body_1__Module_body-seq-helper"].node_type,
            "Module_body-seq-helper")

        self.assertIn("root_body_1_item_body_0_item__If", graph)
        self.assertEqual(graph["root_body_1_item_body_0_item__If"].node_type,
                         "If")
        self.assertEqual(forward_map[id(root.body[1].body[0])],
                         "root_body_1_item_body_0_item__If")

        self.assertIn("root_body_1_item_body_0_item_test_left__Name", graph)
        self.assertEqual(
            graph["root_body_1_item_body_0_item_test_left__Name"].node_type,
            "Name")
        self.assertEqual(forward_map[id(root.body[1].body[0].test.left)],
                         "root_body_1_item_body_0_item_test_left__Name")
        # pytype: enable=attribute-error

    def test_ast_graph_unique_field_edges(self):
        """Test that edges for unique fields are correct."""
        root = gast.parse("print(1)")
        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        self.assertEqual(
            graph["root_body_0_item__Expr"].out_edges["value_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_0_item_value__Call"),
                    in_edge=graph_types.InEdgeType("parent_in"))
            ])

        self.assertEqual(
            graph["root_body_0_item_value__Call"].out_edges["parent_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_0_item__Expr"),
                    in_edge=graph_types.InEdgeType("value_in"))
            ])

    def test_ast_graph_optional_field_edges(self):
        """Test that edges for optional fields are correct."""
        root = gast.parse("return 1\nreturn")
        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        self.assertEqual(
            graph["root_body_0_item__Return"].out_edges["value_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId(
                        "root_body_0_item_value__Constant"),
                    in_edge=graph_types.InEdgeType("parent_in"))
            ])

        self.assertEqual(
            graph["root_body_0_item_value__Constant"].out_edges["parent_out"],
            [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_0_item__Return"),
                    in_edge=graph_types.InEdgeType("value_in"))
            ])

        # A missing optional child is represented by a self-loop into the
        # "<field>_missing" in-edge.
        self.assertEqual(
            graph["root_body_1_item__Return"].out_edges["value_out"], [
                graph_types.InputTaggedNode(
                    node_id=graph_types.NodeId("root_body_1_item__Return"),
                    in_edge=graph_types.InEdgeType("value_missing"))
            ])

    def test_ast_graph_sequence_field_edges(self):
        """Test that edges for sequence fields are correct.

        Note that sequence fields produce connections between three nodes: the
        parent, the helper node, and the child.
        """
        root = gast.parse(
            textwrap.dedent("""\
        print(1)
        print(2)
        print(3)
        print(4)
        print(5)
        print(6)
        """))

        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        # Child edges from the parent node
        node = graph["root__Module"]
        self.assertLen(node.out_edges["body_out_all"], 6)
        self.assertEqual(node.out_edges["body_out_first"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_0__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("parent_in"))
        ])
        self.assertEqual(node.out_edges["body_out_last"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_5__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("parent_in"))
        ])

        # Edges from the sequence helper
        node = graph["root_body_0__Module_body-seq-helper"]
        self.assertEqual(node.out_edges["parent_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId("root__Module"),
                in_edge=graph_types.InEdgeType("body_in"))
        ])
        self.assertEqual(node.out_edges["item_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId("root_body_0_item__Expr"),
                in_edge=graph_types.InEdgeType("parent_in"))
        ])
        self.assertEqual(node.out_edges["prev_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_0__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("prev_missing"))
        ])
        self.assertEqual(node.out_edges["next_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_1__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("prev_in"))
        ])

        # Parent edge of the item
        node = graph["root_body_0_item__Expr"]
        self.assertEqual(node.out_edges["parent_out"], [
            graph_types.InputTaggedNode(
                node_id=graph_types.NodeId(
                    "root_body_0__Module_body-seq-helper"),
                in_edge=graph_types.InEdgeType("item_in"))
        ])

    @parameterized.named_parameters(
        {
            "testcase_name": "unexpected_type",
            "ast": gast.Subscript(value=None, slice=None, ctx=None),
            "expected_error": "Unknown AST node type 'Subscript'",
        }, {
            "testcase_name":
            "too_many_unique",
            "ast":
            gast.Assign(targets=[
                gast.Name("foo", gast.Store(), None, None),
                gast.Name("bar", gast.Store(), None, None)
            ],
                        value=gast.Constant(True, None)),
            "expected_error":
            "Expected 1 child for field 'targets' of node .*; got 2",
        }, {
            "testcase_name":
            "missing_unique",
            "ast":
            gast.Assign(targets=[], value=gast.Constant(True, None)),
            "expected_error":
            "Expected 1 child for field 'targets' of node .*; got 0",
        }, {
            "testcase_name":
            "too_many_optional",
            "ast":
            gast.Return(value=[
                gast.Name("foo", gast.Load(), None, None),
                gast.Name("bar", gast.Load(), None, None)
            ]),
            "expected_error":
            "Expected at most 1 child for field 'value' of node .*; got 2",
        })
    def test_invalid_graphs(self, ast, expected_error):
        """Check that malformed ASTs are rejected with a ValueError."""
        with self.assertRaisesRegex(ValueError, expected_error):
            py_ast_graphs.py_ast_to_graph(ast)
Beispiel #5
0
    def visit_BinOp(self, node):
        """Replace ``i % k`` inside a ``range`` loop by a running counter.

        The rewrite applies when ``i`` is the target of a ``for`` loop over
        ``range(stop)`` or ``range(start, stop)``, and ``k`` is a name with a
        single definition made outside that loop whose value range has a
        non-negative lower bound. A fresh variable ``<i>_m`` then gets a
        header assignment ``<i>_m = (start - 1) % k`` and a step assignment
        ``<i>_m = 0 if <i>_m + 1 == k else <i>_m + 1``. Both are recorded in
        ``self.loops_mod[loop]``; inserting them around/into the loop is done
        by code outside this method. The modulo expression is replaced by a
        load of ``<i>_m``.

        Args:
            node: the ``BinOp`` being visited.

        Returns:
            A ``Name`` loading the counter when the rewrite applies, otherwise
            the result of ``self.generic_visit(node)``.
        """
        if not isinstance(node.op, ast.Mod):
            return self.generic_visit(node)

        # check that right is a name defined once outside of loop
        # TODO: handle expression instead of names
        if not isinstance(node.right, ast.Name):
            return self.generic_visit(node)

        right_def = self.single_def(node.right)
        if not right_def:
            return self.generic_visit(node)

        # NOTE(review): a lower bound of exactly 0 is still accepted, so the
        # hoisted header may evaluate ``% 0`` even when the loop never runs.
        if self.range_values[node.right.id].low < 0:
            return self.generic_visit(node)

        # same for lhs
        if not isinstance(node.left, ast.Name):
            return self.generic_visit(node)

        head = self.single_def(node.left)
        if not head:
            return self.generic_visit(node)

        # check lhs is the actual index of a loop
        loop = self.ancestors[head][-1]

        if not isinstance(loop, ast.For):
            return self.generic_visit(node)

        if not isinstance(loop.iter, ast.Call):
            return self.generic_visit(node)

        # make sure rhs is defined out of the loop
        if loop in self.ancestors[right_def]:
            return self.generic_visit(node)

        # The iterated callable must be ``range`` under *every* possible alias.
        # Stopping at the first mismatch is not enough: the alias collection
        # is unordered, and a ``range`` seen earlier would otherwise be kept.
        range_intrinsic = MODULES['builtins']['range']
        aliases = self.aliases[loop.iter.func]
        if not aliases or any(alias is not range_intrinsic
                              for alias in aliases):
            return self.generic_visit(node)

        # The counter advances by exactly one per iteration, which is only
        # valid for ``range(stop)`` and ``range(start, stop)``. A step
        # argument, keywords, or a starred argument (unknown arity) rule the
        # rewrite out.
        rargs = loop.iter.args
        if (len(rargs) not in (1, 2) or loop.iter.keywords or
                any(isinstance(arg, ast.Starred) for arg in rargs)):
            return self.generic_visit(node)

        # everything is setup for the transformation!
        # NOTE(review): new_id is not added to self.identifiers here, so two
        # rewrites of the same loop variable may pick the same name — confirm
        # the caller guarantees uniqueness.
        new_id = node.left.id + '_m'
        i = 0
        while new_id in self.identifiers:
            new_id = '{}_m{}'.format(node.left.id, i)
            i += 1

        # The start bound comes from the actual call arguments (not from the
        # intrinsic's formal parameters): ``range(stop)`` starts at 0,
        # ``range(start, stop)`` at ``start``.
        lower = rargs[0] if len(rargs) > 1 else ast.Constant(0, None)
        # <new_id> = (lower - 1) % right, so the first step yields lower % right.
        header = ast.Assign([ast.Name(new_id, ast.Store(), None, None)],
                            ast.BinOp(
                                ast.BinOp(deepcopy(lower), ast.Sub(),
                                          ast.Constant(1, None)), ast.Mod(),
                                deepcopy(node.right)), None)
        incr = ast.BinOp(ast.Name(new_id, ast.Load(), None, None), ast.Add(),
                         ast.Constant(1, None))
        # <new_id> = 0 if <new_id> + 1 == right else <new_id> + 1
        step = ast.Assign([ast.Name(new_id, ast.Store(), None, None)],
                          ast.IfExp(
                              ast.Compare(incr, [ast.Eq()],
                                          [deepcopy(node.right)]),
                              ast.Constant(0, None), deepcopy(incr)), None)

        self.loops_mod.setdefault(loop, []).append((header, step))
        self.update = True
        return ast.Name(new_id, ast.Load(), None, None)
Beispiel #6
0
 def test_iter_child_nodes(self):
     """``iter_child_nodes`` yields both children of a ``UnaryOp``.

     The operator (``USub()``) and the operand (``Constant``) are both AST
     node instances, so exactly two children are expected.
     """
     tree = gast.UnaryOp(gast.USub(), gast.Constant(value=1, kind=None))
     self.assertEqual(len(list(gast.iter_child_nodes(tree))), 2)
Beispiel #7
0
    def visit_If(self, node):
        """Outlines an `if` whose test is a static expression.

        Non-static conditionals are visited generically. For a static one,
        both branches are moved into new functions (appended to
        `self.new_functions`). The statement is replaced by a call to the
        dispatcher built by `make_dispatcher`, plus the assignments and
        control-flow handling needed to propagate the branch results
        (assigned variables, early `return`, `break`, `continue`).

        Args:
          node: the `If` node being visited.

        Returns:
          The result of `self.generic_visit(node)` when the test is not in
          `self.static_expressions`; otherwise a single replacement statement
          or a list of replacement statements.
        """
        if node.test not in self.static_expressions:
            return self.generic_visit(node)

        # Identifiers that will be passed into the outlined functions.
        imported_ids = self.gather(ImportedIds, node)

        # Identifiers assigned in each branch, as computed by escaping_ids
        # (presumably those still needed after the `if`).
        assigned_ids_left = self.escaping_ids(node, node.body)
        assigned_ids_right = self.escaping_ids(node, node.orelse)
        assigned_ids_both = assigned_ids_left.union(assigned_ids_right)

        # A name assigned in only one branch is also passed in, presumably so
        # the other branch can return its incoming value unchanged.
        imported_ids.update(i for i in assigned_ids_left
                            if i not in assigned_ids_right)
        imported_ids.update(i for i in assigned_ids_right
                            if i not in assigned_ids_left)
        # Sorted so parameter and result order is deterministic.
        imported_ids = sorted(imported_ids)

        assigned_ids = sorted(assigned_ids_both)

        # Detect return/break/continue in each branch, using the wrapper nodes
        # produced by make_fake.
        fbody = self.make_fake(node.body)
        true_has_return = self.gather(HasReturn, fbody)
        true_has_break = self.gather(HasBreak, fbody)
        true_has_cont = self.gather(HasContinue, fbody)

        felse = self.make_fake(node.orelse)
        false_has_return = self.gather(HasReturn, felse)
        false_has_break = self.gather(HasBreak, felse)
        false_has_cont = self.gather(HasContinue, felse)

        has_return = true_has_return or false_has_return
        has_break = true_has_break or false_has_break
        has_cont = true_has_cont or false_has_cont

        # Visit nested statements before the branch bodies are outlined below.
        self.generic_visit(node)

        func_true = outline(self.true_name(), imported_ids, assigned_ids,
                            node.body, has_return, has_break, has_cont)
        func_false = outline(self.false_name(), imported_ids, assigned_ids,
                             node.orelse, has_return, has_break, has_cont)
        self.new_functions.extend((func_true, func_false))

        actual_call = self.make_dispatcher(node.test, func_true, func_false,
                                           imported_ids)

        # variable modified within the static_if
        expected_return = [
            ast.Name(ii, ast.Store(), None, None) for ii in assigned_ids
        ]

        self.update = True

        # name for various variables resulting from the static_if
        # (suffixed with the number of functions outlined so far)
        n = len(self.new_functions)
        status_n = "$status{}".format(n)
        return_n = "$return{}".format(n)
        cont_n = "$cont{}".format(n)

        if has_return:
            cfg = self.cfgs[-1]
            # always_return: every CFG successor of this node is a Return or
            # Yield and both branches contain a return, so the returned value
            # can be returned unconditionally.
            always_return = all(
                isinstance(x, (ast.Return, ast.Yield)) for x in cfg[node])
            always_return &= true_has_return and false_has_return

            fast_return = [
                ast.Name(status_n, ast.Store(), None, None),
                ast.Name(return_n, ast.Store(), None, None),
                ast.Name(cont_n, ast.Store(), None, None)
            ]

            if always_return:
                return [
                    ast.Assign([ast.Tuple(fast_return, ast.Store())],
                               actual_call),
                    ast.Return(ast.Name(return_n, ast.Load(), None, None))
                ]
            else:
                cont_ass = self.make_control_flow_handlers(
                    cont_n, status_n, expected_return, has_cont, has_break)

                # Return early when the status equals EARLY_RET; otherwise run
                # the handlers built by make_control_flow_handlers.
                cmpr = ast.Compare(ast.Name(status_n, ast.Load(), None, None),
                                   [ast.Eq()], [ast.Constant(EARLY_RET, None)])
                return [
                    ast.Assign([ast.Tuple(fast_return, ast.Store())],
                               actual_call),
                    ast.If(cmpr, [
                        ast.Return(ast.Name(return_n, ast.Load(), None, None))
                    ], cont_ass)
                ]
        elif has_break or has_cont:
            cont_ass = self.make_control_flow_handlers(cont_n, status_n,
                                                       expected_return,
                                                       has_cont, has_break)

            fast_return = [
                ast.Name(status_n, ast.Store(), None, None),
                ast.Name(cont_n, ast.Store(), None, None)
            ]
            return [
                ast.Assign([ast.Tuple(fast_return, ast.Store())], actual_call)
            ] + cont_ass
        elif expected_return:
            # Only assigned variables to propagate.
            return ast.Assign([ast.Tuple(expected_return, ast.Store())],
                              actual_call)
        else:
            # Nothing to propagate: keep the call for its side effects.
            return ast.Expr(actual_call)
Beispiel #8
0
    def visit_While(self, node):
        """Converts a `while` loop into a call to `ag__.while_stmt`.

        The loop test and body are outlined into local functions. Composite
        loop variables get state getter/setter functions. Basic loop variables
        are passed through the functions and, when there are any, assigned
        back from the result of `ag__.while_stmt`.

        Args:
          node: the `While` node being visited; it must carry the BODY_SCOPE
            annotation.

        Returns:
          The output of `_create_undefined_assigns` for the possibly-undefined
          loop variables, followed by the statements produced from the
          template.
        """
        # Convert nested statements first.
        node = self.generic_visit(node)

        # Classify the symbols modified by the loop body.
        (basic_loop_vars, composite_loop_vars, reserved_symbols,
         possibly_undefs) = self._get_loop_vars(
             node,
             anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified)
        loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
            basic_loop_vars)

        state_getter_name = self.ctx.namer.new_symbol('get_state',
                                                      reserved_symbols)
        state_setter_name = self.ctx.namer.new_symbol('set_state',
                                                      reserved_symbols)
        state_functions = self._create_state_functions(composite_loop_vars,
                                                       state_getter_name,
                                                       state_setter_name)

        # String names of the loop variables, passed to ag__.while_stmt.
        basic_symbol_names = tuple(
            gast.Constant(str(symbol), kind=None)
            for symbol in basic_loop_vars)
        composite_symbol_names = tuple(
            gast.Constant(str(symbol), kind=None)
            for symbol in composite_loop_vars)

        opts = self._create_loop_options(node)

        # TODO(mdan): Use a single template.
        # If the body and test functions took a single tuple for loop_vars, instead
        # of *loop_vars, then a single template could be used.
        if loop_vars:
            template = """
        state_functions
        def body_name(loop_vars):
          body
          return loop_vars,
        def test_name(loop_vars):
          return test
        loop_vars_ast_tuple = ag__.while_stmt(
            test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (loop_vars,),
            (basic_symbol_names,),
            (composite_symbol_names,),
            opts)
      """
            node = templates.replace(
                template,
                loop_vars=loop_vars,
                loop_vars_ast_tuple=loop_vars_ast_tuple,
                test_name=self.ctx.namer.new_symbol('loop_test',
                                                    reserved_symbols),
                test=node.test,
                body_name=self.ctx.namer.new_symbol('loop_body',
                                                    reserved_symbols),
                body=node.body,
                state_functions=state_functions,
                state_getter_name=state_getter_name,
                state_setter_name=state_setter_name,
                basic_symbol_names=basic_symbol_names,
                composite_symbol_names=composite_symbol_names,
                opts=opts)
        else:
            # No basic loop variables: the functions take no arguments, the
            # body returns an empty tuple and the call's result is not assigned.
            template = """
        state_functions
        def body_name():
          body
          return ()
        def test_name():
          return test
        ag__.while_stmt(
            test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (),
            (),
            (composite_symbol_names,),
            opts)
      """
            node = templates.replace(
                template,
                test_name=self.ctx.namer.new_symbol('loop_test',
                                                    reserved_symbols),
                test=node.test,
                body_name=self.ctx.namer.new_symbol('loop_body',
                                                    reserved_symbols),
                body=node.body,
                state_functions=state_functions,
                state_getter_name=state_getter_name,
                state_setter_name=state_setter_name,
                composite_symbol_names=composite_symbol_names,
                opts=opts)

        undefined_assigns = self._create_undefined_assigns(possibly_undefs)
        return undefined_assigns + node
Beispiel #9
0
    def visit_For(self, node):
        """Rewrite a `for` loop into a call to `ag__.for_stmt`.

        The loop body is moved into a generated function (`body_name`) that
        receives the current iterate (and, when there are basic loop variables,
        those variables) and returns the updated loop variables. Composite
        loop variables are exposed through generated state getter/setter
        functions. An optional `extra_test` annotation becomes a generated
        predicate function; otherwise the literal `None` is passed instead.

        Returns:
          The nodes produced by `templates.replace` for the selected template
          (presumably a list of statements replacing the original loop).
        """
        # Convert nested nodes first so the generated body already contains
        # converted code.
        node = self.generic_visit(node)

        # Loop variables are the symbols modified in the body or in the
        # iterate (target) scope, split into basic and composite symbols.
        (basic_loop_vars, composite_loop_vars,
         reserved_symbols, possibly_undefs) = self._get_loop_vars(
             node,
             (anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
              | anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE).modified))
        loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
            basic_loop_vars)
        # NOTE(review): new_symbol is called in a fixed order below; the
        # generated names presumably depend on that order, so keep it stable.
        body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

        state_getter_name = self.ctx.namer.new_symbol('get_state',
                                                      reserved_symbols)
        state_setter_name = self.ctx.namer.new_symbol('set_state',
                                                      reserved_symbols)
        # Getter/setter functions only cover the composite loop variables;
        # basic loop variables are threaded through the body's arguments.
        state_functions = self._create_state_functions(composite_loop_vars,
                                                       state_getter_name,
                                                       state_setter_name)

        if anno.hasanno(node, 'extra_test'):
            # Wrap the extra loop condition in a predicate over the loop vars.
            extra_test = anno.getanno(node, 'extra_test')
            extra_test_name = self.ctx.namer.new_symbol(
                'extra_test', reserved_symbols)
            template = """
        def extra_test_name(loop_vars):
          return extra_test_expr
      """
            extra_test_function = templates.replace(
                template,
                extra_test_name=extra_test_name,
                loop_vars=loop_vars,
                extra_test_expr=extra_test)
        else:
            # No extra condition: pass a literal `None` and emit no function.
            extra_test_name = parser.parse_expression('None')
            extra_test_function = []

        # Workaround for PEP-3113
        # iterates_var holds a single variable with the iterates, which may be a
        # tuple.
        iterates_var_name = self.ctx.namer.new_symbol('iterates',
                                                      reserved_symbols)
        # `iterate_expansion` assigns that single argument to the original
        # loop target (e.g. unpacking `a, b = iterates`) inside the body.
        template = """
      iterates = iterates_var_name
    """
        iterate_expansion = templates.replace(
            template,
            iterates=node.target,
            iterates_var_name=iterates_var_name)

        # Assignments for symbols that may be undefined before the loop.
        undefined_assigns = self._create_undefined_assigns(possibly_undefs)

        # Loop variable names passed to for_stmt as string constants.
        basic_symbol_names = tuple(
            gast.Constant(str(symbol), kind=None)
            for symbol in basic_loop_vars)
        composite_symbol_names = tuple(
            gast.Constant(str(symbol), kind=None)
            for symbol in composite_loop_vars)

        opts = self._create_loop_options(node)

        # TODO(mdan): Use a single template.
        # If the body and test functions took a single tuple for loop_vars, instead
        # of *loop_vars, then a single template could be used.
        if loop_vars:
            # With basic loop vars: the body takes and returns them, and the
            # for_stmt result is assigned back to them.
            template = """
        undefined_assigns
        state_functions
        def body_name(iterates_var_name, loop_vars):
          iterate_expansion
          body
          return loop_vars,
        extra_test_function
        loop_vars_ast_tuple = ag__.for_stmt(
            iter_,
            extra_test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (loop_vars,),
            (basic_symbol_names,),
            (composite_symbol_names,),
            opts)
      """
            return templates.replace(
                template,
                undefined_assigns=undefined_assigns,
                loop_vars=loop_vars,
                loop_vars_ast_tuple=loop_vars_ast_tuple,
                iter_=node.iter,
                iterate_expansion=iterate_expansion,
                iterates_var_name=iterates_var_name,
                extra_test_name=extra_test_name,
                extra_test_function=extra_test_function,
                body_name=body_name,
                body=node.body,
                state_functions=state_functions,
                state_getter_name=state_getter_name,
                state_setter_name=state_setter_name,
                basic_symbol_names=basic_symbol_names,
                composite_symbol_names=composite_symbol_names,
                opts=opts)
        else:
            # Without basic loop vars: the body returns an empty tuple and the
            # for_stmt result is discarded; only composite state is tracked.
            template = """
        undefined_assigns
        state_functions
        def body_name(iterates_var_name):
          iterate_expansion
          body
          return ()
        extra_test_function
        ag__.for_stmt(
            iter_,
            extra_test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (),
            (),
            (composite_symbol_names,),
            opts)
      """
            return templates.replace(
                template,
                undefined_assigns=undefined_assigns,
                iter_=node.iter,
                iterate_expansion=iterate_expansion,
                iterates_var_name=iterates_var_name,
                extra_test_name=extra_test_name,
                extra_test_function=extra_test_function,
                body_name=body_name,
                body=node.body,
                state_functions=state_functions,
                state_getter_name=state_getter_name,
                state_setter_name=state_setter_name,
                composite_symbol_names=composite_symbol_names,
                opts=opts)
 def fill(self, hole, rng):
     """Fill ``hole`` with an integer literal drawn via ``rng.randint(0, 100)``.

     The number is drawn once, at fill time; the returned builder produces a
     fresh ``gast.Constant`` holding that same number on every call.
     """
     number = rng.randint(0, 100)

     def build_constant():
         return gast.Constant(value=number, kind=None)

     return ASTWithHoles(1, [], build_constant)
 def fill(self, hole, rng):
     """Fill ``hole`` with a randomly chosen boolean literal.

     The choice is made once, at fill time; the returned builder produces a
     fresh ``gast.Constant`` holding that same value on every call.

     If ``rng`` is a NumPy ``RandomState``, ``rng.choice([True, False])``
     returns a ``numpy.bool_``. That is not a valid ``Constant`` value for
     ``compile()`` and unparses differently from a Python ``bool``, so the
     value is normalized to a plain ``bool``.
     """
     value = bool(rng.choice([True, False]))
     return ASTWithHoles(1, [],
                         lambda: gast.Constant(value=value, kind=None))
Beispiel #12
0
def negate(node):
    if isinstance(node, ast.Name):
        # Not type info, could be anything :(
        raise UnsupportedExpression()

    if isinstance(node, ast.UnaryOp):
        # !~x <> ~x == 0 <> x == ~0 <> x == -1
        if isinstance(node.op, ast.Invert):
            return ast.Compare(node.operand, [ast.Eq()],
                               [ast.Constant(-1, None)])
        # !!x <> x
        if isinstance(node.op, ast.Not):
            return node.operand
        # !+x <> +x == 0 <> x == 0 <> !x
        if isinstance(node.op, ast.UAdd):
            return node.operand
        # !-x <> -x == 0 <> x == 0 <> !x
        if isinstance(node.op, ast.USub):
            return node.operand

    if isinstance(node, ast.BoolOp):
        new_values = [ast.UnaryOp(ast.Not(), v) for v in node.values]
        # !(x or y) <> !x and !y
        if isinstance(node.op, ast.Or):
            return ast.BoolOp(ast.And(), new_values)
        # !(x and y) <> !x or !y
        if isinstance(node.op, ast.And):
            return ast.BoolOp(ast.Or(), new_values)

    if isinstance(node, ast.Compare):
        cmps = [
            ast.Compare(x, [negate(o)], [y])
            for x, o, y in zip([node.left] + node.comparators[:-1], node.ops,
                               node.comparators)
        ]
        if len(cmps) == 1:
            return cmps[0]
        return ast.BoolOp(ast.Or(), cmps)

    if isinstance(node, ast.Eq):
        return ast.NotEq()
    if isinstance(node, ast.NotEq):
        return ast.Eq()
    if isinstance(node, ast.Gt):
        return ast.LtE()
    if isinstance(node, ast.GtE):
        return ast.Lt()
    if isinstance(node, ast.Lt):
        return ast.GtE()
    if isinstance(node, ast.LtE):
        return ast.Gt()
    if isinstance(node, ast.In):
        return ast.NotIn()
    if isinstance(node, ast.NotIn):
        return ast.In()

    if isinstance(node, ast.Attribute):
        if node.attr == 'False':
            return ast.Constant(True, None)
        if node.attr == 'True':
            return ast.Constant(False, None)

    raise UnsupportedExpression()
Beispiel #13
0
 def visit_Bytes(self, node):
     """Replace a ``Bytes`` node with an equivalent ``gast.Constant``.

     The node's ``s`` payload becomes the constant's value, ``kind`` is left
     as ``None``, and the original node's source location is copied over.
     """
     replacement = gast.Constant(value=node.s, kind=None)
     return gast.copy_location(replacement, node)
Beispiel #14
0
 def test_increment_lineno(self):
     """``gast.increment_lineno`` without an explicit amount moves line 1 to 2."""
     constant = gast.Constant(kind=None, value=1)
     constant.lineno = 1
     gast.increment_lineno(constant)
     self.assertEqual(2, constant.lineno)
Beispiel #15
0
    def test_buildable(self, template):
        """Test that each template can be built when given acceptable arguments.

        A hole of the template's own fill type is filled using a seeded
        RandomState, and every resulting sub-hole is plugged with a fixed
        dummy node. The built value is then checked for:
          * the node type(s) implied by the fill type;
          * a node count equal to `template.required_cost`;
          * a filler cost equal to that count minus the sub-hole costs;
          * identical output when rebuilt from the same seed.
        """
        hole_type = python_numbers_control_flow.ASTHoleType
        seed = 1234

        # A hole that this template is always able to fill.
        metadata = python_numbers_control_flow.ASTHoleMetadata(
            names_in_scope=("a", ),
            inside_function=True,
            inside_loop=True,
            op_depth=0)
        hole = top_down_refinement.Hole(template.fills_type, metadata)
        self.assertTrue(template.can_fill(hole))

        # Filling and building must succeed without errors.
        filler = template.fill(hole, np.random.RandomState(seed))
        placeholder_factories = {
            hole_type.NUMBER: lambda: gast.Constant(value=1, kind=None),
            hole_type.BOOL: lambda: gast.Constant(value=True, kind=None),
            hole_type.STMT: gast.Pass,
            hole_type.STMTS: lambda: [],
            hole_type.STMTS_NONEMPTY: lambda: [gast.Pass()],
            hole_type.BLOCK: lambda: [gast.Pass()],
        }
        hole_values = [
            placeholder_factories[sub_hole.hole_type]()
            for sub_hole in filler.holes
        ]
        value = filler.build(*hole_values)

        # The built value must have the shape implied by the fill type.
        fills = template.fills_type
        if fills in (hole_type.STMTS_NONEMPTY, hole_type.BLOCK):
            self.assertTrue(value)
            for stmt in value:
                self.assertIsInstance(stmt, gast.stmt)
        elif fills == hole_type.STMTS:
            for stmt in value:
                self.assertIsInstance(stmt, gast.stmt)
        elif fills == hole_type.STMT:
            self.assertIsInstance(value, gast.stmt)
        elif fills in (hole_type.NUMBER, hole_type.BOOL):
            self.assertIsInstance(value, gast.expr)
        else:
            raise NotImplementedError(
                f"Unexpected fill type {template.fills_type}; "
                "please update this test.")

        # The declared cost must equal the number of AST nodes produced.
        roots = [value] if isinstance(value, gast.AST) else value
        total_cost = sum(1 for root in roots for _ in gast.walk(root))
        self.assertEqual(template.required_cost, total_cost)

        sub_hole_cost = sum(
            python_numbers_control_flow.ALL_COSTS[sub_hole.hole_type]
            for sub_hole in filler.holes)
        self.assertEqual(filler.cost, total_cost - sub_hole_cost)

        # Rebuilding from the same seed must reproduce the same tree(s).
        value_is_list = isinstance(value, list)
        if value_is_list:
            expected_dump = [gast.dump(v) for v in value]
        else:
            expected_dump = gast.dump(value)
        for _ in range(20):
            rebuilt = template.fill(
                hole, np.random.RandomState(seed)).build(*hole_values)
            if value_is_list:
                self.assertEqual(expected_dump,
                                 [gast.dump(v) for v in rebuilt])
            else:
                self.assertEqual(expected_dump, gast.dump(rebuilt))
Beispiel #16
0
 def make_fake(stmts):
     """Wrap ``stmts`` as the body of an ``if 0:`` statement with no else branch."""
     never_true = ast.Constant(value=0, kind=None)
     return ast.If(test=never_true, body=stmts, orelse=[])
Beispiel #17
0
    def visit_For(self, node):
        """Rewrite a `for` loop into a call to `ag__.for_stmt`.

        In this variant the generated body function takes only the iterate
        argument. Loop variables are shared with the surrounding code through
        nonlocal declarations and generated state getter/setter functions,
        rather than being passed in and returned. The loop target's source
        text is recorded in the loop options under 'iterate_names'.

        Returns:
          The nodes produced by `templates.replace` (presumably the list of
          statements replacing the original loop).
        """
        # Convert nested nodes first so the generated body contains
        # converted code.
        node = self.generic_visit(node)
        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)

        # Loop variables are the symbols modified in the body or the iterate
        # (target) scope.
        loop_vars, undefined, _ = self._get_block_vars(
            node, body_scope.modified | iter_scope.modified)

        # Assignments for loop variables that may be undefined before the loop.
        undefined_assigns = self._create_undefined_assigns(undefined)

        # Nonlocal declarations reused by the body, extra test and state
        # functions so they can read/write the loop variables.
        nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

        # Generated names must avoid every symbol referenced in the loop.
        # NOTE(review): the generated names presumably depend on the order of
        # new_symbol calls; keep the call order stable.
        reserved = body_scope.referenced | iter_scope.referenced
        state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
        state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
        state_functions = self._create_state_functions(loop_vars,
                                                       nonlocal_declarations,
                                                       state_getter_name,
                                                       state_setter_name)

        # Loop options (an AST dict, given the keys/values lists) gain an
        # 'iterate_names' entry holding the unparsed loop target source.
        opts = self._create_loop_options(node)
        opts.keys.append(gast.Constant('iterate_names', kind=None))
        opts.values.append(
            gast.Constant(parser.unparse(node.target,
                                         include_encoding_marker=False),
                          kind=None))

        if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
            # Wrap the extra loop condition in a no-argument predicate that
            # accesses loop variables through the nonlocal declarations.
            extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
            extra_test_name = self.ctx.namer.new_symbol('extra_test', reserved)
            template = """
        def extra_test_name():
          nonlocal_declarations
          return extra_test_expr
      """
            # NOTE(review): `loop_vars` is passed below but the template above
            # has no `loop_vars` placeholder; it appears to be unused.
            extra_test_function = templates.replace(
                template,
                extra_test_expr=extra_test,
                extra_test_name=extra_test_name,
                loop_vars=loop_vars,
                nonlocal_declarations=nonlocal_declarations)
        else:
            # No extra condition: pass a literal `None` and emit no function.
            extra_test_name = parser.parse_expression('None')
            extra_test_function = []

        # iterate_arg_name holds a single arg with the iterates, which may be a
        # tuple.
        iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved)
        # Assign that single argument to the original loop target inside the
        # body (e.g. unpacking `a, b = itr`).
        template = """
      iterates = iterate_arg_name
    """
        iterate_expansion = templates.replace(
            template, iterate_arg_name=iterate_arg_name, iterates=node.target)

        template = """
      state_functions
      def body_name(iterate_arg_name):
        nonlocal_declarations
        iterate_expansion
        body
      extra_test_function
      undefined_assigns
      ag__.for_stmt(
          iterated,
          extra_test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (symbol_names,),
          opts)
    """
        return templates.replace(
            template,
            body=node.body,
            body_name=self.ctx.namer.new_symbol('loop_body', reserved),
            extra_test_function=extra_test_function,
            extra_test_name=extra_test_name,
            iterate_arg_name=iterate_arg_name,
            iterate_expansion=iterate_expansion,
            iterated=node.iter,
            nonlocal_declarations=nonlocal_declarations,
            opts=opts,
            symbol_names=tuple(
                gast.Constant(str(s), kind=None) for s in loop_vars),
            state_functions=state_functions,
            state_getter_name=state_getter_name,
            state_setter_name=state_setter_name,
            undefined_assigns=undefined_assigns)
Beispiel #18
0
 def test_iter_fields(self):
     tree = gast.Constant(value=1, kind=None)
     self.assertEqual({name
                       for name, _ in gast.iter_fields(tree)},
                      {'value', 'kind'})