def test_new_symbol_avoids_conflicts(self):
    # The namer starts with 'temp' already present in the global namespace.
    global_namespace = {'temp': 1}
    namer = naming.Namer(global_namespace, True, None, ())

    # 'temp' is taken globally, so the first free candidate is expected.
    first_symbol = namer.new_symbol('temp', set())
    self.assertEqual('temp_1', first_symbol)

    # 'temp_2' is reserved in the local namespace passed to this call.
    locally_reserved = {'temp_2'}
    second_symbol = namer.new_symbol('temp', locally_reserved)
    self.assertEqual('temp_3', second_symbol)

    # Both generated symbols must have been recorded by the namer.
    self.assertItemsEqual(('temp_1', 'temp_3'), namer.generated_names)
def test_compiled_function_name_avoids_global_conflicts(self):
    # Dummy function handed to the namer as the live object being renamed.
    def foo():
        pass

    # 'tf__foo' already exists in the global namespace, so the compiled name
    # is expected to receive a numeric suffix.
    occupied_namespace = {'tf__foo': 1}
    namer = naming.Namer(occupied_namespace, True, None, ())

    compiled = namer.compiled_function_name('foo', foo)
    self.assertEqual(('tf__foo_1', True), compiled)
def convert_func_to_ast(f, program_ctx, do_rename=True):
    """Specialization of `convert_entity_to_ast` for callable functions.

    Args:
      f: The function or lambda to convert; its source is obtained through
        `parser.parse_entity`.
      program_ctx: Program-level context. Its `autograph_module` is handed to
        `_add_self_references`, and the object itself is forwarded to the
        `converter.EntityContext` used for the conversion.
      do_rename: When true and the converted node is not a lambda, the node is
        renamed to `namer.function_name(f.__name__)`. When false, the original
        name is kept and asserted to match the node's name.

    Returns:
      A tuple `((node,), new_name, entity_info)` holding the converted node in
      a one-element tuple, the name it is bound to, and the
      `transformer.EntityInfo` built for `f`.

    Raises:
      ValueError: If `f` is a lambda and the parsed source does not contain
        exactly one matching definition.
    """
    future_features = inspect_utils.getfutureimports(f)
    node, source = parser.parse_entity(f, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
    # Parsed AST should contain future imports and one function def node.

    # inspect.getsource is imprecise for lambdas: it pattern-matches around
    # the line number recorded by CPython and returns the whole containing
    # line, which can hold several lambdas, for example:
    #   x, y = lambda: 1, lambda: 2
    if f.__name__ == '<lambda>':
        candidates = ast_util.find_matching_definitions(node, f)
        if len(candidates) != 1:
            message = (
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
            raise ValueError(message)
        (node,) = candidates

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve_entity(node, source, f)

    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = naming.Namer(namespace)

    # Choose the name under which the converted entity will be defined.
    if isinstance(node, gast.Lambda):
        new_name = namer.new_symbol('tf__lambda', ())
    elif do_rename:
        new_name = namer.function_name(f.__name__)
    else:
        new_name = f.__name__

    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        future_features=future_features,
        namespace=namespace)
    context = converter.EntityContext(namer, entity_info, program_ctx, new_name)
    node = node_to_graph(node, context)

    if isinstance(node, gast.Lambda):
        # Lambdas carry no name of their own, so bind the result by assignment.
        target = gast.Name(
            new_name, ctx=gast.Store(), annotation=None, type_comment=None)
        node = gast.Assign(targets=[target], value=node)
    elif do_rename:
        node.name = new_name
    else:
        assert node.name == new_name

    return (node,), new_name, entity_info
def test_compiled_function_name_consistent(self):
    # Dummy function handed to the namer as the live object being renamed.
    def foo():
        pass

    namer = naming.Namer({}, True, None, ())

    # Asking twice for the same function must give the same answer.
    for _ in range(2):
        compiled = namer.compiled_function_name('foo', foo)
        self.assertEqual(('tf__foo', True), compiled)
def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
    """Specialization of `entity_to_graph` for callable functions.

    Args:
      f: The function or lambda to convert.
      program_ctx: Program-level context. Its `autograph_module` is handed to
        `_add_self_references`, and the object itself is forwarded to the
        `converter.EntityContext` used for the conversion.
      arg_values: Forwarded unchanged to `transformer.EntityInfo`.
      arg_types: Forwarded unchanged to `transformer.EntityInfo`.
      do_rename: When true and the converted node is not a lambda, the node is
        renamed to `namer.function_name(f.__name__)`. When false, the original
        name is kept and asserted to match the node's name.

    Returns:
      A tuple `([node], new_name, namespace)`: the converted node in a
      one-element list, the name it is bound to, and the namespace of `f`.

    Raises:
      ValueError: If `f` is a lambda and the parsed source does not contain
        exactly one matching definition.
      errors.InternalError: If `node_to_graph` raises ValueError,
        AttributeError, KeyError or NotImplementedError.
    """
    parsed, source = parser.parse_entity(f)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
    node = parsed.body[0]

    # inspect.getsource is imprecise for lambdas: it pattern-matches around
    # the line number recorded by CPython and returns the whole containing
    # line, which can hold several lambdas, for example:
    #   x, y = lambda: 1, lambda: 2
    if f.__name__ == '<lambda>':
        candidates = ast_util.find_matching_definitions(node, f)
        if len(candidates) != 1:
            message = (
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
            raise ValueError(message)
        (node,) = candidates

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(node, source, f)

    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = naming.Namer(namespace)

    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        namespace=namespace,
        arg_values=arg_values,
        arg_types=arg_types)
    context = converter.EntityContext(namer, entity_info, program_ctx)

    try:
        node = node_to_graph(node, context)
    except (ValueError, AttributeError, KeyError, NotImplementedError) as e:
        logging.error(1, 'Error converting %s', f, exc_info=True)
        # TODO(mdan): Catch and rethrow syntax errors.
        raise errors.InternalError('conversion', e)

    if isinstance(node, gast.Lambda):
        # Lambdas carry no name of their own, so bind the result by assignment.
        new_name = namer.new_symbol('tf__lambda', ())
        target = gast.Name(new_name, gast.Store(), None)
        node = gast.Assign(targets=[target], value=node)
    elif do_rename:
        # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
        new_name = namer.function_name(f.__name__)
        node.name = new_name
    else:
        new_name = f.__name__
        assert node.name == new_name

    return [node], new_name, namespace
def test_compiled_function_name_tracks_names(self):
    # Dummy function handed to the namer as the live object being renamed.
    def bar():
        pass

    namer = naming.Namer({}, True, None, ())

    # One name requested without a live object, one with.
    foo_result = namer.compiled_function_name('foo')
    self.assertEqual(('tf__foo', True), foo_result)
    bar_result = namer.compiled_function_name('bar', bar)
    self.assertEqual(('tf__bar', True), bar_result)

    # Only the call that supplied a live object shows up in renamed_calls,
    # while both generated names are tracked.
    expected_renames = {bar: 'tf__bar'}
    self.assertEqual(expected_renames, namer.renamed_calls)
    self.assertItemsEqual(('tf__bar', 'tf__foo'), namer.generated_names)
def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names): """Returns a (possibly cached) factory for the converted result of entity.""" # The cache key is the entity's code object if it defined one, otherwise it's # the entity itself. Keying by the code object allows caching of functions # that are dynamically created e.g. in a loop. if hasattr(entity, '__code__'): key = entity.__code__ else: key = entity # The cache subkey encompases any conversion options on which the generated # code may depend. # The cached factory includes the necessary definitions to distinguish # between the global and non-global free variables. For this reason, the # cache subkey includes the names of the free non-globals. subkey = (program_ctx.options, frozenset(free_nonglobal_var_names)) with _CACHE_LOCK: # The cache values are _ConvertedEntityFactoryInfo objects. if _CACHE.has(key, subkey): # TODO(mdan): Check whether the module is still loaded. converted_entity_info = _CACHE[key][subkey] logging.log(3, 'Cache hit for entity %s key %s subkey %s: %s', entity, key, subkey, converted_entity_info) return converted_entity_info logging.log(1, 'Entity %s is not cached for key %s subkey %s', entity, key, subkey) nodes, converted_name, entity_info = convert_entity_to_ast( entity, program_ctx) namer = naming.Namer(entity_info.namespace) factory_factory_name = namer.new_symbol('create_converted_entity_factory', ()) factory_name = namer.new_symbol('create_converted_entity', ()) nodes = _wrap_into_dynamic_factory(nodes, converted_name, factory_factory_name, factory_name, free_nonglobal_var_names, entity_info.future_features) module, _, source_map = compiler.ast_to_object( nodes, include_source_map=True) module_name = module.__name__ converted_entity_info = _ConvertedEntityFactoryInfo( module_name=module_name, converted_name=converted_name, factory_factory_name=factory_factory_name, source_map=source_map) _CACHE[key][subkey] = converted_entity_info return converted_entity_info
def _parse_and_analyze(f):
    """Performs preliminary analyses and transformations.

    The goal is to massage the source program into a form on which the
    `_AutoBatchingTransformer` below will be successful.

    Args:
      f: Function to analyze

    Returns:
      node: A Python AST node representing the function, suitable for
        passing to `_AutoBatchingTransformer.visit`
      ctx: The AutoGraph `EntityContext` built for `f` and used by every
        analysis and transformation pass in this function.
    """
    namespace = {}

    # Obtain the AST of the function.
    future_features = inspect_utils.getfutureimports(f)
    node, _ = parser.parse_entity(f, future_features=future_features)

    # Scaffolding required by the AutoGraph transforms.
    entity_info = transformer.EntityInfo(
        source_code='',
        source_file=None,
        future_features=future_features,
        namespace=namespace)
    options = converter.ConversionOptions(recursive=True)
    program_ctx = converter.ProgramContext(
        options=options, autograph_module=None)
    namer = naming.Namer(namespace)
    ctx = converter.EntityContext(
        namer=namer, entity_info=entity_info, program_ctx=program_ctx)

    # Each canonicalization is preceded by a standard analysis; only the very
    # first analysis is flagged as initial.
    canonicalization_passes = (
        # (is_initial, transform, extra keyword arguments)
        (True, break_statements.transform, {}),      # remove break statements
        (False, continue_statements.transform, {}),  # remove continue statements
        (False, return_statements.transform,         # force single returns
         {'default_to_null_return': False}),
    )
    for is_initial, transform, extra_kwargs in canonicalization_passes:
        node = converter.standard_analysis(node, ctx, is_initial=is_initial)
        node = transform(node, ctx, **extra_kwargs)

    # Transform into ANF, then re-run the analysis on the result.
    node = anf.transform(node, ctx)
    node = converter.standard_analysis(node, ctx, is_initial=False)

    return node, ctx
