def visit_Break(self, node):
        """Lower a ``break`` statement into boolean-flag control flow.

        The statement is tied to a freshly generated boolean variable. The
        statements after it in its first enclosing block are removed and the
        flag is set there. Statements in the outer blocks up to the loop are
        guarded by ``if not flag``. ``not flag`` is then added to the loop
        condition, converting a ``for`` loop into a ``while`` loop first.

        Args:
            node: the ``gast.Break`` node being visited.
        """
        loop_index = _find_ancestor_loop_index(node, self.ancestor_nodes)
        assert loop_index != -1, "SyntaxError: 'break' outside loop"
        enclosing_loop = self.ancestor_nodes[loop_index]

        # Step 1: a unique boolean flag stands in for this 'break'.
        flag_name = unique_name.generate(BREAK_NAME_PREFIX)

        # Step 2: in the innermost block holding the 'break', drop everything
        # after it and set the flag to True instead.
        first_block_index = self._remove_stmts_after_break_continue(
            node, flag_name, loop_index)

        # Step 3: guard statements of the outer blocks, from just outside the
        # first block up to and including the loop, with 'if not flag'.
        self._replace_if_stmt(loop_index, first_block_index, flag_name)

        # Step 4: reset the flag before the loop, then stop looping once it
        # is set.
        self._add_stmt_before_cur_node(
            loop_index, create_fill_constant_node(flag_name, False))

        flag_load = gast.Name(id=flag_name,
                              ctx=gast.Load(),
                              annotation=None,
                              type_comment=None)
        not_flag = gast.UnaryOp(op=gast.Not(), operand=flag_load)

        if isinstance(enclosing_loop, gast.While):
            enclosing_loop.test = gast.BoolOp(
                op=gast.And(), values=[enclosing_loop.test, not_flag])
        elif isinstance(enclosing_loop, gast.For):
            # A 'for' loop has no test to extend, so it is rewritten as a
            # 'while' loop whose condition includes 'not flag'.
            ForToWhileTransformer(self.ancestor_nodes[loop_index - 1],
                                  enclosing_loop, not_flag).transform()
Пример #2
0
    def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name,
                                     max_return_length, parent_node_of_return):

        assert max_return_length >= 0, "Input illegal max_return_length"
        i = index_in_list(stmt_list, return_node)
        if i == -1:
            return False

        assign_nodes = []
        # Here assume that the parent node of return is gast.If
        if isinstance(parent_node_of_return, gast.If):
            # Prepend control flow boolean nodes such as '__return@1 = True'
            node_str = "{} = _jst.create_bool_as_type({}, True)".format(
                return_name,
                ast_to_source_code(parent_node_of_return.test).strip())

            assign_true_node = gast.parse(node_str).body[0]
            assign_nodes.append(assign_true_node)

        cur_func_node = self.function_def[-1]
        return_length = get_return_size(return_node)
        if return_length < max_return_length:
            # In this case we should append RETURN_NO_VALUE placeholder
            #
            # max_return_length must be >= 1 here because return_length will be
            # 0 at least.
            if self.return_value_name[cur_func_node] is None:
                self.return_value_name[cur_func_node] = unique_name.generate(
                    RETURN_VALUE_PREFIX)

            no_value_names = [
                unique_name.generate(RETURN_NO_VALUE_VAR_NAME)
                for j in range(max_return_length - return_length)
            ]
            self.return_no_value_name[cur_func_node].extend(no_value_names)

            # Handle tuple/non-tuple case
            if max_return_length == 1:
                assign_nodes.append(
                    gast.Assign(targets=[
                        gast.Name(id=self.return_value_name[cur_func_node],
                                  ctx=gast.Store(),
                                  annotation=None,
                                  type_comment=None)
                    ],
                                value=gast.Name(id=no_value_names[0],
                                                ctx=gast.Load(),
                                                annotation=None,
                                                type_comment=None)))
            else:
                # max_return_length > 1 which means we should assign tuple
                fill_tuple = [
                    gast.Name(id=n,
                              ctx=gast.Load(),
                              annotation=None,
                              type_comment=None) for n in no_value_names
                ]
                if return_node.value is not None:
                    if isinstance(return_node.value, gast.Tuple):
                        fill_tuple[:0] = return_node.value.elts
                    else:
                        fill_tuple.insert(0, return_node.value)

                assign_nodes.append(
                    gast.Assign(targets=[
                        gast.Name(id=self.return_value_name[cur_func_node],
                                  ctx=gast.Store(),
                                  annotation=None,
                                  type_comment=None)
                    ],
                                value=gast.Tuple(elts=fill_tuple,
                                                 ctx=gast.Load())))
        else:
            # In this case we should NOT append RETURN_NO_VALUE placeholder
            if return_node.value is not None:
                cur_func_node = self.function_def[-1]
                if self.return_value_name[cur_func_node] is None:
                    self.return_value_name[
                        cur_func_node] = unique_name.generate(
                            RETURN_VALUE_PREFIX)

                assign_nodes.append(
                    gast.Assign(targets=[
                        gast.Name(id=self.return_value_name[cur_func_node],
                                  ctx=gast.Store(),
                                  annotation=None,
                                  type_comment=None)
                    ],
                                value=return_node.value))

        stmt_list[i:] = assign_nodes
        return True
