예제 #1
0
    def __init__(self, root_node):
        """Collect variable liveness information by visiting `root_node`."""
        # Variables seen so far (gast.Name / gast.Attribute nodes).
        self.current_seen_vars = set()
        # Stack of gast.While/gast.For nodes currently being visited.
        self.current_loop = []
        # Stack of nodes that introduce a variable scope.
        self.nodes_with_scope = []
        # Identifiers that must never be treated as variables.
        self.blacklist_names = {"False", "True", "None"}

        # loop node -> variable nodes seen before entering the loop body.
        self.before_loop_body_vars = defaultdict(set)
        # loop node -> variable nodes used inside the loop.
        self.in_loop_vars = defaultdict(set)
        # loop node -> variable nodes written during the loop.
        self.write_in_loop = defaultdict(set)
        # loop node -> variable nodes appearing in the loop condition.
        self.condition_vars = defaultdict(set)
        self.in_condition = False

        self.static_analysis_visitor = StaticAnalysisVisitor(root_node)
        self.node_to_wrapper_map = (
            self.static_analysis_visitor.get_node_to_wrapper_map())

        self.visit(root_node)
예제 #2
0
    def __init__(self, wrapper_root):
        """Store the wrapper root and run static analysis over its AST."""
        assert isinstance(wrapper_root, AstNodeWrapper), (
            "Input non-AstNodeWrapper node for the initialization of AssertTransformer."
        )
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
예제 #3
0
 def get_static_ast(self, root):
     """Build and return the static-graph AST wrapper for `root`."""
     # Keep the raw root: some analyses may need the global AST.
     self.root = root
     self.static_analysis_visitor = StaticAnalysisVisitor(root)
     self.static_analysis_root = (
         self.static_analysis_visitor.get_node_wrapper_root())
     self.decorate_func_name = None
     self.transfer_from_node_type(self.static_analysis_root)
     return self.static_analysis_root
예제 #4
0
    def __init__(self, wrapper_root):
        """Set up the static analysis state used to rewrite tensor shapes."""
        assert isinstance(wrapper_root, AstNodeWrapper), (
            "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
        )
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        self.name_to_tensor_shape = {}

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = (
            self.static_analysis_visitor.get_node_to_wrapper_map())
        var_env = self.static_analysis_visitor.get_var_env()
        # NOTE(review): assumes sub_scopes[0] is the decorated function's
        # scope — confirm against StaticAnalysisVisitor.
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()
예제 #5
0
class AssertTransformer(gast.NodeTransformer):
    """
    A class transforms python assert to fluid.layers.Assert.
    """

    def __init__(self, wrapper_root):
        assert isinstance(wrapper_root, AstNodeWrapper), (
            "Input non-AstNodeWrapper node for the initialization of AssertTransformer."
        )
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)

    def transform(self):
        """Run the assert rewriting over the whole AST."""
        self.visit(self.root)

    def visit_Assert(self, node):
        # Only asserts whose test is a Tensor need rewriting.
        if not self.static_analysis_visitor.is_tensor_node(node.test):
            return node
        # Cast the test expression to bool, then wrap it in layers.Assert.
        bool_cast = gast.Call(
            func=gast.parse("fluid.layers.cast").body[0].value,
            args=[node.test, gast.Constant(value="bool", kind=None)],
            keywords=[])
        layers_assert = gast.Call(
            func=gast.parse("fluid.layers.Assert").body[0].value,
            args=[bool_cast],
            keywords=[])
        return gast.Expr(value=layers_assert)
예제 #6
0
    def test_paddle_api_with_andOr(self):
        """Mixed and/or conditions with a Paddle API call are control flow."""
        code_or = """
            def foo(x):
                if 2 > 1 and fluid.layers.shape(x)[0] > 16 or x is not None :
                    x = x + 1
                return x
        """

        code_and = """
            def foo(x):
                if 2 > 1 and fluid.layers.shape(x)[0] > 16 and x is not None :
                    x = x + 1
                return x
        """
        for snippet in [code_or, code_and]:
            ast_root = gast.parse(textwrap.dedent(snippet))
            visitor = StaticAnalysisVisitor(ast_root)
            test_node = ast_root.body[0].body[0].test
            if_visitor = IfConditionVisitor(test_node, visitor)
            self.assertTrue(if_visitor.is_control_flow())

            # Transformation is expected to produce a single bool Name plus
            # four generated assignments, e.g.:
            #   bool_tensor_0 = fluid.layers.fill_constant(..., value=bool(2 > 1))
            #   bool_tensor_1 = fluid.layers.fill_constant(..., value=bool(x is not None))
            #   logic_and_0 = fluid.layers.logical_and(x=bool_tensor_0, y=fluid.layers.shape(x)[0] > 16)
            #   logic_and_1 / logic_or_0 combining logic_and_0 and bool_tensor_1
            new_node, assign_nodes = if_visitor.transform()
            self.assertTrue(isinstance(new_node, gast.Name))
            self.assertTrue(len(assign_nodes) == 4)
    def __init__(self, wrapper_root):
        """Prepare scope/type information used to rewrite `var.shape`."""
        assert isinstance(wrapper_root, AstNodeWrapper), (
            "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
        )
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # Maps an origin var name (e.g. "x" in `x = t.shape`) to the name of
        # its static shape var (e.g. "x_SUFFIX" in `x_SUFFIX = shape(t)`).
        self.name_to_var_shape = {}

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = (
            self.static_analysis_visitor.get_node_to_wrapper_map())
        var_env = self.static_analysis_visitor.get_var_env()
        # NOTE(review): assumes sub_scopes[0] is the decorated function's
        # scope — confirm against StaticAnalysisVisitor.
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()
예제 #8
0
class PrintTransformer(gast.NodeTransformer):
    """
    This class transforms python print function to fluid.layers.Print.
    """

    def __init__(self, wrapper_root):
        assert isinstance(wrapper_root, AstNodeWrapper), (
            "Input non-AstNodeWrapper node for the initialization of PrintTransformer."
        )
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = (
            self.static_analysis_visitor.get_node_to_wrapper_map())

    def transform(self):
        """Apply the print conversion to the whole AST."""
        self.visit(self.root)

    # Python 3: `print(...)` parses as a Call node.
    def visit_Call(self, node):
        is_print = isinstance(node.func, gast.Name) and node.func.id == 'print'
        if is_print:
            return self._create_print_node(node.args)
        return node

    # Python 2: `print ...` is its own statement node.
    def visit_Print(self, node):
        return gast.Expr(value=self._create_print_node(node.values))

    def _create_print_node(self, print_args):
        convert_func = gast.parse(
            'paddle.jit.dy2static.convert_print').body[0].value
        return gast.Call(func=convert_func, args=print_args, keywords=[])
예제 #9
0
 def __init__(self, wrapper_root):
     """Store the AST root and run static analysis on it."""
     assert isinstance(wrapper_root, AstNodeWrapper), (
         "Type of input node should be AstNodeWrapper, but received %s ." %
         type(wrapper_root))
     self.root = wrapper_root.node
     self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
예제 #10
0
    def test_paddle_api(self):
        """A condition built from a Paddle API call is control flow."""
        code = """
            def foo(x):
                if fluid.layers.shape(x)[0] > 16:
                    x = x + 1
                return x
        """
        ast_root = gast.parse(textwrap.dedent(code))
        visitor = StaticAnalysisVisitor(ast_root)
        test_node = ast_root.body[0].body[0].test

        self.assertTrue(is_control_flow_to_transform(test_node, visitor))
예제 #11
0
    def test_shape_with_andOr(self):
        """A condition reading a shape result through and/or is control flow."""
        code = """
            def foo(x):
                batch_size = fluid.layers.shape(x)
                if x is not None and batch_size[0] > 16 or 2 > 1:
                    x = x + 1
                return x
        """
        ast_root = gast.parse(textwrap.dedent(code))
        visitor = StaticAnalysisVisitor(ast_root)
        # The `if` is the second statement of foo's body.
        test_node = ast_root.body[0].body[1].test

        self.assertTrue(is_control_flow_to_transform(test_node, visitor))
예제 #12
0
    def __init__(self,
                 ast_node,
                 static_analysis_visitor=None,
                 node_var_type_map=None):
        """Record the AST root and the (possibly lazily-built) analysis helpers."""
        assert isinstance(ast_node, gast.AST), (
            "Type of input node should be gast.AST, but received %s." %
            type(ast_node))
        self.ast_root = ast_node
        # Build a fresh visitor only if the caller did not supply one.
        if static_analysis_visitor is None:
            static_analysis_visitor = StaticAnalysisVisitor(ast_node)
        self.static_analysis_visitor = static_analysis_visitor
        self.node_var_type_map = node_var_type_map

        self.is_control_flow_num = 0
        # NOTE(review): attribute keeps the original (misspelled) name
        # "tenor" because external code may reference it.
        self._compare_node_tenor_set = set()
예제 #13
0
    def test_paddle_api(self):
        """Paddle-API conditions are control flow but need no rewriting."""
        code = """
            def foo(x):
                if fluid.layers.shape(x)[0] > 16:
                    x = x + 1
                return x
        """
        ast_root = gast.parse(textwrap.dedent(code))
        visitor = StaticAnalysisVisitor(ast_root)
        test_node = ast_root.body[0].body[0].test
        if_visitor = IfConditionVisitor(test_node, visitor)
        self.assertTrue(if_visitor.is_control_flow())

        # No transformation expected: the test node comes back unchanged and
        # no assignments are generated.
        new_node, assign_nodes = if_visitor.transform()
        self.assertTrue(new_node == test_node)
        self.assertTrue(len(assign_nodes) == 0)
예제 #14
0
    def test_paddle_api_with_andOr(self):
        """Mixed and/or conditions around a Paddle API call are control flow."""
        code_or = """
            def foo(x):
                if 2 > 1 and fluid.layers.shape(x)[0] > 16 or x is not None :
                    x = x + 1
                return x
        """

        code_and = """
            def foo(x):
                if 2 > 1 and fluid.layers.shape(x)[0] > 16 and x is not None :
                    x = x + 1
                return x
        """
        for snippet in [code_or, code_and]:
            ast_root = gast.parse(textwrap.dedent(snippet))
            visitor = StaticAnalysisVisitor(ast_root)
            test_node = ast_root.body[0].body[0].test

            self.assertTrue(
                is_control_flow_to_transform(test_node, visitor))
예제 #15
0
    def test_shape_with_andOr(self):
        """Shape-based and/or condition is rewritten into logical_* calls."""
        code = """
            def foo(x):
                batch_size = fluid.layers.shape(x)
                if x is not None and batch_size[0] > 16 or 2 > 1:
                    x = x + 1
                return x
        """
        ast_root = gast.parse(textwrap.dedent(code))
        visitor = StaticAnalysisVisitor(ast_root)
        # The `if` is the second statement of foo's body.
        test_node = ast_root.body[0].body[1].test
        if_visitor = IfConditionVisitor(test_node, visitor)
        self.assertTrue(if_visitor.is_control_flow())

        # Expected transformation (4 generated assignments):
        #   bool_tensor_0 = fluid.layers.fill_constant(..., value=bool(x is not None))
        #   logic_and_0 = fluid.layers.logical_and(x=bool_tensor_0, y=batch_size[0] > 16)
        #   bool_tensor_1 = fluid.layers.fill_constant(..., value=bool(2 > 1))
        #   logic_or_0 = fluid.layers.logical_or(x=logic_and_0, y=bool_tensor_1)
        new_node, assign_nodes = if_visitor.transform()
        self.assertTrue(isinstance(new_node, gast.Name))
        self.assertTrue(len(assign_nodes) == 4)
