def parse_and_analyze(self, test_fn, namespace, namer=None, arg_types=None, include_type_analysis=True, owner_type=None, recursive=True, autograph_decorators=()): node, source = parser.parse_entity(test_fn) if namer is None: namer = FakeNamer() program_ctx = converter.ProgramContext( recursive=recursive, autograph_decorators=autograph_decorators, partial_types=None, autograph_module=None, uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) entity_info = transformer.EntityInfo(source_code=source, source_file='<fragment>', namespace=namespace, arg_values=None, arg_types=arg_types, owner_type=owner_type) ctx = converter.EntityContext(namer, entity_info, program_ctx) node = qual_names.resolve(node) node = activity.resolve(node, entity_info) node = live_values.resolve(node, entity_info, {}) if include_type_analysis: node = type_info.resolve(node, entity_info) node = live_values.resolve(node, entity_info, {}) self.ctx = ctx return node
def to_code(e, recursive=True, arg_values=None, arg_types=None, partial_types=None, indentation=' '): """Return the equivalent of an entity in TensorFlow code. See `to_graph` for more details. Args: e: A Python entity. recursive: See to_graph. arg_values: See to_graph. arg_types: See to_graph. partial_types: See to_graph. indentation: String, when to use for each level of indentation. Returns: String. """ program_ctx = converter.ProgramContext( recursive=recursive, autograph_decorators=(convert, do_not_convert, converted_call), partial_types=partial_types, autograph_module=tf_inspect.getmodule(to_graph), uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) code = '\n'.join( compiler.ast_to_source(dep, indentation) for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache)))) return program_ctx.required_imports + '\n\n' + code
def _simple_program_ctx(self): return converter.ProgramContext( recursive=True, autograph_decorators=(), partial_types=(), autograph_module=api, uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
def prepare(self, test_fn, namespace, namer=None, arg_types=None, owner_type=None, recursive=True, autograph_decorators=()): node, source = parser.parse_entity(test_fn) node = node.body[0] if namer is None: namer = FakeNamer() program_ctx = converter.ProgramContext( recursive=recursive, autograph_decorators=autograph_decorators, partial_types=None, autograph_module=None, uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) entity_info = transformer.EntityInfo(source_code=source, source_file='<fragment>', namespace=namespace, arg_values=None, arg_types=arg_types, owner_type=owner_type) ctx = converter.EntityContext(namer, entity_info, program_ctx) node = converter.standard_analysis(node, ctx, is_initial=True) return node, ctx
def to_graph(e, recursive=True, verbose=False, arg_values=None, arg_types=None, partial_types=None): """Compile a Python entity into equivalent TensorFlow code. Currently supported entities: * functions * classes Classes are handled by converting all their methods into a new class. Args: e: A Python entity. recursive: Whether to recursively convert any functions that the decorator function may call. verbose: Whether to output the compiled code in the logs. arg_values: A dict containing value hints for symbols like function parameters. arg_types: A dict containing type hints for symbols like function parameters. partial_types: A set of types (e.g. classes) that will not be converted entirely. Calls to member functions for these types will be renamed independently. Returns: A function with a signature identical to `o`, but which when executed it creates TF a graph that has the same functionality as the original entity. """ program_ctx = converter.ProgramContext( recursive=recursive, autograph_decorators=(convert, do_not_convert, converted_call), partial_types=partial_types, autograph_module=tf_inspect.getmodule(to_graph), uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) module = gast.Module([]) for dep in reversed(program_ctx.dependency_cache.values()): module.body.append(dep) compiled_node, compiled_src = compiler.ast_to_object( module, source_prefix=program_ctx.required_imports) # The compiled code should see everything the entry entity saw. # TODO(mdan): This might not work well if the call tree spans modules? for key, val in namespace.items(): # Avoid overwriting entities that have been transformed. if key not in compiled_node.__dict__: compiled_node.__dict__[key] = val compiled_fn = getattr(compiled_node, name) if verbose: logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src) return compiled_fn
def to_code(e, recursive=True, arg_values=None, arg_types=None, partial_types=None, indentation=' '): """Returns the equivalent code that uses TensorFlow ops. Also see: `to_graph`, `convert` Args: e: Union[Callable, Type], the Python entity to convert. recursive: bool, whether to recursively convert any functions that the converted function may call. arg_values: Optional[Dict[Text, Any]], value hints for symbols including function arguments. arg_types: Optional[Dict[Text, Type]], type hints for symbols including function arguments. partial_types: Set[Type], reserved for internal use. indentation: Text, when to use for each level of indentation. Returns: Text, the converted code. """ program_ctx = converter.ProgramContext( recursive=recursive, autograph_decorators=(convert, do_not_convert, converted_call), partial_types=partial_types, autograph_module=tf_inspect.getmodule(to_graph), uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) code = '\n'.join( compiler.ast_to_source(dep, indentation) for dep in reversed( tuple(six.itervalues(program_ctx.dependency_cache)))) return program_ctx.required_imports + '\n\n' + code
def to_graph(e, recursive=True, verbose=False, arg_values=None, arg_types=None, partial_types=None): """Compile a Python entity into equivalent TensorFlow code. Currently supported entities: * functions * classes Classes are handled by converting all their methods into a new class. Args: e: A Python entity. recursive: Whether to recursively convert any functions that the decorator function may call. verbose: Whether to output the compiled code in the logs. arg_values: A dict containing value hints for symbols like function parameters. arg_types: A dict containing type hints for symbols like function parameters. partial_types: A set of types (e.g. classes) that will not be converted entirely. Calls to member functions for these types will be renamed independently. Returns: A function with a signature identical to `o`, but which when executed it creates TF a graph that has the same functionality as the original entity. Raises: ValueError: If the converted function defines or refers to symbol names that are reserved for AutoGraph. """ program_ctx = converter.ProgramContext( recursive=recursive, autograph_decorators=(convert, do_not_convert, converted_call), partial_types=partial_types, autograph_module=tf_inspect.getmodule(to_graph), uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) nodes = [] for dep in reversed(program_ctx.dependency_cache.values()): nodes.extend(dep) compiled_module, compiled_src = compiler.ast_to_object( nodes, source_prefix=program_ctx.required_imports, include_source_map=True) # The compiled code should see everything the entry entity saw. # TODO(mdan): This might not work well if the call tree spans modules? for key, val in namespace.items(): # Avoid overwriting entities that have been transformed. if key not in compiled_module.__dict__: compiled_module.__dict__[key] = val compiled = getattr(compiled_module, name) # Need this so the source_mapping attribute is available for the context # manager to access for runtime errors. # # Note that compiler.ast_to_object attaches the source map 'ag_source_map__' # symbol to the compiled module. # TODO(mdan): Record this statically in the generated code. # TODO(mdan): Rename this attribute to 'autograph_info__' source_map_attribute_name = 'ag_source_map' if getattr(compiled, source_map_attribute_name, None) is not None: raise ValueError('cannot convert %s because is has an attribute ' '"%s", which is reserved for AutoGraph.' % (compiled, source_map_attribute_name)) setattr(compiled, source_map_attribute_name, compiled_module.__dict__['ag_source_map__']) if verbose: logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src) return compiled
def to_graph(e, recursive=True, verbose=False, arg_values=None, arg_types=None, partial_types=None): """Converts a Python entity into equivalent code that uses TensorFlow ops. Supported Python entities include: * functions * classes Classes are converted by converting all their methods into a new class. Args: e: Union[Callable, Type], the Python entity to convert. recursive: bool, whether to recursively convert any functions that the converted function may call. verbose: bool, whether to output the compiled code in the logs. arg_values: Optional[Dict[Text, Any]], value hints for symbols including function arguments. arg_types: Optional[Dict[Text, Type]], type hints for symbols including function arguments. partial_types: Set[Type], reserved for internal use. Returns: Union[Callable, Type], the converted entity, which is the same kind as e (that is, a function is e is a function, a class if e is a class, etc.) but its code has been converted to use TF ops. Raises: ValueError: If the entity could not be converted. """ program_ctx = converter.ProgramContext( recursive=recursive, autograph_decorators=(convert, do_not_convert, converted_call), partial_types=partial_types, autograph_module=tf_inspect.getmodule(to_graph), uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) nodes = [] for dep in reversed(program_ctx.dependency_cache.values()): nodes.extend(dep) compiled_module, compiled_src = compiler.ast_to_object( nodes, source_prefix=program_ctx.required_imports, include_source_map=True) # The compiled code should see everything the entry entity saw. # TODO(mdan): This might not work well if the call tree spans modules? for key, val in namespace.items(): # Avoid overwriting entities that have been transformed. if key not in compiled_module.__dict__: compiled_module.__dict__[key] = val compiled = getattr(compiled_module, name) # Need this so the source_mapping attribute is available for the context # manager to access for runtime errors. # # Note that compiler.ast_to_object attaches the source map 'ag_source_map__' # symbol to the compiled module. # TODO(mdan): Record this statically in the generated code. # TODO(mdan): Rename this attribute to 'autograph_info__' source_map_attribute_name = 'ag_source_map' if getattr(compiled, source_map_attribute_name, None) is not None: raise ValueError('cannot convert %s because is has an attribute ' '"%s", which is reserved for AutoGraph.' % (compiled, source_map_attribute_name)) setattr(compiled, source_map_attribute_name, compiled_module.__dict__['ag_source_map__']) if verbose: logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src) return compiled