Пример #3
0
    def visit_Return(self, node):
        """Rewrite a ``return`` into flag/value assignments and guarded code.

        A fresh boolean flag name is generated for this return and recorded
        for the current function. The ancestors of ``node`` are then walked
        from the innermost outwards, stopping at the current function
        definition:

        * In the ``body``/``orelse`` list holding the path to ``node``, the
          return itself (when it is the direct child) is replaced via
          ``_replace_return_in_stmt_list``. Following statements are handled
          by ``_replace_after_node_to_if_in_stmt_list``.
        * A ``while`` ancestor gets ``and not <flag>`` appended to its test.
        * A ``for`` ancestor is converted to a ``while`` loop carrying
          ``not <flag>`` in its condition, and the new while node replaces the
          for node in ``self.ancestor_nodes``.

        Args:
            node: the ``gast.Return`` node being visited.
        """
        cur_func_node = self.function_def[-1]
        return_name = unique_name.generate(RETURN_PREFIX)
        self.return_name[cur_func_node].append(return_name)
        max_return_length = self.pre_analysis.get_func_max_return_length(
            cur_func_node)
        parent_node_of_return = self.ancestor_nodes[-2]

        def _not_returned():
            # Build a fresh `not <return_name>` node on every call, so no node
            # instance is shared between two loop conditions.
            return gast.UnaryOp(op=gast.Not(),
                                operand=gast.Name(id=return_name,
                                                  ctx=gast.Load(),
                                                  annotation=None,
                                                  type_comment=None))

        for ancestor_index in reversed(range(len(self.ancestor_nodes) - 1)):
            ancestor = self.ancestor_nodes[ancestor_index]
            cur_node = self.ancestor_nodes[ancestor_index + 1]

            # Only the first statement list that contains cur_node is
            # rewritten ('body' takes precedence over 'orelse').
            for field_name in ("body", "orelse"):
                if not hasattr(ancestor, field_name):
                    continue
                stmt_list = getattr(ancestor, field_name)
                if index_in_list(stmt_list, cur_node) == -1:
                    continue
                if cur_node == node:
                    self._replace_return_in_stmt_list(stmt_list, cur_node,
                                                      return_name,
                                                      max_return_length,
                                                      parent_node_of_return)
                self._replace_after_node_to_if_in_stmt_list(
                    stmt_list, cur_node, return_name, parent_node_of_return)
                break

            if isinstance(ancestor, gast.While):
                # Stop the while loop once the return has been taken.
                ancestor.test = gast.BoolOp(
                    op=gast.And(), values=[ancestor.test, _not_returned()])
                continue

            if isinstance(ancestor, gast.For):
                # A for loop has no test to extend: convert it to a while loop
                # whose condition includes `not <return_name>`. The last
                # generated statement is taken as the new while node and
                # stored in place of the for node.
                parent_node = self.ancestor_nodes[ancestor_index - 1]
                new_stmts = ForToWhileTransformer(parent_node, ancestor,
                                                  _not_returned()).transform()
                self.ancestor_nodes[ancestor_index] = new_stmts[-1]

            if ancestor == cur_func_node:
                break
