def __init__(self, root_node):
    """
    Collect loop variable information for every loop under `root_node`.

    The whole tree is visited eagerly at the end of construction, so all
    bookkeeping containers below are populated once `__init__` returns.

    Args:
        root_node: the gast root node to analyze.
    """
    # Node wrappers give access to parent nodes and inferred variable types.
    self.static_analysis_visitor = StaticAnalysisVisitor(root_node)
    self.node_to_wrapper_map = \
        self.static_analysis_visitor.get_node_to_wrapper_map()

    # Identifiers that are never treated as variables.
    self.blacklist_names = {"False", "True", "None"}

    # gast.Name / gast.Attribute variable nodes seen so far.
    self.current_seen_vars = set()
    # gast.While / gast.For nodes currently being visited.
    self.current_loop = []
    # Nodes that open a variable scope.
    self.nodes_with_scope = []

    # Loop node (gast.While / gast.For) -> set of variable nodes.
    self.before_loop_body_vars = defaultdict(set)
    self.in_loop_vars = defaultdict(set)
    # Loop node -> variable nodes written inside it / used in its condition.
    self.write_in_loop = defaultdict(set)
    self.condition_vars = defaultdict(set)
    self.in_condition = False

    self.visit(root_node)
def __init__(self, wrapper_root):
    """
    Args:
        wrapper_root (AstNodeWrapper): wrapper around the AST root to transform.
    """
    assert isinstance(
        wrapper_root, AstNodeWrapper
    ), "Input non-AstNodeWrapper node for the initialization of AssertTransformer."
    root = wrapper_root.node
    self.wrapper_root = wrapper_root
    self.root = root
    self.static_analysis_visitor = StaticAnalysisVisitor(root)
def get_static_ast(self, root):
    """
    Run static analysis on `root`, then apply the transformations to it.

    Args:
        root: gast root node of the code being converted.

    Returns:
        The wrapper root built by the static analysis visitor, after
        `transfer_from_node_type` has processed it.
    """
    # Keep the raw root: some analyses may need the global AST.
    self.root = root
    analysis_visitor = StaticAnalysisVisitor(root)
    self.static_analysis_visitor = analysis_visitor
    self.static_analysis_root = analysis_visitor.get_node_wrapper_root()
    self.decorate_func_name = None

    self.transfer_from_node_type(self.static_analysis_root)
    return self.static_analysis_root