def prepare(self, test_fn, namespace, recursive=True):
    # Parses `test_fn`, builds an EntityContext around `namespace`, and returns
    # the initially analyzed node together with that context.
    namespace['ConversionOptions'] = converter.ConversionOptions

    future_features = ('print_function', 'division')
    parsed_node, source = parser.parse_entity(
        test_fn, future_features=future_features)

    namer = naming.Namer(namespace)
    options = converter.ConversionOptions(recursive=recursive)
    program_ctx = converter.ProgramContext(
        options=options, autograph_module=None)
    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        future_features=future_features,
        namespace=namespace)
    ctx = converter.EntityContext(namer, entity_info, program_ctx)

    analyzed_node = converter.standard_analysis(
        parsed_node, ctx, is_initial=True)
    return analyzed_node, ctx
def prepare(self, test_fn, namespace, arg_types=None, recursive=True):
    # Parses `test_fn`, builds an EntityContext around `namespace`, resolves
    # origin information, and returns the initially analyzed node together
    # with that context.
    namespace['ConversionOptions'] = converter.ConversionOptions

    parsed_node, source, _ = parser.parse_entity(test_fn)

    namer = naming.Namer(namespace)
    options = converter.ConversionOptions(recursive=recursive)
    program_ctx = converter.ProgramContext(
        options=options, autograph_module=None)
    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        namespace=namespace,
        arg_values=None,
        arg_types=arg_types)
    ctx = converter.EntityContext(namer, entity_info, program_ctx)

    origin_info.resolve(parsed_node, source, test_fn)
    analyzed_node = converter.standard_analysis(
        parsed_node, ctx, is_initial=True)
    return analyzed_node, ctx