Пример #4
0
    def visit_FunctionDef(self, node):
        """Transform ``node`` so that it has a single trailing ``return``.

        While the function contains more than one return statement (per
        ``ReturnAnalysisVisitor``), its children are visited, which rewrites
        the returns through ``visit_Return``. If any return carries a value,
        the function then gets:

        * zero-initialized value variables and an initial assignment of the
          return value variable (prepended),
        * a final ``return <return value variable>`` (appended),
        * initializations of the RETURN_NO_VALUE placeholder variables
          (prepended).

        Functions whose maximum return length is 0 are returned after the
        visiting step without adding these statements.

        Args:
            node: the ``gast.FunctionDef`` being visited.

        Returns:
            The (mutated) ``node``.
        """
        # Nested function definitions are visited from inside the enclosing
        # function's `generic_visit`, and every visit overwrites
        # `self.pre_analysis`. Keep the enclosing analysis so it can be
        # restored on exit. Otherwise the enclosing function's later
        # `visit_Return` calls would query the nested function's analysis
        # for the enclosing function node.
        enclosing_analysis = getattr(self, "pre_analysis", None)

        self.function_def.append(node)
        self.return_value_name[node] = None
        self.return_name[node] = []
        self.return_no_value_name[node] = []

        def _leave():
            # Pop this function and, when it was nested inside another
            # function being visited, restore that function's analysis.
            self.function_def.pop()
            if self.function_def:
                self.pre_analysis = enclosing_analysis
            return node

        self.pre_analysis = ReturnAnalysisVisitor(node)
        max_return_length = self.pre_analysis.get_func_max_return_length(node)
        # NOTE(review): generic_visit is only called from this loop, so when
        # the function has at most one return, nested function definitions
        # are never visited by this transformer. Confirm that is intended.
        while self.pre_analysis.get_func_return_count(node) > 1:
            self.generic_visit(node)
            self.pre_analysis = ReturnAnalysisVisitor(node)

        if max_return_length == 0:
            return _leave()

        # Prepend initialization of final return and append final return statement
        value_name = self.return_value_name[node]
        if value_name is not None:
            node.body.append(
                gast.Return(value=gast.Name(id=value_name,
                                            ctx=gast.Load(),
                                            annotation=None,
                                            type_comment=None)))
            init_names = [
                unique_name.generate(RETURN_VALUE_INIT_NAME)
                for _ in range(max_return_length)
            ]
            assign_zero_nodes = [
                create_fill_constant_node(iname, 0.0) for iname in init_names
            ]
            if len(init_names) == 1:
                return_value_nodes = gast.Name(id=init_names[0],
                                               ctx=gast.Load(),
                                               annotation=None,
                                               type_comment=None)
            else:
                # We need to initialize return value as a tuple because control
                # flow requires some inputs or outputs have same structure
                return_value_nodes = gast.Tuple(elts=[
                    gast.Name(id=iname,
                              ctx=gast.Load(),
                              annotation=None,
                              type_comment=None) for iname in init_names
                ],
                                                ctx=gast.Load())
            assign_return_value_node = gast.Assign(targets=[
                gast.Name(id=value_name,
                          ctx=gast.Store(),
                          annotation=None,
                          type_comment=None)
            ],
                                                   value=return_value_nodes)
            # Resulting prefix: zero inits, then the return value assignment.
            node.body.insert(0, assign_return_value_node)
            node.body[:0] = assign_zero_nodes

        # Prepend no value placeholders
        for name in self.return_no_value_name[node]:
            assign_no_value_node = create_fill_constant_node(
                name, RETURN_NO_VALUE_MAGIC_NUM)
            node.body.insert(0, assign_no_value_node)

        return _leave()
Пример #5
0
 def setUp(self):
     """Provide a source function and the expected contexts of its names.

     ``all_name_ids`` maps each name in ``self.source`` to the list of gast
     expression-context objects expected for its occurrences.
     """
     self.source = """
       def test_fn(x, y):
         a = 1
         x = y + a
         if x > y:
            z = x * x
            z = z + a
         else:
            z = y * y
         return z
     """
     param, store, load = gast.Param, gast.Store, gast.Load
     self.all_name_ids = {
         'x': [param(), store(), load(), load(), load()],
         'a': [store(), load(), load()],
         'y': [param(), load(), load(), load(), load()],
         'z': [store(), load(), store(), store(), load()],
     }
Пример #6
0
 def setUp(self):
     """Provide a one-line function and the expected contexts of ``x``."""
     self.source = """
       def test_fn(x):
         return x+1
     """
     # 'x' occurs as a parameter, then is read once.
     expected_x_ctxs = [gast.Param(), gast.Load()]
     self.all_name_ids = dict(x=expected_x_ctxs)
