示例#1
0
    def visit_FunctionDef(self, node):
        """Rewrite ``node`` so that its returns funnel into one final return.

        ``generic_visit`` is re-run on the function for as long as a fresh
        ``ReturnAnalysisVisitor`` still counts more than one return in it.
        If the longest return (measured before any rewriting) carries at
        least one value, the body is then adjusted. When a return-value
        variable was allocated for this function, the body becomes:

        1. one ``create_fill_constant_node(<init>, 0.0)`` per value of the
           longest return, then an assignment binding those init names (as a
           tuple when there are several) to the return-value variable;
        2. the original statements;
        3. a trailing ``return <return-value variable>``.

        Every recorded "no value" placeholder is then initialised to
        RETURN_NO_VALUE_MAGIC_NUM at the very top of the body.

        Args:
            node: the ``gast.FunctionDef`` being visited; mutated in place.

        Returns:
            The same ``node``.
        """
        self.function_def.append(node)
        self.return_value_name[node] = None
        self.return_name[node] = []
        self.return_no_value_name[node] = []

        # The longest return length is captured before any rewriting; the
        # analysis itself is refreshed after every generic_visit pass.
        self.pre_analysis = ReturnAnalysisVisitor(node)
        max_return_length = self.pre_analysis.get_func_max_return_length(node)
        while self.pre_analysis.get_func_return_count(node) > 1:
            self.generic_visit(node)
            self.pre_analysis = ReturnAnalysisVisitor(node)

        if max_return_length != 0:
            return_var = self.return_value_name[node]
            if return_var is not None:
                node.body.append(
                    gast.Return(value=gast.Name(id=return_var,
                                                ctx=gast.Load(),
                                                annotation=None,
                                                type_comment=None)))
                init_names = [
                    unique_name.generate(RETURN_VALUE_INIT_NAME)
                    for _ in range(max_return_length)
                ]
                prologue = [
                    create_fill_constant_node(init_name, 0.0)
                    for init_name in init_names
                ]
                init_loads = [
                    gast.Name(id=init_name,
                              ctx=gast.Load(),
                              annotation=None,
                              type_comment=None) for init_name in init_names
                ]
                # Control flow requires some inputs/outputs to share the same
                # structure, so several return values start out as a tuple.
                if len(init_loads) == 1:
                    initial_value = init_loads[0]
                else:
                    initial_value = gast.Tuple(elts=init_loads,
                                               ctx=gast.Load())
                prologue.append(
                    gast.Assign(targets=[
                        gast.Name(id=return_var,
                                  ctx=gast.Store(),
                                  annotation=None,
                                  type_comment=None)
                    ],
                                value=initial_value))
                node.body[:0] = prologue

            no_value_inits = [
                create_fill_constant_node(name, RETURN_NO_VALUE_MAGIC_NUM)
                for name in self.return_no_value_name[node]
            ]
            # Equivalent to inserting each one at index 0 in turn, which
            # leaves them in reverse order at the top of the body.
            node.body[:0] = no_value_inits[::-1]

        self.function_def.pop()
        return node