def test_function_name_unsanitized_fqn(self):
    namer = naming.Namer({})

    # (expected generated name, unsanitized input), checked in order against
    # the same namer instance.
    cases = (
        ('tf__foo_bar', 'foo.bar'),
        ('tf__foo_bar_baz', ('foo.bar', 'baz')),
    )
    for expected, raw_name in cases:
        self.assertEqual(expected, namer.function_name(raw_name))
def _parse_and_analyze(f, autobatch_functions):
    """Performs preliminary analyses and transformations.

    The goal is to massage the source program into a form on which the
    `_AutoBatchingTransformer` below will be successful.

    Args:
      f: Function to analyze
      autobatch_functions: List of Python `str` names of autobatched functions.
        Arguments to these functions will be canonicalized to variable
        references, but others will not.

    Returns:
      node: A Python AST node representing the function, suitable for
        passing to `_AutoBatchingTransformer.visit`
      ctx: The AutoGraph `EntityContext` built for `f` and used by every
        analysis and transformation pass in this function.
    """
    namespace = {}

    # Obtain the AST of the function.
    future_features = inspect_utils.getfutureimports(f)
    node, _ = parser.parse_entity(f, future_features=future_features)

    # Scaffolding required by the AutoGraph transforms.
    entity_info = transformer.EntityInfo(
        source_code='',
        source_file=None,
        future_features=future_features,
        namespace=namespace)
    options = converter.ConversionOptions(recursive=True)
    program_ctx = converter.ProgramContext(
        options=options, autograph_module=None)
    namer = naming.Namer(namespace)
    ctx = converter.EntityContext(
        namer=namer, entity_info=entity_info, program_ctx=program_ctx)

    # Each canonicalization is preceded by a standard analysis; only the very
    # first analysis is flagged as initial.
    canonicalization_passes = (
        # (is_initial, transform, extra keyword arguments)
        (True, break_statements.transform, {}),      # remove break statements
        (False, continue_statements.transform, {}),  # remove continue statements
        (False, return_statements.transform,         # force single returns
         {'default_to_null_return': False}),
    )
    for is_initial, transform, extra_kwargs in canonicalization_passes:
        node = converter.standard_analysis(node, ctx, is_initial=is_initial)
        node = transform(node, ctx, **extra_kwargs)

    # Transform into ANF.
    # `if` tests and the arguments of autobatched calls are replaced because
    # that's where divergence can happen.
    # Every function call is replaced because the downstream transformation
    # expects calls to lead directly to assignments.
    def maybe_replace_function_argument(parent, field_name, child):
        # Replace the argument only when the callee carries a qualified-name
        # annotation whose string form is listed in `autobatch_functions`.
        del field_name, child
        callee = parent.func
        if anno.hasanno(callee, anno.Basic.QN):
            qualified_name = anno.getanno(callee, anno.Basic.QN)
            return str(qualified_name) in autobatch_functions
        return False

    if_test_pattern = anf.ASTEdgePattern(gast.If, 'test', anf.ANY)
    any_call_pattern = anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.Call)
    call_args_pattern = anf.ASTEdgePattern(gast.Call, 'args', anf.ANY)
    anf_config = [
        (if_test_pattern, anf.REPLACE),
        (any_call_pattern, anf.REPLACE),
        (call_args_pattern, maybe_replace_function_argument),
    ]
    node = anf.transform(node, ctx, config=anf_config)
    node = converter.standard_analysis(node, ctx, is_initial=False)

    return node, ctx