Пример #7
0
    def get_while_stmt_nodes(self, node):
        """Convert a ``gast.While`` node into a list of replacement statements.

        The returned list contains, in order:
          1. static-variable creations for names first created inside the
             loop (names containing "." are skipped),
          2. a generated condition function returning ``node.test``,
          3. a generated body function built from ``node.body`` (mutated to end
             with a return of the loop variables),
          4. the nodes produced by ``create_while_nodes``.

        Both generated functions take every loop variable as a parameter.
        """
        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)

        def _make_func_def(func_name, func_body):
            # Wrap `func_body` in a FunctionDef whose parameters are the loop
            # variables.
            params = [
                gast.Name(id=var_name,
                          ctx=gast.Param(),
                          annotation=None,
                          type_comment=None) for var_name in loop_var_names
            ]
            func_node = gast.FunctionDef(
                name=func_name,
                args=gast.arguments(args=params,
                                    posonlyargs=[],
                                    vararg=None,
                                    kwonlyargs=[],
                                    kw_defaults=None,
                                    kwarg=None,
                                    defaults=[]),
                body=func_body,
                decorator_list=[],
                returns=None,
                type_comment=None)
            # Every loop variable whose name contains "." is renamed to a
            # freshly generated name inside this function.
            for var_name in loop_var_names:
                if "." in var_name:
                    RenameTransformer(func_node).rename(
                        var_name,
                        unique_name.generate(GENERATE_VARIABLE_PREFIX))
            return func_node

        # Python can create variable in loop and use it out of loop, E.g.
        #
        # while x < 10:
        #     x += 1
        #     y = x
        # z = y
        #
        # We need to create static variable for those variables
        new_stmts = [
            create_static_variable_gast_node(var_name)
            for var_name in create_var_names if "." not in var_name
        ]

        condition_func_node = _make_func_def(
            unique_name.generate(WHILE_CONDITION_PREFIX),
            [gast.Return(value=node.test)])
        new_stmts.append(condition_func_node)

        # The loop body function hands the updated loop variables back.
        node.body.append(
            gast.Return(value=generate_name_node(
                loop_var_names, ctx=gast.Load(), gen_tuple_if_single=True)))
        body_func_node = _make_func_def(
            unique_name.generate(WHILE_BODY_PREFIX), node.body)
        new_stmts.append(body_func_node)

        new_stmts.extend(
            create_while_nodes(condition_func_node.name, body_func_node.name,
                               loop_var_names))
        return new_stmts
Пример #8
0
    def get_for_stmt_nodes(self, node):
        """Convert a ``gast.For`` node into a list of replacement statements.

        ``ForNodeVisitor(node).parse()`` supplies three pieces:
          * init_stmts: list of nodes preparing the loop (possibly several),
          * cond_stmt: the node deciding whether to keep looping,
          * body_stmts: the (possibly rewritten) loop body statements.
        If ``parse`` returns None, the loop is returned unchanged as ``[node]``.

        Forms that are turned into while statements:
          1). for x in range(*)
          2). for x in iter_var
          3). for i, x in enumerate(*)

        Otherwise, the returned list contains static-variable creations, the
        init statements, a condition function, a body function (both taking
        the loop variables as parameters) and the nodes produced by
        ``create_while_nodes``.
        """
        # TODO: consider for - else in python
        for_parser = ForNodeVisitor(node)
        parsed = for_parser.parse()
        if parsed is None:
            return [node]
        init_stmts, cond_stmt, body_stmts = parsed

        loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
            node)
        # For 'for x in var' and 'for i, x in enumerate(var)', the generated
        # index variable becomes a loop var. 'x' is dropped unless it has to
        # be created for use after the loop.
        if for_parser.is_for_iter() or for_parser.is_for_enumerate_iter():
            iter_var_name = for_parser.iter_var_name
            iter_idx_name = for_parser.iter_idx_name
            loop_var_names.add(iter_idx_name)
            if iter_var_name not in create_var_names:
                loop_var_names.remove(iter_var_name)

        def _make_func_def(func_name, func_body):
            # Wrap `func_body` in a FunctionDef whose parameters are the loop
            # variables.
            params = [
                gast.Name(id=var_name,
                          ctx=gast.Param(),
                          annotation=None,
                          type_comment=None) for var_name in loop_var_names
            ]
            func_node = gast.FunctionDef(
                name=func_name,
                args=gast.arguments(args=params,
                                    posonlyargs=[],
                                    vararg=None,
                                    kwonlyargs=[],
                                    kw_defaults=None,
                                    kwarg=None,
                                    defaults=[]),
                body=func_body,
                decorator_list=[],
                returns=None,
                type_comment=None)
            # Every loop variable whose name contains "." is renamed to a
            # freshly generated name inside this function.
            for var_name in loop_var_names:
                if "." in var_name:
                    RenameTransformer(func_node).rename(
                        var_name,
                        unique_name.generate(GENERATE_VARIABLE_PREFIX))
            return func_node

        # Python can create variable in loop and use it out of loop, E.g.
        #
        # for x in range(10):
        #     y += x
        # print(x) # x = 10
        #
        # We need to create static variable for those variables
        new_stmts = [
            create_static_variable_gast_node(var_name)
            for var_name in create_var_names if "." not in var_name
        ]
        new_stmts.extend(init_stmts)

        condition_func_node = _make_func_def(
            unique_name.generate(FOR_CONDITION_PREFIX),
            [gast.Return(value=cond_stmt)])
        new_stmts.append(condition_func_node)

        # The loop body function hands the updated loop variables back.
        body_stmts.append(
            gast.Return(value=generate_name_node(
                loop_var_names, ctx=gast.Load(), gen_tuple_if_single=True)))
        body_func_node = _make_func_def(unique_name.generate(FOR_BODY_PREFIX),
                                        body_stmts)
        new_stmts.append(body_func_node)

        new_stmts.extend(
            create_while_nodes(condition_func_node.name, body_func_node.name,
                               loop_var_names))
        return new_stmts
