def test_resolve(self):
  """Checks that `resolve` annotates a parsed string with origin info."""
  source = textwrap.dedent("""
      def test_fn(x):
        '''Docstring.'''
        return x  # comment
      """)
  node = parser.parse_str(source)
  origin_info.resolve(node, source)

  def assert_origin(target, lineno, col_offset, code_line, comment):
    """Checks the ORIGIN annotation that `resolve` attached to `target`."""
    origin = anno.getanno(target, anno.Basic.ORIGIN)
    self.assertEqual(origin.loc.lineno, lineno)
    self.assertEqual(origin.loc.col_offset, col_offset)
    self.assertEqual(origin.source_code_line, code_line)
    if comment is None:
      self.assertIsNone(origin.comment)
    else:
      self.assertEqual(origin.comment, comment)

  # Line 1 of the dedented source is empty, so the def is on line 2.
  assert_origin(node, 2, 0, 'def test_fn(x):', None)
  assert_origin(node.body[0], 3, 2, "  '''Docstring.'''", None)
  assert_origin(node.body[1], 4, 2, '  return x  # comment', 'comment')
def test_resolve(self):
  """Checks that `resolve` annotates a parsed live function with origins."""

  def test_fn(x):
    """Docstring."""
    return x  # comment

  parsed, source = parser.parse_entity(test_fn)
  fn_node = parsed.body[0]
  origin_info.resolve(fn_node, source)

  def assert_origin(target, lineno, col_offset, code_line, comment):
    """Checks the ORIGIN annotation that `resolve` attached to `target`."""
    origin = anno.getanno(target, anno.Basic.ORIGIN)
    self.assertEqual(origin.loc.lineno, lineno)
    self.assertEqual(origin.loc.col_offset, col_offset)
    self.assertEqual(origin.source_code_line, code_line)
    if comment is None:
      self.assertIsNone(origin.comment)
    else:
      self.assertEqual(origin.comment, comment)

  # Line numbers are expected to count from the `def` line of the entity.
  assert_origin(fn_node, 1, 0, 'def test_fn(x):', None)
  assert_origin(fn_node.body[0], 2, 2, '  """Docstring."""', None)
  assert_origin(fn_node.body[1], 3, 2, '  return x  # comment', 'comment')