def test_function_name_consistent(self):
    namer = naming.Namer({})

    # Asking twice for the same name must give the same answer.
    for _ in range(2):
        generated = namer.function_name('foo')
        self.assertEqual('tf__foo', generated)
def test_new_symbol_avoids_duplicates(self):
    namer = naming.Namer({}, True, None, ())

    # Requesting the same base name twice must yield two distinct symbols.
    expected_symbols = ('temp', 'temp_1')
    for expected in expected_symbols:
        self.assertEqual(expected, namer.new_symbol('temp', set()))

    self.assertItemsEqual(('temp', 'temp_1'), namer.generated_names)
def test_compiled_class_name_unsanitized_fqn(self):
    namer = naming.Namer({}, True, None, ())

    # The first component of the qualified name contains a dot.
    qualified_name = ('foo.bar', 'Baz')
    compiled = namer.compiled_class_name(qualified_name)
    self.assertEqual('TfFooBarBaz', compiled)
def new_namer(self, namespace):
    # Builds a Namer for `namespace`, configured from this object's options,
    # name map and partial types.
    recursive = self.options.recursive
    name_map = self.name_map
    partial_types = self.partial_types
    return naming.Namer(namespace, recursive, name_map, partial_types)
def test_compiled_class_name_basic(self):
    namer = naming.Namer({}, True, None, ())

    qualified_name = ('foo', 'Bar')
    compiled = namer.compiled_class_name(qualified_name)
    self.assertEqual('TfFooBar', compiled)
