def convert(entity, program_ctx):
  """Converts an entity into an equivalent entity.

  Args:
    entity: The Python object to convert. For functions and methods, the
      names of their free (closure) variables are collected from
      `__code__.co_freevars`; any other entity is treated as having none.
    program_ctx: The program context, forwarded to `_convert_with_cache`.

  Returns:
    Whatever `_instantiate` produces for the converted entity.

  Raises:
    ValueError: if a function or method has no `__code__` attribute, or if it
      closes over a variable named `ag__` whose cell does not hold
      `ag_internal`.
  """
  is_function_like = (
      tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity))

  free_nonglobal_var_names = ()
  if is_function_like:
    if not hasattr(entity, '__code__'):
      raise ValueError(
          'Cannot apply autograph to a function that doesn\'t '
          'expose a __code__ object. If this is a @tf.function,'
          ' try passing f.python_function instead.')
    free_nonglobal_var_names = entity.__code__.co_freevars

  # `ag__` is a reserved name: a closure variable with that name is only
  # accepted when its cell already holds `ag_internal`.
  for cell_index, var_name in enumerate(free_nonglobal_var_names):
    if var_name != 'ag__':
      continue
    if entity.__closure__[cell_index].cell_contents is not ag_internal:
      raise ValueError('entity {} uses the reserved symbol "{}"'.format(
          entity, var_name))
  # TODO(mdan): In extreme cases, other ag__ symbols may also be clobbered.

  converted_entity_info = _convert_with_cache(
      entity, program_ctx, free_nonglobal_var_names)
  return _instantiate(
      entity, converted_entity_info, free_nonglobal_var_names)
def _instantiate(entity, converted_entity_info, free_nonglobal_var_names):
  """Creates a converted instance and binds it to match original entity.

  The factory produced by conversion is bound to the empty module it was
  loaded from. This re-creates it with the original entity's globals and
  closure cells, calls it, and (for functions/methods) copies the original
  default arguments onto the result.

  Args:
    entity: The original Python entity that was converted.
    converted_entity_info: The conversion result; its `get_factory()`,
      `source_map` and `get_module()` members are used here.
    free_nonglobal_var_names: The names of the entity's free (closure)
      variables, in the same order as the cells of `entity.__closure__`.

  Returns:
    The value returned by the re-bound factory. When `entity` is a function
    or method, its `__defaults__` (and `__kwdefaults__`, if present) are
    copied onto that value.

  Raises:
    ValueError: if `entity` is neither a function/method nor has a
      `__module__` attribute, so there are no globals to bind the factory to.
  """
  factory = converted_entity_info.get_factory()
  is_function_like = (
      tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity))

  # `factory` is currently bound to the empty module it was loaded from.
  # It must instead be bound to the globals and closure from the original
  # entity.
  if is_function_like:
    entity_globals = entity.__globals__
    entity_closure = entity.__closure__ or ()
  elif hasattr(entity, '__module__'):
    entity_globals = sys.modules[entity.__module__].__dict__
    entity_closure = ()
  else:
    # Without this branch `entity_globals` would be unbound below and the
    # failure would surface as an opaque UnboundLocalError.
    raise ValueError(
        'cannot instantiate converted entity {}: it is neither a function'
        ' nor a method, and has no __module__ attribute'.format(entity))

  # Internal invariant: `convert` takes the names from this entity's
  # `__code__.co_freevars`, so they pair one-to-one with its closure cells.
  assert len(entity_closure) == len(free_nonglobal_var_names)

  # Fit the original entity's cells to match the order of factory's cells.
  cells_by_name = dict(zip(free_nonglobal_var_names, entity_closure))
  new_factory_cells = tuple(
      cells_by_name[name] for name in factory.__code__.co_freevars)

  bound_factory = types.FunctionType(
      code=factory.__code__,
      globals=entity_globals,
      name=factory.__name__,
      argdefs=(),
      closure=new_factory_cells)

  # Three more values are wired in through the factory's parameters rather
  # than closure cells: the internal "ag__" module, the source map, and the
  # result of `converted_entity_info.get_module()`.
  converted_entity = bound_factory(  # pylint:disable=not-callable
      ag_internal, converted_entity_info.source_map,
      converted_entity_info.get_module())

  if is_function_like:
    # Attach the default argument to the converted function.
    converted_entity.__defaults__ = entity.__defaults__
    if hasattr(entity, '__kwdefaults__'):
      converted_entity.__kwdefaults__ = entity.__kwdefaults__

  return converted_entity
def getfutureimports(entity):
  """Detects what future imports are necessary to safely execute entity source.

  Only functions and methods are inspected. A global name counts as a future
  import when the bound value reports `__module__ == '__future__'`; for
  example, `from __future__ import division` binds `division` to such a value.

  Args:
    entity: Any object

  Returns:
    A sorted tuple of the matching global names (strings). The tuple is empty
    when `entity` is not a function or method.
  """
  is_function_like = (
      tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity))
  if not is_function_like:
    return ()

  future_names = [
      name for name, value in entity.__globals__.items()
      if getattr(value, '__module__', None) == '__future__'
  ]
  future_names.sort()
  return tuple(future_names)