示例#2
0
    def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name,
                                     max_return_length, parent_node_of_return):
        """Replace ``return_node`` and everything after it in ``stmt_list``.

        The return statement is turned into plain assignments:

        * when ``parent_node_of_return`` is a ``gast.If``, a flag assignment
          ``<return_name> = _jst.create_bool_as_type(<if test>, True)`` comes
          first;
        * the returned value is assigned to the enclosing function's shared
          return-value variable (allocated on first use). If the return has
          fewer values than ``max_return_length``, it is padded with freshly
          generated "no value" placeholder names, which are also recorded in
          ``self.return_no_value_name``.

        Args:
            stmt_list: statement list that may directly contain ``return_node``;
                modified in place.
            return_node: the ``gast.Return`` node to replace.
            return_name: name of the boolean flag assigned when the return is
                reached under an ``if``.
            max_return_length: largest number of values returned by any return
                of the enclosing function.
            parent_node_of_return: AST node owning ``stmt_list``.

        Returns:
            True if ``return_node`` was found and replaced, False otherwise.
        """
        assert max_return_length >= 0, "Input illegal max_return_length"
        position = index_in_list(stmt_list, return_node)
        if position == -1:
            return False

        def make_name(identifier, ctx):
            # Build a gast.Name with the extra fields gast requires.
            return gast.Name(id=identifier,
                             ctx=ctx,
                             annotation=None,
                             type_comment=None)

        replacement = []
        # Assumes the parent of the return is a gast.If: emit a control-flow
        # boolean such as '__return@1 = True'.
        if isinstance(parent_node_of_return, gast.If):
            test_source = ast_to_source_code(parent_node_of_return.test).strip()
            flag_source = "{} = _jst.create_bool_as_type({}, True)".format(
                return_name, test_source)
            replacement.append(gast.parse(flag_source).body[0])

        func_node = self.function_def[-1]
        return_length = get_return_size(return_node)

        def return_var_name():
            # Allocate the function's shared return-value variable lazily.
            if self.return_value_name[func_node] is None:
                self.return_value_name[func_node] = unique_name.generate(
                    RETURN_VALUE_PREFIX)
            return self.return_value_name[func_node]

        if return_length < max_return_length:
            # Pad with RETURN_NO_VALUE placeholders. max_return_length is >= 1
            # here because return_length is at least 0.
            target_name = return_var_name()
            placeholders = [
                unique_name.generate(RETURN_NO_VALUE_VAR_NAME)
                for _ in range(max_return_length - return_length)
            ]
            self.return_no_value_name[func_node].extend(placeholders)

            if max_return_length == 1:
                assigned_value = make_name(placeholders[0], gast.Load())
            else:
                # Several values: assign a tuple of real values + placeholders.
                elements = [
                    make_name(placeholder, gast.Load())
                    for placeholder in placeholders
                ]
                returned = return_node.value
                if returned is not None:
                    if isinstance(returned, gast.Tuple):
                        elements[:0] = returned.elts
                    else:
                        elements.insert(0, returned)
                assigned_value = gast.Tuple(elts=elements, ctx=gast.Load())

            replacement.append(
                gast.Assign(targets=[make_name(target_name, gast.Store())],
                            value=assigned_value))
        elif return_node.value is not None:
            # Full-length return: assign its value as-is, no placeholders.
            target_name = return_var_name()
            replacement.append(
                gast.Assign(targets=[make_name(target_name, gast.Store())],
                            value=return_node.value))

        stmt_list[position:] = replacement
        return True
    def _update_name_to_var_shape(self, node):
        """Build statements that track static shapes for the targets of ``node``.

        ``node`` must be a ``gast.Assign``. When its value is a name already
        recorded in ``self.name_to_var_shape``, or an expression accepted by
        ``self._is_var_shape`` (e.g. ``x.shape``), each assignment target gets
        a uniquely named "static shape" variable. An assignment initialising
        that variable is built, and ``self.name_to_var_shape`` maps the
        target's source text to the new variable name. For ``_is_var_shape``
        values a ``global`` declaration of the new variable is also emitted.

        Args:
            node: the ``gast.Assign`` node being transformed.

        Returns:
            For a tuple target, a (possibly empty) list of new statements.
            For any other target, a list of new statements, or None when the
            value is not a tracked shape.
        """

        def replace_dot(name):
            # Replace every '.' with '_' so e.g. attribute targets produce a
            # usable name prefix.
            return name.replace('.', '_')

        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        value_node = node.value

        update_static_shape_var_node = None
        if isinstance(target_node, gast.Tuple):
            update_static_shape_var_node = []

            # Resolve the static-shape variable of a Name value ONCE, before
            # any target is rebound. Python evaluates the right-hand side
            # before binding targets, so in e.g. ``s, t = s`` every element
            # must index the *original* static shape of ``s``. Reading the
            # mapping inside the loop would instead see the entry that an
            # earlier element just rebound for ``s``.
            value_is_tracked_name = (isinstance(value_node, gast.Name) and
                                     value_node.id in self.name_to_var_shape)
            if value_is_tracked_name:
                static_shape_value_name = self.name_to_var_shape[
                    value_node.id]

            for idx, element in enumerate(target_node.elts):
                target_id = ast_to_source_code(element).strip()

                if value_is_tracked_name:
                    # TODO(zhhsplendid): is context a problem for the result node of gast.parse?
                    static_shape_var_name = unique_name.generate(
                        replace_dot(target_id) +
                        STATIC_CONVERT_VAR_SHAPE_SUFFIX)
                    static_shape_var_node = gast.parse(
                        static_shape_var_name).body[0].value

                    # <new static var> = <source static var>[idx]
                    sub_node_str = "{}[{}]".format(static_shape_value_name,
                                                   idx)
                    sub_node = gast.parse(sub_node_str).body[0].value

                    update_static_shape_var_node.append(
                        gast.Assign(targets=[static_shape_var_node],
                                    value=sub_node))

                    self.name_to_var_shape[target_id] = static_shape_var_name
                elif isinstance(value_node, gast.Attribute):
                    if self._is_var_shape(value_node):  # eg: x.shape
                        static_shape_var_name = unique_name.generate(
                            replace_dot(target_id) +
                            STATIC_CONVERT_VAR_SHAPE_SUFFIX)
                        static_shape_var_node = gast.parse(
                            static_shape_var_name).body[0].value

                        static_shape_value_node = copy.deepcopy(value_node)
                        # x.shape becomes convert_var_shape_simple(x)
                        static_shape_value_node = ShapeAttributeTransformer(
                        ).visit(static_shape_value_node)

                        sub_node_str = "{}[{}]".format(
                            ast_to_source_code(static_shape_value_node).strip(),
                            idx)
                        sub_node = gast.parse(sub_node_str).body[0].value
                        # Note(Aurelius84): Because static_shape_var_name is used in
                        # eval_if_exist_else_none() as a plain string, it will not
                        # be parsed as an argument in convert_loop/ifelse. We declare
                        # it as a global var because it has a unique name.
                        update_static_shape_var_node.append(
                            gast.Global(names=[static_shape_var_name]))

                        update_static_shape_var_node.append(
                            gast.Assign(targets=[static_shape_var_node],
                                        value=sub_node))
                        self.name_to_var_shape[
                            target_id] = static_shape_var_name
            return update_static_shape_var_node
        else:
            target_id = ast_to_source_code(target_node).strip()
            if isinstance(value_node, gast.Name):
                if value_node.id in self.name_to_var_shape:
                    static_shape_var_name = unique_name.generate(
                        replace_dot(target_id) +
                        STATIC_CONVERT_VAR_SHAPE_SUFFIX)
                    static_shape_var_node = gast.parse(
                        static_shape_var_name).body[0].value
                    # The value's static var is read before the target's
                    # mapping is updated below, so ``s = s`` is safe here.
                    static_shape_value_name = self.name_to_var_shape[
                        value_node.id]
                    static_shape_value_node = gast.parse(
                        static_shape_value_name).body[0].value

                    update_static_shape_var_node = [
                        gast.Assign(targets=[static_shape_var_node],
                                    value=static_shape_value_node)
                    ]
                    self.name_to_var_shape[target_id] = static_shape_var_name
            elif self._is_var_shape(value_node):  # eg: x.shape or x.shape[0]
                static_shape_var_name = unique_name.generate(
                    replace_dot(target_id) + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
                static_shape_var_node = gast.parse(static_shape_var_name).body[
                    0].value
                static_shape_value_node = copy.deepcopy(value_node)
                # x.shape becomes convert_var_shape_simple(x)
                static_shape_value_node = ShapeAttributeTransformer().visit(
                    static_shape_value_node)
                # Declare static_shape_var_name as a global var.
                update_static_shape_var_node = [
                    gast.Global(names=[static_shape_var_name])
                ]
                update_static_shape_var_node.append(
                    gast.Assign(targets=[static_shape_var_node],
                                value=static_shape_value_node))
                self.name_to_var_shape[target_id] = static_shape_var_name
        return update_static_shape_var_node