def __init__(self, wrapper_root):
    """
    Args:
        wrapper_root (AstNodeWrapper): wrapper around the AST root to transform.
    """
    assert isinstance(
        wrapper_root, AstNodeWrapper
    ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
    self.wrapper_root = wrapper_root
    self.root = wrapper_root.node
    # Variable name -> shape node it was assigned from (filled during visiting).
    self.name_to_tensor_shape = {}

    analysis = StaticAnalysisVisitor(self.root)
    self.static_analysis_visitor = analysis
    self.node_to_wrapper_map = analysis.get_node_to_wrapper_map()

    # Read variable types from the first sub-scope of the top-level scope
    # (presumably the converted function's own scope).
    env = analysis.get_var_env()
    env.cur_scope = env.cur_scope.sub_scopes[0]
    self.scope_var_type_dict = env.get_scope_var_type()
class AssertTransformer(gast.NodeTransformer): """ A class transforms python assert to fluid.layers.Assert. """ def __init__(self, wrapper_root): assert isinstance( wrapper_root, AstNodeWrapper ), "Input non-AstNodeWrapper node for the initialization of AssertTransformer." self.wrapper_root = wrapper_root self.root = wrapper_root.node self.static_analysis_visitor = StaticAnalysisVisitor(self.root) def transform(self): self.visit(self.root) def visit_Assert(self, node): if not self.static_analysis_visitor.is_tensor_node(node.test): return node cast_node = gast.Call( func=gast.parse("fluid.layers.cast").body[0].value, args=[node.test, gast.Constant(value="bool", kind=None)], keywords=[]) assert_node = gast.Call( func=gast.parse("fluid.layers.Assert").body[0].value, args=[cast_node], keywords=[]) return gast.Expr(value=assert_node)
def test_paddle_api_with_andOr(self):
    """Mixed `and`/`or` conditions containing a Paddle API call are
    control flow and lower to four helper assignments."""
    code_or = """
        def foo(x):
            if 2 > 1 and fluid.layers.shape(x)[0] > 16 or x is not None :
                x = x + 1
            return x
    """
    code_and = """
        def foo(x):
            if 2 > 1 and fluid.layers.shape(x)[0] > 16 and x is not None :
                x = x + 1
            return x
    """
    for code in [code_or, code_and]:
        # subTest reports which snippet failed and still runs the other one.
        with self.subTest(code=code):
            code = textwrap.dedent(code)
            node = gast.parse(code)
            static_analysis_visitor = StaticAnalysisVisitor(node)
            # body[0] is `def foo`; its body[0] is the `if` statement.
            test_node = node.body[0].body[0].test
            if_visitor = IfConditionVisitor(test_node,
                                            static_analysis_visitor)
            self.assertTrue(if_visitor.is_control_flow())

            new_node, assign_nodes = if_visitor.transform()
            # Transformation result:
            # bool_tensor_0 = fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool(2 > 1))
            # bool_tensor_1 = fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool(x is not None))
            # logic_and_0 = fluid.layers.logical_and(x=bool_tensor_0, y=fluid.layers.shape(x)[0] > 16)
            # logic_and_1 = fluid.layers.logical_and(x=logic_and_0, y=bool_tensor_1) for code_and
            # logic_or_0 = fluid.layers.logical_or(x=logic_and_0, y=bool_tensor_1) for code_or
            self.assertIsInstance(new_node, gast.Name)
            self.assertEqual(len(assign_nodes), 4)
def __init__(self, wrapper_root):
    """
    Args:
        wrapper_root (AstNodeWrapper): wrapper around the AST root to transform.
    """
    assert isinstance(
        wrapper_root, AstNodeWrapper
    ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
    self.wrapper_root = wrapper_root
    self.root = wrapper_root.node
    # Original variable name (like "x" in `x = t.shape`) -> static shape
    # variable name (like "x_SUFFIX" in `x_SUFFIX = shape(t)`).
    self.name_to_var_shape = {}

    visitor = StaticAnalysisVisitor(self.root)
    self.static_analysis_visitor = visitor
    self.node_to_wrapper_map = visitor.get_node_to_wrapper_map()

    var_env = visitor.get_var_env()
    # Narrow to the first sub-scope before reading variable types
    # (presumably the converted function's own scope).
    var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
    self.scope_var_type_dict = var_env.get_scope_var_type()
class PrintTransformer(gast.NodeTransformer):
    """
    This class transforms python print calls (PY3 `print(...)` and PY2
    `print ...` statements) into `paddle.jit.dy2static.convert_print(...)`.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of PrintTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )

    def transform(self):
        self.visit(self.root)

    # NOTE: deal with print in PY3
    def visit_Call(self, node):
        # Visit children first: without this, a print call nested inside
        # another call (e.g. in an argument or a lambda) is never converted.
        self.generic_visit(node)
        if isinstance(node.func, gast.Name) and node.func.id == 'print':
            node = self._create_print_node(node.args)
        return node

    # NOTE: deal with print in PY2
    def visit_Print(self, node):
        self.generic_visit(node)
        convert_print_node = self._create_print_node(node.values)
        return gast.Expr(value=convert_print_node)

    def _create_print_node(self, print_args):
        # Build `paddle.jit.dy2static.convert_print(*print_args)`; keyword
        # arguments of the original print (sep=, end=, ...) are not forwarded.
        convert_print_func = gast.parse(
            'paddle.jit.dy2static.convert_print').body[0].value
        return gast.Call(func=convert_print_func, args=print_args, keywords=[])
def __init__(self, wrapper_root):
    """
    Args:
        wrapper_root (AstNodeWrapper): wrapper around the AST root to transform.
    """
    assert isinstance(
        wrapper_root, AstNodeWrapper
    ), "Type of input node should be AstNodeWrapper, but received %s ." % type(
        wrapper_root)
    root = wrapper_root.node
    self.root = root
    self.static_analysis_visitor = StaticAnalysisVisitor(root)
def test_paddle_api(self):
    """A condition comparing `fluid.layers.shape(x)[0]` must be transformed."""
    source = """
        def foo(x):
            if fluid.layers.shape(x)[0] > 16:
                x = x + 1
            return x
    """
    func_ast = gast.parse(textwrap.dedent(source))
    analysis = StaticAnalysisVisitor(func_ast)
    # body[0] is `def foo`; its body[0] is the `if` statement.
    if_test = func_ast.body[0].body[0].test
    self.assertTrue(is_control_flow_to_transform(if_test, analysis))
def test_shape_with_andOr(self):
    """A BoolOp condition using a shape-derived variable must be transformed."""
    source = """
        def foo(x):
            batch_size = fluid.layers.shape(x)
            if x is not None and batch_size[0] > 16 or 2 > 1:
                x = x + 1
            return x
    """
    func_ast = gast.parse(textwrap.dedent(source))
    analysis = StaticAnalysisVisitor(func_ast)
    # body[0] is `def foo`; its body[1] is the `if` statement.
    if_test = func_ast.body[0].body[1].test
    self.assertTrue(is_control_flow_to_transform(if_test, analysis))
def __init__(self,
             ast_node,
             static_analysis_visitor=None,
             node_var_type_map=None):
    """
    Args:
        ast_node (gast.AST): node to analyze.
        static_analysis_visitor: optional pre-built StaticAnalysisVisitor;
            a new one is created from `ast_node` when omitted.
        node_var_type_map: optional mapping stored as-is for later lookups.
    """
    assert isinstance(
        ast_node, gast.AST
    ), "Type of input node should be gast.AST, but received %s." % type(
        ast_node)
    self.ast_root = ast_node
    self.static_analysis_visitor = (StaticAnalysisVisitor(ast_node)
                                    if static_analysis_visitor is None else
                                    static_analysis_visitor)
    self.node_var_type_map = node_var_type_map
    self.is_control_flow_num = 0
    # NOTE: attribute name spelling ("tenor") is kept as-is; other code may
    # reference it.
    self._compare_node_tenor_set = set()
def test_paddle_api(self):
    """A bare Paddle API comparison is control flow but needs no rewrite."""
    code = """
        def foo(x):
            if fluid.layers.shape(x)[0] > 16:
                x = x + 1
            return x
    """
    code = textwrap.dedent(code)
    node = gast.parse(code)
    static_analysis_visitor = StaticAnalysisVisitor(node)
    # body[0] is `def foo`; its body[0] is the `if` statement.
    test_node = node.body[0].body[0].test
    if_visitor = IfConditionVisitor(test_node, static_analysis_visitor)
    self.assertTrue(if_visitor.is_control_flow())
    # No transformation will be applied: the very same node comes back.
    new_node, assign_nodes = if_visitor.transform()
    self.assertIs(new_node, test_node)
    self.assertEqual(len(assign_nodes), 0)
def test_paddle_api_with_andOr(self):
    """Both `or` and `and` mixes with a Paddle API call must be transformed."""
    code_or = """
        def foo(x):
            if 2 > 1 and fluid.layers.shape(x)[0] > 16 or x is not None :
                x = x + 1
            return x
    """
    code_and = """
        def foo(x):
            if 2 > 1 and fluid.layers.shape(x)[0] > 16 and x is not None :
                x = x + 1
            return x
    """
    for snippet in (code_or, code_and):
        func_ast = gast.parse(textwrap.dedent(snippet))
        analysis = StaticAnalysisVisitor(func_ast)
        if_test = func_ast.body[0].body[0].test
        self.assertTrue(is_control_flow_to_transform(if_test, analysis))
def test_shape_with_andOr(self):
    """A BoolOp over a shape-derived variable lowers to four assignments."""
    code = """
        def foo(x):
            batch_size = fluid.layers.shape(x)
            if x is not None and batch_size[0] > 16 or 2 > 1:
                x = x + 1
            return x
    """
    code = textwrap.dedent(code)
    node = gast.parse(code)
    static_analysis_visitor = StaticAnalysisVisitor(node)
    # body[0] is `def foo`; its body[1] is the `if` statement.
    test_node = node.body[0].body[1].test
    if_visitor = IfConditionVisitor(test_node, static_analysis_visitor)
    self.assertTrue(if_visitor.is_control_flow())

    new_node, assign_nodes = if_visitor.transform()
    # transformation result:
    # bool_tensor_0 = fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool(x is not None))
    # logic_and_0 = fluid.layers.logical_and(x=bool_tensor_0, y=batch_size[0] > 16)
    # bool_tensor_1 = fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool(2 > 1))
    # logic_or_0 = fluid.layers.logical_or(x=logic_and_0, y=bool_tensor_1)
    self.assertIsInstance(new_node, gast.Name)
    self.assertEqual(len(assign_nodes), 4)
class DygraphToStaticAst(gast.NodeTransformer):
    """
    Main class to transform Dygraph to Static Graph
    """

    def get_static_ast(self, root):
        # save root for some analysis may need global AST
        self.root = root
        self.static_analysis_visitor = StaticAnalysisVisitor(root)
        self.static_analysis_root = self.static_analysis_visitor.get_node_wrapper_root(
        )
        self.decorate_func_name = None
        # Argument name -> positional index in the decorated function.
        self.arg_name_to_idx = {}
        self.transfer_from_node_type(self.static_analysis_root)
        return self.static_analysis_root

    def transfer_from_node_type(self, node_wrapper):
        # Generic transformation
        self.visit(node_wrapper.node)

        # Transform basic api of dygraph to static graph and get feed_name_to_arg_name
        basic_api_trans = BasicApiTransformer(node_wrapper)
        basic_api_trans.transform()
        self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id()

        # Transform Tensor.shape into fluid.layers.shape(Tensor)
        TensorShapeTransformer(node_wrapper).transform()

        # Transform list used in control flow
        ListTransformer(node_wrapper).transform()

        # Transform break/continue in loops
        BreakContinueTransformer(node_wrapper).transform()

        # Transform for loop and while loop
        LoopTransformer(node_wrapper).transform()

        # Transform all if/else statement of Dygraph into Static Graph.
        IfElseTransformer(node_wrapper).transform()

        # Transform python assert statement
        AssertTransformer(node_wrapper).transform()

        # Transform all python print statement
        PrintTransformer(node_wrapper).transform()

        # Transform call recursively
        CallTransformer(node_wrapper).transform()

    def visit_FunctionDef(self, node):
        if self.decorate_func_name is None:
            # The first FunctionDef reached is the decorated function. Only
            # its arguments are recorded: nested functions' arguments would
            # otherwise overwrite the indices of same-named outer arguments.
            self.decorate_func_name = node.name
            for idx, arg in enumerate(node.args.args):
                self.arg_name_to_idx[arg.id] = idx

        self.generic_visit(node)
        # Remove the decorated name of dygraph_to_static
        if hasattr(node, 'decorator_list'):
            for d in node.decorator_list:
                self._check_decorator_supported(d)
            node.decorator_list = []
        return node

    def _check_decorator_supported(self, d):
        """
        Raise NotImplementedError unless decorator `d` is a translation
        decorator (its name, or its dotted full name, matches DECORATOR_NAMES).
        Decorators that are neither gast.Name nor gast.Attribute are not checked.
        """
        if isinstance(d, gast.Name) and d.id not in DECORATOR_NAMES:
            raise NotImplementedError(
                "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                + d.id + " in " + self.decorate_func_name)
        if isinstance(d, gast.Attribute):
            full_attribute_name = get_attribute_full_name(d)
            has_translate_decorator = any(
                deco in full_attribute_name for deco in DECORATOR_NAMES)
            if not has_translate_decorator:
                raise NotImplementedError(
                    "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                    + full_attribute_name + " in " + self.decorate_func_name)

    def get_module_name(self):
        """
        Return the main function name which will be used as module name
        in ast_to_func.
        """
        # Should consider BaseAPITransformer which add new module name in Yamei's PR.
        assert self.decorate_func_name, "decorate_func_name shall not be None."
        return self.decorate_func_name

    def get_feed_name_to_idx(self):
        """Map each feed name to the index of its argument (None if unknown)."""
        return {
            feed_name: self.arg_name_to_idx.get(arg_name)
            for feed_name, arg_name in self.feed_name_to_arg_name.items()
        }
class NameVisitor(gast.NodeVisitor):
    '''
    Analysis name liveness for loop transformer
    '''

    # The whole tree is visited in __init__; afterwards get_loop_var_names()
    # can be queried per gast.While / gast.For node.
    def __init__(self, root_node):
        # Set of gast.Name or gast.Attribute for variables
        self.current_seen_vars = set()

        # List of gast.While/gast.For nodes
        # (used as a stack: pushed/popped by visit_For / visit_While)
        self.current_loop = []

        # List of nodes that have scope of variables.
        # (used as a stack: pushed/popped by visit_FunctionDef)
        self.nodes_with_scope = []

        # Names never recorded as variables; visit_FunctionDef adds
        # function names to this set.
        self.blacklist_names = {"False", "True", "None"}

        # Mapping from gast.While/gast.For to variable nodes
        # before_loop_body_vars: snapshot of current_seen_vars taken after the
        # loop condition/target/iter and before the loop body is visited.
        self.before_loop_body_vars = defaultdict(set)
        self.in_loop_vars = defaultdict(set)

        # Mapping from gast.While/gast.For to variable nodes which is condition
        # of loop or being modified during the loop
        self.write_in_loop = defaultdict(set)
        self.condition_vars = defaultdict(set)
        # True while visiting a loop's test (While) or target/iter (For).
        self.in_condition = False

        self.static_analysis_visitor = StaticAnalysisVisitor(root_node)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )

        self.visit(root_node)

    def is_control_flow_loop(self, node):
        # Delegates to is_control_flow_to_transform with this visitor's
        # static analysis results.
        need_transform = is_control_flow_to_transform(
            node, self.static_analysis_visitor)
        return need_transform

    def get_loop_var_names(self, node):
        """
        Compute the loop variables of a gast.While / gast.For node.

        Returns:
            (loop_var_names, create_var_names): two sets of variable name
            strings. create_var_names is the subset of loop_var_names that is
            first assigned inside the loop and read after it.
        """
        assert isinstance(
            node, (gast.While, gast.For)), "Input node is not gast loop node"
        loop_var_names = set()
        create_var_names = set()
        read_context = {type(gast.Load()), type(gast.AugLoad())}

        in_loop_vars = self.in_loop_vars[node]
        in_loop_name_strs = self._var_nodes_to_names(in_loop_vars)
        before_loop_body_vars = self.before_loop_body_vars[node]
        before_loop_name_strs = self._var_nodes_to_names(before_loop_body_vars)
        # Variable nodes seen neither before nor inside the loop body; only
        # the ones in a read context are kept as names.
        after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars
        after_loop_name_strs = self._var_nodes_to_names(
            after_loop_vars, read_context)
        condition_vars = self.condition_vars[node]
        condition_names = self._var_nodes_to_names(condition_vars)
        write_vars = self.write_in_loop[node]
        write_names = self._var_nodes_to_names(write_vars)

        # Name string -> node_var_type of one of its in-loop nodes (when a
        # name has several nodes, whichever is iterated last wins).
        name_to_type = {}
        for var in in_loop_vars:
            wrapper = self.node_to_wrapper_map[var]
            name_to_type[self._var_node_to_name(var)] = wrapper.node_var_type
        for name in in_loop_name_strs:
            if name in before_loop_name_strs:
                # If a variable is used in loop and created before loop
                # If this var is a basic variable and read-only and not
                # condition var, it may not be loop_var else it should
                # be in loop_var as input
                if (not name in condition_names) and (
                        not name in write_names
                ) and self._node_var_type_is_basic(name_to_type[name]):
                    continue
                loop_var_names.add(name)

            elif name in after_loop_name_strs:
                # If a variable is created in the while loop and read after
                # loop, it should be in loop_var and we should create it
                # because name in after_loop_name must be initialized in loop
                # So it is write-only, we don't have to filter read-only basic
                # vars out
                loop_var_names.add(name)
                create_var_names.add(name)
        return loop_var_names, create_var_names

    def visit_Name(self, node):
        # A Name that is the callee of a Call (e.g. `foo` in `foo(x)`) is
        # not a variable.
        if self._is_call_func_name_node(node):
            self.generic_visit(node)
            return

        if node.id in self.blacklist_names:
            self.generic_visit(node)
            return

        self.current_seen_vars.add(node)
        write_context = {
            type(gast.Store()), type(gast.AugStore()), type(gast.Del())
        }
        # Record the name in every enclosing loop, not just the innermost.
        for loop_node in self.current_loop:
            self.in_loop_vars[loop_node].add(node)
            if type(node.ctx) in write_context:
                self.write_in_loop[loop_node].add(node)
        if self.in_condition:
            # Note: this runs once, for the last loop_node bound by the for
            # loop above.
            self.condition_vars[loop_node].add(node)
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        self.nodes_with_scope.append(node)
        # The function's own name is not treated as a variable.
        self.blacklist_names.add(node.name)
        # The variables in the function are not visible to the outside scope.
        before_func_seen_vars = copy.copy(self.current_seen_vars)

        self.generic_visit(node)
        self.nodes_with_scope.pop()
        # After exiting the scope of the node, variables in this scope
        # should be removed from self.current_seen_vars.
        # Only restored when still inside an outer scope; variables of the
        # outermost function are kept in current_seen_vars.
        if self.nodes_with_scope:
            self.current_seen_vars = before_func_seen_vars

    def visit(self, node):
        # Dispatch to visit_<ClassName>, falling back to generic_visit, and
        # return the visitor's result.
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)

        ret = visitor(node)
        return ret

    def visit_Attribute(self, node):
        # Attributes that are the callee of a Call (e.g. `obj.method(...)`)
        # are not variables.
        if self._is_call_func_name_node(node):
            return
        attr_full_name = get_attribute_full_name(node)
        # Class variables are not allowed to appear in the arguments list
        # of defined function under class methods in Python.
        """
        def class_func(self):
            def while_loop_body(self.x, y) # `self.x` is illegal.
        """
        # TODO: If do change the variable with `self.var`, need a better
        # way to deal with this case.
        if attr_full_name.startswith("self."):
            return
        self.current_seen_vars.add(node)

        # Attribute nodes are recorded in in_loop_vars only; they are not
        # added to write_in_loop or condition_vars here.
        for loop_node in self.current_loop:
            self.in_loop_vars[loop_node].add(node)

        # sub-nodes are visited during get_attribute_full_name and we shouldn't
        # visit again

    def visit_For(self, node):
        self.current_loop.append(node)
        # Target and iter are visited first, flagged as condition variables.
        self.in_condition = True
        self.visit(node.target)
        self.visit(node.iter)
        self.in_condition = False
        self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
        # generic_visit visits all fields, including target and iter again.
        self.generic_visit(node)
        self.current_loop.pop()

    def visit_While(self, node):
        self.current_loop.append(node)
        # The test is visited first, flagged as condition variables.
        self.in_condition = True
        self.visit(node.test)
        self.in_condition = False
        self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
        # generic_visit visits all fields, including the test again.
        self.generic_visit(node)
        self.current_loop.pop()

    def _var_nodes_to_names(self, node_set, ctx_filter_set=None):
        # Convert variable nodes to name strings, optionally keeping only
        # nodes whose ctx type is in ctx_filter_set.
        ret = set()
        for node in node_set:
            if ctx_filter_set is None or type(node.ctx) in ctx_filter_set:
                ret.add(self._var_node_to_name(node))
        return ret

    def _var_node_to_name(self, node):
        # gast.Name -> its id; gast.Attribute -> dotted full name;
        # any other node type -> None (implicit).
        if isinstance(node, gast.Name):
            return node.id
        elif isinstance(node, gast.Attribute):
            return get_attribute_full_name(node)

    def _node_var_type_is_basic(self, node_var_type):
        # True if any type in the node_var_type collection is a basic
        # Python type (bool/int/float/str).
        basic_types = {
            NodeVarType.BOOLEAN, NodeVarType.INT, NodeVarType.FLOAT,
            NodeVarType.STRING
        }
        for t in node_var_type:
            if t in basic_types:
                return True
        return False

    def _is_call_func_name_node(self, node):
        # True if `node` is the `func` of its parent gast.Call.
        parent_node = self.node_to_wrapper_map[node].parent.node
        if isinstance(parent_node, gast.Call) and parent_node.func == node:
            return True
        return False
class ListTransformer(gast.NodeTransformer):
    """
    This class transforms python list used in control flow into Static Graph Ast.

    Lists that get Tensors appended inside Tensor-dependent control flow are
    rewritten to tensor arrays (`fluid.layers.create_array` /
    `fluid.layers.array_write`), and `x.pop(...)` calls are routed through
    `convert_list_pop`.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of ListTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # List variable name -> True once a Tensor has been appended to it.
        self.list_name_to_updated = dict()
        # gast.Assign nodes that create a list literal, e.g. `a = []`.
        self.list_nodes = set()

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )
        var_env = self.static_analysis_visitor.get_var_env()
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        self.visit(self.root)
        self.replace_list_with_tensor_array(self.root)

    def visit_Call(self, node):
        # Visit children first so `pop` calls nested in another call's
        # arguments (e.g. `foo(a.pop())`) are converted too.
        self.generic_visit(node)
        if isinstance(node.func, gast.Attribute):
            func_name = node.func.attr
            if func_name == "pop":
                node = self._replace_list_pop(node)
        return node

    def visit_Assign(self, node):
        if self._update_list_name_to_updated(node):
            return node

        if self._need_to_array_write_node(node):
            return self._transform_slice_to_tensor_write(node)

        self.generic_visit(node)
        return node

    def visit_If(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_While(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def replace_list_with_tensor_array(self, node):
        # Replace `a = []` by `a = fluid.layers.create_array(...)` for lists
        # that received Tensors.
        for child_node in gast.walk(node):
            if isinstance(child_node, gast.Assign):
                if self._need_to_create_tensor_array(child_node):
                    child_node.value = self._create_tensor_array()

    def _transform_list_append_in_control_flow(self, node):
        for child_node in gast.walk(node):
            # Only `list.append(tensor)` expression statements are rewritten
            # here. Indexed writes (`a[i] = x`) are handled by visit_Assign;
            # an Assign's value is not an `append` call, so feeding it to
            # `_to_array_write_node` would fail its assertion.
            if isinstance(child_node, gast.Expr) and \
                    self._need_to_array_write_node(child_node):
                child_node.value = \
                    self._to_array_write_node(child_node.value)

    def _need_to_array_write_node(self, node):
        if isinstance(node, gast.Expr):
            if isinstance(node.value, gast.Call):
                if self._is_list_append_tensor(node.value):
                    return True

        if isinstance(node, gast.Assign):
            target_node = node.targets[0]
            if isinstance(target_node, gast.Subscript):
                list_name = ast_to_source_code(target_node.value).strip()
                # Values are bools: only lists already holding a Tensor count.
                if self.list_name_to_updated.get(list_name):
                    return True
        return False

    def _transform_slice_to_tensor_write(self, node):
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        target_name = target_node.value.id
        slice_node = target_node.slice

        if not isinstance(slice_node, gast.Index):
            # Slice assignment (e.g. `a[1:3] = b`) has no single-index
            # array_write equivalent. Previously this fell through and raised
            # UnboundLocalError; keep the statement unchanged instead.
            return node

        value_code = ast_to_source_code(node.value)
        i = "fluid.layers.cast(" \
            "x=fluid.dygraph.dygraph_to_static.variable_trans_func.to_static_variable({})," \
            "dtype='int64')".format(ast_to_source_code(slice_node))
        assign_code = "{} = fluid.layers.array_write(x={}, i={}, array={})" \
            .format(target_name, value_code, i, target_name)
        assign_node = gast.parse(assign_code).body[0]
        return assign_node

    def _is_list_append_tensor(self, node):
        """
        a.append(b): a is list, b is Tensor
        self.x.append(b): self.x is list, b is Tensor

        Side effect: marks the list as updated when returning True.
        """
        assert isinstance(node, gast.Call)
        # 1. The func is `append`.
        if not isinstance(node.func, gast.Attribute):
            return False
        if node.func.attr != 'append':
            return False

        # 2. It's a `python list` to call append().
        value_name = astor.to_source(gast.gast_to_ast(node.func.value)).strip()
        if value_name not in self.list_name_to_updated:
            return False

        # 3. The arg of append() is one `Tensor`
        # Only one argument is supported in Python list.append()
        if len(node.args) != 1:
            return False
        arg = node.args[0]
        if isinstance(arg, gast.Name):
            # TODO: `arg.id` may be not in scope_var_type_dict if `arg.id` is the arg of decorated function
            # Need a better way to confirm whether `arg.id` is a Tensor.
            try:
                var_type_set = self.scope_var_type_dict[arg.id]
            except KeyError:
                return False

            if NodeVarType.NUMPY_NDARRAY in var_type_set:
                return False
            if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
                return False
        # else:
        # Todo: Consider that `arg` may be a gast.Call about Paddle Api.
        # eg: list_a.append(fluid.layers.reshape(x))
        # return True
        self.list_name_to_updated[value_name] = True
        return True

    def _need_to_create_tensor_array(self, node):
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        if self.list_name_to_updated.get(
                target_id) and node in self.list_nodes:
            return True
        return False

    def _create_tensor_array(self):
        # Although `dtype='float32'`, other types such as `int32` can also be supported
        func_code = "fluid.layers.create_array(dtype='float32')"
        func_node = gast.parse(func_code).body[0].value
        return func_node

    def _to_array_write_node(self, node):
        # `a.append(x)` -> `fluid.layers.array_write(x=x, i=array_length(a), array=a)`
        assert isinstance(node, gast.Call)
        array = astor.to_source(gast.gast_to_ast(node.func.value))
        x = astor.to_source(gast.gast_to_ast(node.args[0]))
        i = "fluid.layers.array_length({})".format(array)
        func_code = "fluid.layers.array_write(x={}, i={}, array={})".format(
            x, i, array)
        return gast.parse(func_code).body[0].value

    def _update_list_name_to_updated(self, node):
        """
        Track list creation: `name = [...]` registers `name` as a list and
        returns True. Reassigning a tracked list that never received a
        Tensor stops tracking it.
        """
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        # TODO: Consider node has more than one target. eg: x, y = a, []
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        value_node = node.value
        if isinstance(value_node, gast.List):
            self.list_name_to_updated[target_id] = False
            self.list_nodes.add(node)
            return True
        elif self.list_name_to_updated.get(target_id) is False:
            del self.list_name_to_updated[target_id]
        return False

    def _replace_list_pop(self, node):
        assert isinstance(node, gast.Call)
        assert isinstance(node.func, gast.Attribute)

        target_node = node.func.value
        target_str = ast_to_source_code(target_node).strip()

        if node.args:
            idx_node = node.args[0]
            idx_str = ast_to_source_code(idx_node).strip()
        else:
            idx_str = "None"

        new_call_str = "fluid.dygraph.dygraph_to_static.list_transformer.convert_list_pop({}, {})".format(
            target_str, idx_str)
        new_call_node = gast.parse(new_call_str).body[0].value
        return new_call_node
class TensorShapeTransformer(gast.NodeTransformer):
    """
    This class transforms variable.shape used in Paddle Apis or control flow conditions into Static Graph Ast.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # Variable name -> `x.shape` / `x.shape[i]` node it was assigned from.
        self.name_to_var_shape = {}

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )
        var_env = self.static_analysis_visitor.get_var_env()
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        SplitAssignTransformer(self.root).transform()
        self.visit(self.root)

    def visit_Assign(self, node):
        if self._update_name_to_var_shape(node):
            return node
        self.generic_visit(node)
        return node

    def visit_Attribute(self, node):
        if self._used_by_paddle_api(node):
            if self.is_var_shape(node):
                return create_convert_shape_node(node)
        return node

    def visit_Name(self, node):
        if node.id in self.name_to_var_shape:
            if self._used_by_paddle_api(node):
                var_shape_node = self.name_to_var_shape[node.id]
                return create_convert_shape_node(var_shape_node)
        return node

    def visit_Call(self, node):
        assert isinstance(node, gast.Call)
        if is_paddle_api(node):
            # Visit gast.Attribute and gast.Name to replace var.shape if necessary.
            self.generic_visit(node)

        return node

    def visit_If(self, node):
        # Call generic_visit first to transform var.shape that is used in Paddle Api.
        self.generic_visit(node)
        cond = node.test
        self._transform_var_shape_if_necessary(cond)

        return node

    def visit_While(self, node):
        self.generic_visit(node)
        cond = node.test
        self._transform_var_shape_if_necessary(cond)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        iter_node = node.iter
        self._transform_var_shape_if_necessary(iter_node)

        # If var.shape is a gast.Name and it is used in range function, transform it
        self._transform_var_shape_in_range(node)
        return node

    def _transform_var_shape_in_range(self, node):
        assert isinstance(node, gast.For)
        if not isinstance(node.iter, gast.Call):
            return False
        if not isinstance(node.iter.func, gast.Name):
            return False
        if node.iter.func.id != "range":
            return False
        args = node.iter.args
        for idx, arg in enumerate(args):
            if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape:
                args[idx] = create_convert_shape_node(
                    self.name_to_var_shape[arg.id])

        return True

    def _transform_var_shape_if_necessary(self, cond):
        """
        Replace `x.shape` attributes and shape-alias names found in `cond`
        with convert-shape nodes. Only nodes held directly in a single-node
        field of their parent are replaced. Returns True if any candidate
        was found.
        """
        need_transformed = False
        for child_node in gast.walk(cond):
            var_shape_node = None
            if isinstance(child_node, gast.Attribute):
                if self.is_var_shape(child_node):
                    var_shape_node = child_node
            elif isinstance(child_node, gast.Name):
                if child_node.id in self.name_to_var_shape:
                    var_shape_node = self.name_to_var_shape[child_node.id]

            if var_shape_node:
                need_transformed = True
                wrapper_node = self.node_to_wrapper_map.get(child_node)
                if wrapper_node is None or wrapper_node.parent is None:
                    # Nodes created by earlier rewrites are not in
                    # node_to_wrapper_map (same case _used_by_paddle_api
                    # guards against); their parent is unknown, so skip.
                    continue
                parent_node = wrapper_node.parent.node
                for field, value in gast.iter_fields(parent_node):
                    if child_node is value:
                        setattr(parent_node, field,
                                create_convert_shape_node(var_shape_node))
                        break
        return need_transformed

    def _used_by_paddle_api(self, node):
        # True if the nearest enclosing gast.Call is a Paddle API call.
        assert isinstance(node, (gast.Attribute, gast.Name))
        wrapper_node = self.node_to_wrapper_map.get(node)
        if not wrapper_node:
            # Transformed node is not in node_to_wrapper_map
            return False
        while wrapper_node.parent:
            parent_node = wrapper_node.parent.node
            if isinstance(parent_node, gast.Call):
                if is_paddle_api(parent_node):
                    return True
                else:
                    return False
            wrapper_node = wrapper_node.parent

        return False

    def is_var_shape(self, node):
        """
        Return True if node is like `x.shape`, return False otherwise.

        Any `<name>.shape` qualifies, whether or not `<name>` is known to be a
        Tensor. The value must be a simple name: `self.x.shape` or
        `f().shape` are rejected.
        """
        assert isinstance(node, gast.Attribute)

        if node.attr != 'shape':
            return False

        return hasattr(node.value, 'id')

    def _update_name_to_var_shape(self, node):
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        value_node = node.value

        if isinstance(value_node, gast.Name):
            if value_node.id in self.name_to_var_shape:
                self.name_to_var_shape[target_id] = self.name_to_var_shape[
                    value_node.id]
                return True
        if isinstance(value_node, gast.Attribute):
            if self.is_var_shape(value_node):  # eg: x.shape
                self.name_to_var_shape[target_id] = value_node
                return True
        if isinstance(value_node, gast.Subscript):
            if isinstance(value_node.value, gast.Attribute):
                if self.is_var_shape(value_node.value):  # eg: x.shape[0]
                    self.name_to_var_shape[target_id] = value_node
                    return True
        return False
class TensorShapeTransformer(gast.NodeTransformer):
    """
    This class transforms Tensor.shape used in Paddle Apis and control flow conditions into Static Graph Ast.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # Variable name -> `x.shape` / `x.shape[i]` node it was assigned from.
        self.name_to_tensor_shape = {}

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )
        var_env = self.static_analysis_visitor.get_var_env()
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        self.visit(self.root)

    def visit_Assign(self, node):
        if self._update_name_to_tensor_shape(node):
            return node
        self.generic_visit(node)
        return node

    def visit_Attribute(self, node):
        if self._used_by_paddle_api(node):
            if self.is_tensor_shape(node):
                return create_api_shape_node(node)
        return node

    def visit_Name(self, node):
        if node.id in self.name_to_tensor_shape:
            if self._used_by_paddle_api(node):
                tensor_shape_node = self.name_to_tensor_shape[node.id]
                return create_api_shape_node(tensor_shape_node)
        return node

    def visit_Call(self, node):
        assert isinstance(node, gast.Call)
        if is_paddle_api(node):
            # Visit gast.Attribute and gast.Name to replace tensor.shape if necessary.
            self.generic_visit(node)

        return node

    def visit_If(self, node):
        # Call generic_visit first to transform Tensor.shape that is used in Paddle Api.
        self.generic_visit(node)
        cond = node.test
        self._transform_tensor_shape_if_necessary(cond)
        return node

    def visit_While(self, node):
        self.generic_visit(node)
        cond = node.test
        self._transform_tensor_shape_if_necessary(cond)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        iter_node = node.iter
        self._transform_tensor_shape_if_necessary(iter_node)

        # If tensor.shape is a gast.Name and it is used in range function, transform it
        self._transform_tensor_shape_in_range(node)
        return node

    def _transform_tensor_shape_in_range(self, node):
        assert isinstance(node, gast.For)
        if not isinstance(node.iter, gast.Call):
            return False
        if not isinstance(node.iter.func, gast.Name):
            return False
        if node.iter.func.id != "range":
            return False
        args = node.iter.args
        for idx, arg in enumerate(args):
            if isinstance(arg, gast.Name) and arg.id in self.name_to_tensor_shape:
                args[idx] = create_api_shape_node(
                    self.name_to_tensor_shape[arg.id])

        return True

    def _transform_tensor_shape_if_necessary(self, cond):
        """
        Replace Tensor `x.shape` attributes and shape-alias names found in
        `cond` with API shape nodes. Only nodes held directly in a
        single-node field of their parent are replaced.
        """
        for child_node in gast.walk(cond):
            tensor_shape_node = None
            if isinstance(child_node, gast.Attribute):
                if self.is_tensor_shape(child_node):
                    tensor_shape_node = child_node
            elif isinstance(child_node, gast.Name):
                if child_node.id in self.name_to_tensor_shape:
                    tensor_shape_node = self.name_to_tensor_shape[
                        child_node.id]

            if tensor_shape_node:
                wrapper_node = self.node_to_wrapper_map.get(child_node)
                if wrapper_node is None or wrapper_node.parent is None:
                    # Nodes created by earlier rewrites are not in
                    # node_to_wrapper_map (same case _used_by_paddle_api
                    # guards against); their parent is unknown, so skip.
                    continue
                parent_node = wrapper_node.parent.node
                for field, value in gast.iter_fields(parent_node):
                    if child_node is value:
                        setattr(parent_node, field,
                                create_api_shape_node(tensor_shape_node))
                        break

    def _used_by_paddle_api(self, node):
        # True if the nearest enclosing gast.Call is a Paddle API call.
        assert isinstance(node, (gast.Attribute, gast.Name))
        wrapper_node = self.node_to_wrapper_map.get(node)
        if not wrapper_node:
            # Transformed node is not in node_to_wrapper_map
            return False
        while wrapper_node.parent:
            parent_node = wrapper_node.parent.node
            if isinstance(parent_node, gast.Call):
                if is_paddle_api(parent_node):
                    return True
                else:
                    return False
            wrapper_node = wrapper_node.parent

        return False

    def is_tensor_shape(self, node):
        """
        Return True if node is like `x.shape` and x is Tensor, return False otherwise.
        """
        assert isinstance(node, gast.Attribute)
        if node.attr != 'shape':
            return False

        try:
            value_id = node.value.id
        except AttributeError:
            return False

        if value_id in self.name_to_tensor_shape:
            return True

        # TODO: `value_id` may be not in scope_var_type_dict if `value_id` is the arg of decorated function
        # Need a better way to confirm whether `value_id` is a Tensor.
        try:
            var_type_set = self.scope_var_type_dict[value_id]
        except KeyError:
            return False

        if NodeVarType.NUMPY_NDARRAY in var_type_set:
            return False
        if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
            return False

        return True

    def _update_name_to_tensor_shape(self, node):
        assert isinstance(node, gast.Assign)
        # TODO: Consider node has more than one target. eg: x, y = a, Tensor.shape[1]
        target_node = node.targets[0]
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        value_node = node.value

        if isinstance(value_node, gast.Name):
            if value_node.id in self.name_to_tensor_shape:
                self.name_to_tensor_shape[
                    target_id] = self.name_to_tensor_shape[value_node.id]
                return True
        if isinstance(value_node, gast.Attribute):
            if self.is_tensor_shape(value_node):  # eg: x.shape
                self.name_to_tensor_shape[target_id] = value_node
                return True
        if isinstance(value_node, gast.Subscript):
            if isinstance(value_node.value, gast.Attribute):
                if self.is_tensor_shape(value_node.value):  # eg: x.shape[0]
                    self.name_to_tensor_shape[target_id] = value_node
                    return True
        return False
class DygraphToStaticAst(gast.NodeTransformer):
    """
    Main class to transform Dygraph to Static Graph
    """

    def get_static_ast(self, root):
        """
        Run the dygraph-to-static transformers over ``root``.

        Args:
            root: the gast AST to convert.

        Returns:
            The wrapper root returned by
            ``StaticAnalysisVisitor.get_node_wrapper_root()``, after every
            transformer in ``transfer_from_node_type`` has been applied to it.
        """
        # save root for some analysis may need global AST
        self.root = root
        self.static_analysis_visitor = StaticAnalysisVisitor(root)
        self.static_analysis_root = self.static_analysis_visitor.get_node_wrapper_root()
        self.decorate_func_name = None
        # Maps each positional argument name of the decorated function to its index.
        self.arg_name_to_idx = {}
        self.transfer_from_node_type(self.static_analysis_root)
        return self.static_analysis_root

    def transfer_from_node_type(self, node_wrapper):
        """
        Apply the generic visit and then each specialized transformer, in order.
        """
        # Generic transformation
        self.visit(node_wrapper.node)

        # Transform basic api of dygraph to static graph and get feed_name_to_arg_name
        basic_api_trans = BasicApiTransformer(node_wrapper)
        basic_api_trans.transform()
        self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id()

        # Transform Tensor.shape into fluid.layers.shape(Tensor)
        TensorShapeTransformer(node_wrapper).transform()

        # Transform list used in control flow
        ListTransformer(node_wrapper).transform()

        # Transform break/continue in loops
        BreakContinueTransformer(node_wrapper).transform()

        # Transform for loop and while loop
        LoopTransformer(node_wrapper).transform()

        # Transform all if/else statement of Dygraph into Static Graph.
        IfElseTransformer(node_wrapper).transform()

    def visit_FunctionDef(self, node):
        """
        Record the outermost function's name and argument positions, then strip
        the dygraph_to_static decorators from every visited function.
        """
        if self.decorate_func_name is None:
            self.decorate_func_name = node.name
            for idx, arg in enumerate(node.args.args):
                self.arg_name_to_idx[arg.id] = idx

        self.generic_visit(node)

        # Remove the decorated name of dygraph_to_static
        if hasattr(node, 'decorator_list'):
            node.decorator_list = [
                d for d in node.decorator_list
                if self._decorator_name(d) not in DECORATOR_NAMES
            ]
        return node

    @staticmethod
    def _decorator_name(decorator):
        """
        Return the bare name of a decorator node, or None if it has no simple name.

        Handles ``@name`` (gast.Name), ``@pkg.name`` (gast.Attribute, last
        component) and ``@name(...)`` / ``@pkg.name(...)`` (gast.Call). Previously
        only gast.Name was supported and any other form raised AttributeError.
        """
        if isinstance(decorator, gast.Call):
            decorator = decorator.func
        if isinstance(decorator, gast.Name):
            return decorator.id
        if isinstance(decorator, gast.Attribute):
            return decorator.attr
        return None

    def get_module_name(self):
        """
        Return the main function name which will be used as module name in ast_to_func.
        """
        # Should consider BaseAPITransformer which add new module name in Yamei's PR.
        assert self.decorate_func_name, "decorate_func_name shall not be None."
        return self.decorate_func_name

    def get_feed_name_to_idx(self):
        """
        Map each feed name to the positional index of the argument feeding it.

        The value is None when the argument name is not a positional argument
        of the decorated function.
        """
        return {
            feed_name: self.arg_name_to_idx.get(arg_name)
            for feed_name, arg_name in self.feed_name_to_arg_name.items()
        }
class TensorShapeTransformer(gast.NodeTransformer):
    """
    This class transforms variable.shape used in Paddle Apis or control flow conditions into Static Graph Ast.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # stores origin var string name (like "x" in `x = t.shape`) to
        # static shape var string name (like "x_SUFFIX" in `x_SUFFIX = shape(t)`)
        self.name_to_var_shape = {}

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map()
        var_env = self.static_analysis_visitor.get_var_env()
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        """Split multi-target assignments first, then rewrite shape usages."""
        SplitAssignTransformer(self.root).transform()
        self.visit(self.root)

    def visit_Assign(self, node):
        """
        Record `name = var.shape`-like assignments and append the static shape
        assignments after them. Otherwise, visit the assignment's children.
        """
        update_static_shape_var_node = self._update_name_to_var_shape(node)
        # A tuple target yields an empty list when none of its elements refers
        # to a shape. In that case nothing was recorded and the children must
        # still be visited, exactly like an ordinary assignment.
        if update_static_shape_var_node:
            ret = [node]
            ret.extend(update_static_shape_var_node)
            return ret
        self.generic_visit(node)
        return node

    def visit_Subscript(self, node):
        """Rewrite `s[i]` / `x.shape[i]` when used as a Paddle API argument."""
        value_node = node.value
        if isinstance(value_node, gast.Name):
            if value_node.id in self.name_to_var_shape and self._used_by_paddle_api(
                    value_node):
                return create_choose_shape_node(
                    value_node.id, self.name_to_var_shape[value_node.id], node)
        elif isinstance(value_node, gast.Attribute):
            if self._used_by_paddle_api(value_node):
                value_name = ast_to_source_code(value_node).strip()
                if value_name in self.name_to_var_shape:
                    return create_choose_shape_node(
                        value_name, self.name_to_var_shape[value_name], node)
                if self._is_var_shape(value_node):
                    return create_convert_shape_node(value_node, node)
        return node

    def visit_Attribute(self, node):
        """Rewrite `self.s` / `x.shape` when used as a Paddle API argument."""
        if self._used_by_paddle_api(node):
            name = ast_to_source_code(node).strip()
            if name in self.name_to_var_shape:
                return create_choose_shape_node(name, self.name_to_var_shape[name])
            if self._is_var_shape(node):
                return create_convert_shape_node(node)
        return node

    def visit_Name(self, node):
        """Rewrite a recorded shape name when used as a Paddle API argument."""
        if node.id in self.name_to_var_shape:
            if self._used_by_paddle_api(node):
                return create_choose_shape_node(node.id, self.name_to_var_shape[node.id])
        return node

    def visit_Call(self, node):
        if is_paddle_api(node):
            # Visit gast.Attribute and gast.Name to replace var.shape if necessary.
            self.generic_visit(node)
        # Don't have to visit other APIs
        return node

    def visit_If(self, node):
        # Call generic_visit first to transform var.shape that is used in Paddle Api.
        self.generic_visit(node)
        self._transform_var_shape_if_necessary(node.test)
        return node

    def visit_While(self, node):
        self.generic_visit(node)
        self._transform_var_shape_if_necessary(node.test)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        self._transform_var_shape_if_necessary(node.iter)

        # If var.shape is a gast.Name and it is used in range function, transform it
        self._transform_var_shape_in_range(node)
        return node

    def _transform_var_shape_in_range(self, node):
        """
        Replace recorded shape names passed to `range(...)` in a for-loop header.

        Returns True if the loop iterates over a `range(...)` call, else False.
        """
        assert isinstance(node, gast.For)
        if not isinstance(node.iter, gast.Call):
            return False
        if not isinstance(node.iter.func, gast.Name):
            return False
        if node.iter.func.id != "range":
            return False
        args = node.iter.args
        for idx, arg in enumerate(args):
            if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape:
                args[idx] = create_choose_shape_node(
                    arg.id, self.name_to_var_shape[arg.id])
        return True

    def _transform_var_shape_if_necessary(self, cond):
        """
        Replace shape usages inside a loop/if condition in their parent node.

        Returns True if at least one shape usage was found.
        """
        need_transformed = False
        for child_node in gast.walk(cond):
            var_shape_node = None
            if isinstance(child_node, (gast.Name, gast.Attribute, gast.Subscript)):
                child_name = ast_to_source_code(child_node).strip()
                if child_name in self.name_to_var_shape:
                    var_shape_node = create_choose_shape_node(
                        child_name, self.name_to_var_shape[child_name])
                elif self._is_var_shape(child_node):
                    var_shape_node = child_node

            if not var_shape_node:
                continue

            wrapper_node = self.node_to_wrapper_map.get(child_node)
            if wrapper_node is None or wrapper_node.parent is None:
                # Nodes created by earlier rewrites are not in node_to_wrapper_map,
                # so their parent is unknown and they cannot be replaced in place.
                continue

            need_transformed = True
            parent_node = wrapper_node.parent.node
            for field, value in gast.iter_fields(parent_node):
                if child_node is value:
                    if var_shape_node is child_node:
                        setattr(parent_node, field,
                                create_convert_shape_node(var_shape_node, None, True))
                    else:
                        setattr(parent_node, field, var_shape_node)
                    break
                # Some child_node may be in a list such as gast.Compare
                if isinstance(value, list):
                    has_converted_shape = False
                    for i, v in enumerate(value):
                        if child_node is v:
                            if var_shape_node is child_node:
                                value[i] = create_convert_shape_node(
                                    var_shape_node, None, True)
                            else:
                                value[i] = var_shape_node
                            has_converted_shape = True
                            break
                    if has_converted_shape:
                        break
        return need_transformed

    def _used_by_paddle_api(self, node):
        """
        Whether node is used in paddle api as arguments.
        For example:
            1) Return True in `paddle.relu(x)` where node is `x` (gast.Name)
            2) Return True in `paddle.add(self.x)` where node is `self.x` (gast.Attribute)
            3) Return False in `paddle.add(self.x)` where node is `paddle.add` (gast.Attribute),
               because the role of node is not arguments but `gast.Call.func`.
        """
        assert isinstance(node, (gast.Attribute, gast.Name))
        wrapper_node = self.node_to_wrapper_map.get(node)
        if not wrapper_node:
            # Transformed node is not in node_to_wrapper_map
            return False
        while wrapper_node.parent:
            parent_node = wrapper_node.parent.node
            if isinstance(parent_node, gast.Call):
                # Note(Aurelius84): Filter the case when the role of node is `gast.Call.func`.
                return is_paddle_api(parent_node) and parent_node.func is not node
            wrapper_node = wrapper_node.parent
        return False

    def _is_var_shape(self, node):
        """
        Return True if node is like `x.shape` or `x.shape[0]`, return False otherwise.
        """
        if isinstance(node, gast.Attribute):
            if node.attr != 'shape':
                return False
            # `paddle.shape` is the API itself, not a `var.shape` access.
            return not (isinstance(node.value, gast.Name) and node.value.id == 'paddle')
        if isinstance(node, gast.Subscript):
            return self._is_var_shape(node.value)
        return False

    def _update_name_to_var_shape(self, node):
        """
        Record assignments whose value is a shape and build the static-shape
        assignments that should follow them.

        Returns:
            For a tuple target: a (possibly empty) list of new statements.
            For any other target: a list of new statements, or None when the
            value is not a shape.
        """

        def replace_dot(name):
            # replace all '.' into '_'
            return name.replace('.', '_')

        def new_static_shape_var(target_id):
            # Generate a unique static shape var name and its gast.Name node.
            var_name = unique_name.generate(
                replace_dot(target_id) + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
            return var_name, gast.parse(var_name).body[0].value

        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        value_node = node.value

        update_static_shape_var_node = None
        if isinstance(target_node, gast.Tuple):
            update_static_shape_var_node = []
            for idx, element in enumerate(target_node.elts):
                target_id = ast_to_source_code(element).strip()

                if isinstance(value_node, gast.Name):
                    if value_node.id in self.name_to_var_shape:
                        # TODO(zhhsplendid): is context a problem for the result node of gast.parse?
                        static_shape_var_name, static_shape_var_node = \
                            new_static_shape_var(target_id)
                        static_shape_value_name = self.name_to_var_shape[value_node.id]
                        sub_node_str = "{}[{}]".format(static_shape_value_name, idx)
                        sub_node = gast.parse(sub_node_str).body[0].value

                        update_static_shape_var_node.append(
                            gast.Assign(targets=[static_shape_var_node], value=sub_node))

                        self.name_to_var_shape[target_id] = static_shape_var_name
                if isinstance(value_node, gast.Attribute):
                    if self._is_var_shape(value_node):  # eg: x.shape
                        static_shape_var_name, static_shape_var_node = \
                            new_static_shape_var(target_id)
                        static_shape_value_node = copy.deepcopy(value_node)
                        # x.shape becomes convert_var_shape_simple(x)
                        static_shape_value_node = ShapeAttributeTransformer().visit(
                            static_shape_value_node)
                        sub_node_str = "{}[{}]".format(
                            ast_to_source_code(static_shape_value_node).strip(), idx)
                        sub_node = gast.parse(sub_node_str).body[0].value
                        # Note(Aurelius84): Because static_shape_var_name is used in
                        # eval_if_exist_else_none() as plain string, so it will not
                        # be parsed as argument in convert_loop/ifelse. We declare it
                        # as global var because it has unique name.
                        update_static_shape_var_node.append(
                            gast.Global(names=[static_shape_var_name]))

                        update_static_shape_var_node.append(
                            gast.Assign(targets=[static_shape_var_node], value=sub_node))
                        self.name_to_var_shape[target_id] = static_shape_var_name
            return update_static_shape_var_node

        target_id = ast_to_source_code(target_node).strip()
        if isinstance(value_node, gast.Name):
            if value_node.id in self.name_to_var_shape:
                static_shape_var_name, static_shape_var_node = \
                    new_static_shape_var(target_id)
                static_shape_value_name = self.name_to_var_shape[value_node.id]
                static_shape_value_node = gast.parse(
                    static_shape_value_name).body[0].value

                update_static_shape_var_node = [
                    gast.Assign(targets=[static_shape_var_node],
                                value=static_shape_value_node)
                ]
                self.name_to_var_shape[target_id] = static_shape_var_name
        elif self._is_var_shape(value_node):  # eg: x.shape or x.shape[0]
            static_shape_var_name, static_shape_var_node = \
                new_static_shape_var(target_id)
            static_shape_value_node = copy.deepcopy(value_node)
            # x.shape becomes convert_var_shape_simple(x)
            static_shape_value_node = ShapeAttributeTransformer().visit(
                static_shape_value_node)
            # Declare static_shape_var_name as global var
            update_static_shape_var_node = [
                gast.Global(names=[static_shape_var_name])
            ]
            update_static_shape_var_node.append(
                gast.Assign(targets=[static_shape_var_node],
                            value=static_shape_value_node))
            self.name_to_var_shape[target_id] = static_shape_var_name
        return update_static_shape_var_node
class PrintTransformer(gast.NodeTransformer):
    """
    This class transforms python print function to fluid.layers.Print.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of PrintTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map()

    def transform(self):
        self.visit(self.root)

    # NOTE: deal with print in PY3
    def visit_Call(self, node):
        """Replace a `print(var)` call with `fluid.layers.Print(var)`."""
        assert isinstance(node, gast.Call)
        if isinstance(node.func, gast.Name) and node.func.id == 'print':
            var = self._get_print_var(node)
            return self._construct_print_node(var)
        return node

    # NOTE: deal with print in PY2
    def visit_Print(self, node):
        """Replace a PY2 `print` statement with an expression statement."""
        var = self._get_print_var(node)
        print_call_node = self._construct_print_node(var)
        return gast.Expr(value=print_call_node)

    def _get_print_var(self, node):
        """
        Return the single object being printed by a print call or statement.

        A single tuple argument, e.g. `print((x,))`, is unpacked first.

        Raises:
            AssertionError: if there is not exactly one object to print. This
                includes `print()`, which previously raised IndexError before
                reaching the assertion.
        """
        if isinstance(node, gast.Call):
            var_list = node.args
        elif isinstance(node, gast.Print):
            var_list = node.values
        if var_list and isinstance(var_list[0], gast.Tuple):
            var_list = var_list[0].elts
        # TODO: support print multiple Var
        assert len(var_list) == 1, "Now only support print one Variable."
        return var_list[0]

    def _construct_print_node(self, node):
        """
        Build a `fluid.layers.Print(node)` call for a tensor name.

        Raises:
            TypeError: if ``node`` is a name that static analysis does not mark
                as a tensor.
            NotImplementedError: if ``node`` is not a plain name.
        """
        if isinstance(node, gast.Name):
            if self._is_tensor_node(node):
                print_node = gast.Call(
                    func=gast.parse('fluid.layers.Print').body[0].value,
                    args=[node],
                    keywords=[])
                return print_node
            else:
                raise TypeError(
                    "print object type error, only support print Variable now.")
        else:
            # TODO: may not only print with format
            raise NotImplementedError(
                "cannot transform print with format temporarily.")

    def _is_tensor_node(self, node):
        """Whether static analysis typed ``node`` as a tensor or Paddle return value."""
        tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES}
        wrapper_node = self.node_to_wrapper_map.get(node, None)
        if wrapper_node is not None:
            if wrapper_node.node_var_type & tensor_types:
                return True
        return False