def prepare(self, test_fn, namespace, namer=None, arg_types=None,
            owner_type=None, recursive=True, strip_decorators=()):
  """Parses `test_fn` and runs standard analysis on its function node.

  Args:
    test_fn: the function whose source is parsed.
    namespace: dict forwarded to `transformer.EntityInfo`; it is mutated to
      hold a `ConversionOptions` entry.
    namer: optional namer; a `FakeNamer` is created when it is None.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo`.
    recursive: forwarded to `converter.ConversionOptions`.
    strip_decorators: forwarded to `converter.ConversionOptions`.

  Returns:
    A `(node, ctx)` tuple: the node returned by `converter.standard_analysis`
    and the `converter.EntityContext` used for it.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions

  parsed, source = parser.parse_entity(test_fn)
  fn_node = parsed.body[0]
  namer = FakeNamer() if namer is None else namer

  options = converter.ConversionOptions(
      recursive=recursive, strip_decorators=strip_decorators, verbose=True)
  program_ctx = converter.ProgramContext(
      options=options,
      partial_types=None,
      autograph_module=None,
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      owner_type=owner_type)
  ctx = converter.EntityContext(namer, entity_info, program_ctx)

  origin_info.resolve(fn_node, source, test_fn)
  return converter.standard_analysis(fn_node, ctx, is_initial=True), ctx
def test_resolve(self):
  """Checks that `resolve` annotates a parsed live function with origins."""

  def test_fn(x):
    """Docstring."""
    return x  # comment

  parsed, source = parser.parse_entity(test_fn)
  fn_node = parsed.body[0]
  origin_info.resolve(fn_node, source)

  def assert_origin(target, lineno, col_offset, code_line, comment):
    """Checks the ORIGIN annotation that `resolve` attached to `target`."""
    origin = anno.getanno(target, anno.Basic.ORIGIN)
    self.assertEqual(origin.loc.lineno, lineno)
    self.assertEqual(origin.loc.col_offset, col_offset)
    self.assertEqual(origin.source_code_line, code_line)
    if comment is None:
      self.assertIsNone(origin.comment)
    else:
      self.assertEqual(origin.comment, comment)

  # Line numbers are expected to count from the `def` line of the entity.
  assert_origin(fn_node, 1, 0, 'def test_fn(x):', None)
  assert_origin(fn_node.body[0], 2, 2, '  """Docstring."""', None)
  assert_origin(fn_node.body[1], 3, 2, '  return x  # comment', 'comment')
def test_resolve(self):
  """Checks that `resolve` applies the given file, line and column offsets.

  Every annotated node, not just the function definition, must carry the
  supplied filename.
  """
  source = """
    def test_fn(x):
      '''Docstring.'''
      return x  # comment
    """
  source = textwrap.dedent(source)
  node = parser.parse(source)
  origin_info.resolve(node, source, 'test_file', 10, 10)

  def_origin = anno.getanno(node, anno.Basic.ORIGIN)
  self.assertEqual(def_origin.loc.filename, 'test_file')
  self.assertEqual(def_origin.loc.lineno, 10)
  self.assertEqual(def_origin.loc.col_offset, 10)
  self.assertEqual(def_origin.source_code_line, 'def test_fn(x):')
  self.assertIsNone(def_origin.comment)

  # Body statements sit one/two lines below the def and are indented two
  # columns past the supplied column offset.
  docstring_origin = anno.getanno(node.body[0], anno.Basic.ORIGIN)
  self.assertEqual(docstring_origin.loc.filename, 'test_file')
  self.assertEqual(docstring_origin.loc.lineno, 11)
  self.assertEqual(docstring_origin.loc.col_offset, 12)
  self.assertEqual(docstring_origin.source_code_line, "  '''Docstring.'''")
  self.assertIsNone(docstring_origin.comment)

  ret_origin = anno.getanno(node.body[1], anno.Basic.ORIGIN)
  self.assertEqual(ret_origin.loc.filename, 'test_file')
  self.assertEqual(ret_origin.loc.lineno, 12)
  self.assertEqual(ret_origin.loc.col_offset, 12)
  self.assertEqual(ret_origin.source_code_line, '  return x  # comment')
  self.assertEqual(ret_origin.comment, 'comment')
def test_origin_info_preserved_in_moved_nodes(self):
  """Statements hoisted out of an `if` keep their original line numbers."""

  class TestTransformer(transformer.Base):

    def visit_If(self, node):
      # Replace the `if` with its own body statements.
      return node.body

  tr = TestTransformer(self._simple_context())

  def test_fn():
    x = 1
    if x > 0:
      x = 1
      x += 3
    return x

  fn_node, source = parser.parse_entity(test_fn, future_features=())
  origin_info.resolve(fn_node, source)
  fn_node = tr.visit(fn_node)

  # Positions 1 and 2 now hold the statements that were inside the `if`.
  moved_assign, moved_aug_assign = fn_node.body[1], fn_node.body[2]
  for moved, expected_lineno in ((moved_assign, 4), (moved_aug_assign, 5)):
    self.assertEqual(
        anno.getanno(moved, anno.Basic.ORIGIN).loc.lineno, expected_lineno)
def disabled_test_resolve_with_future_imports(self):
  """Checks origin info when the parsed module has extra leading nodes."""

  def test_fn(x):
    """Docstring."""
    print(x)
    return x  # comment

  parsed, source = parser.parse_entity(test_fn)
  # The function is taken as the last top-level node. The expected def line
  # is 2, presumably because a future-import line precedes it; TODO confirm.
  fn_node = parsed.body[-1]
  origin_info.resolve(fn_node, source)

  def assert_origin(target, lineno, col_offset, code_line, comment):
    """Checks the ORIGIN annotation that `resolve` attached to `target`."""
    origin = anno.getanno(target, anno.Basic.ORIGIN)
    self.assertEqual(origin.loc.lineno, lineno)
    self.assertEqual(origin.loc.col_offset, col_offset)
    self.assertEqual(origin.source_code_line, code_line)
    if comment is None:
      self.assertIsNone(origin.comment)
    else:
      self.assertEqual(origin.comment, comment)

  assert_origin(fn_node, 2, 0, 'def test_fn(x):', None)
  assert_origin(fn_node.body[0], 3, 2, '  """Docstring."""', None)
  assert_origin(fn_node.body[2], 5, 2, '  return x  # comment', 'comment')
def test_origin_info_preserved_in_moved_nodes(self):
  """Hoisted statements keep their offset line numbers after transforming."""

  class TestTransformer(transformer.Base):

    def visit_If(self, node):
      # Replace the `if` with its own body statements.
      return node.body

  tr = TestTransformer(self._simple_context())

  def test_fn():
    x = 1
    if x > 0:
      x = 1
      x += 3
    return x

  fn_node, source = parser.parse_entity(test_fn, future_features=())
  origin_info.resolve(fn_node, source, 'test_file', 100, 0)
  fn_node = tr.visit(fn_node)

  # Keep their original line numbers (def line is mapped to 100).
  moved_assign, moved_aug_assign = fn_node.body[1], fn_node.body[2]
  for moved, expected_lineno in ((moved_assign, 103),
                                 (moved_aug_assign, 104)):
    self.assertEqual(
        anno.getanno(moved, anno.Basic.ORIGIN).loc.lineno, expected_lineno)
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Parses the source of `f`, attaches origin information, converts the
  function node with `node_to_graph` and gives it its final name.

  Args:
    f: the Python function to convert.
    program_ctx: program-wide context; supplies the namer, the autograph
      module reference and the name map updated here.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo` and to the namer.

  Returns:
    A tuple `([node], new_name, namespace)`.

  Raises:
    NotImplementedError: if the namer does not rename the function and the
      parsed definition's name differs from `f.__name__`.
  """
  parsed, source = parser.parse_entity(f)
  fn_node = parsed.body[0]

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(fn_node, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  fn_node = node_to_graph(fn_node, context)

  # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
  new_name, did_rename = namer.compiled_function_name(
      f.__name__, f, owner_type)
  if not did_rename:
    # Without a rename, the parsed definition must already carry f's name.
    if fn_node.name != f.__name__:
      raise NotImplementedError(
          'Strange corner case. Send us offending code!')
    new_name = f.__name__
  fn_node.name = new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return [fn_node], new_name, namespace
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the Python function or lambda to convert.
    program_ctx: program-wide context; supplies the namer, the autograph
      module reference and the name map updated here.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo` and to the namer.

  Returns:
    A tuple `([node], new_name, namespace)`. For a lambda, `node` is an
    assignment of the converted lambda to a fresh `tf__lambda` symbol.

  Raises:
    ValueError: if the source reported for `f` does not contain exactly one
      definition matching it.
  """
  parsed, source = parser.parse_entity(f)
  candidate = parsed.body[0]

  # inspect.getsource is imprecise: it uses regex matching to adjust the
  # location around the line number CPython records. Lambdas are hit hardest,
  # because the entire containing lines are returned.
  matches = ast_util.find_matching_definitions(candidate, f)
  if len(matches) != 1:
    if f.__name__ == '<lambda>':
      template = (
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.')
    else:
      template = (
          'Unable to identify source code of function {}. The source code'
          ' reported by Python did not include exactly one matching signature:'
          '\n{}\n. This is an extremely rare occurrence. Please report it to'
          ' the TensorFlow team.')
    raise ValueError(template.format(f, source))
  fn_node, = matches

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(fn_node, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  fn_node = node_to_graph(fn_node, context)

  if isinstance(fn_node, gast.Lambda):
    # Bind the converted lambda to a fresh symbol so it is reachable through
    # `new_name`.
    new_name = namer.new_symbol('tf__lambda', ())
    target = gast.Name(new_name, gast.Store(), None)
    fn_node = gast.Assign(targets=[target], value=fn_node)
  else:
    # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
    new_name, did_rename = namer.compiled_function_name(
        f.__name__, f, owner_type)
    if not did_rename:
      new_name = f.__name__
      assert fn_node.name == new_name
    else:
      fn_node.name = new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return [fn_node], new_name, namespace
def test_resolve(self):
  """Checks that `resolve` annotates a parsed string with origin info."""
  source = textwrap.dedent("""
      def test_fn(x):
        '''Docstring.'''
        return x  # comment
      """)
  node = parser.parse_str(source)
  origin_info.resolve(node, source)

  def assert_origin(target, lineno, col_offset, code_line, comment):
    """Checks the ORIGIN annotation that `resolve` attached to `target`."""
    origin = anno.getanno(target, anno.Basic.ORIGIN)
    self.assertEqual(origin.loc.lineno, lineno)
    self.assertEqual(origin.loc.col_offset, col_offset)
    self.assertEqual(origin.source_code_line, code_line)
    if comment is None:
      self.assertIsNone(origin.comment)
    else:
      self.assertEqual(origin.comment, comment)

  # Line 1 of the dedented source is empty, so the def is on line 2.
  assert_origin(node, 2, 0, 'def test_fn(x):', None)
  assert_origin(node.body[0], 3, 2, "  '''Docstring.'''", None)
  assert_origin(node.body[1], 4, 2, '  return x  # comment', 'comment')
def prepare(self, test_fn, namespace, namer=None, arg_types=None,
            owner_type=None, recursive=True, strip_decorators=()):
  """Parses `test_fn` and runs standard analysis on its function node.

  Args:
    test_fn: the function whose source is parsed.
    namespace: dict forwarded to `transformer.EntityInfo`; it is mutated to
      hold a `ConversionOptions` entry.
    namer: optional namer; a `FakeNamer` is created when it is None.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo`.
    recursive: forwarded to `converter.ConversionOptions`.
    strip_decorators: forwarded to `converter.ConversionOptions`.

  Returns:
    A `(node, ctx)` tuple: the node returned by `converter.standard_analysis`
    and the `converter.EntityContext` used for it.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions

  parsed, source = parser.parse_entity(test_fn)
  fn_node = parsed.body[0]
  namer = FakeNamer() if namer is None else namer

  options = converter.ConversionOptions(
      recursive=recursive, strip_decorators=strip_decorators, verbose=True)
  program_ctx = converter.ProgramContext(
      options=options,
      partial_types=None,
      autograph_module=None,
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      owner_type=owner_type)
  ctx = converter.EntityContext(namer, entity_info, program_ctx)

  origin_info.resolve(fn_node, source, test_fn)
  return converter.standard_analysis(fn_node, ctx, is_initial=True), ctx
def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the Python function or lambda to convert.
    program_ctx: program-wide context; its `autograph_module` is used for
      self references.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    do_rename: if true, a non-lambda function receives a name from
      `namer.function_name`; otherwise it keeps `f.__name__`.

  Returns:
    A tuple `((node,), new_name, entity_info)`.

  Raises:
    ValueError: if `f` is a lambda and its source line does not contain
      exactly one matching lambda.
    errors.InternalError: if `node_to_graph` raises ValueError,
      AttributeError, KeyError or NotImplementedError.
  """
  future_features = inspect_utils.getfutureimports(f)
  node, source = parser.parse_entity(f, future_features=future_features)
  logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
  # Parsed AST should contain future imports and one function def node.

  # For lambdas, inspect.getsource relies on regex matching around the line
  # number CPython records and returns the entire containing line, which may
  # hold several candidates to disambiguate, e.g.:
  #   x, y = lambda: 1, lambda: 2
  is_lambda = f.__name__ == '<lambda>'
  if is_lambda:
    matches = ast_util.find_matching_definitions(node, f)
    if len(matches) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    node, = matches

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(node, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = naming.Namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      future_features=future_features,
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  try:
    node = node_to_graph(node, context)
  except (ValueError, AttributeError, KeyError, NotImplementedError) as err:
    logging.error(1, 'Error converting %s', f, exc_info=True)
    # TODO(mdan): Catch and rethrow syntax errors.
    raise errors.InternalError('conversion', err)

  if isinstance(node, gast.Lambda):
    new_name = namer.new_symbol('tf__lambda', ())
    target = gast.Name(new_name, gast.Store(), None)
    node = gast.Assign(targets=[target], value=node)
  elif not do_rename:
    new_name = f.__name__
    assert node.name == new_name
  else:
    new_name = namer.function_name(f.__name__)
    node.name = new_name

  return (node,), new_name, entity_info
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the Python function or lambda to convert.
    program_ctx: program-wide context; supplies the namer, the autograph
      module reference and the name map updated here.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo` and to the namer.

  Returns:
    A tuple `([node], new_name, namespace)`. For a lambda, `node` is an
    assignment of the converted lambda to a fresh `tf__lambda` symbol.

  Raises:
    ValueError: if `f` is a lambda and its source line does not contain
      exactly one matching lambda.
  """
  parsed, source = parser.parse_entity(f)
  fn_node = parsed.body[0]

  # TODO(mdan): Can we convert everything and scoop the lambda afterwards?
  if f.__name__ == '<lambda>':
    lambdas = ast_util.find_matching_lambda_definitions(fn_node, f)
    if len(lambdas) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which contains multiple lambdas with'
          ' identical argument names. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    fn_node, = lambdas

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(fn_node, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  fn_node = node_to_graph(fn_node, context)

  if isinstance(fn_node, gast.Lambda):
    # Bind the converted lambda to a fresh symbol reachable via `new_name`.
    new_name = namer.new_symbol('tf__lambda', ())
    target = gast.Name(new_name, gast.Store(), None)
    fn_node = gast.Assign(targets=[target], value=fn_node)
  else:
    # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
    new_name, did_rename = namer.compiled_function_name(
        f.__name__, f, owner_type)
    if not did_rename:
      new_name = f.__name__
      assert fn_node.name == new_name
    else:
      fn_node.name = new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return [fn_node], new_name, namespace
def convert_func_to_ast(f, program_ctx, do_rename=True):
  """Specialization of `convert_entity_to_ast` for callable functions.

  Args:
    f: the Python function or lambda to convert.
    program_ctx: program-wide context; its `autograph_module` is used for
      self references.
    do_rename: if true, a non-lambda function receives a name from
      `namer.function_name`; otherwise it keeps `f.__name__`.

  Returns:
    A tuple `((node,), new_name, entity_info)`.

  Raises:
    ValueError: if `f` is a lambda and its source line does not contain
      exactly one matching lambda.
    errors.InternalError: if `node_to_graph` raises ValueError,
      AttributeError, KeyError or NotImplementedError.
  """
  future_features = inspect_utils.getfutureimports(f)
  node, source = parser.parse_entity(f, future_features=future_features)
  logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
  # Parsed AST should contain future imports and one function def node.

  # For lambdas, inspect.getsource relies on regex matching around the line
  # number CPython records and returns the entire containing line, which may
  # hold several candidates to disambiguate, e.g.:
  #   x, y = lambda: 1, lambda: 2
  is_lambda = f.__name__ == '<lambda>'
  if is_lambda:
    matches = ast_util.find_matching_definitions(node, f)
    if len(matches) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    node, = matches

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(node, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = naming.Namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      future_features=future_features,
      namespace=namespace)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  try:
    node = node_to_graph(node, context)
  except (ValueError, AttributeError, KeyError, NotImplementedError) as err:
    logging.error(1, 'Error converting %s', f, exc_info=True)
    # TODO(mdan): Catch and rethrow syntax errors.
    raise errors.InternalError('conversion', err)

  if isinstance(node, gast.Lambda):
    new_name = namer.new_symbol('tf__lambda', ())
    target = gast.Name(new_name, gast.Store(), None)
    node = gast.Assign(targets=[target], value=node)
  elif not do_rename:
    new_name = f.__name__
    assert node.name == new_name
  else:
    new_name = namer.function_name(f.__name__)
    node.name = new_name

  return (node,), new_name, entity_info
def test_resolve_with_trailing_garbage(self):
  """Origin info resolves against a raw line with unbalanced trailing text."""
  # Only the clean prefix is parsed. Any comment in the raw line would be
  # missed because the tokenizer fails to reach it.
  source = ' lambda: foo([], bar=1)), baz=2)()'
  parseable_prefix = 'lambda: foo([], bar=1)'
  lambda_node = parser.parse(parseable_prefix).value
  origin_info.resolve(lambda_node, source, 'test_file', 10, 10)

  origin = anno.getanno(lambda_node, anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 10)
  self.assertEqual(origin.loc.col_offset, 10)
  # The reported line is the raw source, trailing text included.
  self.assertEqual(origin.source_code_line, source)
  self.assertIsNone(origin.comment)
def test_basic_codegen(self):
  """A minimal CodeGenerator subclass emits the expected pseudo-code."""

  class TestCodegen(transformer.CodeGenerator):

    def visit_Assign(self, node):
      code = parser.unparse(node, include_encoding_marker=False)
      self.emit(code)
      self.emit('\n')

    def visit_Return(self, node):
      code = parser.unparse(node, include_encoding_marker=False)
      self.emit(code)
      self.emit('\n')

    def visit_If(self, node):
      self.emit('if ')
      # This is just for simplicity. A real generator would walk the tree and
      # emit proper code.
      condition = parser.unparse(node.test, include_encoding_marker=False)
      self.emit(condition)
      self.emit(' {\n')
      self.visit_block(node.body)
      self.emit('} else {\n')
      self.visit_block(node.orelse)
      self.emit('}\n')

  tg = TestCodegen(self._simple_context())

  def test_fn():
    x = 1
    if x > 0:
      x = 2
      if x > 1:
        x = 3
    return x

  fn_node, source = parser.parse_entity(test_fn, future_features=())
  origin_info.resolve(fn_node, source, 'test_file', 100, 0)
  tg.visit(fn_node)

  expected_lines = (
      'x = 1',
      'if (x > 0) {',
      'x = 2',
      'if (x > 1) {',
      'x = 3',
      '} else {',
      '}',
      '} else {',
      '}',
      'return x',
      '',
  )
  self.assertEqual(tg.code_buffer, '\n'.join(expected_lines))
def prepare(self, test_fn, namespace, arg_types=None, recursive=True):
  """Parses `test_fn` and runs standard analysis on it.

  Args:
    test_fn: the function whose source is parsed.
    namespace: dict used for the namer and `transformer.EntityInfo`; it is
      mutated to hold a `ConversionOptions` entry.
    arg_types: forwarded to `transformer.EntityInfo`.
    recursive: forwarded to `converter.ConversionOptions`.

  Returns:
    A `(node, ctx)` tuple: the node returned by `converter.standard_analysis`
    and the `converter.EntityContext` used for it.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions

  # parse_entity is unpacked as a 3-tuple here; the third item is unused.
  fn_node, source, _unused = parser.parse_entity(test_fn)
  namer = naming.Namer(namespace)

  options = converter.ConversionOptions(recursive=recursive)
  program_ctx = converter.ProgramContext(
      options=options, autograph_module=None)
  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types)
  ctx = converter.EntityContext(namer, entity_info, program_ctx)

  origin_info.resolve(fn_node, source, test_fn)
  return converter.standard_analysis(fn_node, ctx, is_initial=True), ctx
def prepare(self, test_fn, namespace, recursive=True):
  """Parses `test_fn` with fixed future features and analyzes it.

  Args:
    test_fn: the function whose source is parsed.
    namespace: dict used for the namer and `transformer.EntityInfo`; it is
      mutated to hold a `ConversionOptions` entry.
    recursive: forwarded to `converter.ConversionOptions`.

  Returns:
    A `(node, ctx)` tuple: the node returned by `converter.standard_analysis`
    and the `converter.EntityContext` used for it.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions

  future_features = ('print_function', 'division')
  fn_node, source = parser.parse_entity(
      test_fn, future_features=future_features)
  namer = naming.Namer(namespace)

  options = converter.ConversionOptions(recursive=recursive)
  program_ctx = converter.ProgramContext(
      options=options, autograph_module=None)
  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      future_features=future_features,
      namespace=namespace)
  ctx = converter.EntityContext(namer, entity_info, program_ctx)

  origin_info.resolve(fn_node, source, test_fn)
  return converter.standard_analysis(fn_node, ctx, is_initial=True), ctx
def test_source_map(self):
  """Traced nodes map back to their origin; untraced nodes are absent."""

  def test_fn(x):
    if x > 0:
      x += 1
    return x

  node, source = parser.parse_entity(test_fn)
  fn_node = node.body[0]
  origin_info.resolve(fn_node, source)

  # Insert a traced line: it copies the ORIGIN annotation of the `if`.
  new_node = parser.parse_str('x = abs(x)').body[0]
  anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN)
  fn_node.body.insert(0, new_node)

  # Insert an untraced line (no ORIGIN annotation) in front of it.
  fn_node.body.insert(0, parser.parse_str('x = 0').body[0])

  modified_source = compiler.ast_to_source(fn_node)

  source_map = origin_info.source_map(fn_node, modified_source,
                                      'test_filename', [0])

  loc = origin_info.LineLocation('test_filename', 1)
  origin = source_map[loc]
  self.assertEqual(origin.source_code_line, 'def test_fn(x):')
  self.assertEqual(origin.loc.lineno, 1)

  # The untraced line, inserted second, has no entry. assertNotIn reports the
  # offending key on failure, unlike assertFalse(loc in source_map).
  loc = origin_info.LineLocation('test_filename', 2)
  self.assertNotIn(loc, source_map)

  # The traced line, inserted first, maps back to the `if` it copied from.
  loc = origin_info.LineLocation('test_filename', 3)
  origin = source_map[loc]
  self.assertEqual(origin.source_code_line, '  if x > 0:')
  self.assertEqual(origin.loc.lineno, 2)

  # The original `if` line itself.
  loc = origin_info.LineLocation('test_filename', 4)
  origin = source_map[loc]
  self.assertEqual(origin.source_code_line, '  if x > 0:')
  self.assertEqual(origin.loc.lineno, 2)
