Beispiel #1
0
 def setUp(self):
     """Provide a sample function and the expected contexts of its names."""
     self.source = """
       def test_fn(x, y):
         z = 1
         if x > y:
            z = x * x
            z = z + y
         return z
     """
     param, load, store = gast.Param, gast.Load, gast.Store
     # Expected context of every occurrence of each variable in `test_fn`.
     self.all_name_ids = {
         'x': [param(), load(), load(), load()],
         'y': [param(), load(), load()],
         'z': [store(), store(), load(), store(), load()],
     }
    def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node,
                                               return_name,
                                               parent_node_of_return):
        i = index_in_list(stmt_list, node)
        if i < 0 or i >= len(stmt_list):
            return False
        if i == len(stmt_list) - 1:
            # No need to add, we consider this as added successfully
            return True

        if_stmt = gast.If(test=gast.UnaryOp(op=gast.Not(),
                                            operand=gast.Name(
                                                id=return_name,
                                                ctx=gast.Store(),
                                                annotation=None,
                                                type_comment=None)),
                          body=stmt_list[i + 1:],
                          orelse=[])

        stmt_list[i + 1:] = [if_stmt]

        # Here assume that the parent node of return is gast.If
        if isinstance(parent_node_of_return, gast.If):
            # Prepend control flow boolean nodes such as '__return@1 = False'
            node_str = "{} = _jst.create_bool_as_type({}, False)".format(
                return_name,
                ast_to_source_code(parent_node_of_return.test).strip())
            assign_false_node = gast.parse(node_str).body[0]

            stmt_list[i:i] = [assign_false_node]
        return True
Beispiel #3
0
 def setUp(self):
     """Provide a sample function with if/else and its expected name contexts."""
     self.source = """
       def test_fn(x, y):
         a = 1
         x = y + a
         if x > y:
            z = x * x
            z = z + a
         else:
            z = y * y
         return z
     """
     param, load, store = gast.Param, gast.Load, gast.Store
     # Expected context of every occurrence of each variable in `test_fn`.
     self.all_name_ids = {
         'x': [param(), store(), load(), load(), load()],
         'a': [store(), load(), load()],
         'y': [param(), load(), load(), load(), load()],
         'z': [store(), load(), store(), store(), load()],
     }