def test_compiled_function_name_unsanitized_fqn(self):
    namer = naming.Namer({}, True, None, ())

    # (expected result, unsanitized input), checked in order against the same
    # namer instance.
    cases = (
        (('tf__foo_bar', True), 'foo.bar'),
        (('tf__foo_bar_baz', True), ('foo.bar', 'baz')),
    )
    for expected, raw_name in cases:
        self.assertEqual(expected, namer.compiled_function_name(raw_name))
def test_class_name_unsanitized_fqn(self):
    namer = naming.Namer({})

    # The first component of the qualified name contains a dot.
    qualified_name = ('foo.bar', 'Baz')
    generated = namer.class_name(qualified_name)
    self.assertEqual('TfFooBarBaz', generated)
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes.

    Converts every method that `c` defines directly, then assembles a new
    class definition whose body consists of the converted methods.

    Args:
      c: The class to convert.
      program_ctx: Program-level context, forwarded to `function_to_graph` for
        each converted method.

    Returns:
      A tuple `(output_nodes, class_name, class_namespace)`: the generated AST
      nodes (any import statements for base classes followed by the class
      definition), the name of the generated class, and the merged namespace
      of all converted methods.

    Raises:
      ValueError: If `c` has no function or method members.
      NotImplementedError: If a base class is neither accepted by the `object`
        check below nor whitelisted for graph conversion.
    """
    # TODO(mdan): Revisit this altogether. Not sure we still need it.
    converted_members = {}

    def method_filter(m):
        return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    # The namespaces of all converted methods are merged into one dict. It
    # starts out empty, so updating it in place is always valid.
    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        nodes, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            do_rename=False)
        class_namespace.update(namespace)
        converted_members[m] = nodes[0]

    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # NOTE(review): isinstance(object, base) is true for `object` and also
        # for `type`; presumably `base is object` was intended - confirm before
        # changing, since that would alter behavior for subclasses of `type`.
        if isinstance(object, base):
            base_names.append('object')
            continue
        if not is_whitelisted_for_graph(base):
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        alias = namer.new_symbol(base.__name__, ())
        output_nodes.append(
            gast.ImportFrom(
                module=base.__module__,
                names=[gast.alias(name=base.__name__, asname=alias)],
                level=0))
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(
        class_name,
        bases=bases,
        keywords=[],
        body=list(converted_members.values()),
        decorator_list=[])

    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    return output_nodes, class_name, class_namespace
def test_class_name_basic(self):
    namer = naming.Namer({})

    qualified_name = ('foo', 'Bar')
    generated = namer.class_name(qualified_name)
    self.assertEqual('TfFooBar', generated)
def test_new_symbol_tracks_names(self):
    namer = naming.Namer({}, True, None, ())

    # With nothing reserved, the requested name is expected back unchanged
    # and must be recorded by the namer.
    symbol = namer.new_symbol('temp', set())
    self.assertEqual('temp', symbol)
    self.assertItemsEqual(('temp', ), namer.generated_names)
def test_function_name_avoids_global_conflicts(self):
    # 'tf__foo' already exists in the global namespace, so the generated name
    # is expected to receive a numeric suffix.
    occupied_namespace = {'tf__foo': 1}
    namer = naming.Namer(occupied_namespace)

    generated = namer.function_name('foo')
    self.assertEqual('tf__foo_1', generated)
def convert_class_to_ast(c, program_ctx):
    """Specialization of `convert_entity_to_ast` for classes.

    Converts every method that `c` defines directly, then assembles a new
    class definition whose body consists of the converted methods.

    Args:
      c: The class to convert.
      program_ctx: Program-level context, forwarded to `convert_func_to_ast`
        for each converted method.

    Returns:
      A tuple `(output_nodes, class_name, entity_info)`: the generated AST
      nodes (any import statements for base classes followed by the class
      definition), the name of the generated class, and a
      `transformer.EntityInfo` carrying the merged namespace and the future
      features shared by the methods.

    Raises:
      ValueError: If `c` has no function or method members, or if its methods
        were built with differing future features.
      NotImplementedError: If a base class is neither accepted by the `object`
        check below nor whitelisted for graph conversion.
    """
    # TODO(mdan): Revisit this altogether. Not sure we still need it.
    converted_members = {}

    def method_filter(m):
        return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('cannot convert %s: no member methods' % c)

    # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
    # A single namespace is only adequate when every method was defined in the
    # same module. If functions are instead imported from several modules and
    # spliced into the class, each one has its own globals and __future__
    # imports, and those need to stay separate.
    # For example, C's methods could both have `global x` statements referring to
    # mod1.x and mod2.x, but using one namespace for C would cause a conflict.
    #   from mod1 import f1
    #   from mod2 import f2
    #   class C(object):
    #     method1 = f1
    #     method2 = f2

    class_namespace = {}
    future_features = None
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        (node,), _, entity_info = convert_func_to_ast(
            m, program_ctx=program_ctx, do_rename=False)
        class_namespace.update(entity_info.namespace)
        converted_members[m] = node

        # TODO(mdan): Similarly check the globals.
        member_features = entity_info.future_features
        if future_features is None:
            future_features = member_features
        elif frozenset(future_features) != frozenset(member_features):
            # Note: we can support this case if ever needed.
            raise ValueError(
                'cannot convert {}: if has methods built with mismatched future'
                ' features: {} and {}'.format(c, future_features,
                                              entity_info.future_features))

    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: whitelisted superclasses are brought in through
    # a generated absolute import line.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        if isinstance(object, base):
            base_names.append('object')
            continue
        if not is_whitelisted_for_graph(base):
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        alias = namer.new_symbol(base.__name__, ())
        imported_name = gast.alias(name=base.__name__, asname=alias)
        import_node = gast.ImportFrom(
            module=base.__module__, names=[imported_name], level=0)
        output_nodes.append(import_node)
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = []
    for base_name in base_names:
        bases.append(gast.Name(base_name, gast.Load(), None))
    class_def = gast.ClassDef(
        class_name,
        bases=bases,
        keywords=[],
        body=list(converted_members.values()),
        decorator_list=[])

    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    # TODO(mdan): Find a way better than forging this object.
    entity_info = transformer.EntityInfo(
        source_code=None,
        source_file=None,
        future_features=future_features,
        namespace=class_namespace)

    return output_nodes, class_name, entity_info
def test_function_name_tracks_names(self):
    namer = naming.Namer({})

    # (expected generated name, requested name), checked in order against the
    # same namer instance.
    cases = (
        ('tf__foo', 'foo'),
        ('tf__bar', 'bar'),
    )
    for expected, requested in cases:
        self.assertEqual(expected, namer.function_name(requested))

    # Every generated name must have been recorded by the namer.
    self.assertItemsEqual(('tf__bar', 'tf__foo'), namer.generated_names)