def call_function(self, _, func_name):
  """Records a direct call to `func_name` made from this builtin.

  A direct call means the caller module imported the function by name, as in
  `from <builtin> import <func_name>`. A matching `ImportFrom` node is stored
  in `self.exported_functions`, keyed by the function name, so that the import
  can be generated for it.

  Args:
    _: Unused.
    func_name: Text, the name of the called function.

  Returns:
    Text, `func_name` unchanged.
  """
  imported_name = ast.alias(name=func_name, asname=None)
  # level=0 requests an absolute import (no leading dots): the module named by
  # `self.name` is not resolved relative to the importing package.
  self.exported_functions[func_name] = ast.ImportFrom(
      module=self.name, names=[imported_name], level=0)
  return func_name
def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes.

  Converts every method defined directly on `c` and assembles the results into
  a new class definition whose bases are resolved as described below.

  Args:
    c: the class to convert.
    program_ctx: the program context; used to convert the member methods and
      to create and update the namer.

  Returns:
    A tuple `(output_nodes, class_name, class_namespace)`:
      output_nodes: list of gast nodes: one `ImportFrom` per whitelisted base
        class, followed by the converted `ClassDef`.
      class_name: Text, the name chosen for the converted class.
      class_namespace: dict, the namespaces returned by `function_to_graph`
        for each converted method, merged together.

  Raises:
    ValueError: if `c` has no function or method members.
  """
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  class_namespace = {}
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    node, _, namespace = function_to_graph(
        m,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    class_namespace.update(namespace)
    converted_members[m] = node[0]

  namer = program_ctx.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)

  # TODO(mdan): This needs to be explained more thoroughly.
  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated. Otherwise, it is marked for conversion
  # (as a side effect of the call to namer.compiled_class_name() followed by
  # program_ctx.update_name_map(namer)).
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Compare by identity. `isinstance(object, base)` is also true for `type`
    # and for ABCs such as `collections.abc.Callable` or `Hashable` (the
    # `object` class is itself callable and hashable), which would silently
    # replace those bases with `object`.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      # This will trigger a conversion into a class with this name.
      alias = namer.compiled_class_name(base.__name__, base)
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
  program_ctx.update_name_map(namer)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  return output_nodes, class_name, class_namespace
def convert_class_to_ast(c, program_ctx):
  """Specialization of `convert_entity_to_ast` for classes.

  Converts every method defined directly on `c` and assembles the results into
  a new class definition. Only `object` and whitelisted base classes are
  supported.

  Args:
    c: the class to convert.
    program_ctx: the program context, forwarded to `convert_func_to_ast`.

  Returns:
    A tuple `(output_nodes, class_name, entity_info)`:
      output_nodes: list of gast nodes: one `ImportFrom` per whitelisted base
        class, followed by the converted `ClassDef`.
      class_name: Text, the name chosen for the converted class.
      entity_info: a `transformer.EntityInfo` holding the merged namespace of
        the converted methods and their common future features (None if no
        method is defined directly on `c`).

  Raises:
    ValueError: if `c` has no function or method members, or if its methods
      were built with different future features.
    NotImplementedError: if `c` has a base class other than `object` that is
      not whitelisted.
  """
  # TODO(mdan): Revisit this altogether. Not sure we still need it.
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('cannot convert %s: no member methods' % c)

  # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
  # The assumption that one namespace suffices for all methods only holds if
  # all methods were defined in the same module.
  # If, instead, functions are imported from multiple modules and then spliced
  # into the class, then each function has its own globals and __future__
  # imports that need to stay separate.

  # For example, C's methods could both have `global x` statements referring to
  # mod1.x and mod2.x, but using one namespace for C would cause a conflict.
  # from mod1 import f1
  # from mod2 import f2
  # class C(object):
  #   method1 = f1
  #   method2 = f2

  class_namespace = {}
  future_features = None
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    (node,), _, entity_info = convert_func_to_ast(
        m, program_ctx=program_ctx, do_rename=False)
    class_namespace.update(entity_info.namespace)
    converted_members[m] = node

    # TODO(mdan): Similarly check the globals.
    if future_features is None:
      future_features = entity_info.future_features
    elif frozenset(future_features) ^ frozenset(entity_info.future_features):
      # Note: we can support this case if ever needed.
      raise ValueError(
          'cannot convert {}: it has methods built with mismatched future'
          ' features: {} and {}'.format(c, future_features,
                                       entity_info.future_features))

  namer = naming.Namer(class_namespace)
  class_name = namer.class_name(c.__name__)

  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated.
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Compare by identity. `isinstance(object, base)` is also true for `type`
    # and for ABCs such as `collections.abc.Callable` or `Hashable` (the
    # `object` class is itself callable and hashable), which would silently
    # replace those bases with `object`.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      raise NotImplementedError(
          'Conversion of classes that do not directly extend classes from'
          ' whitelisted modules is temporarily suspended. If this breaks'
          ' existing code please notify the AutoGraph team immediately.')
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  # TODO(mdan): Find a way better than forging this object.
  entity_info = transformer.EntityInfo(
      source_code=None,
      source_file=None,
      future_features=future_features,
      namespace=class_namespace)

  return output_nodes, class_name, entity_info
def _wrap_into_dynamic_factory(nodes, entity_name, factory_factory_name,
                               factory_name, closure_vars, future_features):
  """Wraps an AST into the body of a dynamic factory.

  This uses the dynamic factory (factory of factory) pattern to achieve the
  following:

   1. The inner factory dynamically creates the entity represented by nodes.
   2. The entity is parametrized by `ag__`, the internal AutoGraph module.
   3. The outer factory creates the inner factory with a lexical scope
      in which `closure_vars` are bound local variables. This in turn allows
      the caller to control the exact closure (i.e. non-global free variables)
      for the inner factory.

  The AST is expected to define some symbol named by `entity_name`.

  Args:
    nodes: ast.AST, or a list/tuple of them; a single node is wrapped in a
      tuple.
    entity_name: Union[Text, ast.AST]
    factory_factory_name: Text
    factory_name: Text
    closure_vars: Iterable[Text]
    future_features: Iterable[Text], see EntityInfo.future_features.

  Returns:
    The result of `templates.replace` on the factory template below.
  """
  if isinstance(nodes, (list, tuple)):
    entity_defs = nodes
  else:
    entity_defs = (nodes,)

  closure_template = """
    var_name = None
  """
  # One `<name> = None` statement per closure variable, flattened into a
  # single list.
  dummy_closure_defs = [
      stmt for var_name in closure_vars
      for stmt in templates.replace(closure_template, var_name=var_name)
  ]

  future_imports = []
  if future_features:
    future_imports = gast.ImportFrom(
        module='__future__',
        names=[gast.alias(name=name, asname=None) for name in future_features],
        level=0)

  # The dummy declarations above create local variables in the outer
  # function's scope, so that the Python parser marks them as free non-global
  # variables of the inner factory upon load (that is, it creates cell slots
  # for each symbol). Their values are not used: the cells are swapped with
  # the original entity's cells after the code has been loaded.
  factory_template = """
    future_imports
    def factory_factory_name():
      dummy_closure_defs
      def factory_name(ag__, ag_source_map__, ag_module__):
        entity_defs
        entity_name.ag_source_map = ag_source_map__
        entity_name.ag_module = ag_module__
        entity_name.autograph_info__ = {}
        return entity_name
      return factory_name
  """
  return templates.replace(
      factory_template,
      future_imports=future_imports,
      factory_factory_name=factory_factory_name,
      factory_name=factory_name,
      dummy_closure_defs=dummy_closure_defs,
      entity_defs=entity_defs,
      entity_name=entity_name)
def _wrap_into_factory(nodes, entity_name, inner_factory_name,
                       outer_factory_name, closure_vars, factory_args,
                       future_features):
  """Wraps an AST into the body of a factory with consistent lexical context.

  The AST is expected to define some symbol with a name given by `entity_name`.

  This mechanism ensures that the resulting transformed entity has lexical
  scoping identical to that of the source entity, while allowing extra
  parametrization.

  Two nested factories achieve the following:

   1. The inner factory dynamically creates the entity represented by `nodes`.
   2. The inner factory is parametrized by a custom set of arguments.
   3. The inner factory has a closure identical to that of the transformed
       entity.
   4. The inner factory has local variables named like `args`, which `nodes`
       may use as additional parameters.
   5. The inner factory returns the variables given by `entity_name`.
   6. The outer factory is niladic.
   7. The outer factory has no closure.
   8. The outer factory creates the necessary lexical scope for the inner
       factory, so that the loaded code has the given configuration for
       closure/globals.
   9. The outer factory returns the inner factory.

  Roughly speaking, the following code is generated:

      from __future__ import future_feature_1
      from __future__ import future_feature_2
      ...

      def outer_factory():
        closure_var_1 = None
        closure_var_2 = None
        ...

        def inner_factory(arg_1, arg_2, ...):
          <<nodes>>
          return entity

        return inner_factory

  The lexical scoping is created using dummy symbol declarations which create
  local variables in the body of the outer factory, so that the Python parser
  correctly marks them as free non-global variables upon load (that is, it
  creates cell slots for each symbol). These symbols are initialized with None,
  but their values are not expected to be used; instead, the caller is
  expected to replace them with the cells of the source entity.

  For more details, see:
  https://docs.python.org/3/reference/executionmodel.html#binding-of-names

  Args:
    nodes: Tuple[ast.AST], the source code to wrap.
    entity_name: Union[Text, ast.AST], the name of the principal entity that
      `nodes` define.
    inner_factory_name: Text, the name of the inner factory.
    outer_factory_name: Text, the name of the outer factory.
    closure_vars: Iterable[Text], names of the closure variables for the inner
      factory.
    factory_args: Iterable[Text], names of additional arguments for the
      inner factory. Useful to configure variables that the converted code can
      use. Typically, these are modules.
    future_features: Iterable[Text], names of future statements to associate
      the code with.

  Returns:
    The result of `templates.replace` on the factory template below.
  """
  closure_template = """
    var_name = None
  """
  dummy_closure_defs = []
  for closure_var in closure_vars:
    dummy_closure_defs += templates.replace(
        closure_template, var_name=closure_var)

  if not future_features:
    future_imports = []
  else:
    future_imports = gast.ImportFrom(
        module='__future__',
        names=[gast.alias(name=name, asname=None) for name in future_features],
        level=0)

  # The extra arguments become parameters of the inner factory.
  factory_params = [
      gast.Name(name, ctx=gast.Param(), annotation=None, type_comment=None)
      for name in factory_args
  ]

  factory_template = """
    future_imports
    def outer_factory_name():
      dummy_closure_defs
      def inner_factory_name(factory_args):
        entity_defs
        return entity_name
      return inner_factory_name
  """
  return templates.replace(
      factory_template,
      dummy_closure_defs=dummy_closure_defs,
      entity_defs=nodes,
      entity_name=entity_name,
      factory_args=factory_params,
      future_imports=future_imports,
      inner_factory_name=inner_factory_name,
      outer_factory_name=outer_factory_name)
def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes.

  Converts every method defined directly on `c` and assembles the results into
  a new class definition. Only `object` and whitelisted base classes are
  supported.

  Args:
    c: the class to convert.
    program_ctx: the program context, forwarded to `function_to_graph`.

  Returns:
    A tuple `(output_nodes, class_name, class_namespace)`:
      output_nodes: list of gast nodes: one `ImportFrom` per whitelisted base
        class, followed by the converted `ClassDef`.
      class_name: Text, the name chosen for the converted class.
      class_namespace: dict, the namespaces returned by `function_to_graph`
        for each converted method, merged together.

  Raises:
    ValueError: if `c` has no function or method members.
    NotImplementedError: if `c` has a base class other than `object` that is
      not whitelisted.
  """
  # TODO(mdan): Revisit this altogether. Not sure we still need it.
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  class_namespace = {}
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    nodes, _, namespace = function_to_graph(
        m,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        do_rename=False)
    class_namespace.update(namespace)
    converted_members[m] = nodes[0]

  namer = naming.Namer(class_namespace)
  class_name = namer.class_name(c.__name__)

  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated.
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Compare by identity. `isinstance(object, base)` is also true for `type`
    # and for ABCs such as `collections.abc.Callable` or `Hashable` (the
    # `object` class is itself callable and hashable), which would silently
    # replace those bases with `object`.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      raise NotImplementedError(
          'Conversion of classes that do not directly extend classes from'
          ' whitelisted modules is temporarily suspended. If this breaks'
          ' existing code please notify the AutoGraph team immediately.')
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  return output_nodes, class_name, class_namespace