def get_code(self, dygraph_func):
    """
    Translate a dygraph function and return the source code of the
    resulting static function.

    The function's source is fetched (after unwrapping decorators), dedented,
    parsed with gast, run through ``DygraphToStaticAst`` and printed back to
    Python source text.

    Args:
        dygraph_func (callable): the dygraph function to translate.

    Returns:
        str: source code of the translated static function.

    Raises:
        AssertionError: if ``dygraph_func`` is not callable (the check is an
            ``assert``, so it is skipped when Python runs with ``-O``).

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            import numpy as np

            def func(x):
                x = fluid.dygraph.to_variable(x)
                if fluid.layers.mean(x) > 0:
                    x_v = x - 1
                else:
                    x_v = x + 1
                return x_v

            prog_trans = fluid.dygraph.ProgramTranslator()
            code = prog_trans.get_code(func)
            print(type(code))  # <class 'str'>
    """
    assert callable(
        dygraph_func
    ), "Input dygraph_func is not a callable in ProgramTranslator.get_code"

    # Source text of the undecorated function, dedented so that methods and
    # nested functions can be parsed as top-level code.
    dedented_source = textwrap.dedent(inspect.getsource(unwrap(dygraph_func)))
    ast_root = gast.parse(dedented_source)

    # Apply the dygraph -> static AST transformation, then print it back.
    static_ast = DygraphToStaticAst().get_static_ast(ast_root)
    return ast_to_source_code(static_ast.node)
def __init__(self):
    """Set up empty conversion caches and the dygraph-to-static AST transformer."""
    # NOTE(review): this body mirrors FunctionCache.__init__; the enclosing
    # class of this method is not visible here.
    # Converted static functions, keyed by the dygraph function:
    # {dygraph_func: static_func}
    self._converted_static_func_caches = {}
    # Transformed AST wrappers shared by functions with identical source:
    # {source_code: ast_root}
    self._code_to_ast_caches = {}
    # Transformer used to turn a dygraph AST into a static one.
    self._dygraph_to_static = DygraphToStaticAst()
class FunctionCache(object):
    """
    Memoizes dygraph-to-static conversions so the same function is not
    transformed more than once.

    Two levels of caching are kept:
      * per function object: dygraph function -> converted static function;
      * per source text: dedented source code -> transformed AST wrapper, so
        distinct functions whose source text is identical share a single AST
        transformation.
    """

    def __init__(self):
        # {dygraph_func: static_func}
        self._converted_static_func_caches = {}
        # {source_code: root_wrapper produced by DygraphToStaticAst.get_static_ast}
        self._code_to_ast_caches = {}
        self._dygraph_to_static = DygraphToStaticAst()

    def convert_with_cache(self, func):
        """
        Return the static counterpart of ``func``, converting it on first use.

        A cached entry whose value is ``None`` is treated as a miss and the
        function is converted again.
        """
        caches = self._converted_static_func_caches
        static_func = caches.get(func)
        if static_func is not None:
            return static_func
        static_func = caches[func] = self._convert(func)
        return static_func

    def _convert(self, func):
        """
        Convert a dygraph function into its static-graph counterpart.

        The AST transformation is cached by source text. When two functions
        have the same (dedented) source, the second conversion reuses the
        transformed AST of the first. For example, given

            # A.py
            def foo(x, y):
                z = x + y
                return z

            # B.py
            def foo(x, y):
                z = x + y
                return z

        converting A.foo after B.foo reuses B.foo's transformed AST node to
        speed up the conversion, and only rebuilds the static function object.
        """
        # NOTE: In Python 2, inspecting a decorated function directly raises
        # OSError; ``function.__wrapped__`` holds the actual function, hence
        # the unwrap.
        func = unwrap(func)
        source_code = func_to_source_code(func)

        # TODO(liym27): Two methods in different classes can have the same
        # source_code and will then share one entry in this cache.
        # Maybe use (__class__, source_code) as the key.
        # NOTE(review): origin info is attached only on a cache miss, so on a
        # hit the reused AST presumably still carries the origin info of the
        # function that first populated the entry. Confirm this is intended.
        ast_cache = self._code_to_ast_caches
        if source_code not in ast_cache:
            root = attach_origin_info(gast.parse(source_code), func)
            ast_cache[source_code] = self._dygraph_to_static.get_static_ast(root)
        root_wrapper = ast_cache[source_code]

        # Build the static function from the transformed AST; the file name
        # returned alongside it is not needed here.
        static_func, _ = ast_to_func(root_wrapper.node, func)
        create_and_update_origin_info_map(root_wrapper.node, static_func)
        return static_func

    def exist(self, func):
        """Return True if ``func`` already has an entry in the conversion cache."""
        return func in self._converted_static_func_caches