Beispiel #4
0
    def visit_Name(self, node):
        """Record a variable occurrence for every loop currently being visited.

        Names for which ``self._is_call_func_name_node`` is true, and names
        whose id is in ``self.blacklist_names``, are skipped. Otherwise
        ``node`` is added to ``self.current_seen_vars`` and appended to
        ``self.in_loop_vars`` of every loop in ``self.current_loop``. When its
        context is a write (``Store``, ``AugStore`` or ``Del``), it is also
        added to ``self.write_in_loop`` of those loops. While
        ``self.in_condition`` is set, it is added to ``self.condition_vars``
        of the last loop in ``self.current_loop`` (presumably the innermost
        one).
        """
        if self._is_call_func_name_node(node):
            self.generic_visit(node)
            return
        if node.id in self.blacklist_names:
            self.generic_visit(node)
            return

        self.current_seen_vars.add(node)
        is_write = isinstance(node.ctx, (gast.Store, gast.AugStore, gast.Del))
        # Track the last loop explicitly rather than reading the `for` target
        # after the loop: that variable is unbound (UnboundLocalError) when
        # `self.current_loop` is empty.
        last_loop = None
        for loop_node in self.current_loop:
            self.in_loop_vars[loop_node].append(node)
            if is_write:
                self.write_in_loop[loop_node].add(node)
            last_loop = loop_node
        if self.in_condition and last_loop is not None:
            self.condition_vars[last_loop].add(node)
        self.generic_visit(node)
    def visit_If(self, node):
        """Collect the names created by an ``if/else`` while respecting scope.

        A variable created in only one branch is not guaranteed to exist after
        the ``if``, and variables created in ``if.body`` are not visible in
        ``if.orelse``. For example::

            x = 1
            if m > 1:
                res = new_tensor
            res = res + 1   # Error, `res` is not visible here.

            if x_tensor > 0:
                res = new_tensor
            else:
                res = res + 1   # Error, `res` is not visible here.

        Name scopes must therefore be managed to derive the arguments and
        returned vars correctly:

        * If the node is outside the range of interest (``self._in_range`` is
          falsy) or ``self.end_node`` is not set, it is visited generically.
        * If the traversal stops early inside ``if.body`` or ``if.orelse``,
          the names seen before the ``if`` are handed to
          ``self._update_name_ids``.
        * Otherwise ``self.name_ids`` becomes the names seen before the ``if``
          plus one ``gast.Store()`` for every name that
          ``self._find_new_name_ids`` reports as created in both branches.

        NOTE(review): on the in-range path ``node.test`` is not visited here.
        """
        if not (self._in_range and self.end_node):
            self.generic_visit(node)
            return

        name_ids_before_if = copy.deepcopy(self.name_ids)

        body_name_ids = self._visit_child(node.body)
        if not self._in_range:
            # Traversal stopped early inside `if.body`.
            self._update_name_ids(name_ids_before_if)
            return

        orelse_name_ids = self._visit_child(node.orelse)
        if not self._in_range:
            # Traversal stopped early inside `if.orelse`.
            self._update_name_ids(name_ids_before_if)
            return

        # Discard branch-local vars; keep only those created in both branches.
        for name_id in self._find_new_name_ids(body_name_ids, orelse_name_ids):
            name_ids_before_if[name_id].append(gast.Store())
        self.name_ids = name_ids_before_if
    def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node,
                                               break_continue_name):
        """Guard the statements following ``node`` with ``if not <flag>:``.

        Every statement in ``stmt_list`` after ``node`` is moved into the body
        of a new ``if not <break_continue_name>:`` statement, which replaces
        them at the end of ``stmt_list``.

        Args:
            stmt_list: list of statements; modified in place.
            node: statement of ``stmt_list`` after which statements are guarded.
            break_continue_name: name of the boolean break/continue flag.

        Returns:
            False if ``index_in_list`` reports ``-1`` for ``node``; True
            otherwise, including when ``node`` is already the last statement
            (``stmt_list`` is then left unchanged).
        """
        i = index_in_list(stmt_list, node)
        if i == -1:
            return False

        if i == len(stmt_list) - 1:
            # No need to add, we consider this as added successfully
            return True

        # The flag is only *read* in the test, so it needs a Load context. A
        # Store context in this position is an invalid expression AST (Python's
        # `compile` rejects it) and ctx-based visitors would see it as a write.
        if_stmt = gast.If(test=gast.UnaryOp(op=gast.Not(),
                                            operand=gast.Name(
                                                id=break_continue_name,
                                                ctx=gast.Load(),
                                                annotation=None,
                                                type_comment=None)),
                          body=stmt_list[i + 1:],
                          orelse=[])
        stmt_list[i + 1:] = [if_stmt]
        return True
    def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name,
                                     max_return_length, parent_node_of_return):

        assert max_return_length >= 0, "Input illegal max_return_length"
        i = index_in_list(stmt_list, return_node)
        if i == -1:
            return False

        assign_nodes = []
        # Here assume that the parent node of return is gast.If
        if isinstance(parent_node_of_return, gast.If):
            # Prepend control flow boolean nodes such as '__return@1 = True'
            node_str = "{} = _jst.create_bool_as_type({}, True)".format(
                return_name,
                ast_to_source_code(parent_node_of_return.test).strip())

            assign_true_node = gast.parse(node_str).body[0]
            assign_nodes.append(assign_true_node)

        cur_func_node = self.function_def[-1]
        return_length = get_return_size(return_node)
        if return_length < max_return_length:
            # In this case we should append RETURN_NO_VALUE placeholder
            #
            # max_return_length must be >= 1 here because return_length will be
            # 0 at least.
            if self.return_value_name[cur_func_node] is None:
                self.return_value_name[cur_func_node] = unique_name.generate(
                    RETURN_VALUE_PREFIX)

            no_value_names = [
                unique_name.generate(RETURN_NO_VALUE_VAR_NAME)
                for j in range(max_return_length - return_length)
            ]
            self.return_no_value_name[cur_func_node].extend(no_value_names)

            # Handle tuple/non-tuple case
            if max_return_length == 1:
                assign_nodes.append(
                    gast.Assign(targets=[
                        gast.Name(id=self.return_value_name[cur_func_node],
                                  ctx=gast.Store(),
                                  annotation=None,
                                  type_comment=None)
                    ],
                                value=gast.Name(id=no_value_names[0],
                                                ctx=gast.Load(),
                                                annotation=None,
                                                type_comment=None)))
            else:
                # max_return_length > 1 which means we should assign tuple
                fill_tuple = [
                    gast.Name(id=n,
                              ctx=gast.Load(),
                              annotation=None,
                              type_comment=None) for n in no_value_names
                ]
                if return_node.value is not None:
                    if isinstance(return_node.value, gast.Tuple):
                        fill_tuple[:0] = return_node.value.elts
                    else:
                        fill_tuple.insert(0, return_node.value)

                assign_nodes.append(
                    gast.Assign(targets=[
                        gast.Name(id=self.return_value_name[cur_func_node],
                                  ctx=gast.Store(),
                                  annotation=None,
                                  type_comment=None)
                    ],
                                value=gast.Tuple(elts=fill_tuple,
                                                 ctx=gast.Load())))
        else:
            # In this case we should NOT append RETURN_NO_VALUE placeholder
            if return_node.value is not None:
                cur_func_node = self.function_def[-1]
                if self.return_value_name[cur_func_node] is None:
                    self.return_value_name[
                        cur_func_node] = unique_name.generate(
                            RETURN_VALUE_PREFIX)

                assign_nodes.append(
                    gast.Assign(targets=[
                        gast.Name(id=self.return_value_name[cur_func_node],
                                  ctx=gast.Store(),
                                  annotation=None,
                                  type_comment=None)
                    ],
                                value=return_node.value))

        stmt_list[i:] = assign_nodes
        return True
    def visit_FunctionDef(self, node):
        """Rewrite ``node`` so that its body ends in a single return.

        The per-function bookkeeping is reset. The body is then re-visited
        until ``ReturnAnalysisVisitor`` reports at most one return statement.
        If the maximum return length is 0, ``node`` is returned right away.

        Otherwise, when a return-value variable was created during the
        rewrite, the body is extended to roughly::

            <no-value placeholders>        # reverse registration order
            <init names>                   # create_fill_constant_node(.., 0.0)
            <return value> = <init name(s)>
            ...rewritten body...
            return <return value>

        The no-value placeholder assignments
        (``create_fill_constant_node(name, RETURN_NO_VALUE_MAGIC_NUM)``) are
        prepended whenever the maximum return length is non-zero.

        Returns:
            The same (mutated) ``node``.
        """
        self.function_def.append(node)
        self.return_value_name[node] = None
        self.return_name[node] = []
        self.return_no_value_name[node] = []

        self.pre_analysis = ReturnAnalysisVisitor(node)
        # Computed once, before any rewriting below.
        max_return_length = self.pre_analysis.get_func_max_return_length(node)
        while self.pre_analysis.get_func_return_count(node) > 1:
            self.generic_visit(node)
            self.pre_analysis = ReturnAnalysisVisitor(node)

        if max_return_length == 0:
            self.function_def.pop()
            return node

        value_name = self.return_value_name[node]
        if value_name is not None:
            node.body.append(
                gast.Return(value=gast.Name(id=value_name,
                                            ctx=gast.Load(),
                                            annotation=None,
                                            type_comment=None)))
            init_names = [
                unique_name.generate(RETURN_VALUE_INIT_NAME)
                for _ in range(max_return_length)
            ]
            assign_zero_nodes = [
                create_fill_constant_node(init_name, 0.0)
                for init_name in init_names
            ]
            init_loads = [
                gast.Name(id=init_name,
                          ctx=gast.Load(),
                          annotation=None,
                          type_comment=None) for init_name in init_names
            ]
            # A tuple is used for several values because control flow requires
            # some inputs or outputs to have the same structure.
            if len(init_loads) == 1:
                initial_value = init_loads[0]
            else:
                initial_value = gast.Tuple(elts=init_loads, ctx=gast.Load())
            return_target = gast.Name(id=value_name,
                                      ctx=gast.Store(),
                                      annotation=None,
                                      type_comment=None)
            assign_return_value_node = gast.Assign(targets=[return_target],
                                                   value=initial_value)
            node.body[:0] = assign_zero_nodes + [assign_return_value_node]

        # Each placeholder goes to the very front, so the registered names end
        # up in reverse order ahead of everything else.
        placeholder_nodes = [
            create_fill_constant_node(name, RETURN_NO_VALUE_MAGIC_NUM)
            for name in self.return_no_value_name[node]
        ]
        node.body[:0] = placeholder_nodes[::-1]

        self.function_def.pop()
        return node