def test_find_matching_definitions_lambda_uses_arg_names(self):
  # Two lambdas share a line; only their argument names differ.
  source = textwrap.dedent("""
      f = lambda x: 1, lambda y: 2
  """)
  tree = parser.parse_str(source)

  # A probe taking `x` must select only the first lambda.
  probe = lambda x: x
  self.assertLambdaNodes(
      ast_util.find_matching_definitions(tree, probe), ('(1)',))

  # A probe taking `y` must select only the second lambda.
  probe = lambda y: y
  self.assertLambdaNodes(
      ast_util.find_matching_definitions(tree, probe), ('(2)',))
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Parses the source reported for `f`, isolates the single definition that
  matches `f`, converts it with `node_to_graph` and decides the name the
  converted entity is bound to.

  Args:
    f: the function or lambda to convert.
    program_ctx: program context; this function uses its `autograph_module`,
      `new_namer` and `update_name_map` members.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: optional; forwarded to `transformer.EntityInfo` and to
      `namer.compiled_function_name`.

  Returns:
    A tuple `(nodes, new_name, namespace)`: a one-element list holding the
    converted node (wrapped in a `gast.Assign` when conversion yields a
    `gast.Lambda`), the name bound to the converted entity, and the
    namespace obtained from `inspect_utils.getnamespace(f)`.

  Raises:
    ValueError: if the reported source does not contain exactly one
      definition matching `f`.
  """
  module_node, source = parser.parse_entity(f)
  candidate = module_node.body[0]

  # inspect.getsource only approximates an entity's location: it regex-scans
  # around the line number CPython recorded. For lambdas it returns whole
  # containing lines, so the matching definition must be picked out here.
  matches = ast_util.find_matching_definitions(candidate, f)
  if len(matches) != 1:
    if f.__name__ == '<lambda>':
      message = (
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.')
    else:
      message = (
          'Unable to identify source code of function {}. The source code'
          ' reported by Python did not include exactly one matching signature:'
          '\n{}\n. This is an extremely rare occurrence. Please report it to'
          ' the TensorFlow team.')
    raise ValueError(message.format(f, source))
  node, = matches

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(node, source, f)
  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  converted = node_to_graph(node, context)

  # Naming happens after conversion, so the namer already knows about every
  # symbol conversion introduced.
  if isinstance(converted, gast.Lambda):
    new_name = namer.new_symbol('tf__lambda', ())
    converted = gast.Assign(
        targets=[gast.Name(new_name, gast.Store(), None)], value=converted)
  else:
    # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
    new_name, did_rename = namer.compiled_function_name(
        f.__name__, f, owner_type)
    if not did_rename:
      # The original name is kept; the converted def must still carry it.
      new_name = f.__name__
      assert converted.name == new_name
    else:
      converted.name = new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.
  return [converted], new_name, namespace
def test_find_matching_definitions_lambda(self):
  # A single lambda with a matching signature must be found exactly once.
  source = textwrap.dedent("""
      f = lambda x: 1
  """)
  tree = parser.parse(source)

  probe = lambda x: x
  matches = ast_util.find_matching_definitions(tree, probe)
  self.assertLambdaNodes(matches, ('(1)', ))
def convert_func_to_ast(f, program_ctx, do_rename=True):
  """Specialization of `convert_entity_to_ast` for callable functions.

  Parses the source of `f` (honoring its future imports), isolates the
  matching lambda when `f` is one, converts the node with `node_to_graph`
  and applies the chosen name to the result.

  Args:
    f: the function or lambda to convert.
    program_ctx: program context; its `autograph_module` is added to the
      namespace and it is passed to `converter.EntityContext`.
    do_rename: when true, non-lambda functions receive a name produced by
      `namer.function_name`; otherwise they keep `f.__name__`.

  Returns:
    A tuple `(nodes, new_name, entity_info)`: a one-element tuple holding the
    converted node (wrapped in a `gast.Assign` when conversion yields a
    `gast.Lambda`), the name bound to it, and the `transformer.EntityInfo`
    built for the conversion.

  Raises:
    ValueError: if `f` is a lambda and the reported source does not contain
      exactly one matching lambda.
  """
  future_features = inspect_utils.getfutureimports(f)
  node, source = parser.parse_entity(f, future_features=future_features)
  logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
  # Parsed AST should contain future imports and one function def node.

  # For lambdas, inspect.getsource regex-scans around the recorded line
  # number and returns the entire containing line, which may hold several
  # lambdas, e.g.:
  #   x, y = lambda: 1, lambda: 2
  if f.__name__ == '<lambda>':
    matches = ast_util.find_matching_definitions(node, f)
    if len(matches) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    node, = matches

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve_entity(node, source, f)
  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = naming.Namer(namespace)

  # The name is chosen before conversion because EntityContext receives it.
  if isinstance(node, gast.Lambda):
    new_name = namer.new_symbol('tf__lambda', ())
  else:
    new_name = namer.function_name(f.__name__) if do_rename else f.__name__

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      future_features=future_features,
      namespace=namespace)
  context = converter.EntityContext(namer, entity_info, program_ctx, new_name)
  converted = node_to_graph(node, context)

  if isinstance(converted, gast.Lambda):
    # Lambdas have no name of their own; bind them via an assignment.
    target = gast.Name(
        new_name, ctx=gast.Store(), annotation=None, type_comment=None)
    converted = gast.Assign(targets=[target], value=converted)
  elif do_rename:
    converted.name = new_name
  else:
    assert converted.name == new_name

  return (converted, ), new_name, entity_info
def test_find_matching_definitions_lambda_multiple_matches(self):
  # Both lambdas take `x`, so a probe taking `x` is ambiguous and must
  # match each of them.
  source = textwrap.dedent("""
      f = lambda x: 1, lambda x: 2
  """)
  tree = parser.parse_str(source)

  probe = lambda x: x
  matches = ast_util.find_matching_definitions(tree, probe)
  self.assertLambdaNodes(matches, ('(1)', '(2)'))
def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
  """Specialization of `entity_to_graph` for callable functions.

  Parses the source of `f`, isolates the matching lambda when `f` is one,
  converts the node with `node_to_graph` and names the result.

  Args:
    f: the function or lambda to convert.
    program_ctx: program context; its `autograph_module` is added to the
      namespace and it is passed to `converter.EntityContext`.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    do_rename: when true, non-lambda functions receive a name produced by
      `namer.function_name`; otherwise they keep `f.__name__`.

  Returns:
    A tuple `(nodes, new_name, namespace)`: a one-element list holding the
    converted node (wrapped in a `gast.Assign` when conversion yields a
    `gast.Lambda`), the name bound to it, and the namespace obtained from
    `inspect_utils.getnamespace(f)`.

  Raises:
    ValueError: if `f` is a lambda and the reported source does not contain
      exactly one matching lambda.
    errors.InternalError: wrapping a ValueError, AttributeError, KeyError or
      NotImplementedError raised during `node_to_graph`.
  """
  module_node, source = parser.parse_entity(f)
  logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
  node = module_node.body[0]

  # For lambdas, inspect.getsource regex-scans around the recorded line
  # number and returns the entire containing line, which may hold several
  # lambdas, e.g.:
  #   x, y = lambda: 1, lambda: 2
  if f.__name__ == '<lambda>':
    candidates = ast_util.find_matching_definitions(node, f)
    if len(candidates) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    node, = candidates

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(node, source, f)
  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = naming.Namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types)
  context = converter.EntityContext(namer, entity_info, program_ctx)

  try:
    node = node_to_graph(node, context)
  except (ValueError, AttributeError, KeyError, NotImplementedError) as err:
    logging.error(1, 'Error converting %s', f, exc_info=True)
    raise errors.InternalError('conversion', err)
  # TODO(mdan): Catch and rethrow syntax errors.

  if isinstance(node, gast.Lambda):
    # Lambdas have no name of their own; bind them via an assignment.
    new_name = namer.new_symbol('tf__lambda', ())
    node = gast.Assign(
        targets=[gast.Name(new_name, gast.Store(), None)], value=node)
  elif not do_rename:
    new_name = f.__name__
    assert node.name == new_name
  else:
    # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
    new_name = namer.function_name(f.__name__)
    node.name = new_name

  return [node], new_name, namespace
def test_find_matching_definitions_function(self):
  # A plain def with the same name and signature is matched.
  source = textwrap.dedent("""
      def f(x):
        return 1
  """)
  tree = parser.parse_str(source)

  # The local must be named `f`: function matching compares names.
  def f(x):
    return x

  matches = ast_util.find_matching_definitions(tree, f)
  self.assertFunctionDefNodes(matches, ('return 1', ))
def _transform_function(self, fn, user_context):
  """Performs source code transformation on a function.

  Args:
    fn: the function or lambda whose source is transformed.
    user_context: passed through to `transformer.Context`.

  Returns:
    A tuple `(node, context)`: the transformed node (wrapped in a
    `gast.Assign` to the new name when `fn` is a lambda, otherwise renamed
    in place) and the `transformer.Context` used for the transformation.

  Raises:
    ValueError: if `fn` is a lambda and the reported source does not contain
      exactly one matching lambda.
  """
  future_features = inspect_utils.getfutureimports(fn)
  node, source = parser.parse_entity(fn, future_features=future_features)
  logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

  # For lambdas, inspect.getsource regex-scans around the recorded line
  # number and returns the entire containing line, which may hold several
  # lambdas, e.g.:
  #   x, y = lambda: 1, lambda: 2
  is_lambda = fn.__name__ == '<lambda>'
  if is_lambda:
    candidates = ast_util.find_matching_definitions(node, fn)
    if len(candidates) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}.'
          ' It was defined in this code:\n'
          '{}\n'
          'This code must contain a single distinguishable lambda.'
          ' To avoid this problem, define each lambda in a separate'
          ' expression.'.format(fn, source))
    node, = candidates

  origin_info.resolve_entity(node, source, fn)

  namespace = inspect_utils.getnamespace(fn)
  namer = naming.Namer(namespace)
  new_name = namer.new_symbol(self.get_transformed_name(node), ())
  entity_info = transformer.EntityInfo(
      name=new_name,
      source_code=source,
      source_file='<fragment>',
      future_features=future_features,
      namespace=namespace)
  context = transformer.Context(entity_info, namer, user_context)

  node = self._erase_arg_defaults(node)
  node = self.transform_ast(node, context)

  if not is_lambda:
    node.name = new_name
  else:
    # Lambdas have no name of their own; bind them via an assignment.
    target = gast.Name(
        new_name, ctx=gast.Store(), annotation=None, type_comment=None)
    node = gast.Assign(targets=[target], value=node)

  return node, context
def test_find_matching_definitions_decorated_compatible(self):
  # A decorated def with a catch-all signature (*args, **kwargs) should
  # still be considered a match for a same-named function.
  source = textwrap.dedent("""
      @sneaky_decorator
      def f(x, *args, **kwargs):
        return 1
  """)
  tree = parser.parse_str(source)

  # The local must be named `f`: function matching compares names.
  def f(a, b, c, d=1):
    return a + b + c + d

  matches = ast_util.find_matching_definitions(tree, f)
  self.assertFunctionDefNodes(matches, ('return 1', ))
def test_find_matching_definitions_nested_functions_same_name(self):
  # Outer and inner defs share the name `f`; only the inner one, whose
  # signature is (x, y), is expected to match.
  source = textwrap.dedent("""
      def f(x, *args, **kwargs):
        def f(x, y):
          return 1
        return 2
  """)
  tree = parser.parse_str(source)

  # The local must be named `f`: function matching compares names.
  def f(x, y):
    return x + y

  matches = ast_util.find_matching_definitions(tree, f)
  self.assertFunctionDefNodes(matches, ('return 1', ))
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Parses the source reported for `f`, isolates the single definition that
  matches `f`, converts it with `node_to_graph` and decides the name the
  converted entity is bound to.

  Args:
    f: the function or lambda to convert.
    program_ctx: program context; this function uses its `autograph_module`,
      `new_namer` and `update_name_map` members.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: optional; forwarded to `transformer.EntityInfo` and to
      `namer.compiled_function_name`.

  Returns:
    A tuple `(nodes, new_name, namespace)`: a one-element list holding the
    converted node (wrapped in a `gast.Assign` when conversion yields a
    `gast.Lambda`), the name bound to the converted entity, and the
    namespace obtained from `inspect_utils.getnamespace(f)`.

  Raises:
    ValueError: if the reported source does not contain exactly one
      definition matching `f`.
  """
  module_node, source = parser.parse_entity(f)
  candidate = module_node.body[0]

  # inspect.getsource locates source with crude regex matching over the
  # file. Lambdas come back as their whole containing lines, and some CPython
  # distributions return the enclosing function for local functions, so the
  # matching definition must be picked out here.
  matches = ast_util.find_matching_definitions(candidate, f)
  if len(matches) != 1:
    if f.__name__ == '<lambda>':
      message = (
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.')
    else:
      # The inspect.getsource bug is currently known to occur in the Windows
      # integration tests which run Python 3.6.
      # TODO(mdan): Find out exactly which distribution of Python is that.
      message = (
          'Unable to identify source code of function {}. The source code'
          ' reported by Python did not include exactly one matching signature:'
          '\n{}\nTo avoid ambiguity, use a unique name for each'
          ' function.\nNote that some distributions of Python may report source'
          ' code incorrectly. It may be possible to avoid that bug by'
          ' organizing the code into smaller units (smaller files, functions or'
          ' classes), or by turning AutoGraph off.')
    raise ValueError(message.format(f, source))
  node, = matches

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(node, source, f)
  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  converted = node_to_graph(node, context)

  # Naming happens after conversion, so the namer already knows about every
  # symbol conversion introduced.
  if isinstance(converted, gast.Lambda):
    new_name = namer.new_symbol('tf__lambda', ())
    converted = gast.Assign(
        targets=[gast.Name(new_name, gast.Store(), None)], value=converted)
  else:
    # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
    new_name, did_rename = namer.compiled_function_name(
        f.__name__, f, owner_type)
    if not did_rename:
      # The original name is kept; the converted def must still carry it.
      new_name = f.__name__
      assert converted.name == new_name
    else:
      converted.name = new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.
  return [converted], new_name, namespace