예제 #16
0
class DygraphToStaticAst(gast.NodeTransformer):
    """
    Main class to transform Dygraph to Static Graph
    """
    def get_static_ast(self, root):
        """Run all transformers on `root` and return the wrapped static AST root."""
        # save root for some analysis may need global AST
        self.root = root
        self.static_analysis_visitor = StaticAnalysisVisitor(root)
        self.static_analysis_root = self.static_analysis_visitor.get_node_wrapper_root(
        )
        # Filled in by visit_FunctionDef with the decorated function's name.
        self.decorate_func_name = None
        # Maps the entry function's argument names to positional indices.
        self.arg_name_to_idx = {}
        self.transfer_from_node_type(self.static_analysis_root)
        return self.static_analysis_root

    def transfer_from_node_type(self, node_wrapper):
        """Apply each specialized transformer to the wrapped AST.

        NOTE(review): the ordering below looks deliberate (break/continue
        before loops, loops before if/else) — confirm before reordering.
        """
        # Generic transformation
        self.visit(node_wrapper.node)

        # Transform basic api of dygraph to static graph and get feed_name_to_arg_name
        basic_api_trans = BasicApiTransformer(node_wrapper)
        basic_api_trans.transform()
        self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id()

        # Transform Tensor.shape into fluid.layers.shape(Tensor)
        TensorShapeTransformer(node_wrapper).transform()

        # Transform list used in control flow
        ListTransformer(node_wrapper).transform()

        # Transform break/continue in loops
        BreakContinueTransformer(node_wrapper).transform()

        # Transform for loop and while loop
        LoopTransformer(node_wrapper).transform()

        # Transform all if/else statement of Dygraph into Static Graph.
        IfElseTransformer(node_wrapper).transform()

        # Transform python assert statement
        AssertTransformer(node_wrapper).transform()

        # Transform all python print statement
        PrintTransformer(node_wrapper).transform()

        # Transform call recursively
        CallTransformer(node_wrapper).transform()

    def visit_FunctionDef(self, node):
        """Record the entry function's name/args and strip its decorators."""
        # The first FunctionDef encountered is treated as the decorated
        # entry function.
        if self.decorate_func_name is None:
            self.decorate_func_name = node.name
            for idx, arg in enumerate(node.args.args):
                self.arg_name_to_idx[arg.id] = idx

        self.generic_visit(node)
        # Remove the decorated name of dygraph_to_static
        if hasattr(node, 'decorator_list'):
            # `decorator_list` is never appended to below, so every decorator
            # that passes validation is removed from the function.
            decorator_list = []
            for d in node.decorator_list:
                # A plain-name decorator must be one of the translate
                # decorators; anything else is unsupported.
                if isinstance(d, gast.Name) and d.id not in DECORATOR_NAMES:
                    raise NotImplementedError(
                        "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                        + d.id + " in " + self.decorate_func_name)
                # A dotted decorator must contain a translate decorator name.
                if isinstance(d, gast.Attribute):
                    full_attribute_name = get_attribute_full_name(d)
                    has_translate_decorator = False
                    for deco in DECORATOR_NAMES:
                        if deco in full_attribute_name:
                            has_translate_decorator = True
                            break
                    if not has_translate_decorator:
                        raise NotImplementedError(
                            "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                            + full_attribute_name + " in " +
                            self.decorate_func_name)
            node.decorator_list = decorator_list
        return node

    def get_module_name(self):
        """
        Return the main function name which will be used as module name
        in ast_to_func.
        """
        # Should consider BaseAPITransformer which add new module name in Yamei's PR.
        assert self.decorate_func_name, "decorate_func_name shall not be None."
        return self.decorate_func_name

    def get_feed_name_to_idx(self):
        """Map each feed name to the positional index of its source argument."""
        feed_name_to_idx = {}
        for feed_name, arg_name in self.feed_name_to_arg_name.items():
            feed_name_to_idx[feed_name] = self.arg_name_to_idx.get(arg_name)
        return feed_name_to_idx
예제 #17
0
class NameVisitor(gast.NodeVisitor):
    '''
    Analysis name liveness for loop transformer
    '''
    def __init__(self, root_node):
        """Visit `root_node` once and record per-loop variable usage."""
        # Set of gast.Name or gast.Attribute for variables
        self.current_seen_vars = set()

        # List of gast.While/gast.For nodes
        self.current_loop = []

        # List of nodes that have scope of variables.
        self.nodes_with_scope = []

        # Identifiers never treated as variables.
        self.blacklist_names = {"False", "True", "None"}

        # Mapping from gast.While/gast.For to variable nodes
        self.before_loop_body_vars = defaultdict(set)
        self.in_loop_vars = defaultdict(set)

        # Mapping from gast.While/gast.For to variable nodes which is condition
        # of loop or being modified during the loop
        self.write_in_loop = defaultdict(set)
        self.condition_vars = defaultdict(set)
        self.in_condition = False

        self.static_analysis_visitor = StaticAnalysisVisitor(root_node)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )

        self.visit(root_node)

    def is_control_flow_loop(self, node):
        """Return True if loop `node` must be converted to static control flow."""
        need_transform = is_control_flow_to_transform(
            node, self.static_analysis_visitor)
        return need_transform

    def get_loop_var_names(self, node):
        """Return (loop_var_names, create_var_names) for a While/For node.

        loop_var_names: names that must be carried through the static loop.
        create_var_names: subset first created inside the loop (and read
        after it) that must therefore be created before the static loop.
        """
        assert isinstance(
            node, (gast.While, gast.For)), "Input node is not gast loop node"
        loop_var_names = set()
        create_var_names = set()
        # ctx types that count as "read" accesses.
        read_context = {type(gast.Load()), type(gast.AugLoad())}

        in_loop_vars = self.in_loop_vars[node]
        in_loop_name_strs = self._var_nodes_to_names(in_loop_vars)
        before_loop_body_vars = self.before_loop_body_vars[node]
        before_loop_name_strs = self._var_nodes_to_names(before_loop_body_vars)
        # Vars first seen after the loop; only read accesses matter here.
        after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars
        after_loop_name_strs = self._var_nodes_to_names(
            after_loop_vars, read_context)
        condition_vars = self.condition_vars[node]
        condition_names = self._var_nodes_to_names(condition_vars)
        write_vars = self.write_in_loop[node]
        write_names = self._var_nodes_to_names(write_vars)

        # name -> inferred static type, used to filter read-only basic vars.
        name_to_type = {}
        for var in in_loop_vars:
            wrapper = self.node_to_wrapper_map[var]
            name_to_type[self._var_node_to_name(var)] = wrapper.node_var_type

        for name in in_loop_name_strs:
            if name in before_loop_name_strs:
                # If a variable is used in loop and created before loop

                # If this var is a basic variable and read-only and not
                # condition var, it may not be loop_var else it should
                # be in loop_var as input
                if (not name in condition_names) and (
                        not name in write_names
                ) and self._node_var_type_is_basic(name_to_type[name]):
                    continue
                loop_var_names.add(name)

            elif name in after_loop_name_strs:
                # If a variable is created in the while loop and read after
                # loop, it should be in loop_var and we should create it

                # because name in after_loop_name must be initialized in loop
                # So it is write-only, we don't have to filter read-only basic
                # vars out
                loop_var_names.add(name)
                create_var_names.add(name)
        return loop_var_names, create_var_names

    def visit_Name(self, node):
        """Record a variable Name node against every enclosing loop."""
        # Names used as a call's func (e.g. `foo` in `foo(x)`) are not vars.
        if self._is_call_func_name_node(node):
            self.generic_visit(node)
            return
        if node.id in self.blacklist_names:
            self.generic_visit(node)
            return

        self.current_seen_vars.add(node)
        # ctx types that count as "write" accesses.
        write_context = {
            type(gast.Store()),
            type(gast.AugStore()),
            type(gast.Del())
        }
        for loop_node in self.current_loop:
            self.in_loop_vars[loop_node].add(node)
            if type(node.ctx) in write_context:
                self.write_in_loop[loop_node].add(node)
            if self.in_condition:
                self.condition_vars[loop_node].add(node)
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        """Enter a new variable scope for the function body."""
        self.nodes_with_scope.append(node)
        self.blacklist_names.add(node.name)
        # The variables in the function are not visible to the outside scope.
        before_func_seen_vars = copy.copy(self.current_seen_vars)

        self.generic_visit(node)
        self.nodes_with_scope.pop()
        # After exiting the scope of the node, variables in this scope
        # should be removed from self.current_seen_vars.
        if self.nodes_with_scope:
            self.current_seen_vars = before_func_seen_vars

    def visit(self, node):
        """Dispatch to visit_<ClassName>, defaulting to generic_visit."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        ret = visitor(node)
        return ret

    def visit_Attribute(self, node):
        """Record an Attribute node (e.g. `a.b`) as a variable where valid."""
        if self._is_call_func_name_node(node):
            return
        attr_full_name = get_attribute_full_name(node)
        # Class variables are not allowed to appear in the arguments list
        # of defined function under class methods in Python.
        """
        def class_func(self):
            def while_loop_body(self.x, y) # `self.x` is illegal.
        """
        # TODO: If do change the variable with `self.var`, need a better
        # way to deal with this case.
        if attr_full_name.startswith("self."):
            return
        self.current_seen_vars.add(node)

        for loop_node in self.current_loop:
            self.in_loop_vars[loop_node].add(node)

        # sub-nodes are visited during get_attribute_full_name and we shouldn't
        # visit again

    def visit_For(self, node):
        """Track a for loop: target/iter count as the loop condition."""
        self.current_loop.append(node)
        self.in_condition = True
        self.visit(node.target)
        self.visit(node.iter)
        self.in_condition = False
        # Snapshot vars seen before the body to diff against later.
        self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
        self.generic_visit(node)
        self.current_loop.pop()

    def visit_While(self, node):
        """Track a while loop: the test expression is the loop condition."""
        self.current_loop.append(node)
        self.in_condition = True
        self.visit(node.test)
        self.in_condition = False
        # Snapshot vars seen before the body to diff against later.
        self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
        self.generic_visit(node)
        self.current_loop.pop()

    def _var_nodes_to_names(self, node_set, ctx_filter_set=None):
        """Convert var nodes to name strings, optionally filtered by ctx type."""
        ret = set()
        for node in node_set:
            if ctx_filter_set is None or type(node.ctx) in ctx_filter_set:
                ret.add(self._var_node_to_name(node))
        return ret

    def _var_node_to_name(self, node):
        """Return the string name of a Name or Attribute variable node."""
        if isinstance(node, gast.Name):
            return node.id
        elif isinstance(node, gast.Attribute):
            return get_attribute_full_name(node)

    def _node_var_type_is_basic(self, node_var_type):
        """True if any inferred type is a basic scalar/string type."""
        basic_types = {
            NodeVarType.BOOLEAN, NodeVarType.INT, NodeVarType.FLOAT,
            NodeVarType.STRING
        }
        for t in node_var_type:
            if t in basic_types:
                return True
        return False

    def _is_call_func_name_node(self, node):
        """True if `node` is used as the callee of a Call (not a variable)."""
        parent_node = self.node_to_wrapper_map[node].parent.node
        if isinstance(parent_node, gast.Call) and parent_node.func == node:
            return True
        return False
예제 #18
0
class ListTransformer(gast.NodeTransformer):
    """
    This class transforms python list used in control flow into Static Graph Ast.
    """
    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of ListTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # list var name -> True once append() with a Tensor arg was seen.
        self.list_name_to_updated = dict()
        # gast.Assign nodes that create the lists tracked above.
        self.list_nodes = set()

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )
        var_env = self.static_analysis_visitor.get_var_env()
        # NOTE(review): assumes sub_scopes[0] is the decorated function's
        # scope — confirm against StaticAnalysisVisitor.
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        """Rewrite list usage in control flow, then convert the qualifying
        list creations into tensor arrays."""
        self.visit(self.root)
        self.replace_list_with_tensor_array(self.root)

    def visit_Call(self, node):
        # `x.pop(...)` is rewritten into the convert_list_pop helper.
        if isinstance(node.func, gast.Attribute):
            func_name = node.func.attr
            if func_name == "pop":
                node = self._replace_list_pop(node)
        return node

    def visit_Assign(self, node):
        # `x = [...]` registers/unregisters x as a tracked list.
        if self._update_list_name_to_updated(node):
            return node

        # `x[i] = v` on a converted list becomes an array_write call.
        if self._need_to_array_write_node(node):
            return self._transform_slice_to_tensor_write(node)

        self.generic_visit(node)
        return node

    def visit_If(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_While(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def replace_list_with_tensor_array(self, node):
        """Replace the creation of each tracked list with create_array()."""
        for child_node in gast.walk(node):
            if isinstance(child_node, gast.Assign):
                if self._need_to_create_tensor_array(child_node):
                    child_node.value = self._create_tensor_array()

    def _transform_list_append_in_control_flow(self, node):
        """Rewrite qualifying append()/subscript-assign nodes under `node`."""
        for child_node in gast.walk(node):
            if self._need_to_array_write_node(child_node):
                child_node.value = \
                    self._to_array_write_node(child_node.value)

    def _need_to_array_write_node(self, node):
        """True if `node` should be rewritten to an array_write call."""
        # Case 1: expression statement `x.append(tensor)`.
        if isinstance(node, gast.Expr):
            if isinstance(node.value, gast.Call):
                if self._is_list_append_tensor(node.value):
                    return True

        # Case 2: subscript assignment on a list already marked as updated.
        if isinstance(node, gast.Assign):
            target_node = node.targets[0]
            if isinstance(target_node, gast.Subscript):
                list_name = ast_to_source_code(target_node.value).strip()
                if self.list_name_to_updated.get(list_name):
                    return True
        return False

    def _transform_slice_to_tensor_write(self, node):
        """Rewrite `x[i] = v` on a converted list into an array_write call.

        Slice assignment (`x[a:b] = ...`) is not supported and the node is
        returned unchanged. (Previously this path fell through and raised
        UnboundLocalError on an undefined `assign_node`.)
        """
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        target_name = target_node.value.id
        slice_node = target_node.slice

        if isinstance(slice_node, gast.Index):
            value_code = ast_to_source_code(node.value)
            # The index must be an int64 Variable for array_write.
            i = "fluid.layers.cast(" \
                "x=fluid.dygraph.dygraph_to_static.variable_trans_func.to_static_variable({})," \
                "dtype='int64')".format(ast_to_source_code(slice_node))
            assign_code = "{} = fluid.layers.array_write(x={}, i={}, array={})" \
                .format(target_name, value_code, i, target_name)
            return gast.parse(assign_code).body[0]
        # gast.Slice or any other subscript kind: leave the node untouched.
        return node

    def _is_list_append_tensor(self, node):
        """
        a.append(b): a is list, b is Tensor
        self.x.append(b): self.x is list, b is Tensor
        """
        assert isinstance(node, gast.Call)
        # 1. The func is `append`.
        if not isinstance(node.func, gast.Attribute):
            return False
        if node.func.attr != 'append':
            return False

        # 2. It's a `python list` to call append().
        value_name = astor.to_source(gast.gast_to_ast(node.func.value)).strip()
        if value_name not in self.list_name_to_updated:
            return False

        # 3. The arg of append() is one `Tensor`
        # Only one argument is supported in Python list.append()
        if len(node.args) != 1:
            return False
        arg = node.args[0]
        if isinstance(arg, gast.Name):
            # TODO: `arg.id` may be not in scope_var_type_dict if `arg.id` is the arg of decorated function
            # Need a better way to confirm whether `arg.id` is a Tensor.
            try:
                var_type_set = self.scope_var_type_dict[arg.id]
            except KeyError:
                return False

            if NodeVarType.NUMPY_NDARRAY in var_type_set:
                return False
            if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
                return False
        # else:
        # Todo: Consider that `arg` may be a gast.Call about Paddle Api.
        # eg: list_a.append(fluid.layers.reshape(x))
        # return True

        # `value_name` is already stripped above.
        self.list_name_to_updated[value_name] = True
        return True

    def _need_to_create_tensor_array(self, node):
        """True if this Assign creates a tracked list that saw tensor appends."""
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        if self.list_name_to_updated.get(
                target_id) and node in self.list_nodes:
            return True
        return False

    def _create_tensor_array(self):
        # Although `dtype='float32'`, other types such as `int32` can also be supported
        func_code = "fluid.layers.create_array(dtype='float32')"
        func_node = gast.parse(func_code).body[0].value
        return func_node

    def _to_array_write_node(self, node):
        """Build `fluid.layers.array_write(x=arg, i=len(array), array=array)`."""
        assert isinstance(node, gast.Call)
        array = astor.to_source(gast.gast_to_ast(node.func.value))
        x = astor.to_source(gast.gast_to_ast(node.args[0]))
        i = "fluid.layers.array_length({})".format(array)
        func_code = "fluid.layers.array_write(x={}, i={}, array={})".format(
            x, i, array)
        return gast.parse(func_code).body[0].value

    def _update_list_name_to_updated(self, node):
        """Register `x = [...]` as a candidate list; drop stale entries."""
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        # TODO: Consider node has more than one target. eg: x, y = a, []
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        value_node = node.value
        if isinstance(value_node, gast.List):
            self.list_name_to_updated[target_id] = False
            self.list_nodes.add(node)
            return True
        elif target_id in self.list_name_to_updated and \
                not self.list_name_to_updated[target_id]:
            # Rebinding a candidate list to a non-list value: stop tracking it.
            del self.list_name_to_updated[target_id]
        return False

    def _replace_list_pop(self, node):
        """Rewrite `x.pop(...)` into a convert_list_pop helper call."""
        assert isinstance(node, gast.Call)
        assert isinstance(node.func, gast.Attribute)

        target_node = node.func.value
        target_str = ast_to_source_code(target_node).strip()

        if node.args:
            idx_node = node.args[0]
            idx_str = ast_to_source_code(idx_node).strip()
        else:
            idx_str = "None"

        new_call_str = "fluid.dygraph.dygraph_to_static.list_transformer.convert_list_pop({}, {})".format(
            target_str, idx_str)
        new_call_node = gast.parse(new_call_str).body[0].value
        return new_call_node
class TensorShapeTransformer(gast.NodeTransformer):
    """
    This class transforms variable.shape used in Paddle Apis or control flow conditions into Static Graph Ast.
    """
    def __init__(self, wrapper_root):
        # Root must be an AstNodeWrapper so the static analysis info below
        # corresponds to the tree being transformed.
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # Maps a variable name (str) to the `x.shape` / `x.shape[i]` gast
        # node it was assigned from, so later uses of that name can also be
        # converted.
        self.name_to_var_shape = {}

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )
        var_env = self.static_analysis_visitor.get_var_env()
        # Descend into the first sub-scope before collecting variable types.
        # presumably this is the decorated function's own scope — TODO confirm
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        # NOTE(review): scope_var_type_dict is collected but never read in
        # this class version.
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        # Split multi-target assignments first so visit_Assign only has to
        # deal with single-target assignment nodes.
        SplitAssignTransformer(self.root).transform()
        self.visit(self.root)

    def visit_Assign(self, node):
        # If the assignment records a shape alias (e.g. `s = x.shape`), keep
        # the node untouched and skip visiting its children.
        if self._update_name_to_var_shape(node):
            return node
        self.generic_visit(node)
        return node

    def visit_Attribute(self, node):
        # Replace `x.shape` only when it appears as a Paddle API argument.
        if self._used_by_paddle_api(node):
            if self.is_var_shape(node):
                return create_convert_shape_node(node)
        return node

    def visit_Name(self, node):
        # A bare name may be an alias of a previously recorded `x.shape`.
        if node.id in self.name_to_var_shape:
            if self._used_by_paddle_api(node):
                var_shape_node = self.name_to_var_shape[node.id]
                return create_convert_shape_node(var_shape_node)
        return node

    def visit_Call(self, node):
        assert isinstance(node, gast.Call)
        if is_paddle_api(node):
            # Visit gast.Attribute and gast.Name to replace var.shape if necessary.
            self.generic_visit(node)

        return node

    def visit_If(self, node):
        # Call generic_visit first to transform var.shape that is used in Paddle Api.
        self.generic_visit(node)
        cond = node.test
        self._transform_var_shape_if_necessary(cond)

        return node

    def visit_While(self, node):
        self.generic_visit(node)
        cond = node.test
        self._transform_var_shape_if_necessary(cond)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        iter = node.iter
        self._transform_var_shape_if_necessary(iter)

        # If var.shape is a gast.Name and it is used in range function, transform it
        self._transform_var_shape_in_range(node)
        return node

    def _transform_var_shape_in_range(self, node):
        # Rewrite `range(...)` arguments that alias a recorded var.shape.
        # Returns False for non-`range` for-loops, True otherwise.
        assert isinstance(node, gast.For)
        if not isinstance(node.iter, gast.Call):
            return False
        if not isinstance(node.iter.func, gast.Name):
            return False
        if node.iter.func.id != "range":
            return False
        args = node.iter.args
        for idx, arg in enumerate(args):
            if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape:
                args[idx] = create_convert_shape_node(
                    self.name_to_var_shape[arg.id])

        return True

    def _transform_var_shape_if_necessary(self, cond):
        # Walk the (sub-)expression `cond` and replace every `x.shape`
        # attribute or shape-alias name with the converted shape node,
        # patching the parent node in place. Returns True if any node was
        # replaced.
        need_transformed = False
        for child_node in gast.walk(cond):
            var_shape_node = None
            if isinstance(child_node, (gast.Attribute)):
                if self.is_var_shape(child_node):
                    var_shape_node = child_node
            elif isinstance(child_node, (gast.Name)):
                if child_node.id in self.name_to_var_shape:
                    var_shape_node = self.name_to_var_shape[child_node.id]

            if var_shape_node:
                need_transformed = True
                wrapper_node = self.node_to_wrapper_map.get(child_node)
                parent_node = wrapper_node.parent.node
                # Find which field of the parent holds this child and swap it.
                # NOTE(review): parent fields whose value is a list (e.g. the
                # comparators of gast.Compare) are not handled here.
                for field, value in gast.iter_fields(parent_node):
                    if child_node is value:
                        setattr(parent_node, field,
                                create_convert_shape_node(var_shape_node))
                        break
        return need_transformed

    def _used_by_paddle_api(self, node):
        # True when the nearest enclosing gast.Call ancestor is a Paddle API
        # call; False when it is some other call or there is no enclosing
        # call at all.
        assert isinstance(node, (gast.Attribute, gast.Name))
        wrapper_node = self.node_to_wrapper_map.get(node)
        if not wrapper_node:
            # Transformed node is not in node_to_wrapper_map
            return False
        while wrapper_node.parent:
            parent_node = wrapper_node.parent.node
            if isinstance(parent_node, gast.Call):
                if is_paddle_api(parent_node):
                    return True
                else:
                    return False
            wrapper_node = wrapper_node.parent

        return False

    def is_var_shape(self, node):
        """
        Return True if node is like `x.shape`, return False otherwise.
        """
        assert isinstance(node, gast.Attribute)

        if node.attr != 'shape':
            return False

        try:
            value_id = node.value.id
        except AttributeError:
            # `node.value` is not a plain name, e.g. `f().shape`.
            return False

        # NOTE(review): this membership test is redundant — the function
        # returns True below regardless of the lookup result, so every
        # `<name>.shape` attribute is treated as a var shape.
        if value_id in self.name_to_var_shape:
            return True

        return True

    def _update_name_to_var_shape(self, node):
        # Record `name -> shape node` aliases created by assignments such as
        # `s = x.shape`, `s = x.shape[0]`, or `t = s` (alias of an alias).
        # Returns True when an alias was recorded.
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        try:
            target_id = target_node.id
        except AttributeError:
            # Target is not a plain name (tuple/attribute/subscript target).
            return False
        value_node = node.value

        if isinstance(value_node, gast.Name):
            if value_node.id in self.name_to_var_shape:
                self.name_to_var_shape[target_id] = self.name_to_var_shape[
                    value_node.id]
                return True
        if isinstance(value_node, gast.Attribute):
            if self.is_var_shape(value_node):  # eg: x.shape
                self.name_to_var_shape[target_id] = value_node
                return True
        if isinstance(value_node, gast.Subscript):
            if isinstance(value_node.value, gast.Attribute):
                if self.is_var_shape(value_node.value):  # eg: x.shape[0]
                    self.name_to_var_shape[target_id] = value_node
                    return True
        return False
# --- Example #20 (snippet-collection separator; original marker "예제 #20", vote count 0) ---
class TensorShapeTransformer(gast.NodeTransformer):
    """
    This class transforms Tensor.shape used in Paddle Apis and control flow conditions into Static Graph Ast.
    """
    def __init__(self, wrapper_root):
        # Root must be an AstNodeWrapper so the static analysis info below
        # corresponds to the tree being transformed.
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # Maps a variable name (str) to the `x.shape` / `x.shape[i]` gast
        # node it was assigned from, so later uses of the name can also be
        # converted.
        self.name_to_tensor_shape = {}

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )
        var_env = self.static_analysis_visitor.get_var_env()
        # Descend into the first sub-scope before collecting variable types.
        # presumably this is the decorated function's own scope — TODO confirm
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        # Used by is_tensor_shape to decide whether a name denotes a Tensor.
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        self.visit(self.root)

    def visit_Assign(self, node):
        # If the assignment records a shape alias (e.g. `s = x.shape`), keep
        # the node untouched and skip visiting its children.
        if self._update_name_to_tensor_shape(node):
            return node
        self.generic_visit(node)
        return node

    def visit_Attribute(self, node):
        # Replace `x.shape` only when it appears as a Paddle API argument.
        if self._used_by_paddle_api(node):
            if self.is_tensor_shape(node):
                return create_api_shape_node(node)
        return node

    def visit_Name(self, node):
        # A bare name may be an alias of a previously recorded `x.shape`.
        if node.id in self.name_to_tensor_shape:
            if self._used_by_paddle_api(node):
                tensor_shape_node = self.name_to_tensor_shape[node.id]
                return create_api_shape_node(tensor_shape_node)
        return node

    def visit_Call(self, node):
        assert isinstance(node, gast.Call)
        if is_paddle_api(node):
            # Visit gast.Attribute and gast.Name to replace tensor.shape if necessary.
            self.generic_visit(node)

        return node

    def visit_If(self, node):
        # Call generic_visit first to transform Tensor.shape that is used in Paddle Api.
        self.generic_visit(node)
        cond = node.test
        self._transform_tensor_shape_if_necessary(cond)
        return node

    def visit_While(self, node):
        self.generic_visit(node)
        cond = node.test
        self._transform_tensor_shape_if_necessary(cond)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        iter = node.iter
        self._transform_tensor_shape_if_necessary(iter)

        # If tensor.shape is a gast.Name and it is used in range function, transform it
        self._transform_tensor_shape_in_range(node)
        return node

    def _transform_tensor_shape_in_range(self, node):
        # Rewrite `range(...)` arguments that alias a recorded tensor.shape.
        # Returns False for non-`range` for-loops, True otherwise.
        assert isinstance(node, gast.For)
        if not isinstance(node.iter, gast.Call):
            return False
        if not isinstance(node.iter.func, gast.Name):
            return False
        if node.iter.func.id != "range":
            return False
        args = node.iter.args
        for idx, arg in enumerate(args):
            if isinstance(arg,
                          gast.Name) and arg.id in self.name_to_tensor_shape:
                args[idx] = create_api_shape_node(
                    self.name_to_tensor_shape[arg.id])

        return True

    def _transform_tensor_shape_if_necessary(self, cond):
        # Walk expression `cond`, replacing every `x.shape` attribute or
        # shape-alias name with an API shape node, patching the parent node
        # in place.
        for child_node in gast.walk(cond):
            tensor_shape_node = None
            if isinstance(child_node, (gast.Attribute)):
                if self.is_tensor_shape(child_node):
                    tensor_shape_node = child_node
            elif isinstance(child_node, (gast.Name)):
                if child_node.id in self.name_to_tensor_shape:
                    tensor_shape_node = self.name_to_tensor_shape[
                        child_node.id]

            if tensor_shape_node:
                wrapper_node = self.node_to_wrapper_map.get(child_node)
                parent_node = wrapper_node.parent.node
                # NOTE(review): parent fields whose value is a list (e.g. the
                # comparators of gast.Compare) are not handled here.
                for field, value in gast.iter_fields(parent_node):
                    if child_node is value:
                        setattr(parent_node, field,
                                create_api_shape_node(tensor_shape_node))
                        break

    def _used_by_paddle_api(self, node):
        # True when the nearest enclosing gast.Call ancestor is a Paddle API
        # call; False otherwise.
        assert isinstance(node, (gast.Attribute, gast.Name))
        wrapper_node = self.node_to_wrapper_map.get(node)
        if not wrapper_node:
            # Transformed node is not in node_to_wrapper_map
            return False
        while wrapper_node.parent:
            parent_node = wrapper_node.parent.node
            if isinstance(parent_node, gast.Call):
                if is_paddle_api(parent_node):
                    return True
                else:
                    return False
            wrapper_node = wrapper_node.parent

        return False

    def is_tensor_shape(self, node):
        """
        Return True if node is like `x.shape` and x is Tensor, return False otherwise.
        """
        assert isinstance(node, gast.Attribute)
        if node.attr != 'shape':
            return False

        try:
            value_id = node.value.id
        except AttributeError:
            # `node.value` is not a plain name, e.g. `f().shape`.
            return False

        # Names already recorded as shape aliases are trusted directly.
        if value_id in self.name_to_tensor_shape:
            return True

        # TODO: `value_id` may be not in scope_var_type_dict if `value_id` is the arg of decorated function
        # Need a better way to confirm whether `value_id` is a Tensor.
        try:
            var_type_set = self.scope_var_type_dict[value_id]
        except KeyError:
            return False

        # A numpy ndarray's `.shape` must remain a plain attribute access.
        if NodeVarType.NUMPY_NDARRAY in var_type_set:
            return False
        if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
            return False

        return True

    def _update_name_to_tensor_shape(self, node):
        # Record `name -> shape node` aliases created by assignments such as
        # `s = x.shape`, `s = x.shape[0]`, or `t = s`. Returns True when an
        # alias was recorded.
        assert isinstance(node, gast.Assign)
        # TODO: Consider node has more than one target. eg: x, y = a, Tensor.shape[1]
        target_node = node.targets[0]
        try:
            target_id = target_node.id
        except AttributeError:
            # Target is not a plain name (tuple/attribute/subscript target).
            return False
        value_node = node.value

        if isinstance(value_node, gast.Name):
            if value_node.id in self.name_to_tensor_shape:
                self.name_to_tensor_shape[
                    target_id] = self.name_to_tensor_shape[value_node.id]
                return True
        if isinstance(value_node, gast.Attribute):
            if self.is_tensor_shape(value_node):  # eg: x.shape
                self.name_to_tensor_shape[target_id] = value_node
                return True
        if isinstance(value_node, gast.Subscript):
            if isinstance(value_node.value, gast.Attribute):
                if self.is_tensor_shape(value_node.value):  # eg: x.shape[0]
                    self.name_to_tensor_shape[target_id] = value_node
                    return True
        return False
# --- Example #21 (snippet-collection separator; original marker "예제 #21", vote count 0) ---
class DygraphToStaticAst(gast.NodeTransformer):
    """
    Main class to transform Dygraph to Static Graph
    """
    def get_static_ast(self, root):
        """Run static analysis on `root`, apply all transformer passes and
        return the resulting AstNodeWrapper root."""
        # save root for some analysis may need global AST
        self.root = root
        self.static_analysis_visitor = StaticAnalysisVisitor(root)
        self.static_analysis_root = \
            self.static_analysis_visitor.get_node_wrapper_root()

        self.decorate_func_name = None
        self.arg_name_to_idx = {}
        self.transfer_from_node_type(self.static_analysis_root)
        return self.static_analysis_root

    def transfer_from_node_type(self, node_wrapper):
        """Apply the full dygraph-to-static transformer pipeline."""
        # Generic transformation (see visit_FunctionDef).
        self.visit(node_wrapper.node)

        # Transform basic api of dygraph to static graph and get feed_name_to_arg_name
        api_transformer = BasicApiTransformer(node_wrapper)
        api_transformer.transform()
        self.feed_name_to_arg_name = api_transformer.get_feed_name_to_arg_id()

        # Transform Tensor.shape into fluid.layers.shape(Tensor)
        TensorShapeTransformer(node_wrapper).transform()
        # Transform list used in control flow
        ListTransformer(node_wrapper).transform()
        # Transform break/continue in loops
        BreakContinueTransformer(node_wrapper).transform()
        # Transform for loop and while loop
        LoopTransformer(node_wrapper).transform()
        # Transform all if/else statement of Dygraph into Static Graph.
        IfElseTransformer(node_wrapper).transform()

    def visit_FunctionDef(self, node):
        # The first FunctionDef encountered is the decorated entry function;
        # remember its name and the index of each positional argument.
        if self.decorate_func_name is None:
            self.decorate_func_name = node.name
            self.arg_name_to_idx = {
                arg.id: idx
                for idx, arg in enumerate(node.args.args)
            }

        self.generic_visit(node)
        # Remove the decorated name of dygraph_to_static
        if hasattr(node, 'decorator_list'):
            node.decorator_list = [
                d for d in node.decorator_list if d.id not in DECORATOR_NAMES
            ]
        return node

    def get_module_name(self):
        """
        Return the main function name which will be used as module name
        in ast_to_func.
        """
        # Should consider BaseAPITransformer which add new module name in Yamei's PR.
        assert self.decorate_func_name, "decorate_func_name shall not be None."
        return self.decorate_func_name

    def get_feed_name_to_idx(self):
        """Map each feed name to the positional index of its argument."""
        return {
            feed_name: self.arg_name_to_idx.get(arg_name)
            for feed_name, arg_name in self.feed_name_to_arg_name.items()
        }
class TensorShapeTransformer(gast.NodeTransformer):
    """
    This class transforms variable.shape used in Paddle Apis or control flow conditions into Static Graph Ast.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # stores origin var string name (like "x" in `x = t.shape`) to
        # static shape var string name (like "x_SUFFIX" in `x_SUFFIX = shape(t)`)
        self.name_to_var_shape = {}

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )
        var_env = self.static_analysis_visitor.get_var_env()
        # Descend into the first sub-scope before collecting variable types.
        # presumably this is the decorated function's own scope — TODO confirm
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        # NOTE(review): scope_var_type_dict is collected but not read in this
        # class version.
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        # Split multi-target assignments first so visit_Assign only deals
        # with single-target assignment nodes.
        SplitAssignTransformer(self.root).transform()
        self.visit(self.root)

    def visit_Assign(self, node):
        # If the assignment defines a shape alias, keep the original node
        # and append the generated `x_SUFFIX = ...` static-shape statements
        # right after it.
        update_static_shape_var_node = self._update_name_to_var_shape(node)
        if update_static_shape_var_node is not None:
            ret = [node]
            ret.extend(update_static_shape_var_node)
            return ret
        self.generic_visit(node)
        return node

    def visit_Subscript(self, node):
        # Handle `s[i]` where `s` aliases a recorded shape, and `x.shape[i]`
        # used directly as a Paddle API argument.
        value_node = node.value
        slice_node = node.slice
        if isinstance(value_node, gast.Name):
            if value_node.id in self.name_to_var_shape and self._used_by_paddle_api(
                    value_node):
                return create_choose_shape_node(
                    value_node.id, self.name_to_var_shape[value_node.id], node)
        elif isinstance(value_node, gast.Attribute):
            if self._used_by_paddle_api(value_node):
                value_name = ast_to_source_code(value_node).strip()
                if value_name in self.name_to_var_shape:
                    return create_choose_shape_node(
                        value_name, self.name_to_var_shape[value_name], node)
                if self._is_var_shape(value_node):
                    return create_convert_shape_node(value_node, node)
        return node

    def visit_Attribute(self, node):
        # Replace `x.shape` (or an attribute that aliases a recorded shape)
        # only when used as a Paddle API argument.
        if self._used_by_paddle_api(node):
            name = ast_to_source_code(node).strip()
            if name in self.name_to_var_shape:
                return create_choose_shape_node(name,
                                                self.name_to_var_shape[name])
            if self._is_var_shape(node):
                return create_convert_shape_node(node)
        return node

    def visit_Name(self, node):
        # A bare name may alias a previously recorded `x.shape`.
        if node.id in self.name_to_var_shape:
            if self._used_by_paddle_api(node):
                return create_choose_shape_node(node.id,
                                                self.name_to_var_shape[node.id])
        return node

    def visit_Call(self, node):
        if is_paddle_api(node):
            # Visit gast.Attribute and gast.Name to replace var.shape if necessary.
            self.generic_visit(node)
        # Don't have to visit other APIs
        return node

    def visit_If(self, node):
        # Call generic_visit first to transform var.shape that is used in Paddle Api.
        self.generic_visit(node)
        cond = node.test
        self._transform_var_shape_if_necessary(cond)

        return node

    def visit_While(self, node):
        self.generic_visit(node)
        cond = node.test
        self._transform_var_shape_if_necessary(cond)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        iter = node.iter
        self._transform_var_shape_if_necessary(iter)

        # If var.shape is a gast.Name and it is used in range function, transform it
        self._transform_var_shape_in_range(node)
        return node

    def _transform_var_shape_in_range(self, node):
        # Rewrite `range(...)` arguments that alias a recorded var.shape.
        # Returns False for non-`range` for-loops, True otherwise.
        assert isinstance(node, gast.For)
        if not isinstance(node.iter, gast.Call):
            return False
        if not isinstance(node.iter.func, gast.Name):
            return False
        if node.iter.func.id != "range":
            return False
        args = node.iter.args
        for idx, arg in enumerate(args):
            if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape:
                args[idx] = create_choose_shape_node(
                    arg.id, self.name_to_var_shape[arg.id])
        return True

    def _transform_var_shape_if_necessary(self, cond):
        # Walk expression `cond`, replacing shape aliases and direct
        # `x.shape` accesses with the converted static nodes, patching the
        # parent node in place. Returns True if anything was replaced.
        need_transformed = False
        for child_node in gast.walk(cond):
            var_shape_node = None
            if isinstance(child_node,
                          (gast.Name, gast.Attribute, gast.Subscript)):
                child_name = ast_to_source_code(child_node).strip()
                if child_name in self.name_to_var_shape:
                    var_shape_node = create_choose_shape_node(
                        child_name, self.name_to_var_shape[child_name])
                elif self._is_var_shape(child_node):
                    var_shape_node = child_node

            if var_shape_node:
                need_transformed = True
                wrapper_node = self.node_to_wrapper_map.get(child_node)
                parent_node = wrapper_node.parent.node
                for field, value in gast.iter_fields(parent_node):
                    if child_node is value:
                        # `var_shape_node is child_node` means the child is a
                        # raw `x.shape`(-like) node that still needs the
                        # convert-shape wrapper; otherwise it is already a
                        # choose-shape node ready to substitute.
                        if var_shape_node is child_node:
                            setattr(parent_node, field,
                                    create_convert_shape_node(var_shape_node,
                                                              None, True))
                        else:
                            setattr(parent_node, field, var_shape_node)
                        break
                    # Some child_node may be in a list such as gast.Compare
                    if isinstance(value, list):
                        has_converted_shape = False
                        for i, v in enumerate(value):
                            if child_node is v:
                                if var_shape_node is child_node:
                                    value[i] = create_convert_shape_node(
                                        var_shape_node, None, True)
                                else:
                                    value[i] = var_shape_node
                                has_converted_shape = True
                                break
                        if has_converted_shape:
                            break
        return need_transformed

    def _used_by_paddle_api(self, node):
        """
        Whether node is used in paddle api as arguments.
        For example:
            1) Return True in `paddle.relu(x)` where node is `x` (gast.Name)
            2) Return True in `paddle.add(self.x)` where node is `self.x` (gast.Attribute)
            3) Return False in `paddle.add(self.x)` where node is `paddle.add` (gast.Attribute),
               because the role of node is not arguments but `gast.Call.func`.
        """
        assert isinstance(node, (gast.Attribute, gast.Name))
        wrapper_node = self.node_to_wrapper_map.get(node)
        if not wrapper_node:
            # Transformed node is not in node_to_wrapper_map
            return False
        while wrapper_node.parent:
            parent_node = wrapper_node.parent.node
            if isinstance(parent_node, gast.Call):
                # Note(Aurelius84): Filter the case when the role of node is `gast.Call.func`.
                if is_paddle_api(parent_node) and parent_node.func != node:
                    return True
                else:
                    return False
            wrapper_node = wrapper_node.parent

        return False

    def _is_var_shape(self, node):
        """
        Return True if node is like `x.shape` or `x.shape[0]`, return False otherwise.
        """
        if not isinstance(node, (gast.Attribute, gast.Subscript)):
            return False

        if isinstance(node, gast.Attribute):
            # If node is `paddle.shape`, return False
            if (node.attr == 'shape' and isinstance(node.value, gast.Name) and
                    node.value.id == 'paddle'):
                return False
            if node.attr != 'shape':
                return False
            return True

        if isinstance(node, gast.Subscript):
            # Recurse: `x.shape[0]` is a var shape iff `x.shape` is.
            value_node = node.value
            return self._is_var_shape(value_node)

        return False

    def _update_name_to_var_shape(self, node):
        # For an assignment that creates a shape alias, build the extra
        # `x_SUFFIX = ...` statements that compute the static shape and
        # record the alias in self.name_to_var_shape. Returns the list of
        # new statements to insert, or None when the assignment is not a
        # shape alias.
        def replace_dot(name):
            # replace all '.' into '_'
            return name.replace('.', '_')

        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        value_node = node.value

        update_static_shape_var_node = None
        if isinstance(target_node, gast.Tuple):
            # Tuple target, e.g. `a, b = x.shape`: each element gets its own
            # static shape var fed from the corresponding index.
            update_static_shape_var_node = []
            for idx, element in enumerate(target_node.elts):
                target_id = ast_to_source_code(element).strip()

                if isinstance(value_node, gast.Name):
                    if value_node.id in self.name_to_var_shape:
                        # TODO(zhhsplendid): is context a problem for the result node of gast.parse?
                        static_shape_var_name = unique_name.generate(
                            replace_dot(target_id) +
                            STATIC_CONVERT_VAR_SHAPE_SUFFIX)
                        static_shape_var_node = gast.parse(
                            static_shape_var_name).body[0].value

                        static_shape_value_name = self.name_to_var_shape[
                            value_node.id]

                        sub_node_str = "{}[{}]".format(static_shape_value_name,
                                                       idx)
                        sub_node = gast.parse(sub_node_str).body[0].value

                        update_static_shape_var_node.append(
                            gast.Assign(
                                targets=[static_shape_var_node],
                                value=sub_node))

                        self.name_to_var_shape[
                            target_id] = static_shape_var_name
                if isinstance(value_node, gast.Attribute):
                    if self._is_var_shape(value_node):  # eg: x.shape
                        static_shape_var_name = unique_name.generate(
                            replace_dot(target_id) +
                            STATIC_CONVERT_VAR_SHAPE_SUFFIX)
                        static_shape_var_node = gast.parse(
                            static_shape_var_name).body[0].value

                        static_shape_value_node = copy.deepcopy(value_node)
                        # x.shape becomes convert_var_shape_simple(x)
                        static_shape_value_node = ShapeAttributeTransformer(
                        ).visit(static_shape_value_node)

                        sub_node_str = "{}[{}]".format(
                            ast_to_source_code(static_shape_value_node).strip(),
                            idx)
                        sub_node = gast.parse(sub_node_str).body[0].value
                        # Note(Aurelius84): Because static_shape_var_name is used in
                        # eval_if_exist_else_none() as plain string, so it will not
                        # be parsed as argument in convert_loop/ifelse. We declare it
                        # as global var because it has unique name.
                        update_static_shape_var_node.append(
                            gast.Global(names=[static_shape_var_name]))

                        update_static_shape_var_node.append(
                            gast.Assign(
                                targets=[static_shape_var_node],
                                value=sub_node))
                        self.name_to_var_shape[
                            target_id] = static_shape_var_name
            return update_static_shape_var_node
        else:
            # Single-name target, e.g. `s = x.shape` or `t = s`.
            target_id = ast_to_source_code(target_node).strip()
            if isinstance(value_node, gast.Name):
                if value_node.id in self.name_to_var_shape:
                    static_shape_var_name = unique_name.generate(
                        replace_dot(target_id) +
                        STATIC_CONVERT_VAR_SHAPE_SUFFIX)
                    static_shape_var_node = gast.parse(
                        static_shape_var_name).body[0].value
                    static_shape_value_name = self.name_to_var_shape[
                        value_node.id]
                    static_shape_value_node = gast.parse(
                        static_shape_value_name).body[0].value

                    update_static_shape_var_node = [
                        gast.Assign(
                            targets=[static_shape_var_node],
                            value=static_shape_value_node)
                    ]
                    self.name_to_var_shape[target_id] = static_shape_var_name
            elif self._is_var_shape(value_node):  # eg: x.shape or x.shape[0]
                static_shape_var_name = unique_name.generate(
                    replace_dot(target_id) + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
                static_shape_var_node = gast.parse(static_shape_var_name).body[
                    0].value
                static_shape_value_node = copy.deepcopy(value_node)
                # x.shape becomes convert_var_shape_simple(x)
                static_shape_value_node = ShapeAttributeTransformer().visit(
                    static_shape_value_node)
                # Declare static_shape_var_name as global var
                update_static_shape_var_node = [
                    gast.Global(names=[static_shape_var_name])
                ]
                update_static_shape_var_node.append(
                    gast.Assign(
                        targets=[static_shape_var_node],
                        value=static_shape_value_node))
                self.name_to_var_shape[target_id] = static_shape_var_name
        return update_static_shape_var_node
# --- Example #23 (snippet-collection separator; original marker "예제 #23", vote count 0) ---
class PrintTransformer(gast.NodeTransformer):
    """
    This class transforms python print function to fluid.layers.Print.
    """
    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of PrintTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = \
            self.static_analysis_visitor.get_node_to_wrapper_map()

    def transform(self):
        self.visit(self.root)

    # NOTE: deal with print in PY3
    def visit_Call(self, node):
        assert isinstance(node, gast.Call)
        is_print_call = (
            isinstance(node.func, gast.Name) and node.func.id == 'print')
        if not is_print_call:
            return node
        return self._construct_print_node(self._get_print_var(node))

    # NOTE: deal with print in PY2
    def visit_Print(self, node):
        new_call = self._construct_print_node(self._get_print_var(node))
        return gast.Expr(value=new_call)

    def _get_print_var(self, node):
        """Extract the single printed variable node from a print call/stmt."""
        if isinstance(node, gast.Call):
            candidates = node.args
        elif isinstance(node, gast.Print):
            candidates = node.values
            # A PY2 `print a, b` parses its operands as one Tuple.
            if isinstance(candidates[0], gast.Tuple):
                candidates = candidates[0].elts
        # TODO: support print multiple Var
        assert len(candidates) == 1, "Now only support print one Variable."
        return candidates[0]

    def _construct_print_node(self, node):
        """Build a `fluid.layers.Print(node)` call for a tensor Name node."""
        if not isinstance(node, gast.Name):
            # TODO: may not only print with format
            raise NotImplementedError(
                "cannot transform print with format temporarily.")
        if not self._is_tensor_node(node):
            raise TypeError(
                "print object type error, only support print Variable now."
            )
        return gast.Call(
            func=gast.parse('fluid.layers.Print').body[0].value,
            args=[node],
            keywords=[])

    def _is_tensor_node(self, node):
        """Whether static analysis typed `node` as a Paddle tensor."""
        tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES}
        wrapper = self.node_to_wrapper_map.get(node, None)
        return wrapper is not None and bool(
            wrapper.node_var_type & tensor_types)