def convert_entity_to_ast(o, program_ctx):
  """Compile a Python entity into equivalent TensorFlow.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.

  Returns:
    A tuple (nodes, new_name, entity_info):
      * nodes: The AST nodes representing an entity with an interface
        equivalent to `o`, but which, when executed, creates a TF graph.
      * new_name: The symbol name under which the new entity can be found.
      * entity_info: The entity info returned by `convert_class_to_ast` or
        `convert_func_to_ast`.

  Raises:
    NotImplementedError: if entity is of a type that is not yet supported.
  """
  logging.log(1, 'Converting %s', o)

  if tf_inspect.isclass(o):
    nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
  elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
    nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
  else:
    # Every Python object has `__class__`, so anything that is not a class,
    # function or method is an object instance; a separate "unsupported
    # type" branch after a `hasattr(o, '__class__')` test could never run.
    # Note: this should only be raised when attempting to convert the object
    # directly. converted_call should still support it.
    raise NotImplementedError(
        'cannot convert entity "{}": object conversion is not yet'
        ' supported.'.format(o))

  if logging.has_verbosity(2):
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o, parser.unparse(nodes))
  if logging.has_verbosity(4):
    for n in nodes:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(n, color=False))

  return nodes, name, entity_info
def islambda(f):
  """Returns whether `f` is a function named '<lambda>'.

  Args:
    f: Any object.

  Returns:
    True only when `tf_inspect.isfunction(f)` holds and `f.__name__` equals
    '<lambda>'; False otherwise, including when `f` has no `__name__`.
  """
  if tf_inspect.isfunction(f):
    return getattr(f, '__name__', None) == '<lambda>'
  return False
def convert_class_to_ast(c, program_ctx):
  """Specialization of `convert_entity_to_ast` for classes.

  Args:
    c: The class to convert.
    program_ctx: A ProgramContext object, forwarded to `convert_func_to_ast`.

  Returns:
    A tuple (output_nodes, class_name, entity_info):
      * output_nodes: A list of AST nodes. It holds one `ImportFrom` per
        whitelisted base class, followed by the converted `ClassDef`.
      * class_name: The name chosen for the converted class.
      * entity_info: A `transformer.EntityInfo` carrying the namespace merged
        from all converted methods and their (shared) future features.

  Raises:
    ValueError: if `c` has no member methods, or if its directly defined
      methods were built with different future features.
    NotImplementedError: if `c` directly extends a class that is neither
      `object` nor whitelisted.
  """
  # TODO(mdan): Revisit this altogether. Not sure we still need it.
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('cannot convert %s: no member methods' % c)

  # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
  # The assumption that one namespace suffices for all methods only holds if
  # all methods were defined in the same module.
  # If, instead, functions are imported from multiple modules and then spliced
  # into the class, then each function has its own globals and __future__
  # imports that need to stay separate.

  # For example, C's methods could both have `global x` statements referring to
  # mod1.x and mod2.x, but using one namespace for C would cause a conflict.

  # from mod1 import f1
  # from mod2 import f2
  # class C(object):
  #   method1 = f1
  #   method2 = f2

  class_namespace = {}
  future_features = None
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    (node,), _, entity_info = convert_func_to_ast(
        m, program_ctx=program_ctx, do_rename=False)
    class_namespace.update(entity_info.namespace)
    converted_members[m] = node

    # TODO(mdan): Similarly check the globals.
    if future_features is None:
      future_features = entity_info.future_features
    elif frozenset(future_features) != frozenset(entity_info.future_features):
      # Note: we can support this case if ever needed.
      raise ValueError(
          'cannot convert {}: it has methods built with mismatched future'
          ' features: {} and {}'.format(c, future_features,
                                        entity_info.future_features))

  # NOTE(review): if every member is inherited, `converted_members` stays empty
  # and the ClassDef below gets an empty body; confirm downstream handles it.

  namer = naming.Namer(class_namespace)
  class_name = namer.class_name(c.__name__)

  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated.
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Compare by identity. `isinstance(object, base)` was also true for bases
    # such as `type` or ABCs like `collections.abc.Callable`, which were then
    # silently replaced by `object` and skipped the whitelist check.
    if base is object:
      base_names.append('object')
      continue
    if not is_whitelisted(base):
      raise NotImplementedError(
          'Conversion of classes that do not directly extend classes from'
          ' whitelisted modules is temporarily suspended. If this breaks'
          ' existing code please notify the AutoGraph team immediately.')
    alias = namer.new_symbol(base.__name__, ())
    output_nodes.append(
        gast.ImportFrom(
            module=base.__module__,
            names=[gast.alias(name=base.__name__, asname=alias)],
            level=0))
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

  # Generate the definition of the converted class.
  bases = [
      gast.Name(n, ctx=gast.Load(), annotation=None, type_comment=None)
      for n in base_names
  ]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  # TODO(mdan): Find a way better than forging this object.
  entity_info = transformer.EntityInfo(
      source_code=None,
      source_file=None,
      future_features=future_features,
      namespace=class_namespace)

  return output_nodes, class_name, entity_info