Пример #9
0
    def get_loop_var_names(self, node):
        """Decide which variables of loop ``node`` become loop variables.

        Args:
            node: a ``gast.While`` or ``gast.For`` node previously visited.

        Returns:
            ``(loop_var_names, create_var_names)``, two sets of names:
            ``loop_var_names`` must be threaded through the loop, and
            ``create_var_names`` (a subset of it) must additionally be created
            before the loop.
        """
        assert isinstance(
            node, (gast.While, gast.For)), "Input node is not gast loop node"
        loop_var_names = set()
        create_var_names = set()
        read_context = {type(gast.Load()), type(gast.AugLoad())}

        in_loop_vars_list = self.in_loop_vars[node]

        # name -> contexts of its in-loop occurrences, in recorded order.
        var_name_to_ctxs = defaultdict(list)
        for var_node in in_loop_vars_list:
            var_name_to_ctxs[self._var_node_to_name(var_node)].append(
                var_node.ctx)

        in_loop_vars = self._remove_unnecessary_vars(set(in_loop_vars_list),
                                                     node)
        in_loop_name_strs = self._var_nodes_to_names(in_loop_vars)

        before_loop_body_vars = self._remove_unnecessary_vars(
            self.before_loop_body_vars[node], node)
        before_loop_name_strs = self._var_nodes_to_names(before_loop_body_vars)

        after_loop_vars = self._remove_unnecessary_vars(
            self.current_seen_vars - before_loop_body_vars - in_loop_vars, node)
        after_loop_name_strs = self._var_nodes_to_names(after_loop_vars,
                                                        read_context)

        condition_names = self._var_nodes_to_names(self.condition_vars[node])
        write_names = self._var_nodes_to_names(self.write_in_loop[node])

        name_to_type = {}
        for var in in_loop_vars:
            var_type = self.node_to_wrapper_map[var].node_var_type
            name_to_type[self._var_node_to_name(var)] = var_type

        for name in in_loop_name_strs:
            if name in before_loop_name_strs:
                # Created before the loop and used in it. A basic-type
                # variable that is only read and is not part of the condition
                # need not be a loop var; everything else is passed in.
                skip_as_read_only_basic = (
                    name not in condition_names and name not in write_names
                    and self._node_var_type_is_basic(name_to_type[name]))
                if not skip_as_read_only_basic:
                    loop_var_names.add(name)
            elif name in after_loop_name_strs:
                # Created in the loop and read after it: it must be a loop var
                # and be created beforehand. Such a name is written in the
                # loop, so there is no read-only filtering here.
                loop_var_names.add(name)
                create_var_names.add(name)
            else:
                # Used and created inside the loop. If it is read before it is
                # first written, it must be a loop var and be created, e.g.
                # `var_a` below:
                #
                #   res = 0
                #   for i, x in enumerate(x_array):
                #       if i > 2:
                #           x = func1(var_a)
                #       var_a = func2(x)
                #
                ctxs = var_name_to_ctxs[name]
                if isinstance(ctxs[0], gast.Load) and any(
                        isinstance(ctx, gast.Store) for ctx in ctxs):
                    loop_var_names.add(name)
                    create_var_names.add(name)

        return loop_var_names, create_var_names