# --- Example #24 (snippet-collection separator; original marker "예제 #24", vote count 0) ---
class DygraphToStaticAst(gast.NodeTransformer):
    """
    Main class to transform Dygraph to Static Graph
    """

    def __init__(self):
        self.translator_logger = logging_utils.TranslatorLogger()

    def get_static_ast(self, root):
        """Run static analysis on ``root``, apply all transformers, and
        return the wrapped static AST root."""
        # save root for some analysis may need global AST
        self.root = root
        self.static_analysis_visitor = StaticAnalysisVisitor(root)
        self.static_analysis_root = (
            self.static_analysis_visitor.get_node_wrapper_root())
        self.decorate_func_name = None
        self.transfer_from_node_type(self.static_analysis_root)
        return self.static_analysis_root

    def _apply(self, transformer, node_wrapper, log_level):
        """Instantiate and run one transformer, then log the resulting code."""
        instance = transformer(node_wrapper)
        instance.transform()
        self.translator_logger.log_transformed_code(
            log_level, self.root, transformer.__name__)

    def transfer_from_node_type(self, node_wrapper):
        """Apply the generic visit plus every specific transformer in order."""
        self.translator_logger.log(
            1, "Source code: \n{}".format(ast_to_source_code(self.root)))
        # Generic transformation
        self.visit(node_wrapper.node)

        transformers = [
            BasicApiTransformer,  # Basic Api
            TensorShapeTransformer,  # Tensor.shape -> layers.shape(Tensor)
            ListTransformer,  # List used in control flow
            BreakTransformOptimizer,  # optimize transfromation of break in loops
            BreakContinueTransformer,  # break/continue in loops
            ReturnTransformer,  # return in functions
            LogicalTransformer,  # logical and/or/not
            LoopTransformer,  # for/while -> while_op
            IfElseTransformer,  # if/else -> cond_op
            AssertTransformer,  # assert statement
            PrintTransformer,  # print statement
            CallTransformer,  # transform call recursively
            CastTransformer,  # type casting statement
            GradTransformer,  # transform paddle.grad to paddle.gradients
        ]

        # Log levels are the transformer's 1-based position in the list.
        for log_level, transformer in enumerate(transformers, start=1):
            self._apply(transformer, node_wrapper, log_level=log_level)

        self.translator_logger.log_transformed_code(
            logging_utils.LOG_AllTransformer, self.root, "All Transformers")

    def visit_FunctionDef(self, node):
        """Record the outermost function name and strip the dygraph_to_static
        decorators; any other decorator is rejected."""
        if self.decorate_func_name is None:
            self.decorate_func_name = node.name

        self.generic_visit(node)
        # Remove the decorated name of dygraph_to_static
        if hasattr(node, 'decorator_list'):
            for d in node.decorator_list:
                if isinstance(d, gast.Name) and d.id not in DECORATOR_NAMES:
                    raise NotImplementedError(
                        "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                        + d.id + " in " + self.decorate_func_name)
                if isinstance(d, gast.Attribute):
                    full_attribute_name = get_attribute_full_name(d)
                    if not any(deco in full_attribute_name
                               for deco in DECORATOR_NAMES):
                        raise NotImplementedError(
                            "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                            + full_attribute_name + " in " +
                            self.decorate_func_name)
            # All surviving decorators are translate decorators; drop them.
            node.decorator_list = []
        return node

    def get_module_name(self):
        """
        Return the main function name which will be used as module name
        in ast_to_func.
        """
        # Should consider BaseAPITransformer which add new module name in Yamei's PR.
        assert self.decorate_func_name, "decorate_func_name shall not be None."
        return self.decorate_func_name
예제 #25
0
class ListTransformer(gast.NodeTransformer):
    """
    This class transforms python list used in control flow into Static Graph Ast.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of ListTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # Maps list variable name -> bool. False when the name is bound to a
        # python list literal; flipped to True once that list is appended to
        # inside control flow that will be converted.
        self.list_name_to_updated = dict()
        # gast.Assign nodes whose value is a list literal; candidates for
        # replacement by `create_array` in replace_list_with_tensor_array.
        self.list_nodes = set()

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )
        var_env = self.static_analysis_visitor.get_var_env()
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        """Entry point: normalize multi-target assigns, rewrite list ops,
        then replace tracked list initializations with tensor arrays."""
        SplitAssignTransformer(self.root).transform()
        self.visit(self.root)
        self.replace_list_with_tensor_array(self.root)

    def visit_Call(self, node):
        # `x.pop(...)` / `d.pop(...)` -> paddle.jit.dy2static.convert_pop(...)
        if isinstance(node.func, gast.Attribute):
            if node.func.attr == "pop":
                node = self._replace_pop(node)
        return node

    def visit_Assign(self, node):
        if self._update_list_name_to_updated(node):
            return node

        if self._need_to_array_write_node(node):
            return self._transform_slice_to_tensor_write(node)

        self.generic_visit(node)
        return node

    def _visit_control_flow_node(self, node):
        # Shared body of visit_If/visit_While/visit_For: only lists modified
        # inside control flow that will be converted need tensor arrays.
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_If(self, node):
        return self._visit_control_flow_node(node)

    def visit_While(self, node):
        return self._visit_control_flow_node(node)

    def visit_For(self, node):
        return self._visit_control_flow_node(node)

    def replace_list_with_tensor_array(self, node):
        """Replace `x = [...]` with `x = fluid.layers.create_array(...)` for
        every tracked list that was updated in converted control flow."""
        for child_node in gast.walk(node):
            if isinstance(child_node, gast.Assign):
                if self._need_to_create_tensor_array(child_node):
                    child_node.value = self._create_tensor_array()

    def _transform_list_append_in_control_flow(self, node):
        for child_node in gast.walk(node):
            if self._need_to_array_write_node(child_node):
                child_node.value = \
                    self._to_array_write_node(child_node.value)

    def _need_to_array_write_node(self, node):
        """True for `list.append(tensor)` Expr nodes, and for `list[i] = v`
        assigns whose list is already marked as updated."""
        if isinstance(node, gast.Expr):
            if isinstance(node.value, gast.Call):
                if self._is_list_append_tensor(node.value):
                    return True

        if isinstance(node, gast.Assign):
            target_node = node.targets[0]
            if isinstance(target_node, gast.Subscript):
                list_name = ast_to_source_code(target_node.value).strip()
                # `.get` covers both "unknown name" and "not updated yet".
                if self.list_name_to_updated.get(list_name):
                    return True
        return False

    def _transform_slice_to_tensor_write(self, node):
        """Rewrite `a[i] = x` into `a = fluid.layers.array_write(x, i, a)`."""
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        target_name = target_node.value.id
        slice_node = target_node.slice

        if isinstance(slice_node, gast.Index):
            value_code = ast_to_source_code(node.value)
            i = "paddle.cast(" \
                "x=paddle.jit.dy2static.to_static_variable({})," \
                "dtype='int64')".format(ast_to_source_code(slice_node))
            assign_code = "{} = fluid.layers.array_write(x={}, i={}, array={})" \
                .format(target_name, value_code, i, target_name)
            return gast.parse(assign_code).body[0]
        # BUG FIX: a gast.Slice target (e.g. `a[1:3] = x`) previously fell
        # through `pass` and raised UnboundLocalError on `assign_node`.
        # Slice writing is not supported yet, so leave the node unchanged.
        return node

    def _is_list_append_tensor(self, node):
        """
        a.append(b): a is list, b is Tensor
        self.x.append(b): self.x is list, b is Tensor
        """
        assert isinstance(node, gast.Call)
        # 1. The func is `append`.
        if not isinstance(node.func, gast.Attribute):
            return False
        if node.func.attr != 'append':
            return False

        # 2. It's a `python list` to call append().
        value_name = astor.to_source(gast.gast_to_ast(node.func.value)).strip()
        if value_name not in self.list_name_to_updated:
            return False

        # 3. The number of arg of append() is one
        # Only one argument is supported in Python list.append()
        if len(node.args) != 1:
            return False

        # TODO(liym27): The arg of append() should be Tensor. But because the
        #   type of arg is often wrong with static analysis, the arg is not
        #   required to be Tensor here.
        self.list_name_to_updated[value_name] = True
        return True

    def _need_to_create_tensor_array(self, node):
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        try:
            target_id = target_node.id
        except AttributeError:
            # Target is not a plain name (e.g. attribute/subscript); skip.
            return False
        if self.list_name_to_updated.get(
                target_id) and node in self.list_nodes:
            return True
        return False

    def _create_tensor_array(self):
        # Although `dtype='float32'`, other types such as `int32` can also be supported
        func_code = "fluid.layers.create_array(dtype='float32')"
        func_node = gast.parse(func_code).body[0].value
        return func_node

    def _to_array_write_node(self, node):
        """Rewrite `a.append(x)` into
        `fluid.layers.array_write(x, fluid.layers.array_length(a), a)`."""
        assert isinstance(node, gast.Call)
        array = astor.to_source(gast.gast_to_ast(node.func.value))
        x = astor.to_source(gast.gast_to_ast(node.args[0]))
        i = "fluid.layers.array_length({})".format(array)
        func_code = "fluid.layers.array_write(x={}, i={}, array={})".format(
            x, i, array)
        return gast.parse(func_code).body[0].value

    def _update_list_name_to_updated(self, node):
        """Track `x = [...]` assigns; forget names rebound before any update."""
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        # NOTE: Code like `x, y = a, []` has been transformed to `x=a; y=[]`
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        value_node = node.value
        if isinstance(value_node, gast.List):
            self.list_name_to_updated[target_id] = False
            self.list_nodes.add(node)
            return True
        elif target_id in self.list_name_to_updated and \
                not self.list_name_to_updated[target_id]:
            # Rebound to a non-list before ever being updated: stop tracking.
            del self.list_name_to_updated[target_id]
        return False

    def _replace_pop(self, node):
        """
        Replace a pop statement for a list or dict.
        For example:

            list_a = [0,1,2,3,4]
            x = list_a.pop()  # --> convert_pop(list_a)
            y = list_a.pop(1) # --> convert_pop(list_a, 1)

            dict_a = {"red":0, "blue":1, "yellow":2}
            m = dict_a.pop("red")           # --> convert_pop(dict_a, "red")
            n = dict_a.pop("black", 3)      # --> convert_pop(dict_a, "black", 3)

        """
        assert isinstance(node, gast.Call)
        assert isinstance(node.func, gast.Attribute)

        target_node = node.func.value
        target_str = ast_to_source_code(target_node).strip()

        args_str = [ast_to_source_code(arg).strip() for arg in node.args]

        # NOTE(liym27):
        # 1. pop stmt for a list if len(args_str) == 0
        # 2. pop stmt for a list or dict if len(args_str) == 1
        # 3. pop stmt for a dict if len(args_str) == 2
        if len(args_str) <= 2:
            new_pop_str = "paddle.jit.dy2static.convert_pop({}, {})"\
                .format(target_str, ",".join(args_str))
            new_pop_node = gast.parse(new_pop_str).body[0].value
            return new_pop_node
        else:
            return node
예제 #26
0
class NameVisitor(gast.NodeVisitor):
    '''
    Analysis name liveness for loop transformer
    '''

    def __init__(self, root_node):
        # Set of gast.Name or gast.Attribute for variables
        self.current_seen_vars = set()

        # List of gast.While/gast.For nodes (stack of enclosing loops)
        self.current_loop = []

        # List of nodes that have scope of variables.
        self.nodes_with_scope = []

        self.blacklist_names = {"False", "True", "None"}

        # Mapping from gast.While/gast.For to variable nodes
        self.before_loop_body_vars = defaultdict(set)
        # NOTE: Use ordered list as dict value so the access order of each
        # variable's ctx values is preserved (get_loop_var_names relies on it).
        self.in_loop_vars = defaultdict(list)

        # Mapping from gast.While/gast.For to variable nodes which is condition
        # of loop or being modified during the loop
        self.write_in_loop = defaultdict(set)
        self.condition_vars = defaultdict(set)
        self.in_condition = False

        # Some names are types, we shouldn't record them as loop var names.
        self.type_vars = set()

        self.static_analysis_visitor = StaticAnalysisVisitor(root_node)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )

        self.visit(root_node)

    def get_loop_var_names(self, node):
        """Return (loop_var_names, create_var_names) for a loop node.

        loop_var_names: names that must be passed through the converted loop.
        create_var_names: subset that must be created before the loop because
        they are first assigned inside it.
        """
        assert isinstance(
            node, (gast.While, gast.For)), "Input node is not gast loop node"
        loop_var_names = set()
        create_var_names = set()
        read_context = {type(gast.Load()), type(gast.AugLoad())}

        in_loop_vars_list = self.in_loop_vars[node]

        # get dict `var_name_to_ctxs`
        var_name_to_ctxs = defaultdict(list)
        for var_node in in_loop_vars_list:
            var_name_to_ctxs[self._var_node_to_name(var_node)].append(
                var_node.ctx)

        in_loop_vars = set(in_loop_vars_list)
        in_loop_vars = self._remove_unnecessary_vars(in_loop_vars, node)
        in_loop_name_strs = self._var_nodes_to_names(in_loop_vars)

        before_loop_body_vars = self.before_loop_body_vars[node]
        before_loop_body_vars = self._remove_unnecessary_vars(
            before_loop_body_vars, node)
        before_loop_name_strs = self._var_nodes_to_names(before_loop_body_vars)

        after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars
        after_loop_vars = self._remove_unnecessary_vars(after_loop_vars, node)
        after_loop_name_strs = self._var_nodes_to_names(after_loop_vars,
                                                        read_context)
        condition_vars = self.condition_vars[node]
        condition_names = self._var_nodes_to_names(condition_vars)

        write_vars = self.write_in_loop[node]
        write_names = self._var_nodes_to_names(write_vars)

        name_to_type = {}
        for var in in_loop_vars:
            wrapper = self.node_to_wrapper_map[var]
            name_to_type[self._var_node_to_name(var)] = wrapper.node_var_type
        for name in in_loop_name_strs:
            if name in before_loop_name_strs:
                # If a variable is used in loop and created before loop

                # If this var is a basic variable and read-only and not
                # condition var, it may not be loop_var else it should
                # be in loop_var as input
                if (not name in condition_names) and (
                        not name in write_names
                ) and self._node_var_type_is_basic(name_to_type[name]):
                    continue
                loop_var_names.add(name)

            elif name in after_loop_name_strs:
                # If a variable is created in the while loop and read after
                # loop, it should be in loop_var and we should create it

                # because name in after_loop_name must be initialized in loop
                # So it is write-only, we don't have to filter read-only basic
                # vars out
                loop_var_names.add(name)
                create_var_names.add(name)
            else:
                # If a variable is used and created in loop, but used before created,
                # it should be in loop_var and we should create it.

                # For example, `var_a` should be in loop_var and we should create it.
                #
                #   res = 0
                #   for i, x in enumerate(x_array):
                #       if i > 2:
                #           x = func1(var_a)
                #       var_a = func2(x)
                #

                is_created = False
                for ctx in var_name_to_ctxs[name]:
                    if isinstance(ctx, gast.Store):
                        is_created = True

                if isinstance(var_name_to_ctxs[name][0],
                              gast.Load) and is_created:
                    loop_var_names.add(name)
                    create_var_names.add(name)

        return loop_var_names, create_var_names

    def visit_Name(self, node):
        if self._is_call_func_name_node(node):
            self.generic_visit(node)
            return
        if node.id in self.blacklist_names:
            self.generic_visit(node)
            return

        self.current_seen_vars.add(node)
        write_context = {
            type(gast.Store()), type(gast.AugStore()), type(gast.Del())
        }
        for loop_node in self.current_loop:
            self.in_loop_vars[loop_node].append(node)
            if type(node.ctx) in write_context:
                self.write_in_loop[loop_node].add(node)
            # BUG FIX: record condition vars for EVERY enclosing loop. This
            # check used to sit outside the for-loop, so only the last
            # loop_node was updated (and `loop_node` could even be unbound
            # when current_loop was empty).
            if self.in_condition:
                self.condition_vars[loop_node].add(node)
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        self.nodes_with_scope.append(node)
        self.blacklist_names.add(node.name)
        # The variables in the function are not visible to the outside scope.
        before_func_seen_vars = copy.copy(self.current_seen_vars)

        self.generic_visit(node)
        self.nodes_with_scope.pop()
        # After exiting the scope of the node, variables in this scope
        # should be removed from self.current_seen_vars.
        if self.nodes_with_scope:
            self.current_seen_vars = before_func_seen_vars

    def visit(self, node):
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        ret = visitor(node)
        return ret

    def visit_Attribute(self, node):
        if self._is_call_func_name_node(node):
            return
        attr_full_name = get_attribute_full_name(node)
        # Class variables are not allowed to appear in the arguments list
        # of defined function under class methods in Python.
        """
        def class_func(self):
            def while_loop_body(self.x, y) # `self.x` is illegal.
        """
        # TODO: If do change the variable with `self.var`, need a better
        # way to deal with this case.
        if attr_full_name.startswith("self."):
            return
        self.current_seen_vars.add(node)

        for loop_node in self.current_loop:
            self.in_loop_vars[loop_node].append(node)

        # sub-nodes are visited during get_attribute_full_name and we shouldn't
        # visit again

    def visit_For(self, node):
        self.current_loop.append(node)
        self.in_condition = True
        self.visit(node.target)
        self.visit(node.iter)
        self.in_condition = False
        self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
        self.generic_visit(node)
        self.current_loop.pop()

    def visit_While(self, node):
        self.current_loop.append(node)
        self.in_condition = True
        self.visit(node.test)
        self.in_condition = False
        self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
        self.generic_visit(node)
        self.current_loop.pop()

    def visit_Call(self, node):
        # Store type var names such as "isinstance(x, some_type_names)" and
        # Remove them later
        if isinstance(node.func, gast.Name) and node.func.id == 'isinstance':
            type_node = node.args[1]
            if isinstance(type_node, gast.Tuple):
                for element in type_node.elts:
                    self.type_vars.add(ast_to_source_code(element).strip())
            else:
                self.type_vars.add(ast_to_source_code(type_node).strip())
        self.generic_visit(node)

    def _var_nodes_to_names(self, node_set, ctx_filter_set=None):
        """Map variable nodes to name strings, optionally filtered by ctx type."""
        ret = set()
        for node in node_set:
            if ctx_filter_set is None or type(node.ctx) in ctx_filter_set:
                ret.add(self._var_node_to_name(node))
        return ret

    def _var_node_to_name(self, node):
        if isinstance(node, gast.Name):
            return node.id
        elif isinstance(node, gast.Attribute):
            return get_attribute_full_name(node)

    def _node_var_type_is_basic(self, node_var_type):
        basic_types = {
            NodeVarType.BOOLEAN, NodeVarType.INT, NodeVarType.FLOAT,
            NodeVarType.STRING
        }
        for t in node_var_type:
            if t in basic_types:
                return True
        return False

    def _is_call_func_name_node(self, node):
        # True when `node` is the callee of a Call (e.g. `f` in `f(x)`).
        parent_node = self._get_parent_node(node)
        if isinstance(parent_node, gast.Call) and parent_node.func == node:
            return True
        return False

    def _is_ancestor_node(self, ancestor_node, node):
        parent_node = self._get_parent_node(node)

        while parent_node is not None:
            if parent_node == ancestor_node:
                return True
            parent_node = self._get_parent_node(parent_node)
        return False

    def _get_parent_node(self, node):
        wrapper_node = self.node_to_wrapper_map.get(node)
        if wrapper_node:
            if wrapper_node.parent:
                parent_node = wrapper_node.parent.node
                return parent_node
        return None

    def _remove_unnecessary_vars(self, loop_vars, loop_node):
        """
        Remove unnecessary vars from before_loop_vars, after_loop_vars or in_loop_vars about loop_node.
            1. Remove target vars of gast.For from before_loop_vars or after_loop_vars.
            2. Remove vars only in gast.comprehension.
            3. Remove vars that are type names, for example: "isinstance(x, var_type_name)"
        :param loop_vars: before_loop_vars, after_loop_vars or in_loop_vars of loop_node.
        :param loop_node: Current loop node.
        """

        def filter_name_nodes_from(root_node, target_var_names):
            """
            Filter children with gast.Name type from node.(inclusivly)
            """
            name_nodes = set()
            # BUG FIX: this used to test `node.id` (a variable captured from
            # the enclosing loop) instead of `root_node.id`.
            if isinstance(root_node, gast.Name):
                if root_node.id in target_var_names:
                    name_nodes.add(root_node)
            for child_node in gast.walk(root_node):
                if isinstance(child_node, gast.Name):
                    if child_node.id in target_var_names:
                        name_nodes.add(child_node)

            return name_nodes

        vars_of_list_generator = set()
        target_vars_of_for_node = set()

        for name_node in loop_vars:
            if not isinstance(name_node, gast.Name):
                continue

            parent_node = self._get_parent_node(name_node)

            # NOTE: gast.For.target or gast.comprehension.target can be gast.Tuple.
            #  For examples:
            #   1) `for i, j in enumerate(x)` has two target vars: i and j
            #   2) `[x for x,y in array]` has two target vars: x and y
            if isinstance(parent_node, gast.Tuple):
                parent_node = self._get_parent_node(parent_node)

            # 1. Get vars only in gast.comprehension.
            # For examples:
            #  1) [x for x,y in array] -> x, x, y
            #  2) [f(x) for x in array] -> x
            #  3) [func(x, y) for x in array] -> x, x
            if isinstance(parent_node, gast.comprehension):
                # 1.1 target vars in list/set comprehensions
                target_node = parent_node.target
                if isinstance(target_node, gast.Tuple):
                    target_vars = target_node.elts
                else:
                    target_vars = [target_node]

                vars_of_list_generator = vars_of_list_generator | set(
                    target_vars)

                # 1.2 vars from target vars used in elt_node
                target_var_names = {var.id for var in target_vars}
                comp_node = self._get_parent_node(parent_node)
                elt_nodes = []
                if isinstance(comp_node, gast.ListComp):
                    elt_nodes.append(comp_node.elt)
                elif isinstance(comp_node, gast.DictComp):
                    elt_nodes.extend([comp_node.key, comp_node.value])

                for node in elt_nodes:
                    vars_of_list_generator |= filter_name_nodes_from(
                        node, target_var_names)

            # 2. Get target vars or vars from target vars used in for-loop but the for-loop is
            #   1) not the "loop_node" itself
            #   2) not the ancestor of the "loop_node"
            #
            # For examples:
            #   for k in range(x):   # if it's this "loop_node", i or j both should be target vars.
            #      # do something
            #
            #   for i in range(a):   # if it's this "loop_node", k or j should be in target vars but i should not.
            #     for j in range(a): # if it's this "loop_node", k should be in target_vars but i or j should not.
            #       x = i+j
            elif isinstance(parent_node, gast.For):
                if parent_node is loop_node:
                    continue
                if self._is_ancestor_node(parent_node, loop_node):
                    continue
                # 2.1 target vars in gast.For node.
                target_node = parent_node.target
                if isinstance(target_node, gast.Tuple):
                    target_vars = target_node.elts
                else:
                    target_vars = [target_node]

                target_vars_of_for_node = target_vars_of_for_node | set(
                    target_vars)

        # 2.2 vars from target vars used in for-loop
        target_vars_name_strs = {var.id for var in target_vars_of_for_node}
        for var in loop_vars:
            if not isinstance(var, gast.Name):
                continue
            if var.id in target_vars_name_strs and var not in self.condition_vars[
                    loop_node]:
                target_vars_of_for_node.add(var)

        removed_vars = target_vars_of_for_node | vars_of_list_generator

        # 3. Remove var type names which are stored in self.type_vars
        for var in loop_vars:
            if ast_to_source_code(var).strip() in self.type_vars:
                removed_vars.add(var)

        return loop_vars - removed_vars
예제 #27
0
class DygraphToStaticAst(gast.NodeTransformer):
    """
    Main class to transform Dygraph to Static Graph
    """

    def get_static_ast(self, root):
        """Run static analysis on ``root``, apply all transformers, and
        return the wrapped static AST root."""
        # save root for some analysis may need global AST
        self.root = root
        self.static_analysis_visitor = StaticAnalysisVisitor(root)
        self.static_analysis_root = self.static_analysis_visitor.get_node_wrapper_root(
        )
        self.decorate_func_name = None
        self.transfer_from_node_type(self.static_analysis_root)
        return self.static_analysis_root

    def transfer_from_node_type(self, node_wrapper):
        """Apply the generic visit plus every specific transformer in order.

        Refactored from twelve copy-pasted apply-and-log stanzas into one
        data-driven loop; each transformer's log level is its 1-based
        position, and the logged name (``__name__``) matches the string
        literals used previously.
        """
        translator_logger = logging_utils.TranslatorLogger()
        translator_logger.log(
            1, "   Source code: \n{}".format(ast_to_source_code(self.root)))
        # Generic transformation
        self.visit(node_wrapper.node)

        transformers = [
            BasicApiTransformer,  # Basic api of dygraph to static graph
            TensorShapeTransformer,  # Tensor.shape -> fluid.layers.shape(Tensor)
            ListTransformer,  # List used in control flow
            BreakContinueTransformer,  # break/continue in loops
            ReturnTransformer,  # return in functions
            LogicalTransformer,  # logical and/or/not
            LoopTransformer,  # for loop and while loop
            IfElseTransformer,  # if/else statement of Dygraph into Static Graph
            AssertTransformer,  # python assert statement
            PrintTransformer,  # python print statement
            CallTransformer,  # transform call recursively
            CastTransformer,  # python type casting statement
        ]
        for log_level, transformer in enumerate(transformers, start=1):
            transformer(node_wrapper).transform()
            translator_logger.log_transformed_code(log_level, self.root,
                                                   transformer.__name__)

        translator_logger.log_transformed_code(
            logging_utils.LOG_AllTransformer, self.root, "All Transformers")

    def visit_FunctionDef(self, node):
        """Record the outermost function name and strip the dygraph_to_static
        decorators; any other decorator is rejected."""
        if self.decorate_func_name is None:
            self.decorate_func_name = node.name

        self.generic_visit(node)
        # Remove the decorated name of dygraph_to_static
        if hasattr(node, 'decorator_list'):
            decorator_list = []
            for d in node.decorator_list:
                if isinstance(d, gast.Name) and d.id not in DECORATOR_NAMES:
                    raise NotImplementedError(
                        "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                        + d.id + " in " + self.decorate_func_name)
                if isinstance(d, gast.Attribute):
                    full_attribute_name = get_attribute_full_name(d)
                    has_translate_decorator = False
                    for deco in DECORATOR_NAMES:
                        if deco in full_attribute_name:
                            has_translate_decorator = True
                            break
                    if not has_translate_decorator:
                        raise NotImplementedError(
                            "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                            + full_attribute_name + " in " +
                            self.decorate_func_name)
            # Nothing is ever appended above, so all decorators are dropped.
            node.decorator_list = decorator_list
        return node

    def get_module_name(self):
        """
        Return the main function name which will be used as module name
        in ast_to_func.
        """
        # Should consider BaseAPITransformer which add new module name in Yamei's PR.
        assert self.decorate_func_name, "decorate_func_name shall not be None."
        return self.decorate_func_name
예제 #28
0
class PrintTransformer(gast.NodeTransformer):
    """
    This class transforms python print function to fluid.layers.Print.

    A ``print(x)`` call (PY3) or ``print x`` statement (PY2) on a single
    Tensor variable is rewritten into a ``fluid.layers.Print`` call; any
    other usage (multiple values, non-Tensor values) is left unchanged with
    a warning.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of PrintTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node

        # Static analysis provides parent links and inferred variable types.
        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )

    def transform(self):
        """Rewrite all transformable prints in the wrapped AST in place."""
        self.visit(self.root)

    # NOTE: deal with print in PY3
    def visit_Call(self, node):
        if isinstance(node.func, gast.Name) and node.func.id == 'print':
            parent_node = self.node_to_wrapper_map[node].parent.node
            if isinstance(parent_node, gast.Expr):
                # NOTE: why need transform to gast.Assign node
                # only fluid.layers.Print(x) will be pruned when exe.run(use_prune=True)
                print_assign_node = self._create_assign_node(node)
                if print_assign_node is not None:
                    return print_assign_node
            else:
                return self._transform_call_node(node)
        return node

    # NOTE: deal with print in PY2
    def visit_Print(self, node):
        print_assign_node = self._create_assign_node(node)
        if print_assign_node is not None:
            return print_assign_node
        return node

    def _transform_call_node(self, node):
        """Return a fluid.layers.Print call node, or `node` unchanged."""
        assert isinstance(node, gast.Call), "visit Node is not gast.Call node."
        var_node = self._get_print_var_node(node)
        if var_node is None:
            return node
        if self._need_transform(var_node, node):
            return self._build_print_call_node(var_node)
        return node

    def _create_assign_node(self, node):
        """Return `x = fluid.layers.Print(x)` for the printed var, or None."""
        var_node = self._get_print_var_node(node)
        if var_node is None:
            return None
        if self._need_transform(var_node, node):
            # Assigning back to the same variable keeps the Print op alive
            # under pruning (see NOTE in visit_Call).
            return gast.Assign(
                targets=[var_node], value=self._build_print_call_node(var_node))
        return None

    def _build_print_call_node(self, node):
        """Build `fluid.layers.Print(node, summarize=-1, print_phase='forward')`."""
        return gast.Call(
            func=gast.parse('fluid.layers.Print').body[0].value,
            args=[node],
            keywords=[
                gast.keyword(
                    arg='summarize',
                    value=gast.UnaryOp(
                        op=gast.USub(),
                        operand=gast.Constant(
                            value=1, kind=None))), gast.keyword(
                                arg='print_phase',
                                value=gast.Constant(
                                    value='forward', kind=None))
            ])

    def _get_print_var_node(self, node):
        """Return the single printed variable node, or None if unsupported."""
        if isinstance(node, gast.Call):
            var_list = node.args
        elif isinstance(node, gast.Print):
            var_list = node.values
            # `print a, b` parses the values as one Tuple node.
            if isinstance(var_list[0], gast.Tuple):
                var_list = var_list[0].elts
        else:
            # BUGFIX: previously `var_list` was left unbound for any other
            # node type, raising UnboundLocalError instead of returning None.
            return None
        # TODO: support print multiple Var
        if len(var_list) == 1:
            return var_list[0]
        else:
            _logger.warning(
                "ProgramTranslator could not transform printing multiple values like < %s > now and will run it as-is."
                % ast_to_source_code(node).strip())
        return None

    def _need_transform(self, var_node, print_node):
        """True iff `var_node` is a plain name whose inferred type is Tensor."""
        if isinstance(var_node, gast.Name):
            if self._is_tensor_node(var_node):
                return True
            else:
                _logger.warning(
                    "ProgramTranslator could not transform printing value that are not Tensor like < %s > now and will run it as-is."
                    % ast_to_source_code(print_node).strip())
        else:
            _logger.warning(
                "ProgramTranslator could not transform < %s > now and will run it as-is."
                % ast_to_source_code(print_node).strip())
        return False

    def _is_tensor_node(self, node):
        """True iff static analysis typed `node` as a Tensor (or Paddle return)."""
        tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES}
        wrapper_node = self.node_to_wrapper_map.get(node, None)
        if wrapper_node is not None:
            # node_var_type is a set of inferred types; a non-empty
            # intersection means the variable may be a Tensor.
            if wrapper_node.node_var_type & tensor_types:
                return True
        return False