class TensorShapeTransformer(gast.NodeTransformer):
    """
    This class transforms `variable.shape` expressions used in Paddle APIs or control flow conditions into static graph AST nodes.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # Maps the original var name string (like "x" in `x = t.shape`) to
        # the static shape var name string (like "x_SUFFIX" in `x_SUFFIX = shape(t)`).
        self.name_to_var_shape = {}

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )
        var_env = self.static_analysis_visitor.get_var_env()
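        # Descend into the first sub-scope (presumably the scope of the
        # function being transformed) before reading variable types.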
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        SplitAssignTransformer(self.root).transform()
        self.visit(self.root)

    def visit_Assign(self, node):
        update_static_shape_var_node = self._update_name_to_var_shape(node)
        if update_static_shape_var_node is not None:
            ret = [node]
            ret.extend(update_static_shape_var_node)
            return ret
        self.generic_visit(node)
        return node

    def visit_Subscript(self, node):
        value_node = node.value
        slice_node = node.slice
        if isinstance(value_node, gast.Name):
            if value_node.id in self.name_to_var_shape and self._used_by_paddle_api(
                    value_node):
                return create_choose_shape_node(
                    value_node.id, self.name_to_var_shape[value_node.id], node)
        elif isinstance(value_node, gast.Attribute):
            if self._used_by_paddle_api(value_node):
                value_name = ast_to_source_code(value_node).strip()
                if value_name in self.name_to_var_shape:
                    return create_choose_shape_node(
                        value_name, self.name_to_var_shape[value_name], node)
                if self._is_var_shape(value_node):
                    return create_convert_shape_node(value_node, node)
        return node

    def visit_Attribute(self, node):
        if self._used_by_paddle_api(node):
            name = ast_to_source_code(node).strip()
            if name in self.name_to_var_shape:
                return create_choose_shape_node(name,
                                                self.name_to_var_shape[name])
            if self._is_var_shape(node):
                return create_convert_shape_node(node)
        return node

    def visit_Name(self, node):
        if node.id in self.name_to_var_shape:
            if self._used_by_paddle_api(node):
                return create_choose_shape_node(node.id,
                                                self.name_to_var_shape[node.id])
        return node

    def visit_Call(self, node):
        if is_paddle_api(node):
            # Visit gast.Attribute and gast.Name to replace var.shape if necessary.
            self.generic_visit(node)
        # Don't have to visit other APIs
        return node

    def visit_If(self, node):
        # Call generic_visit first to transform var.shape that is used in Paddle Api.
        self.generic_visit(node)
        cond = node.test
        self._transform_var_shape_if_necessary(cond)

        return node

    def visit_While(self, node):
        self.generic_visit(node)
        cond = node.test
        self._transform_var_shape_if_necessary(cond)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        iter = node.iter
        self._transform_var_shape_if_necessary(iter)

        # If var.shape is a gast.Name and it is used in a range() call, transform it
        self._transform_var_shape_in_range(node)
        return node

    def _transform_var_shape_in_range(self, node):
        assert isinstance(node, gast.For)
        if not isinstance(node.iter, gast.Call):
            return False
        if not isinstance(node.iter.func, gast.Name):
            return False
        if node.iter.func.id != "range":
            return False
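        # e.g. after `s = x.shape[0]` was recorded in name_to_var_shape, a
        # loop `for i in range(s):` has its `s` argument replaced below.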
        args = node.iter.args
        for idx, arg in enumerate(args):
            if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape:
                args[idx] = create_choose_shape_node(
                    arg.id, self.name_to_var_shape[arg.id])
        return True

    def _transform_var_shape_if_necessary(self, cond):
        need_transformed = False
        for child_node in gast.walk(cond):
            var_shape_node = None
            if isinstance(child_node,
                          (gast.Name, gast.Attribute, gast.Subscript)):
                child_name = ast_to_source_code(child_node).strip()
                if child_name in self.name_to_var_shape:
                    var_shape_node = create_choose_shape_node(
                        child_name, self.name_to_var_shape[child_name])
                elif self._is_var_shape(child_node):
                    var_shape_node = child_node

            if var_shape_node:
                need_transformed = True
                wrapper_node = self.node_to_wrapper_map.get(child_node)
                parent_node = wrapper_node.parent.node
                for field, value in gast.iter_fields(parent_node):
                    if child_node is value:
                        if var_shape_node is child_node:
                            setattr(parent_node, field,
                                    create_convert_shape_node(var_shape_node,
                                                              None, True))
                        else:
                            setattr(parent_node, field, var_shape_node)
                        break
                    # Some child_node may be in a list such as gast.Compare
                    if isinstance(value, list):
                        has_converted_shape = False
                        for i, v in enumerate(value):
                            if child_node is v:
                                if var_shape_node is child_node:
                                    value[i] = create_convert_shape_node(
                                        var_shape_node, None, True)
                                else:
                                    value[i] = var_shape_node
                                has_converted_shape = True
                                break
                        if has_converted_shape:
                            break
        return need_transformed

    def _used_by_paddle_api(self, node):
        """
        Whether the node is used as an argument of a Paddle API call.
        For example:
            1) Return True in `paddle.relu(x)` where node is `x` (gast.Name)
            2) Return True in `paddle.add(self.x)` where node is `self.x` (gast.Attribute)
            3) Return False in `paddle.add(self.x)` where node is `paddle.add` (gast.Attribute),
               because the node's role is `gast.Call.func`, not an argument.
        """
        assert isinstance(node, (gast.Attribute, gast.Name))
        wrapper_node = self.node_to_wrapper_map.get(node)
        if not wrapper_node:
            # Transformed node is not in node_to_wrapper_map
            return False
        while wrapper_node.parent:
            parent_node = wrapper_node.parent.node
            if isinstance(parent_node, gast.Call):
                # Note(Aurelius84): Filter the case when the role of node is `gast.Call.func`.
                if is_paddle_api(parent_node) and parent_node.func != node:
                    return True
                else:
                    return False
            wrapper_node = wrapper_node.parent

        return False

    def _is_var_shape(self, node):
        """
        Return True if node is like `x.shape` or `x.shape[0]`; return False otherwise.
        """
        if not isinstance(node, (gast.Attribute, gast.Subscript)):
            return False

        if isinstance(node, gast.Attribute):
            # If node is `paddle.shape`, return False
            if (node.attr == 'shape' and isinstance(node.value, gast.Name) and
                    node.value.id == 'paddle'):
                return False
            if node.attr != 'shape':
                return False
            return True

        if isinstance(node, gast.Subscript):
            value_node = node.value
            return self._is_var_shape(value_node)

        return False

    def _update_name_to_var_shape(self, node):
        def replace_dot(name):
            # replace every '.' with '_'
            return name.replace('.', '_')

        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        value_node = node.value

        update_static_shape_var_node = None
        if isinstance(target_node, gast.Tuple):
            update_static_shape_var_node = []
            for idx, element in enumerate(target_node.elts):
                target_id = ast_to_source_code(element).strip()

                if isinstance(value_node, gast.Name):
                    if value_node.id in self.name_to_var_shape:
                        # TODO(zhhsplendid): is context a problem for the result node of gast.parse?
                        static_shape_var_name = unique_name.generate(
                            replace_dot(target_id) +
                            STATIC_CONVERT_VAR_SHAPE_SUFFIX)
                        static_shape_var_node = gast.parse(
                            static_shape_var_name).body[0].value

                        static_shape_value_name = self.name_to_var_shape[
                            value_node.id]

                        sub_node_str = "{}[{}]".format(static_shape_value_name,
                                                       idx)
                        sub_node = gast.parse(sub_node_str).body[0].value

                        update_static_shape_var_node.append(
                            gast.Assign(
                                targets=[static_shape_var_node],
                                value=sub_node))

                        self.name_to_var_shape[
                            target_id] = static_shape_var_name
                if isinstance(value_node, gast.Attribute):
                    if self._is_var_shape(value_node):  # eg: x.shape
                        static_shape_var_name = unique_name.generate(
                            replace_dot(target_id) +
                            STATIC_CONVERT_VAR_SHAPE_SUFFIX)
                        static_shape_var_node = gast.parse(
                            static_shape_var_name).body[0].value

                        static_shape_value_node = copy.deepcopy(value_node)
                        # x.shape becomes convert_var_shape_simple(x)
                        static_shape_value_node = ShapeAttributeTransformer(
                        ).visit(static_shape_value_node)

                        sub_node_str = "{}[{}]".format(
                            ast_to_source_code(static_shape_value_node).strip(),
                            idx)
                        sub_node = gast.parse(sub_node_str).body[0].value
                        # Note(Aurelius84): Because static_shape_var_name is used in
                        # eval_if_exist_else_none() as a plain string, it will not
                        # be parsed as an argument in convert_loop/ifelse. We declare it
                        # as a global var because it has a unique name.
                        update_static_shape_var_node.append(
                            gast.Global(names=[static_shape_var_name]))

                        update_static_shape_var_node.append(
                            gast.Assign(
                                targets=[static_shape_var_node],
                                value=sub_node))
                        self.name_to_var_shape[
                            target_id] = static_shape_var_name
            return update_static_shape_var_node
        else:
            target_id = ast_to_source_code(target_node).strip()
            if isinstance(value_node, gast.Name):
                if value_node.id in self.name_to_var_shape:
                    static_shape_var_name = unique_name.generate(
                        replace_dot(target_id) +
                        STATIC_CONVERT_VAR_SHAPE_SUFFIX)
                    static_shape_var_node = gast.parse(
                        static_shape_var_name).body[0].value
                    static_shape_value_name = self.name_to_var_shape[
                        value_node.id]
                    static_shape_value_node = gast.parse(
                        static_shape_value_name).body[0].value

                    update_static_shape_var_node = [
                        gast.Assign(
                            targets=[static_shape_var_node],
                            value=static_shape_value_node)
                    ]
                    self.name_to_var_shape[target_id] = static_shape_var_name
            elif self._is_var_shape(value_node):  # eg: x.shape or x.shape[0]
                static_shape_var_name = unique_name.generate(
                    replace_dot(target_id) + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
                static_shape_var_node = gast.parse(static_shape_var_name).body[
                    0].value
                static_shape_value_node = copy.deepcopy(value_node)
                # x.shape becomes convert_var_shape_simple(x)
                static_shape_value_node = ShapeAttributeTransformer().visit(
                    static_shape_value_node)
                # Declare static_shape_var_name as global var
                update_static_shape_var_node = [
                    gast.Global(names=[static_shape_var_name])
                ]
                update_static_shape_var_node.append(
                    gast.Assign(
                        targets=[static_shape_var_node],
                        value=static_shape_value_node))
                self.name_to_var_shape[target_id] = static_shape_var_name
        return update_static_shape_var_node
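

# Usage sketch (added for illustration; not part of the original example).
# It assumes the helpers this class already relies on (gast,
# StaticAnalysisVisitor, ast_to_source_code) are importable; the function
# name `dygraph_source_to_static_code` is hypothetical.
def dygraph_source_to_static_code(dygraph_source):
    root = gast.parse(dygraph_source)
    # StaticAnalysisVisitor.get_node_wrapper_root() returns the
    # AstNodeWrapper root expected by the transformer's __init__.
    wrapper_root = StaticAnalysisVisitor(root).get_node_wrapper_root()
    TensorShapeTransformer(wrapper_root).transform()
    return ast_to_source_code(wrapper_root.node)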
Example #2
class ListTransformer(gast.NodeTransformer):
    """
    This class transforms Python lists used in control flow into static graph AST nodes.
    """
    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of ListTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        self.list_name_to_updated = dict()
        self.list_nodes = set()

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )
        var_env = self.static_analysis_visitor.get_var_env()
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        SplitAssignTransformer(self.root).transform()
        self.visit(self.root)
        self.replace_list_with_tensor_array(self.root)

    def visit_Call(self, node):
        if isinstance(node.func, gast.Attribute):
            func_name = node.func.attr
            if func_name == "pop":
                node = self._replace_pop(node)
        return node

    def visit_Assign(self, node):
        if self._update_list_name_to_updated(node):
            return node

        if self._need_to_array_write_node(node):
            return self._transform_slice_to_tensor_write(node)

        self.generic_visit(node)
        return node

    def visit_If(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_While(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def replace_list_with_tensor_array(self, node):
        for child_node in gast.walk(node):
            if isinstance(child_node, gast.Assign):
                if self._need_to_create_tensor_array(child_node):
                    child_node.value = self._create_tensor_array()

    def _transform_list_append_in_control_flow(self, node):
        for child_node in gast.walk(node):
            if self._need_to_array_write_node(child_node):
                child_node.value = \
                    self._to_array_write_node(child_node.value)

    def _need_to_array_write_node(self, node):
        if isinstance(node, gast.Expr):
            if isinstance(node.value, gast.Call):
                if self._is_list_append_tensor(node.value):
                    return True

        if isinstance(node, gast.Assign):
            target_node = node.targets[0]
            if isinstance(target_node, gast.Subscript):
                list_name = ast_to_source_code(target_node.value).strip()
                if list_name in self.list_name_to_updated:
                    if self.list_name_to_updated[list_name]:
                        return True
        return False

    def _transform_slice_to_tensor_write(self, node):
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        target_name = target_node.value.id
        slice_node = target_node.slice

        if isinstance(slice_node, gast.Slice):
            # Slice assignment such as `a[1:3] = x` is not supported yet;
            # return the node unchanged so `assign_node` below is never
            # referenced while unbound.
            return node
        elif isinstance(slice_node, gast.Index):
            value_code = ast_to_source_code(node.value)
            i = "paddle.cast(" \
                "x=paddle.jit.dy2static.to_static_variable({})," \
                "dtype='int64')".format(ast_to_source_code(slice_node))
            assign_code = "{} = fluid.layers.array_write(x={}, i={}, array={})" \
                .format(target_name, value_code, i, target_name)
            assign_node = gast.parse(assign_code).body[0]
        return assign_node

    def _is_list_append_tensor(self, node):
        """
        a.append(b): a is list, b is Tensor
        self.x.append(b): self.x is list, b is Tensor
        """
        assert isinstance(node, gast.Call)
        # 1. The func is `append`.
        if not isinstance(node.func, gast.Attribute):
            return False
        if node.func.attr != 'append':
            return False

        # 2. append() is called on a Python list.
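        # astor renders standard `ast` nodes, so convert from gast first.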
        value_name = astor.to_source(gast.gast_to_ast(node.func.value)).strip()
        if value_name not in self.list_name_to_updated:
            return False

        # 3. The number of arg of append() is one
        # Only one argument is supported in Python list.append()
        if len(node.args) != 1:
            return False

        # TODO(liym27): The arg of append() should be a Tensor, but static
        #   analysis often infers the arg type incorrectly, so the Tensor
        #   check is skipped here.
        # 4. The arg of append() is Tensor
        # arg = node.args[0]
        # if isinstance(arg, gast.Name):
        #     # TODO: `arg.id` may not be in scope_var_type_dict if `arg.id` is an arg of the decorated function
        #     # Need a better way to confirm whether `arg.id` is a Tensor.
        #     try:
        #         var_type_set = self.scope_var_type_dict[arg.id]
        #     except KeyError:
        #         return False
        #     if NodeVarType.NUMPY_NDARRAY in var_type_set:
        #         return False
        #     if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
        #         return False
        # # TODO: Consider that `arg` may be a gast.Call about Paddle Api. eg: list_a.append(fluid.layers.reshape(x))
        # # else:
        # # return True
        self.list_name_to_updated[value_name.strip()] = True
        return True

    def _need_to_create_tensor_array(self, node):
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        if self.list_name_to_updated.get(
                target_id) and node in self.list_nodes:
            return True
        return False

    def _create_tensor_array(self):
        # Although dtype is set to 'float32' here, other types such as 'int32' are also supported.
        func_code = "fluid.layers.create_array(dtype='float32')"
        func_node = gast.parse(func_code).body[0].value
        return func_node

    def _to_array_write_node(self, node):
        assert isinstance(node, gast.Call)
        array = astor.to_source(gast.gast_to_ast(node.func.value))
        x = astor.to_source(gast.gast_to_ast(node.args[0]))
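        # Write at index == current array length, i.e. append to the array.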
        i = "fluid.layers.array_length({})".format(array)
        func_code = "fluid.layers.array_write(x={}, i={}, array={})".format(
            x, i, array)
        return gast.parse(func_code).body[0].value

    def _update_list_name_to_updated(self, node):
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        # NOTE: Code like `x, y = a, []` has been transformed to `x=a; y=[]`
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        value_node = node.value
        if isinstance(value_node, gast.List):
            self.list_name_to_updated[target_id] = False
            self.list_nodes.add(node)
            return True
        elif target_id in self.list_name_to_updated and \
                self.list_name_to_updated[target_id] == False:
            del self.list_name_to_updated[target_id]
        return False

    def _replace_pop(self, node):
        """
        Replace a pop statement for a list or dict.
        For example:

            list_a = [0,1,2,3,4]
            x = list_a.pop()  # --> convert_pop(list_a)
            y = list_a.pop(1) # --> convert_pop(list_a, 1)

            dict_a = {"red":0, "blue":1, "yellow":2}
            m = dict_a.pop("red")           # --> convert_pop(dict_a, "red")
            n = dict_a.pop("black", 3)      # --> convert_pop(dict_a, "black", 3)

        """
        assert isinstance(node, gast.Call)
        assert isinstance(node.func, gast.Attribute)

        target_node = node.func.value
        target_str = ast_to_source_code(target_node).strip()

        args_str = [ast_to_source_code(arg).strip() for arg in node.args]

        # NOTE(liym27):
        # 1. pop stmt for a list if len(args_str) == 0
        # 2. pop stmt for a list or dict if len(args_str) == 1
        # 3. pop stmt for a dict if len(args_str) == 2
        if len(args_str) <= 2:
            new_pop_str = "paddle.jit.dy2static.convert_pop({}, {})"\
                .format(target_str, ",".join(args_str))
            new_pop_node = gast.parse(new_pop_str).body[0].value
            return new_pop_node
        else:
            return node
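

# Before/after sketch (illustration only) of what this transformer does to a
# list appended to inside control flow; the generated calls mirror the
# literal strings in _create_tensor_array and _to_array_write_node.
#
# dygraph code:
#     a = []
#     for x in inputs:
#         a.append(x)
#
# transformed (roughly):
#     a = fluid.layers.create_array(dtype='float32')
#     for x in inputs:
#         fluid.layers.array_write(
#             x=x, i=fluid.layers.array_length(a), array=a)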
Example #3
class TensorShapeTransformer(gast.NodeTransformer):
    """
    This class transforms `Tensor.shape` used in Paddle APIs and control flow conditions into static graph AST nodes.
    """
    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        self.name_to_tensor_shape = {}

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )
        var_env = self.static_analysis_visitor.get_var_env()
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        self.visit(self.root)

    def visit_Assign(self, node):
        if self._update_name_to_tensor_shape(node):
            return node
        self.generic_visit(node)
        return node

    def visit_Attribute(self, node):
        if self._used_by_paddle_api(node):
            if self.is_tensor_shape(node):
                return create_api_shape_node(node)
        return node

    def visit_Name(self, node):
        if node.id in self.name_to_tensor_shape:
            if self._used_by_paddle_api(node):
                tensor_shape_node = self.name_to_tensor_shape[node.id]
                return create_api_shape_node(tensor_shape_node)
        return node

    def visit_Call(self, node):
        assert isinstance(node, gast.Call)
        if is_paddle_api(node):
            # Visit gast.Attribute and gast.Name to replace tensor.shape if necessary.
            self.generic_visit(node)

        return node

    def visit_If(self, node):
        # Call generic_visit first to transform Tensor.shape that is used in Paddle Api.
        self.generic_visit(node)
        cond = node.test
        self._transform_tensor_shape_if_necessary(cond)
        return node

    def visit_While(self, node):
        self.generic_visit(node)
        cond = node.test
        self._transform_tensor_shape_if_necessary(cond)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        iter = node.iter
        self._transform_tensor_shape_if_necessary(iter)

        # If tensor.shape is a gast.Name and it is used in a range() call, transform it
        self._transform_tensor_shape_in_range(node)
        return node

    def _transform_tensor_shape_in_range(self, node):
        assert isinstance(node, gast.For)
        if not isinstance(node.iter, gast.Call):
            return False
        if not isinstance(node.iter.func, gast.Name):
            return False
        if node.iter.func.id != "range":
            return False
        args = node.iter.args
        for idx, arg in enumerate(args):
            if isinstance(arg,
                          gast.Name) and arg.id in self.name_to_tensor_shape:
                args[idx] = create_api_shape_node(
                    self.name_to_tensor_shape[arg.id])

        return True

    def _transform_tensor_shape_if_necessary(self, cond):
        for child_node in gast.walk(cond):
            tensor_shape_node = None
            if isinstance(child_node, gast.Attribute):
                if self.is_tensor_shape(child_node):
                    tensor_shape_node = child_node
            elif isinstance(child_node, gast.Name):
                if child_node.id in self.name_to_tensor_shape:
                    tensor_shape_node = self.name_to_tensor_shape[
                        child_node.id]

            if tensor_shape_node:
                wrapper_node = self.node_to_wrapper_map.get(child_node)
                parent_node = wrapper_node.parent.node
                for field, value in gast.iter_fields(parent_node):
                    if child_node is value:
                        setattr(parent_node, field,
                                create_api_shape_node(tensor_shape_node))
                        break

    def _used_by_paddle_api(self, node):
        assert isinstance(node, (gast.Attribute, gast.Name))
        wrapper_node = self.node_to_wrapper_map.get(node)
        if not wrapper_node:
            # Transformed node is not in node_to_wrapper_map
            return False
        while wrapper_node.parent:
            parent_node = wrapper_node.parent.node
            if isinstance(parent_node, gast.Call):
                if is_paddle_api(parent_node):
                    return True
                else:
                    return False
            wrapper_node = wrapper_node.parent

        return False

    def is_tensor_shape(self, node):
        """
        Return True if node is like `x.shape` and `x` is a Tensor; return False otherwise.
        """
        assert isinstance(node, gast.Attribute)
        if node.attr != 'shape':
            return False

        try:
            value_id = node.value.id
        except AttributeError:
            return False

        if value_id in self.name_to_tensor_shape:
            return True

        # TODO: `value_id` may not be in scope_var_type_dict if `value_id` is an arg of the decorated function.
        # Need a better way to confirm whether `value_id` is a Tensor.
        try:
            var_type_set = self.scope_var_type_dict[value_id]
        except KeyError:
            return False

        if NodeVarType.NUMPY_NDARRAY in var_type_set:
            return False
        if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
            return False

        return True

    def _update_name_to_tensor_shape(self, node):
        assert isinstance(node, gast.Assign)
        # TODO: Consider assignments with more than one target, e.g. x, y = a, Tensor.shape[1]
        target_node = node.targets[0]
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        value_node = node.value

        if isinstance(value_node, gast.Name):
            if value_node.id in self.name_to_tensor_shape:
                self.name_to_tensor_shape[
                    target_id] = self.name_to_tensor_shape[value_node.id]
                return True
        if isinstance(value_node, gast.Attribute):
            if self.is_tensor_shape(value_node):  # eg: x.shape
                self.name_to_tensor_shape[target_id] = value_node
                return True
        if isinstance(value_node, gast.Subscript):
            if isinstance(value_node.value, gast.Attribute):
                if self.is_tensor_shape(value_node.value):  # eg: x.shape[0]
                    self.name_to_tensor_shape[target_id] = value_node
                    return True
        return False
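

# Bookkeeping sketch (illustration only): after visiting
#     s = x.shape
# _update_name_to_tensor_shape maps "s" to the `x.shape` attribute node, so a
# later use such as
#     if s[0] > 1: ...
# can be rewritten through create_api_shape_node (defined elsewhere in this
# module) into the corresponding static-graph shape call.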
Example #4
class ListTransformer(gast.NodeTransformer):
    """
    This class transforms Python lists used in control flow into static graph AST nodes.
    """
    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of ListTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        self.list_name_to_updated = dict()
        self.list_nodes = set()

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )
        var_env = self.static_analysis_visitor.get_var_env()
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        self.visit(self.root)
        self.replace_list_with_tensor_array(self.root)

    def visit_Call(self, node):
        if isinstance(node.func, gast.Attribute):
            func_name = node.func.attr
            if func_name == "pop":
                node = self._replace_list_pop(node)
        return node

    def visit_Assign(self, node):
        if self._update_list_name_to_updated(node):
            return node

        if self._need_to_array_write_node(node):
            return self._transform_slice_to_tensor_write(node)

        self.generic_visit(node)
        return node

    def visit_If(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_While(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def replace_list_with_tensor_array(self, node):
        for child_node in gast.walk(node):
            if isinstance(child_node, gast.Assign):
                if self._need_to_create_tensor_array(child_node):
                    child_node.value = self._create_tensor_array()

    def _transform_list_append_in_control_flow(self, node):
        for child_node in gast.walk(node):
            if self._need_to_array_write_node(child_node):
                child_node.value = \
                    self._to_array_write_node(child_node.value)

    def _need_to_array_write_node(self, node):
        if isinstance(node, gast.Expr):
            if isinstance(node.value, gast.Call):
                if self._is_list_append_tensor(node.value):
                    return True

        if isinstance(node, gast.Assign):
            target_node = node.targets[0]
            if isinstance(target_node, gast.Subscript):
                list_name = ast_to_source_code(target_node.value).strip()
                if list_name in self.list_name_to_updated:
                    if self.list_name_to_updated[list_name]:
                        return True
        return False

    def _transform_slice_to_tensor_write(self, node):
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        target_name = target_node.value.id
        slice_node = target_node.slice

        if isinstance(slice_node, gast.Slice):
            # Slice assignment such as `a[1:3] = x` is not supported yet;
            # return the node unchanged so `assign_node` below is never
            # referenced while unbound.
            return node
        elif isinstance(slice_node, gast.Index):
            value_code = ast_to_source_code(node.value)
            i = "fluid.layers.cast(" \
                "x=fluid.dygraph.dygraph_to_static.variable_trans_func.to_static_variable({})," \
                "dtype='int64')".format(ast_to_source_code(slice_node))
            assign_code = "{} = fluid.layers.array_write(x={}, i={}, array={})" \
                .format(target_name, value_code, i, target_name)
            assign_node = gast.parse(assign_code).body[0]
        return assign_node

    def _is_list_append_tensor(self, node):
        """
        a.append(b): a is list, b is Tensor
        self.x.append(b): self.x is list, b is Tensor
        """
        assert isinstance(node, gast.Call)
        # 1. The func is `append`.
        if not isinstance(node.func, gast.Attribute):
            return False
        if node.func.attr != 'append':
            return False

        # 2. append() is called on a Python list.
        value_name = astor.to_source(gast.gast_to_ast(node.func.value)).strip()
        if value_name not in self.list_name_to_updated:
            return False

        # 3. append() takes exactly one arg, and the arg is a `Tensor`
        # Only one argument is supported in Python list.append()
        if len(node.args) != 1:
            return False
        arg = node.args[0]
        if isinstance(arg, gast.Name):
            # TODO: `arg.id` may not be in scope_var_type_dict if `arg.id` is an arg of the decorated function.
            # Need a better way to confirm whether `arg.id` is a Tensor.
            try:
                var_type_set = self.scope_var_type_dict[arg.id]
            except KeyError:
                return False

            if NodeVarType.NUMPY_NDARRAY in var_type_set:
                return False
            if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
                return False
        # else:
        #     TODO: Consider that `arg` may be a gast.Call to a Paddle API,
        #     e.g. list_a.append(fluid.layers.reshape(x))
        #     return True
        self.list_name_to_updated[value_name.strip()] = True
        return True

    def _need_to_create_tensor_array(self, node):
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        if self.list_name_to_updated.get(
                target_id) and node in self.list_nodes:
            return True
        return False

    def _create_tensor_array(self):
        # Although dtype is set to 'float32' here, other types such as 'int32' are also supported.
        func_code = "fluid.layers.create_array(dtype='float32')"
        func_node = gast.parse(func_code).body[0].value
        return func_node

    def _to_array_write_node(self, node):
        assert isinstance(node, gast.Call)
        array = astor.to_source(gast.gast_to_ast(node.func.value))
        x = astor.to_source(gast.gast_to_ast(node.args[0]))
        i = "fluid.layers.array_length({})".format(array)
        func_code = "fluid.layers.array_write(x={}, i={}, array={})".format(
            x, i, array)
        return gast.parse(func_code).body[0].value

    def _update_list_name_to_updated(self, node):
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        # TODO: Consider assignments with more than one target, e.g. x, y = a, []
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        value_node = node.value
        if isinstance(value_node, gast.List):
            self.list_name_to_updated[target_id] = False
            self.list_nodes.add(node)
            return True
        elif target_id in self.list_name_to_updated and \
                self.list_name_to_updated[target_id] == False:
            del self.list_name_to_updated[target_id]
        return False

    def _replace_list_pop(self, node):
        assert isinstance(node, gast.Call)
        assert isinstance(node.func, gast.Attribute)

        target_node = node.func.value
        target_str = ast_to_source_code(target_node).strip()

        if node.args:
            idx_node = node.args[0]
            idx_str = ast_to_source_code(idx_node).strip()
        else:
            idx_str = "None"

        new_call_str = "fluid.dygraph.dygraph_to_static.list_transformer.convert_list_pop({}, {})".format(
            target_str, idx_str)
        new_call_node = gast.parse(new_call_str).body[0].value
        return new_call_node
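

# Worked example for _replace_list_pop (illustration only), following the
# literal format string above:
#     x = list_a.pop(1)
# becomes
#     x = fluid.dygraph.dygraph_to_static.list_transformer.convert_list_pop(list_a, 1)
# and a no-arg pop passes idx_str == "None":
#     y = list_a.pop()  ->  convert_list_pop(list_a, None)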
Example #5
class TensorShapeTransformer(gast.NodeTransformer):
    """
    This class transforms `variable.shape` expressions used in Paddle APIs or control flow conditions into static graph AST nodes.
    """
    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        self.name_to_var_shape = {}

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )
        var_env = self.static_analysis_visitor.get_var_env()
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        SplitAssignTransformer(self.root).transform()
        self.visit(self.root)

    def visit_Assign(self, node):
        if self._update_name_to_var_shape(node):
            return node
        self.generic_visit(node)
        return node

    def visit_Attribute(self, node):
        if self._used_by_paddle_api(node):
            if self.is_var_shape(node):
                return create_convert_shape_node(node)
        return node

    def visit_Name(self, node):
        if node.id in self.name_to_var_shape:
            if self._used_by_paddle_api(node):
                var_shape_node = self.name_to_var_shape[node.id]
                return create_convert_shape_node(var_shape_node)
        return node

    def visit_Call(self, node):
        assert isinstance(node, gast.Call)
        if is_paddle_api(node):
            # Visit gast.Attribute and gast.Name to replace var.shape if necessary.
            self.generic_visit(node)

        return node

    def visit_If(self, node):
        # Call generic_visit first to transform var.shape that is used in Paddle Api.
        self.generic_visit(node)
        cond = node.test
        self._transform_var_shape_if_necessary(cond)

        return node

    def visit_While(self, node):
        self.generic_visit(node)
        cond = node.test
        self._transform_var_shape_if_necessary(cond)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        iter = node.iter
        self._transform_var_shape_if_necessary(iter)

        # If var.shape is a gast.Name and it is used in a range() call, transform it
        self._transform_var_shape_in_range(node)
        return node

    def _transform_var_shape_in_range(self, node):
        assert isinstance(node, gast.For)
        if not isinstance(node.iter, gast.Call):
            return False
        if not isinstance(node.iter.func, gast.Name):
            return False
        if node.iter.func.id != "range":
            return False
        args = node.iter.args
        for idx, arg in enumerate(args):
            if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape:
                args[idx] = create_convert_shape_node(
                    self.name_to_var_shape[arg.id])

        return True

    def _transform_var_shape_if_necessary(self, cond):
        need_transformed = False
        for child_node in gast.walk(cond):
            var_shape_node = None
            if isinstance(child_node, gast.Attribute):
                if self.is_var_shape(child_node):
                    var_shape_node = child_node
            elif isinstance(child_node, gast.Name):
                if child_node.id in self.name_to_var_shape:
                    var_shape_node = self.name_to_var_shape[child_node.id]

            if var_shape_node:
                need_transformed = True
                wrapper_node = self.node_to_wrapper_map.get(child_node)
                parent_node = wrapper_node.parent.node
                for field, value in gast.iter_fields(parent_node):
                    if child_node is value:
                        setattr(parent_node, field,
                                create_convert_shape_node(var_shape_node))
                        break
        return need_transformed

    def _used_by_paddle_api(self, node):
        assert isinstance(node, (gast.Attribute, gast.Name))
        wrapper_node = self.node_to_wrapper_map.get(node)
        if not wrapper_node:
            # Transformed node is not in node_to_wrapper_map
            return False
        while wrapper_node.parent:
            parent_node = wrapper_node.parent.node
            if isinstance(parent_node, gast.Call):
                if is_paddle_api(parent_node):
                    return True
                else:
                    return False
            wrapper_node = wrapper_node.parent

        return False

    def is_var_shape(self, node):
        """
        Return True if node is like `x.shape`; return False otherwise.
        """
        assert isinstance(node, gast.Attribute)

        if node.attr != 'shape':
            return False

        try:
            value_id = node.value.id
        except AttributeError:
            return False

        # Whether or not `value_id` was recorded in name_to_var_shape, any
        # remaining `<name>.shape` attribute is treated as a var shape.
        return True

    def _update_name_to_var_shape(self, node):
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        value_node = node.value

        if isinstance(value_node, gast.Name):
            if value_node.id in self.name_to_var_shape:
                self.name_to_var_shape[target_id] = self.name_to_var_shape[
                    value_node.id]
                return True
        if isinstance(value_node, gast.Attribute):
            if self.is_var_shape(value_node):  # eg: x.shape
                self.name_to_var_shape[target_id] = value_node
                return True
        if isinstance(value_node, gast.Subscript):
            if isinstance(value_node.value, gast.Attribute):
                if self.is_var_shape(value_node.value):  # eg: x.shape[0]
                    self.name_to_var_shape[target_id] = value_node
                    return True
        return False
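

# End-to-end sketch for this simpler variant (illustration only). Given
#     def forward(self, x):
#         if x.shape[0] > 1:
#             ...
# visit_If runs _transform_var_shape_if_necessary on the condition, which
# replaces the `x.shape` attribute node via create_convert_shape_node (defined
# alongside this class), so the branch condition is evaluated against the
# shape computed when the static graph is built.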