class DygraphToStaticAst(gast.NodeTransformer):
    """
    Main class to transform Dygraph to Static Graph.

    The specialized transformers run in a fixed order. The code produced after
    each one is handed to ``TranslatorLogger.log_transformed_code`` with that
    transformer's 1-based position as the log level.
    """

    def __init__(self):
        self.translator_logger = logging_utils.TranslatorLogger()

    def get_static_ast(self, root):
        """
        Convert ``root`` and return the wrapper root created by StaticAnalysisVisitor.
        """
        # Keep the root: some analyses need the global AST.
        self.root = root
        self.static_analysis_visitor = StaticAnalysisVisitor(root)
        self.static_analysis_root = \
            self.static_analysis_visitor.get_node_wrapper_root()
        self.decorate_func_name = None
        self.transfer_from_node_type(self.static_analysis_root)
        return self.static_analysis_root

    def _apply(self, transformer, node_wrapper, log_level):
        """Run one transformer class on ``node_wrapper`` and log the resulting code."""
        transformer(node_wrapper).transform()
        self.translator_logger.log_transformed_code(log_level, self.root,
                                                    transformer.__name__)

    def transfer_from_node_type(self, node_wrapper):
        """Run the generic visit and then every transformer in the pipeline."""
        original_source = ast_to_source_code(self.root)
        self.translator_logger.log(1, "Source code: \n{}".format(original_source))

        # Generic transformation (strips translation decorators, records names).
        self.visit(node_wrapper.node)

        pipeline = (
            BasicApiTransformer,  # Basic Api
            TensorShapeTransformer,  # Tensor.shape -> layers.shape(Tensor)
            ListTransformer,  # List used in control flow
            BreakTransformOptimizer,  # optimize transformation of break in loops
            BreakContinueTransformer,  # break/continue in loops
            ReturnTransformer,  # return in functions
            LogicalTransformer,  # logical and/or/not
            LoopTransformer,  # for/while -> while_op
            IfElseTransformer,  # if/else -> cond_op
            AssertTransformer,  # assert statement
            PrintTransformer,  # print statement
            CallTransformer,  # transform call recursively
            CastTransformer,  # type casting statement
            GradTransformer,  # transform paddle.grad to paddle.gradients
        )
        for level, transformer_cls in enumerate(pipeline, start=1):
            self._apply(transformer_cls, node_wrapper, log_level=level)

        self.translator_logger.log_transformed_code(
            logging_utils.LOG_AllTransformer, self.root, "All Transformers")

    def visit_FunctionDef(self, node):
        """
        Remember the outermost function name and drop its translation decorators.

        Raises:
            NotImplementedError: if a decorator is not a translation decorator.
        """
        if self.decorate_func_name is None:
            self.decorate_func_name = node.name

        self.generic_visit(node)

        if not hasattr(node, 'decorator_list'):
            return node

        # Every decorator must be a translation decorator; all of them are removed.
        for decorator in node.decorator_list:
            self._check_translate_decorator(decorator)
        node.decorator_list = []
        return node

    def _check_translate_decorator(self, decorator):
        """
        Raise NotImplementedError when ``decorator`` is not a translation decorator.

        A gast.Name must be one of DECORATOR_NAMES. For a gast.Attribute, the
        full dotted name must contain one of them. Other node kinds (for example
        gast.Call) are not checked.
        """
        if isinstance(decorator, gast.Name):
            if decorator.id not in DECORATOR_NAMES:
                raise NotImplementedError(
                    "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                    + decorator.id + " in " + self.decorate_func_name)
        elif isinstance(decorator, gast.Attribute):
            full_attribute_name = get_attribute_full_name(decorator)
            if not any(deco in full_attribute_name for deco in DECORATOR_NAMES):
                raise NotImplementedError(
                    "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                    + full_attribute_name + " in " + self.decorate_func_name)

    def get_module_name(self):
        """
        Return the main function name which will be used as module name in ast_to_func.
        """
        # Should consider BaseAPITransformer which add new module name in Yamei's PR.
        assert self.decorate_func_name, "decorate_func_name shall not be None."
        return self.decorate_func_name
class ListTransformer(gast.NodeTransformer):
    """
    This class transforms python list used in control flow into Static Graph Ast.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of ListTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # list name -> whether it is appended/written inside control flow
        self.list_name_to_updated = dict()
        # gast.Assign nodes of the form `name = [...]`
        self.list_nodes = set()

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map()
        var_env = self.static_analysis_visitor.get_var_env()
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        SplitAssignTransformer(self.root).transform()
        self.visit(self.root)
        self.replace_list_with_tensor_array(self.root)

    def visit_Call(self, node):
        """Rewrite `x.pop(...)` calls, including ones nested in other call arguments."""
        # Visit children first so that nested pops such as `f(a.pop())` or
        # `a.pop().pop()` are converted as well.
        self.generic_visit(node)
        if isinstance(node.func, gast.Attribute):
            func_name = node.func.attr
            if func_name == "pop":
                node = self._replace_pop(node)
        return node

    def visit_Assign(self, node):
        if self._update_list_name_to_updated(node):
            return node

        if self._need_to_array_write_node(node):
            return self._transform_slice_to_tensor_write(node)

        self.generic_visit(node)
        return node

    def visit_If(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_While(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def replace_list_with_tensor_array(self, node):
        """Turn `name = [...]` into `name = create_array(...)` for updated lists."""
        for child_node in gast.walk(node):
            if isinstance(child_node, gast.Assign):
                if self._need_to_create_tensor_array(child_node):
                    child_node.value = self._create_tensor_array()

    def _transform_list_append_in_control_flow(self, node):
        for child_node in gast.walk(node):
            if self._need_to_array_write_node(child_node):
                child_node.value = \
                    self._to_array_write_node(child_node.value)

    def _need_to_array_write_node(self, node):
        """
        Whether ``node`` is `lst.append(x)` (as an expression statement) or
        `lst[i] = x` on a list already marked as updated.

        Note: checking the append case marks the list as updated (see
        `_is_list_append_tensor`).
        """
        if isinstance(node, gast.Expr):
            if isinstance(node.value, gast.Call):
                if self._is_list_append_tensor(node.value):
                    return True

        if isinstance(node, gast.Assign):
            target_node = node.targets[0]
            if isinstance(target_node, gast.Subscript):
                list_name = ast_to_source_code(target_node.value).strip()
                if self.list_name_to_updated.get(list_name):
                    return True
        return False

    def _transform_slice_to_tensor_write(self, node):
        """
        Rewrite `lst[i] = x` into `lst = fluid.layers.array_write(...)`.

        Slice targets (`lst[i:j] = x`) and any other slice kind are not
        supported and are returned unchanged. Previously this fell through and
        returned None, which deleted the statement from the AST.
        """
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        target_name = target_node.value.id
        slice_node = target_node.slice

        if isinstance(slice_node, gast.Index):
            value_code = ast_to_source_code(node.value)
            i = "paddle.cast(" \
                "x=paddle.jit.dy2static.to_static_variable({})," \
                "dtype='int64')".format(ast_to_source_code(slice_node))
            assign_code = "{} = fluid.layers.array_write(x={}, i={}, array={})" \
                .format(target_name, value_code, i, target_name)
            assign_node = gast.parse(assign_code).body[0]
            return assign_node
        # TODO: support gast.Slice (slice assignment) on tensor arrays.
        return node

    def _is_list_append_tensor(self, node):
        """
        a.append(b): a is list, b is Tensor
        self.x.append(b): self.x is list, b is Tensor

        Side effect: when this returns True, the list is marked as updated in
        ``self.list_name_to_updated``.
        """
        assert isinstance(node, gast.Call)
        # 1. The func is `append`.
        if not isinstance(node.func, gast.Attribute):
            return False
        if node.func.attr != 'append':
            return False

        # 2. It's a `python list` to call append().
        value_name = astor.to_source(gast.gast_to_ast(node.func.value)).strip()
        if value_name not in self.list_name_to_updated:
            return False

        # 3. The number of arg of append() is one
        # Only one argument is supported in Python list.append()
        if len(node.args) != 1:
            return False

        # TODO(liym27): The arg of append() should be Tensor. But because the type
        # of arg is often wrong with static analysis, the arg is not required to be
        # Tensor here. A stricter check would look up the arg in
        # `self.scope_var_type_dict`, reject NodeVarType.NUMPY_NDARRAY and require
        # NodeVarType.TENSOR or NodeVarType.PADDLE_RETURN_TYPES (and also consider
        # args that are Paddle API calls, e.g. list_a.append(fluid.layers.reshape(x))).

        self.list_name_to_updated[value_name] = True
        return True

    def _need_to_create_tensor_array(self, node):
        """Whether `name = [...]` should become a tensor array creation."""
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        return bool(self.list_name_to_updated.get(target_id)) and node in self.list_nodes

    def _create_tensor_array(self):
        # Although `dtype='float32'`, other types such as `int32` can also be supported
        func_code = "fluid.layers.create_array(dtype='float32')"
        func_node = gast.parse(func_code).body[0].value
        return func_node

    def _to_array_write_node(self, node):
        """Turn `lst.append(x)` into `array_write(x, array_length(lst), lst)`."""
        assert isinstance(node, gast.Call)
        array = astor.to_source(gast.gast_to_ast(node.func.value)).strip()
        x = astor.to_source(gast.gast_to_ast(node.args[0])).strip()
        i = "fluid.layers.array_length({})".format(array)
        func_code = "fluid.layers.array_write(x={}, i={}, array={})".format(
            x, i, array)
        return gast.parse(func_code).body[0].value

    def _update_list_name_to_updated(self, node):
        """
        Track list creation and re-binding.

        `name = [...]` registers ``name`` as a not-yet-updated list and returns
        True. Re-binding a registered, not-yet-updated list to a non-list value
        forgets it. Returns False in every case other than list creation.
        """
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]
        # NOTE: Code like `x, y = a, []` has been transformed to `x=a; y=[]`
        try:
            target_id = target_node.id
        except AttributeError:
            return False
        value_node = node.value
        if isinstance(value_node, gast.List):
            self.list_name_to_updated[target_id] = False
            self.list_nodes.add(node)
            return True
        elif self.list_name_to_updated.get(target_id, None) is False:
            del self.list_name_to_updated[target_id]
        return False

    def _replace_pop(self, node):
        """
        Replace a pop statement for a list or dict.
        For example:

            list_a = [0,1,2,3,4]
            x = list_a.pop()  # --> convert_pop(list_a)
            y = list_a.pop(1) # --> convert_pop(list_a, 1)

            dict_a = {"red":0, "blue":1, "yellow":2}
            m = dict_a.pop("red")           # --> convert_pop(dict_a, "red")
            n = dict_a.pop("black", 3)      # --> convert_pop(dict_a, "black", 3)

        """
        assert isinstance(node, gast.Call)
        assert isinstance(node.func, gast.Attribute)

        target_node = node.func.value
        target_str = ast_to_source_code(target_node).strip()

        args_str = [ast_to_source_code(arg).strip() for arg in node.args]

        # NOTE(liym27):
        # 1. pop stmt for a list if len(args_str) == 0
        # 2. pop stmt for a list or dict if len(args_str) == 1
        # 3. pop stmt for a dict if len(args_str) == 2
        if len(args_str) <= 2:
            new_pop_str = "paddle.jit.dy2static.convert_pop({}, {})"\
                .format(target_str, ",".join(args_str))
            new_pop_node = gast.parse(new_pop_str).body[0].value
            return new_pop_node
        else:
            return node
class NameVisitor(gast.NodeVisitor):
    '''
    Analysis name liveness for loop transformer
    '''

    def __init__(self, root_node):
        # Set of gast.Name or gast.Attribute for variables
        self.current_seen_vars = set()

        # List of gast.While/gast.For nodes
        self.current_loop = []

        # List of nodes that have scope of variables.
        self.nodes_with_scope = []
        self.blacklist_names = {"False", "True", "None"}

        # Mapping from gast.While/gast.For to variable nodes
        self.before_loop_body_vars = defaultdict(set)
        # NOTE: Use ordered list as dict value
        self.in_loop_vars = defaultdict(list)

        # Mapping from gast.While/gast.For to variable nodes which is condition
        # of loop or being modified during the loop
        self.write_in_loop = defaultdict(set)
        self.condition_vars = defaultdict(set)
        self.in_condition = False

        # Some names are types, we shouldn't record them as loop var names.
        self.type_vars = set()

        self.static_analysis_visitor = StaticAnalysisVisitor(root_node)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map()

        self.visit(root_node)

    def get_loop_var_names(self, node):
        """
        Return ``(loop_var_names, create_var_names)`` for a loop node.

        loop_var_names are the names that must be passed through the loop.
        create_var_names are the subset that must be created before the loop.
        """
        assert isinstance(
            node, (gast.While, gast.For)), "Input node is not gast loop node"
        loop_var_names = set()
        create_var_names = set()
        read_context = {type(gast.Load()), type(gast.AugLoad())}

        in_loop_vars_list = self.in_loop_vars[node]

        # get dict `var_name_to_ctxs`
        var_name_to_ctxs = defaultdict(list)
        for var_node in in_loop_vars_list:
            var_name_to_ctxs[self._var_node_to_name(var_node)].append(var_node.ctx)

        in_loop_vars = set(in_loop_vars_list)
        in_loop_vars = self._remove_unnecessary_vars(in_loop_vars, node)
        in_loop_name_strs = self._var_nodes_to_names(in_loop_vars)

        before_loop_body_vars = self.before_loop_body_vars[node]
        before_loop_body_vars = self._remove_unnecessary_vars(
            before_loop_body_vars, node)
        before_loop_name_strs = self._var_nodes_to_names(before_loop_body_vars)

        after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars
        after_loop_vars = self._remove_unnecessary_vars(after_loop_vars, node)
        after_loop_name_strs = self._var_nodes_to_names(after_loop_vars,
                                                        read_context)
        condition_vars = self.condition_vars[node]
        condition_names = self._var_nodes_to_names(condition_vars)

        write_vars = self.write_in_loop[node]
        write_names = self._var_nodes_to_names(write_vars)

        name_to_type = {}
        for var in in_loop_vars:
            wrapper = self.node_to_wrapper_map[var]
            name_to_type[self._var_node_to_name(var)] = wrapper.node_var_type
        for name in in_loop_name_strs:
            if name in before_loop_name_strs:
                # If a variable is used in loop and created before loop
                # If this var is a basic variable and read-only and not
                # condition var, it may not be loop_var else it should
                # be in loop_var as input
                if (name not in condition_names and name not in write_names
                        and self._node_var_type_is_basic(name_to_type[name])):
                    continue
                loop_var_names.add(name)

            elif name in after_loop_name_strs:
                # If a variable is created in the while loop and read after
                # loop, it should be in loop_var and we should create it
                # because name in after_loop_name must be initialized in loop
                # So it is write-only, we don't have to filter read-only basic
                # vars out
                loop_var_names.add(name)
                create_var_names.add(name)
            else:
                # If a variable is used and created in loop, but used before created,
                # it should be in loop_var and we should create it.

                # For example, `var_a` should be in loop_var and we should create it.
                #
                #   res = 0
                #   for i, x in enumerate(x_array):
                #       if i > 2:
                #           x = func1(var_a)
                #       var_a = func2(x)
                #
                ctxs = var_name_to_ctxs[name]
                is_created = any(isinstance(ctx, gast.Store) for ctx in ctxs)
                if isinstance(ctxs[0], gast.Load) and is_created:
                    loop_var_names.add(name)
                    create_var_names.add(name)

        return loop_var_names, create_var_names

    def visit_Name(self, node):
        if self._is_call_func_name_node(node):
            self.generic_visit(node)
            return

        if node.id in self.blacklist_names:
            self.generic_visit(node)
            return

        self.current_seen_vars.add(node)
        write_context = {
            type(gast.Store()), type(gast.AugStore()), type(gast.Del())
        }
        for loop_node in self.current_loop:
            self.in_loop_vars[loop_node].append(node)
            if type(node.ctx) in write_context:
                self.write_in_loop[loop_node].add(node)
        if self.in_condition and self.current_loop:
            # The condition being visited belongs to the innermost loop.
            self.condition_vars[self.current_loop[-1]].add(node)
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        self.nodes_with_scope.append(node)
        self.blacklist_names.add(node.name)
        # The variables in the function are not visible to the outside scope.
        before_func_seen_vars = copy.copy(self.current_seen_vars)

        self.generic_visit(node)
        self.nodes_with_scope.pop()
        # After exiting the scope of the node, variables in this scope
        # should be removed from self.current_seen_vars.
        if self.nodes_with_scope:
            self.current_seen_vars = before_func_seen_vars

    def visit(self, node):
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        ret = visitor(node)
        return ret

    def visit_Attribute(self, node):
        if self._is_call_func_name_node(node):
            return
        attr_full_name = get_attribute_full_name(node)
        # Class variables are not allowed to appear in the arguments list
        # of defined function under class methods in Python.
        """
        def class_func(self):
            def while_loop_body(self.x, y) # `self.x` is illegal.
        """
        # TODO: If do change the variable with `self.var`, need a better
        # way to deal with this case.
        if attr_full_name.startswith("self."):
            return
        self.current_seen_vars.add(node)

        for loop_node in self.current_loop:
            self.in_loop_vars[loop_node].append(node)

        # sub-nodes are visited during get_attribute_full_name and we shouldn't
        # visit again

    def visit_For(self, node):
        self.current_loop.append(node)
        self.in_condition = True
        self.visit(node.target)
        self.visit(node.iter)
        self.in_condition = False
        self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
        self.generic_visit(node)
        self.current_loop.pop()

    def visit_While(self, node):
        self.current_loop.append(node)
        self.in_condition = True
        self.visit(node.test)
        self.in_condition = False
        self.before_loop_body_vars[node] = copy.copy(self.current_seen_vars)
        self.generic_visit(node)
        self.current_loop.pop()

    def visit_Call(self, node):
        # Store type var names such as "isinstance(x, some_type_names)" and
        # Remove them later
        if isinstance(node.func, gast.Name) and node.func.id == 'isinstance':
            type_node = node.args[1]
            if isinstance(type_node, gast.Tuple):
                for element in type_node.elts:
                    self.type_vars.add(ast_to_source_code(element).strip())
            else:
                self.type_vars.add(ast_to_source_code(type_node).strip())
        self.generic_visit(node)

    def _var_nodes_to_names(self, node_set, ctx_filter_set=None):
        return {
            self._var_node_to_name(node)
            for node in node_set
            if ctx_filter_set is None or type(node.ctx) in ctx_filter_set
        }

    def _var_node_to_name(self, node):
        if isinstance(node, gast.Name):
            return node.id
        elif isinstance(node, gast.Attribute):
            return get_attribute_full_name(node)

    def _node_var_type_is_basic(self, node_var_type):
        basic_types = {
            NodeVarType.BOOLEAN, NodeVarType.INT, NodeVarType.FLOAT,
            NodeVarType.STRING
        }
        return any(t in basic_types for t in node_var_type)

    def _is_call_func_name_node(self, node):
        parent_node = self._get_parent_node(node)
        return isinstance(parent_node, gast.Call) and parent_node.func is node

    def _is_ancestor_node(self, ancestor_node, node):
        parent_node = self._get_parent_node(node)
        while parent_node is not None:
            if parent_node is ancestor_node:
                return True
            parent_node = self._get_parent_node(parent_node)
        return False

    def _get_parent_node(self, node):
        wrapper_node = self.node_to_wrapper_map.get(node)
        if wrapper_node:
            if wrapper_node.parent:
                parent_node = wrapper_node.parent.node
                return parent_node
        return None

    def _remove_unnecessary_vars(self, loop_vars, loop_node):
        """
        Remove unnecessary vars from before_loop_vars, after_loop_vars or in_loop_vars about loop_node.
            1. Remove target vars of gast.For from before_loop_vars or after_loop_vars.
            2. Remove vars only in gast.comprehension.
            3. Remove vars that are type names, for example: "isinstance(x, var_type_name)"
        :param loop_vars: before_loop_vars, after_loop_vars or in_loop_vars of loop_node.
        :param loop_node: Current loop node.
        """

        def filter_name_nodes_from(root_node, target_var_names):
            """
            Filter children with gast.Name type from node.(inclusivly)
            """
            # gast.walk yields root_node itself first, so the root is included.
            return {
                child_node
                for child_node in gast.walk(root_node)
                if isinstance(child_node, gast.Name)
                and child_node.id in target_var_names
            }

        def flatten_target_names(target_node):
            """
            Return the gast.Name nodes bound by a for/comprehension target,
            descending into nested tuple/list targets such as `i, (a, b)`.
            Attribute/Subscript targets bind no local name and yield nothing.
            """
            if isinstance(target_node, gast.Name):
                return [target_node]
            if isinstance(target_node, (gast.Tuple, gast.List)):
                names = []
                for elt in target_node.elts:
                    names.extend(flatten_target_names(elt))
                return names
            if isinstance(target_node, gast.Starred):
                return flatten_target_names(target_node.value)
            return []

        vars_of_list_generator = set()
        target_vars_of_for_node = set()

        for name_node in loop_vars:
            if not isinstance(name_node, gast.Name):
                continue

            parent_node = self._get_parent_node(name_node)

            # NOTE: gast.For.target or gast.comprehension.target can be gast.Tuple.
            #  For examples:
            #   1) `for i, j in enumerate(x)` has two target vars: i and j
            #   2) `[x for x,y in array]` has two target vars: x and y
            if isinstance(parent_node, gast.Tuple):
                parent_node = self._get_parent_node(parent_node)

            # 1. Get vars only in gast.comprehension.
            # For examples:
            #  1) [x for x,y in array] -> x, x, y
            #  2) [f(x) for x in array] -> x
            #  3) [func(x, y) for x in array] -> x, x
            if isinstance(parent_node, gast.comprehension):
                # 1.1 target vars in list/set comprehensions
                target_vars = flatten_target_names(parent_node.target)
                vars_of_list_generator |= set(target_vars)

                # 1.2 vars from target vars used in elt_node
                target_var_names = {var.id for var in target_vars}
                comp_node = self._get_parent_node(parent_node)
                elt_nodes = []
                if isinstance(comp_node, (gast.ListComp, gast.SetComp,
                                          gast.GeneratorExp)):
                    elt_nodes.append(comp_node.elt)
                elif isinstance(comp_node, gast.DictComp):
                    elt_nodes.extend([comp_node.key, comp_node.value])

                for elt_node in elt_nodes:
                    vars_of_list_generator |= filter_name_nodes_from(
                        elt_node, target_var_names)

            # 2. Get target vars or vars from target vars used in for-loop but the for-loop is
            #   1) not the "loop_node" itself
            #   2) not the ancestor of the "loop_node"
            #
            # For examples:
            #   for k in range(x):   # if it's this "loop_node", i or j both should be target vars.
            #      # do something
            #
            #   for i in range(a):   # if it's this "loop_node", k or j should be in target vars but i should not.
            #     for j in range(a): # if it's this "loop_node", k should be in target_vars but i or j should not.
            #       x = i+j
            elif isinstance(parent_node, gast.For):
                if parent_node is loop_node:
                    continue
                if self._is_ancestor_node(parent_node, loop_node):
                    continue
                # 2.1 target vars in gast.For node.
                target_vars_of_for_node |= set(
                    flatten_target_names(parent_node.target))

        # 2.2 vars from target vars used in for-loop
        target_vars_name_strs = {var.id for var in target_vars_of_for_node}
        for var in loop_vars:
            if not isinstance(var, gast.Name):
                continue
            if var.id in target_vars_name_strs and var not in self.condition_vars[
                    loop_node]:
                target_vars_of_for_node.add(var)

        removed_vars = target_vars_of_for_node | vars_of_list_generator

        # 3. Remove var type names which are stored in self.type_vars
        for var in loop_vars:
            if ast_to_source_code(var).strip() in self.type_vars:
                removed_vars.add(var)

        return loop_vars - removed_vars
class DygraphToStaticAst(gast.NodeTransformer):
    """
    Main class to transform Dygraph to Static Graph
    """

    def get_static_ast(self, root):
        """
        Run the whole dygraph-to-static transformer pipeline over ``root``.

        Args:
            root: the gast AST to convert. It is kept as ``self.root`` because
                some analyses and the transformed-code logging need the
                global AST.

        Returns:
            The wrapper root built by ``StaticAnalysisVisitor(root)``; the
            transformers in ``transfer_from_node_type`` are applied to its
            ``.node``.
        """
        # save root for some analysis may need global AST
        self.root = root
        self.static_analysis_visitor = StaticAnalysisVisitor(root)
        self.static_analysis_root = self.static_analysis_visitor.get_node_wrapper_root(
        )
        self.decorate_func_name = None
        self.transfer_from_node_type(self.static_analysis_root)
        return self.static_analysis_root

    def transfer_from_node_type(self, node_wrapper):
        """
        Apply the generic visit and then every transformer, in a fixed order,
        to ``node_wrapper``, logging the code after each step.

        Each transformer's 1-based position in the pipeline below is the log
        level passed to ``log_transformed_code``. Keep the order unchanged
        unless the dependencies between transformers have been checked.
        """
        translator_logger = logging_utils.TranslatorLogger()
        translator_logger.log(
            1, " Source code: \n{}".format(ast_to_source_code(self.root)))
        # Generic transformation
        self.visit(node_wrapper.node)

        # Ordered (transformer class, log name) pairs.
        pipeline = [
            # Transform basic api of dygraph to static graph and get feed_name_to_arg_name
            (BasicApiTransformer, "BasicApiTransformer"),
            # Transform Tensor.shape into fluid.layers.shape(Tensor)
            (TensorShapeTransformer, "TensorShapeTransformer"),
            # Transform list used in control flow
            (ListTransformer, "ListTransformer"),
            # Transform break/continue in loops
            (BreakContinueTransformer, "BreakContinueTransformer"),
            # Transform return in functions
            (ReturnTransformer, "ReturnTransformer"),
            # Transform logical and/or/not
            (LogicalTransformer, "LogicalTransformer"),
            # Transform for loop and while loop
            (LoopTransformer, "LoopTransformer"),
            # Transform all if/else statement of Dygraph into Static Graph.
            (IfElseTransformer, "IfElseTransformer"),
            # Transform python assert statement
            (AssertTransformer, "AssertTransformer"),
            # Transform all python print statement
            (PrintTransformer, "PrintTransformer"),
            # Transform call recursively
            (CallTransformer, "CallTransformer"),
            # Transform python type casting statement
            (CastTransformer, "CastTransformer"),
        ]
        for log_level, (transformer_cls, log_name) in enumerate(
                pipeline, start=1):
            transformer_cls(node_wrapper).transform()
            translator_logger.log_transformed_code(log_level, self.root,
                                                   log_name)

        translator_logger.log_transformed_code(
            logging_utils.LOG_AllTransformer, self.root, "All Transformers")

    def visit_FunctionDef(self, node):
        """
        Record the outermost function name and strip function decorators.

        The first FunctionDef visited sets ``self.decorate_func_name``.
        Nested functions reached through ``generic_visit`` also have their
        decorators removed.

        Raises:
            NotImplementedError: for a plain-name decorator not in
                DECORATOR_NAMES, or an attribute decorator whose dotted name
                contains none of DECORATOR_NAMES.

        Other decorator forms (e.g. calls such as ``@deco(arg)``) are not
        checked and are dropped as well.
        """
        if self.decorate_func_name is None:
            self.decorate_func_name = node.name

        self.generic_visit(node)

        # Remove the decorated name of dygraph_to_static
        if hasattr(node, 'decorator_list'):
            for d in node.decorator_list:
                if isinstance(d, gast.Name) and d.id not in DECORATOR_NAMES:
                    raise NotImplementedError(
                        "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                        + d.id + " in " + self.decorate_func_name)
                if isinstance(d, gast.Attribute):
                    full_attribute_name = get_attribute_full_name(d)
                    # Substring match against the full dotted name.
                    has_translate_decorator = any(
                        deco in full_attribute_name
                        for deco in DECORATOR_NAMES)
                    if not has_translate_decorator:
                        raise NotImplementedError(
                            "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                            + full_attribute_name + " in " +
                            self.decorate_func_name)
            # Every decorator left at this point is either a recognised
            # dygraph_to_static decorator or an unchecked form; none are kept.
            node.decorator_list = []
        return node

    def get_module_name(self):
        """
        Return the main function name which will be used as module name in ast_to_func.
        """
        # TODO: consider BaseAPITransformer, which may add a new module name.
        assert self.decorate_func_name, "decorate_func_name shall not be None."
        return self.decorate_func_name
class PrintTransformer(gast.NodeTransformer):
    """
    This class transforms python print function to fluid.layers.Print.

    Only a ``print`` of exactly one ``gast.Name`` whose inferred var type is a
    Tensor is rewritten. Every other ``print`` is left as-is, and some of those
    cases log a warning.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of PrintTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node

        self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
        )

    def transform(self):
        """Visit ``self.root`` and rewrite the print calls/statements found."""
        self.visit(self.root)

    # NOTE: deal with print in PY3
    def visit_Call(self, node):
        """
        Rewrite ``print(x)`` calls.

        A ``print`` used as a statement (its parent is a gast.Expr) becomes
        ``x = fluid.layers.Print(x, ...)``. A ``print`` used in any other
        position is replaced by the bare ``fluid.layers.Print(x, ...)`` call.

        NOTE(review): children of Call nodes are not visited here (no
        generic_visit), so prints nested inside call arguments are untouched.
        """
        if isinstance(node.func, gast.Name) and node.func.id == 'print':
            parent_node = self.node_to_wrapper_map[node].parent.node
            if isinstance(parent_node, gast.Expr):
                # A bare fluid.layers.Print(x) would be pruned when
                # exe.run(use_prune=True); assigning the result back to the
                # printed variable keeps the Print op.
                print_assign_node = self._create_assign_node(node)
                if print_assign_node is not None:
                    return print_assign_node
            else:
                return self._transform_call_node(node)
        return node

    # NOTE: deal with print in PY2
    def visit_Print(self, node):
        """Rewrite a PY2 ``print x`` statement into ``x = fluid.layers.Print(x, ...)``."""
        print_assign_node = self._create_assign_node(node)
        if print_assign_node is not None:
            return print_assign_node
        return node

    def _transform_call_node(self, node):
        """Return ``fluid.layers.Print(var)`` for a transformable print call, else ``node``."""
        assert isinstance(node, gast.Call), "visit Node is not gast.Call node."
        var_node = self._get_print_var_node(node)
        if var_node is None:
            return node
        if self._need_transform(var_node, node):
            return self._build_print_call_node(var_node)
        return node

    def _create_assign_node(self, node):
        """Return ``var = fluid.layers.Print(var, ...)`` as a gast.Assign, or None if not transformable."""
        var_node = self._get_print_var_node(node)
        if var_node is None:
            return None
        if self._need_transform(var_node, node):
            return gast.Assign(
                targets=[var_node], value=self._build_print_call_node(var_node))
        return None

    def _build_print_call_node(self, node):
        """Build ``fluid.layers.Print(node, summarize=-1, print_phase='forward')``."""
        return gast.Call(
            func=gast.parse('fluid.layers.Print').body[0].value,
            args=[node],
            keywords=[
                gast.keyword(
                    arg='summarize',
                    value=gast.UnaryOp(
                        op=gast.USub(),
                        operand=gast.Constant(
                            value=1, kind=None))), gast.keyword(
                                arg='print_phase',
                                value=gast.Constant(
                                    value='forward', kind=None))
            ])

    def _get_print_var_node(self, node):
        """
        Return the single value printed by ``node``, or None.

        ``node`` is a PY3 print ``gast.Call`` or a PY2 ``gast.Print``. None is
        returned when nothing is printed (``print()`` / bare ``print``) or when
        more than one value is printed (with a warning).
        """
        if isinstance(node, gast.Call):
            var_list = node.args
        elif isinstance(node, gast.Print):
            var_list = node.values
        else:
            # Not a print node: nothing to transform.
            return None

        if not var_list:
            # `print()` / bare `print` has no value to feed to
            # fluid.layers.Print; indexing var_list[0] would raise IndexError.
            return None

        # A single tuple argument is unpacked (presumably to handle the PY2
        # form `print (a, b)`).
        if isinstance(var_list[0], gast.Tuple):
            var_list = var_list[0].elts
        # TODO: support print multiple Var
        if len(var_list) == 1:
            return var_list[0]
        _logger.warning(
            "ProgramTranslator could not transform printing multiple values like < %s > now and will run it as-is."
            % ast_to_source_code(node).strip())
        return None

    def _need_transform(self, var_node, print_node):
        """
        Return True only if ``var_node`` is a gast.Name inferred to be a Tensor.

        Otherwise log a warning quoting ``print_node`` and return False.
        """
        if isinstance(var_node, gast.Name):
            if self._is_tensor_node(var_node):
                return True
            _logger.warning(
                "ProgramTranslator could not transform printing value that are not Tensor like < %s > now and will run it as-is."
                % ast_to_source_code(print_node).strip())
        else:
            _logger.warning(
                "ProgramTranslator could not transform < %s > now and will run it as-is."
                % ast_to_source_code(print_node).strip())
        return False

    def _is_tensor_node(self, node):
        """Return True if static analysis typed ``node`` as TENSOR or PADDLE_RETURN_TYPES."""
        tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES}
        wrapper_node = self.node_to_wrapper_map.get(node, None)
        if wrapper_node is not None:
            if wrapper_node.node_var_type & tensor_types:
                return True
        return False