def test_origin_info_propagated_to_new_nodes(self):
  """A node created by a transformer inherits the replaced node's origin."""

  class TestTransformer(transformer.Base):

    def visit_If(self, node):
      # Replace every `if` statement with a freshly created `pass`.
      return gast.Pass()

  tr = TestTransformer(self._simple_context())

  def test_fn():
    x = 1
    if x > 0:
      x = 1
    return x

  fn_node, source = parser.parse_entity(test_fn, future_features=())
  origin_info.resolve(fn_node, source, 'test_file', 100, 0)
  fn_node = tr.visit(fn_node)

  # Takes the line number of the if statement.
  pass_origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN)
  self.assertEqual(pass_origin.loc.lineno, 102)
def test_origin_info_propagated_to_new_nodes(self):
  """A node created by a transformer inherits the replaced node's origin."""

  class TestTransformer(transformer.Base):

    def visit_If(self, node):
      # Replace every `if` statement with a freshly created `pass`.
      return gast.Pass()

  tr = TestTransformer(self._simple_context())

  def test_fn():
    x = 1
    if x > 0:
      x = 1
    return x

  fn_node, source = parser.parse_entity(test_fn, future_features=())
  origin_info.resolve(fn_node, source)
  fn_node = tr.visit(fn_node)

  # The `pass` replacing the `if` reports the `if`'s line (3).
  pass_origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN)
  self.assertEqual(pass_origin.loc.lineno, 3)
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None,
                      rewrite_errors=True):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the Python function to convert.
    program_ctx: program-wide context; supplies the namer, the autograph
      module reference and the name map updated here.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo` and to the namer.
    rewrite_errors: forwarded to `node_to_graph`.

  Returns:
    A tuple `([node], new_name, namespace)`.

  Raises:
    NotImplementedError: if the namer does not rename the function and the
      parsed definition's name differs from `f.__name__`.
  """
  parsed, source = parser.parse_entity(f)
  fn_node = parsed.body[0]
  origin_info.resolve(fn_node, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  fn_node = node_to_graph(fn_node, context, rewrite_errors=rewrite_errors)

  # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
  new_name, did_rename = namer.compiled_function_name(
      f.__name__, f, owner_type)
  if not did_rename:
    # Without a rename, the parsed definition must already carry f's name.
    if fn_node.name != f.__name__:
      raise NotImplementedError('Strange corner case. Send us offending code!')
    new_name = f.__name__
  fn_node.name = new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return [fn_node], new_name, namespace
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the Python function or lambda to convert.
    program_ctx: program-wide context; supplies the namer, the autograph
      module reference and the name map updated here.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo` and to the namer.

  Returns:
    A tuple `([node], new_name, namespace)`. For a lambda, `node` is an
    assignment of the converted lambda to a fresh `tf__lambda` symbol.

  Raises:
    ValueError: if the source reported for `f` does not contain exactly one
      definition matching it.
  """
  parsed, source = parser.parse_entity(f)
  candidate = parsed.body[0]

  # inspect.getsource is imprecise: it searches the source file with crude
  # regex matching. Lambdas are hit hardest, because the entire containing
  # lines are returned; some CPython distributions may also return the
  # enclosing function for local functions.
  matches = ast_util.find_matching_definitions(candidate, f)
  if len(matches) != 1:
    if f.__name__ == '<lambda>':
      template = (
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.')
    else:
      template = (
          'Unable to identify source code of function {}. The source code'
          ' reported by Python did not include exactly one matching signature:'
          '\n{}\nTo avoid ambiguity, use a unique name for each'
          ' function.')
    raise ValueError(template.format(f, source))
  fn_node, = matches

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(fn_node, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  fn_node = node_to_graph(fn_node, context)

  if isinstance(fn_node, gast.Lambda):
    # Bind the converted lambda to a fresh symbol reachable via `new_name`.
    new_name = namer.new_symbol('tf__lambda', ())
    target = gast.Name(new_name, gast.Store(), None)
    fn_node = gast.Assign(targets=[target], value=fn_node)
  else:
    # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
    new_name, did_rename = namer.compiled_function_name(
        f.__name__, f, owner_type)
    if not did_rename:
      new_name = f.__name__
      assert fn_node.name == new_name
    else:
      fn_node.name = new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return [fn_node], new_name, namespace
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the Python function or lambda to convert.
    program_ctx: program-wide context; supplies the namer, the autograph
      module reference and the name map updated here.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo` and to the namer.

  Returns:
    A tuple `([node], new_name, namespace)`. For a lambda, `node` is an
    assignment of the converted lambda to a fresh `tf__lambda` symbol.

  Raises:
    ValueError: if the source reported for `f` does not contain exactly one
      definition matching it.
  """
  parsed, source = parser.parse_entity(f)
  candidate = parsed.body[0]

  # inspect.getsource is imprecise: it searches the source file with crude
  # regex matching. Lambdas are hit hardest, because the entire containing
  # lines are returned; some CPython distributions may also return the
  # enclosing function for local functions.
  matches = ast_util.find_matching_definitions(candidate, f)
  if len(matches) != 1:
    if f.__name__ == '<lambda>':
      template = (
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.')
    else:
      # The inspect.getsource bug is currently known to occur in the Windows
      # integration tests which run Python 3.6.
      # TODO(mdan): Find out exactly which distribution of Python that is.
      template = (
          'Unable to identify source code of function {}. The source code'
          ' reported by Python did not include exactly one matching signature:'
          '\n{}\nTo avoid ambiguity, use a unique name for each'
          ' function.\nNote that some distributions of Python may report source'
          ' code incorrectly. It may be possible to avoid that bug by'
          ' organizing the code into smaller units (smaller files, functions or'
          ' classes), or by turning AutoGraph off.')
    raise ValueError(template.format(f, source))
  fn_node, = matches

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(fn_node, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  fn_node = node_to_graph(fn_node, context)

  if isinstance(fn_node, gast.Lambda):
    # Bind the converted lambda to a fresh symbol reachable via `new_name`.
    new_name = namer.new_symbol('tf__lambda', ())
    target = gast.Name(new_name, gast.Store(), None)
    fn_node = gast.Assign(targets=[target], value=fn_node)
  else:
    # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
    new_name, did_rename = namer.compiled_function_name(
        f.__name__, f, owner_type)
    if not did_rename:
      new_name = f.__name__
      assert fn_node.name == new_name
    else:
      fn_node.name = new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return [fn_node], new_name, namespace