class DygraphToStaticAst(gast.NodeTransformer):
    """
    Main class to transform Dygraph to Static Graph.

    Call ``get_static_ast(root)`` first; afterwards ``get_module_name()`` and
    ``get_feed_name_to_idx()`` expose information gathered during that run.
    """

    def get_static_ast(self, root):
        """
        Run the dygraph-to-static pipeline over ``root``.

        Args:
            root: the gast tree to transform. It is stored on ``self.root`` so
                analyses that need the global AST can reach it.

        Returns:
            The node-wrapper root built by ``StaticAnalysisVisitor``, after all
            transformations in ``transfer_from_node_type`` have been applied.
        """
        self.root = root
        self.static_analysis_visitor = StaticAnalysisVisitor(root)
        self.static_analysis_root = (
            self.static_analysis_visitor.get_node_wrapper_root())
        # Per-run state that visit_FunctionDef fills in.
        self.decorate_func_name = None
        self.arg_name_to_idx = {}
        self.transfer_from_node_type(self.static_analysis_root)
        return self.static_analysis_root

    def transfer_from_node_type(self, node_wrapper):
        """
        Visit the tree with this transformer, then apply each dygraph-to-static
        pass in sequence to the same ``node_wrapper``.

        Side effect: sets ``self.feed_name_to_arg_name`` from the result of
        ``BasicApiTransformer.get_feed_name_to_arg_id()``.
        """
        # Generic transformation (runs visit_FunctionDef among others).
        self.visit(node_wrapper.node)

        # BasicApiTransformer is kept as a named instance because its
        # feed-name -> argument-name mapping is needed afterwards.
        api_pass = BasicApiTransformer(node_wrapper)
        api_pass.transform()
        self.feed_name_to_arg_name = api_pass.get_feed_name_to_arg_id()

        later_passes = (
            TensorShapeTransformer,  # Tensor.shape -> fluid.layers.shape(Tensor)
            ListTransformer,  # lists used in control flow
            BreakContinueTransformer,  # break/continue in loops
            LoopTransformer,  # for loops and while loops
            IfElseTransformer,  # if/else statements
            AssertTransformer,  # python assert statements
            PrintTransformer,  # python print statements
            CallTransformer,  # function calls, recursively
        )
        for pass_cls in later_passes:
            pass_cls(node_wrapper).transform()

    def visit_FunctionDef(self, node):
        """
        Record the first visited function and strip decorators.

        The first FunctionDef reached is treated as the decorated function:
        its name and the positional index of each ``args.args`` entry are
        recorded. After visiting children, each decorator is validated and
        the decorator list is cleared.

        Raises:
            NotImplementedError: if a decorator is a plain name not listed in
                DECORATOR_NAMES, or a dotted attribute whose full name contains
                none of the DECORATOR_NAMES entries.
        """
        if self.decorate_func_name is None:
            self.decorate_func_name = node.name
            self.arg_name_to_idx.update(
                (arg.id, position)
                for position, arg in enumerate(node.args.args))

        self.generic_visit(node)

        # Remove the decorated name of dygraph_to_static
        if not hasattr(node, 'decorator_list'):
            return node
        for deco_node in node.decorator_list:
            if isinstance(deco_node, gast.Name):
                if deco_node.id not in DECORATOR_NAMES:
                    raise NotImplementedError(
                        "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                        + deco_node.id + " in " + self.decorate_func_name)
            elif isinstance(deco_node, gast.Attribute):
                dotted_name = get_attribute_full_name(deco_node)
                # Substring match: any known decorator name inside the
                # dotted path is accepted.
                if not any(known in dotted_name for known in DECORATOR_NAMES):
                    raise NotImplementedError(
                        "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                        + dotted_name + " in " + self.decorate_func_name)
        # Accepted decorators, and decorators of any other node kind, are
        # all dropped.
        node.decorator_list = []
        return node

    def get_module_name(self):
        """
        Return the main function name which will be used as module name in
        ast_to_func.
        """
        # Should consider BaseAPITransformer which add new module name in Yamei's PR.
        assert self.decorate_func_name, "decorate_func_name shall not be None."
        return self.decorate_func_name

    def get_feed_name_to_idx(self):
        """
        Map each feed name to the positional index of the argument it came
        from. The value is None when that argument name was not recorded by
        visit_FunctionDef.
        """
        return {
            feed_name: self.arg_name_to_idx.get(arg_name)
            for feed_name, arg_name in self.feed_name_to_arg_name.items()
        }
class DygraphToStaticAst(gast.NodeTransformer):
    """
    Main class to transform Dygraph to Static Graph.

    Call ``get_static_ast(root)`` first; afterwards ``get_module_name()`` and
    ``get_feed_name_to_idx()`` expose information gathered during that run.
    """

    def get_static_ast(self, root):
        """
        Run the dygraph-to-static pipeline over ``root``.

        Args:
            root: the gast tree to transform. It is stored on ``self.root`` so
                analyses that need the global AST can reach it.

        Returns:
            The node-wrapper root built by ``StaticAnalysisVisitor``, after all
            transformations in ``transfer_from_node_type`` have been applied.
        """
        self.root = root
        self.static_analysis_visitor = StaticAnalysisVisitor(root)
        self.static_analysis_root = self.static_analysis_visitor.get_node_wrapper_root(
        )
        self.decorate_func_name = None
        self.arg_name_to_idx = {}
        self.transfer_from_node_type(self.static_analysis_root)
        return self.static_analysis_root

    def transfer_from_node_type(self, node_wrapper):
        """
        Visit the tree with this transformer, then apply each dygraph-to-static
        pass in sequence to the same ``node_wrapper``.

        Side effect: sets ``self.feed_name_to_arg_name``.
        """
        # Generic transformation
        self.visit(node_wrapper.node)

        # Transform basic api of dygraph to static graph and get feed_name_to_arg_name
        basic_api_trans = BasicApiTransformer(node_wrapper)
        basic_api_trans.transform()
        self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id()

        # Transform Tensor.shape into fluid.layers.shape(Tensor)
        TensorShapeTransformer(node_wrapper).transform()

        # Transform list used in control flow
        ListTransformer(node_wrapper).transform()

        # Transform break/continue in loops
        BreakContinueTransformer(node_wrapper).transform()

        # Transform for loop and while loop
        LoopTransformer(node_wrapper).transform()

        # Transform all if/else statement of Dygraph into Static Graph.
        IfElseTransformer(node_wrapper).transform()

    @staticmethod
    def _decorator_name(decorator):
        """
        Return the trailing identifier of a decorator expression.

        ``@foo`` -> "foo", ``@a.b.foo`` -> "foo", and ``@foo(...)`` /
        ``@a.b.foo(...)`` -> "foo". Returns None for any other expression
        shape, so such decorators are never mistaken for translator ones.
        """
        if isinstance(decorator, gast.Call):
            decorator = decorator.func
        if isinstance(decorator, gast.Name):
            return decorator.id
        if isinstance(decorator, gast.Attribute):
            return decorator.attr
        return None

    def visit_FunctionDef(self, node):
        """
        Record the first visited function and remove translator decorators.

        The first FunctionDef reached is treated as the decorated function:
        its name and the positional index of each ``args.args`` entry are
        recorded. After visiting children, every decorator whose trailing
        identifier is in DECORATOR_NAMES is removed; other decorators are kept.
        """
        if self.decorate_func_name is None:
            self.decorate_func_name = node.name
            for idx, arg in enumerate(node.args.args):
                self.arg_name_to_idx[arg.id] = idx

        self.generic_visit(node)
        # Remove the decorated name of dygraph_to_static.
        # Only gast.Name has `.id`. Dotted decorators (gast.Attribute) and
        # called decorators (gast.Call) previously raised AttributeError here,
        # so resolve each decorator's trailing identifier before filtering.
        if hasattr(node, 'decorator_list'):
            decorator_list = [
                d for d in node.decorator_list
                if self._decorator_name(d) not in DECORATOR_NAMES
            ]
            node.decorator_list = decorator_list
        return node

    def get_module_name(self):
        """
        Return the main function name which will be used as module name in
        ast_to_func.
        """
        # Should consider BaseAPITransformer which add new module name in Yamei's PR.
        assert self.decorate_func_name, "decorate_func_name shall not be None."
        return self.decorate_func_name

    def get_feed_name_to_idx(self):
        """
        Map each feed name to the positional index of the argument it came
        from. The value is None when that argument name was not recorded.
        """
        feed_name_to_idx = {}
        for feed_name, arg_name in self.feed_name_to_arg_name.items():
            feed_name_to_idx[feed_name] = self.arg_name_to_idx.get(arg_name)
        return feed_name_to_idx
class DygraphToStaticAst(gast.NodeTransformer):
    """
    Main class to transform Dygraph to Static Graph.

    Each transformer pass is applied in a fixed order. The code produced after
    every pass is handed to a ``logging_utils.TranslatorLogger``.
    """

    def __init__(self):
        self.translator_logger = logging_utils.TranslatorLogger()

    def get_static_ast(self, root):
        """
        Run the dygraph-to-static pipeline over ``root``.

        Args:
            root: the gast tree to transform. It is stored on ``self.root`` so
                analyses (and the logger) can reach the global AST.

        Returns:
            The node-wrapper root built by ``StaticAnalysisVisitor``, after all
            transformer passes have been applied.
        """
        self.root = root
        self.static_analysis_visitor = StaticAnalysisVisitor(root)
        wrapper_root = self.static_analysis_visitor.get_node_wrapper_root()
        self.static_analysis_root = wrapper_root
        self.decorate_func_name = None
        self.transfer_from_node_type(self.static_analysis_root)
        return self.static_analysis_root

    def _apply(self, transformer, node_wrapper, log_level):
        """
        Run one transformer class on ``node_wrapper``, then log the current
        code under ``log_level`` tagged with the transformer's class name.
        """
        transformer(node_wrapper).transform()
        self.translator_logger.log_transformed_code(log_level, self.root,
                                                    transformer.__name__)

    def transfer_from_node_type(self, node_wrapper):
        """
        Log the source, run the generic visit, then apply every transformer
        pass in order.

        Pass number ``k`` (1-based) is logged at level ``k``. The final result
        is logged at ``logging_utils.LOG_AllTransformer``.
        """
        self.translator_logger.log(
            1, "Source code: \n{}".format(ast_to_source_code(self.root)))
        # Generic transformation
        self.visit(node_wrapper.node)

        pipeline = (
            BasicApiTransformer,  # Basic Api
            TensorShapeTransformer,  # Tensor.shape -> layers.shape(Tensor)
            ListTransformer,  # List used in control flow
            BreakTransformOptimizer,  # optimize transformation of break in loops
            BreakContinueTransformer,  # break/continue in loops
            ReturnTransformer,  # return in functions
            LogicalTransformer,  # logical and/or/not
            LoopTransformer,  # for/while -> while_op
            IfElseTransformer,  # if/else -> cond_op
            AssertTransformer,  # assert statement
            PrintTransformer,  # print statement
            CallTransformer,  # transform call recursively
            CastTransformer,  # type casting statement
            GradTransformer,  # transform paddle.grad to paddle.gradients
        )
        for level, pass_cls in enumerate(pipeline, 1):
            self._apply(pass_cls, node_wrapper, log_level=level)

        self.translator_logger.log_transformed_code(
            logging_utils.LOG_AllTransformer, self.root, "All Transformers")

    def visit_FunctionDef(self, node):
        """
        Record the first visited function's name and strip decorators.

        Raises:
            NotImplementedError: if a decorator is a plain name not listed in
                DECORATOR_NAMES, or a dotted attribute whose full name contains
                none of the DECORATOR_NAMES entries.
        """
        if self.decorate_func_name is None:
            self.decorate_func_name = node.name

        self.generic_visit(node)

        # Remove the decorated name of dygraph_to_static
        if not hasattr(node, 'decorator_list'):
            return node
        for deco_node in node.decorator_list:
            if isinstance(deco_node, gast.Name):
                if deco_node.id not in DECORATOR_NAMES:
                    raise NotImplementedError(
                        "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                        + deco_node.id + " in " + self.decorate_func_name)
            elif isinstance(deco_node, gast.Attribute):
                dotted_name = get_attribute_full_name(deco_node)
                # Substring match against each known decorator name.
                if not any(known in dotted_name for known in DECORATOR_NAMES):
                    raise NotImplementedError(
                        "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                        + dotted_name + " in " + self.decorate_func_name)
        # Accepted decorators, and decorators of any other node kind, are
        # all dropped.
        node.decorator_list = []
        return node

    def get_module_name(self):
        """
        Return the main function name which will be used as module name in
        ast_to_func.
        """
        # Should consider BaseAPITransformer which add new module name in Yamei's PR.
        assert self.decorate_func_name, "decorate_func_name shall not be None."
        return self.decorate_func_name
class DygraphToStaticAst(gast.NodeTransformer):
    """
    Main class to transform Dygraph to Static Graph.

    Each transformer pass is applied in a fixed order. The code produced after
    every pass is logged through a ``logging_utils.TranslatorLogger``.
    """

    def get_static_ast(self, root):
        """
        Run the dygraph-to-static pipeline over ``root``.

        Args:
            root: the gast tree to transform. It is stored on ``self.root`` so
                analyses (and the logger) can reach the global AST.

        Returns:
            The node-wrapper root built by ``StaticAnalysisVisitor``, after all
            transformer passes have been applied.
        """
        # save root for some analysis may need global AST
        self.root = root
        self.static_analysis_visitor = StaticAnalysisVisitor(root)
        self.static_analysis_root = self.static_analysis_visitor.get_node_wrapper_root(
        )
        self.decorate_func_name = None
        self.transfer_from_node_type(self.static_analysis_root)
        return self.static_analysis_root

    def transfer_from_node_type(self, node_wrapper):
        """
        Log the source, run the generic visit, then apply every transformer
        pass in order.

        Pass number ``k`` (1-based) is logged at level ``k`` under its class
        name. The final result is logged at
        ``logging_utils.LOG_AllTransformer``.
        """
        translator_logger = logging_utils.TranslatorLogger()
        translator_logger.log(
            1, " Source code: \n{}".format(ast_to_source_code(self.root)))
        # Generic transformation
        self.visit(node_wrapper.node)

        # Ordered pipeline. The log level is the 1-based position, so adding
        # or reordering a pass keeps levels and names consistent
        # automatically. Previously both were hand-typed for each of the
        # twelve passes.
        transformers = (
            BasicApiTransformer,  # basic dygraph APIs -> static graph APIs
            TensorShapeTransformer,  # Tensor.shape -> fluid.layers.shape(Tensor)
            ListTransformer,  # lists used in control flow
            BreakContinueTransformer,  # break/continue in loops
            ReturnTransformer,  # return in functions
            LogicalTransformer,  # logical and/or/not
            LoopTransformer,  # for loops and while loops
            IfElseTransformer,  # if/else statements
            AssertTransformer,  # python assert statements
            PrintTransformer,  # python print statements
            CallTransformer,  # function calls, recursively
            CastTransformer,  # python type casting statements
        )
        for log_level, transformer in enumerate(transformers, 1):
            transformer(node_wrapper).transform()
            translator_logger.log_transformed_code(log_level, self.root,
                                                   transformer.__name__)

        translator_logger.log_transformed_code(
            logging_utils.LOG_AllTransformer, self.root, "All Transformers")

    def visit_FunctionDef(self, node):
        """
        Record the first visited function's name and strip decorators.

        Raises:
            NotImplementedError: if a decorator is a plain name not listed in
                DECORATOR_NAMES, or a dotted attribute whose full name contains
                none of the DECORATOR_NAMES entries.
        """
        if self.decorate_func_name is None:
            self.decorate_func_name = node.name

        self.generic_visit(node)
        # Remove the decorated name of dygraph_to_static
        if hasattr(node, 'decorator_list'):
            for d in node.decorator_list:
                if isinstance(d, gast.Name) and d.id not in DECORATOR_NAMES:
                    raise NotImplementedError(
                        "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                        + d.id + " in " + self.decorate_func_name)
                if isinstance(d, gast.Attribute):
                    full_attribute_name = get_attribute_full_name(d)
                    # Substring match against each known decorator name.
                    if not any(deco in full_attribute_name
                               for deco in DECORATOR_NAMES):
                        raise NotImplementedError(
                            "ProgramTranslator hasn't implemented multiple decorators. Please remove "
                            + full_attribute_name + " in " +
                            self.decorate_func_name)
            # Accepted decorators, and decorators of any other node kind, are
            # all dropped.
            node.decorator_list = []
        return node

    def get_module_name(self):
        """
        Return the main function name which will be used as module name in
        ast_to_func.
        """
        # Should consider BaseAPITransformer which add new module name in Yamei's PR.
        assert self.decorate_func_name, "decorate_func_name shall not be None."
        return self.decorate_func_name