def test_local_scope_info_stack_checks_integrity(self):

  class TestTransformer(transformer.Base):
    # Deliberately unbalanced: every If enters a local scope that is never
    # exited, and every For exits a scope that was never entered.

    def visit_If(self, node):
      self.enter_local_scope()
      return self.generic_visit(node)

    def visit_For(self, node):
      visited = self.generic_visit(node)
      self.exit_local_scope()
      return visited

  tr = TestTransformer(self._simple_context())

  def no_exit(a):
    if a > 0:
      print(a)
    return None

  def no_entry(a):
    for _ in a:
      print(a)

  # Both kinds of imbalance must be reported with an AssertionError.
  for entity in (no_exit, no_entry):
    parsed, _ = parser.parse_entity(entity, future_features=())
    with self.assertRaises(AssertionError):
      tr.visit(parsed)
# Expects parser.parse_entity to reject (with ValueError) a function whose
# body contains a comment that is not indented to the body's level.
# NOTE(review): this definition appears collapsed onto one line; the test
# depends on the comment inside `f` being unindented, so the original
# multi-line layout must be restored for it to be meaningful.
def test_parse_comments(self): def f(): # unindented comment pass with self.assertRaises(ValueError): parser.parse_entity(f, future_features=())
# Expects parser.parse_entity(f) to raise ValueError when `f` contains a
# multi-line string literal.
# NOTE(review): the whitespace inside the string (presumably an unindented
# continuation line) is significant and appears collapsed here — confirm
# against the original file.
def test_parse_multiline_strings(self): def f(): print(""" some multiline string""") with self.assertRaises(ValueError): parser.parse_entity(f)
def test_origin_info_preserved_in_moved_nodes(self):

  class TestTransformer(transformer.Base):

    def visit_If(self, node):
      # Replace the If statement by the statements of its body.
      return node.body

  tr = TestTransformer(self._simple_context())

  def test_fn():
    x = 1
    if x > 0:
      x = 1
      x += 3
    return x

  node, source = parser.parse_entity(test_fn, future_features=())
  origin_info.resolve(node, source, 'test_file', 100, 0)
  node = tr.visit(node)

  # The hoisted Assign (body[1]) and AugAssign (body[2]) must keep the line
  # numbers they had inside the If.
  for index, expected_lineno in ((1, 103), (2, 104)):
    origin = anno.getanno(node.body[index], anno.Basic.ORIGIN)
    self.assertEqual(origin.loc.lineno, expected_lineno)
def test_robust_error_on_ast_corruption(self):
  # A child class must never break `transformer.Base` so badly that its
  # error handling itself raises. Otherwise the original error location is
  # lost, and an error handler higher up the call stack reports misleading
  # information.
  # This test checks that the error handling in `visit` completes and blames
  # the original exception, even when the AST has been corrupted.

  class NotANode(object):
    pass

  class BrokenTransformer(transformer.Base):

    def visit_If(self, node):
      node.body = NotANode()
      raise ValueError('I blew up')

  def test_function(x):
    if x > 0:
      return x

  tr = BrokenTransformer(self._simple_context())

  node, _ = parser.parse_entity(test_function, future_features=())
  with self.assertRaises(ValueError) as cm:
    node = tr.visit(node)
  obtained_message = str(cm.exception)
  # The message should reference the exception actually raised, not anything
  # raised from within the exception handler.
  self.assertIn('I blew up', obtained_message, obtained_message)
# Checks that parser.parse_entity accepts a function whose body contains an
# unindented comment, and that the returned node is the FunctionDef `f`.
# NOTE(review): this definition appears collapsed onto one line; the test
# depends on the comment inside `f` being unindented, so restore the original
# multi-line layout.
def test_parse_comments(self): def f(): # unindented comment pass node, _ = parser.parse_entity(f, future_features=()) self.assertEqual('f', node.name)
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Parses `f`, converts its function definition node and renames it to the
  compiled name chosen by the program's namer.

  Args:
    f: the function to convert.
    program_ctx: program-wide context; supplies the namer, the autograph
      module added to the namespace, and the name map that is updated.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo` and to the namer.

  Returns:
    A tuple `([node], new_name, namespace)`.

  Raises:
    NotImplementedError: if no rename happened but the converted node's name
      no longer matches `f.__name__`.
  """
  module_node, source = parser.parse_entity(f)
  fn_node = module_node.body[0]
  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(fn_node, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  fn_node = node_to_graph(fn_node, context)

  # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
  original_name = f.__name__
  new_name, did_rename = namer.compiled_function_name(
      original_name, f, owner_type)
  if not did_rename:
    new_name = original_name
    if fn_node.name != original_name:
      raise NotImplementedError(
          'Strange corner case. Send us offending code!')
  fn_node.name = new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return [fn_node], new_name, namespace
def test_ext_slice_roundtrip(self):

  def ext_slice(n):
    return n[:, :], n[0, :], n[:, 0]

  # Extended slices must survive a parse -> unparse round trip.
  parsed, _ = parser.parse_entity(ext_slice, future_features=())
  regenerated = parser.unparse(parsed)
  self.assertAstMatches(parsed, regenerated, expr=False)
def test_parse_entity_print_function(self):

  def f(x):
    print(x)

  # Parsing with the print_function future feature yields the FunctionDef.
  fn_node, _ = parser.parse_entity(
      f, future_features=('print_function',))
  self.assertEqual('f', fn_node.name)
def test_parse_entity(self):

  def f(x):
    return x + 1

  # parse_entity returns the FunctionDef node itself, named after `f`.
  fn_node, _ = parser.parse_entity(f, future_features=())
  self.assertEqual('f', fn_node.name)
# Checks that a lambda spanning several lines parses to the expected AST and
# that the recovered source starts with `base_source`. Trailing characters
# (the closing parenthesis) are tolerated via assertMatchesWithPotentialGarbage.
# NOTE(review): this definition appears collapsed onto one line; the
# indentation inside `base_source` and the lambda layout are significant,
# so restore the original multi-line layout from version control.
def test_parse_lambda_complex_body(self): l = lambda x: ( # pylint:disable=g-long-lambda x.y( [], x.z, (), x[0:2], ), x.u, 'abc', 1, ) node, source = parser.parse_entity(l, future_features=()) expected_node_src = "lambda x: (x.y([], x.z, (), x[0:2]), x.u, 'abc', 1)" self.assertAstMatches(node, expected_node_src) base_source = ('lambda x: ( # pylint:disable=g-long-lambda\n' ' x.y(\n' ' [],\n' ' x.z,\n' ' (),\n' ' x[0:2],\n' ' ),\n' ' x.u,\n' ' \'abc\',\n' ' 1,') # The complete source includes the trailing parenthesis. But that is only # detected in runtimes which correctly track end_lineno for ASTs. self.assertMatchesWithPotentialGarbage(source, base_source, '\n )')
# Checks that origin_info.resolve annotates the FunctionDef (`node.body[0]` of
# the parsed module) and its two statements with line number, column offset,
# source line and trailing comment (None when there is no comment).
# NOTE(review): this definition appears collapsed onto one line; the expected
# source-line strings depend on the original indentation of `test_fn`.
def test_resolve(self): def test_fn(x): """Docstring.""" return x # comment node, source = parser.parse_entity(test_fn) fn_node = node.body[0] origin_info.resolve(fn_node, source) origin = anno.getanno(fn_node, anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 1) self.assertEqual(origin.loc.col_offset, 0) self.assertEqual(origin.source_code_line, 'def test_fn(x):') self.assertIsNone(origin.comment) origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 2) self.assertEqual(origin.loc.col_offset, 2) self.assertEqual(origin.source_code_line, ' """Docstring."""') self.assertIsNone(origin.comment) origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 3) self.assertEqual(origin.loc.col_offset, 2) self.assertEqual(origin.source_code_line, ' return x # comment') self.assertEqual(origin.comment, 'comment')
def test_state_tracking_context_manager(self):

  class CondState(object):
    pass

  class TestTransformer(transformer.Base):

    def visit(self, node):
      # Tag each visited node with the CondState value current at that point.
      current_state = self.state[CondState].value
      anno.setanno(node, 'cond_state', current_state)
      return super(TestTransformer, self).visit(node)

    def visit_If(self, node):
      # Everything nested under an If is visited inside a fresh CondState.
      with self.state[CondState]:
        return self.generic_visit(node)

  tr = TestTransformer(self._simple_context())

  def test_function(a):
    a = 1
    if a > 2:
      _ = 'b'
      if a < 5:
        _ = 'c'
      _ = 'd'

  parsed, _ = parser.parse_entity(test_function, future_features=())
  parsed = tr.visit(parsed)

  body = parsed.body
  outer_body = body[1].body
  self.assertDifferentAnno(body[0], outer_body[0], 'cond_state')
  self.assertSameAnno(outer_body[0], outer_body[2], 'cond_state')

  inner_body = outer_body[1].body
  self.assertDifferentAnno(inner_body[0], outer_body[0], 'cond_state')
def test_origin_info_preserved_in_moved_nodes(self):

  class TestTransformer(transformer.Base):

    def visit_If(self, node):
      # Replace the If statement by the statements of its body.
      return node.body

  tr = TestTransformer(self._simple_context())

  def test_fn():
    x = 1
    if x > 0:
      x = 1
      x += 3
    return x

  node, source = parser.parse_entity(test_fn, future_features=())
  origin_info.resolve(node, source)
  node = tr.visit(node)

  # The hoisted Assign (body[1]) and AugAssign (body[2]) keep the line
  # numbers they had inside the If.
  for index, expected_lineno in ((1, 4), (2, 5)):
    origin = anno.getanno(node.body[index], anno.Basic.ORIGIN)
    self.assertEqual(origin.loc.lineno, expected_lineno)
def assert_body_anfs_as_expected(self, expected_fn, test_fn, config=None):
  """Asserts that ANF-transforming `test_fn` produces `expected_fn`.

  Only the code bodies are compared. The snippets are wrapped in functions so
  the syntax highlights nicely, but Python never executes their statements.

  Args:
    expected_fn: function whose AST is the expected transformation result.
    test_fn: function to transform with `anf.transform`.
    config: optional ANF configuration. It is used for both the initial
      transform and the idempotence check.
  """
  exp_node, _ = parser.parse_entity(expected_fn, future_features=())
  node, _ = parser.parse_entity(test_fn, future_features=())
  node = anf.transform(node, self._simple_context(), config=config)
  # The function names in the result are ignored: both functions must exist
  # in the same scope at the same time, so their names necessarily differ.
  node.name = exp_node.name
  self.assert_same_ast(exp_node, node)
  # Check that ANF is idempotent. Re-run with the same config so the second
  # pass is governed by the same settings as the first.
  node_repeated = anf.transform(node, self._simple_context(), config=config)
  self.assert_same_ast(node_repeated, node)
# Checks that origin_info.resolve annotates the FunctionDef returned by
# parse_entity (first element of its 3-tuple; the source is the third) and its
# statements with line, column, source line and trailing comment.
# NOTE(review): this definition appears collapsed onto one line; the expected
# source-line strings depend on the original indentation of `test_fn`.
def test_resolve(self): def test_fn(x): """Docstring.""" return x # comment node, _, source = parser.parse_entity(test_fn, future_imports=()) origin_info.resolve(node, source) origin = anno.getanno(node, anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 1) self.assertEqual(origin.loc.col_offset, 0) self.assertEqual(origin.source_code_line, 'def test_fn(x):') self.assertIsNone(origin.comment) origin = anno.getanno(node.body[0], anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 2) self.assertEqual(origin.loc.col_offset, 2) self.assertEqual(origin.source_code_line, ' """Docstring."""') self.assertIsNone(origin.comment) origin = anno.getanno(node.body[1], anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 3) self.assertEqual(origin.loc.col_offset, 2) self.assertEqual(origin.source_code_line, ' return x # comment') self.assertEqual(origin.comment, 'comment')
def _should_compile(self, node, fqn):
  """Determines whether an entity should be compiled in the context.

  Args:
    node: the call node under consideration; `node.func` is resolved to the
      target entity.
    fqn: fully qualified name of the target as a tuple; `fqn[0]` is treated
      as the module name.

  Returns:
    False if the target belongs to an uncompiled module, is annotated
    `graph_ready`, is a compile decorator, is one of the decorators to strip,
    or is itself decorated with one of them. True otherwise, including when
    the target's source cannot be retrieved.
  """
  # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
  module_name = fqn[0]
  for mod in self.ctx.program.uncompiled_modules:
    if module_name.startswith(mod[0] + '.'):
      return False

  for i in range(1, len(fqn)):
    if fqn[:i] in self.ctx.program.uncompiled_modules:
      return False

  # Check for local decorations
  if anno.hasanno(node, 'graph_ready'):
    return False

  # The decorators themselves are not to be converted.
  # If present, the decorators should appear as static functions.
  target_entity = self._try_resolve_target(node.func)

  if target_entity is not None:

    # This may be reached when "calling" a callable attribute of an object.
    # For example:
    #
    #   self.fc = tf.keras.layers.Dense()
    #   self.fc()
    #
    # `__module__` may be absent or None for some callables. Treat those as
    # not belonging to any uncompiled module instead of crashing.
    target_module = getattr(target_entity, '__module__', None) or ''
    for mod in self.ctx.program.uncompiled_modules:
      if target_module.startswith(mod[0] + '.'):
        return False

    # This attribute is set by the decorator itself.
    # TODO(mdan): This may not play nicely with other wrapping decorators.
    if hasattr(target_entity, '__pyct_is_compile_decorator'):
      return False

    if target_entity in self.ctx.program.options.strip_decorators:
      return False

    # Inspect the target function decorators. If any include a @convert
    # or @graph_ready annotation, then they must be called as they are.
    # TODO(mdan): This may be quite heavy.
    # To parse and re-analyze each function for every call site could be quite
    # wasteful. Maybe we could cache the parsed AST?
    try:
      target_node, _ = parser.parse_entity(target_entity)
      target_node = target_node.body[0]
    except (IOError, TypeError):
      # Functions whose source we cannot access are compilable (e.g. wrapped
      # to py_func). `inspect` raises TypeError for built-ins and
      # IOError/OSError (aliases on Python 3) when the source is unavailable.
      return True

    for dec in target_node.decorator_list:
      decorator_fn = self._resolve_name(dec)
      if (decorator_fn is not None and
          decorator_fn in self.ctx.program.options.strip_decorators):
        return False

  return True
def test_visit_block_postprocessing(self):

  class TestTransformer(transformer.Base):

    def _process_body_item(self, node):
      # Wrap `... = y` assignments in `if x:`. Return the If's body as the
      # list into which the test expects subsequent statements to be placed.
      if isinstance(node, gast.Assign) and (node.value.id == 'y'):
        if_node = gast.If(gast.Name('x', gast.Load(), None), [node], [])
        return if_node, if_node.body
      return node, None

    def visit_FunctionDef(self, node):
      node.body = self.visit_block(
          node.body, after_visit=self._process_body_item)
      return node

  def test_function(x, y):
    z = x
    z = y
    return z

  tr = TestTransformer(self._simple_context())

  node, _ = parser.parse_entity(test_function, future_features=())
  node = tr.visit(node)

  self.assertEqual(len(node.body), 2)
  self.assertIsInstance(node.body[0], gast.Assign)
  self.assertIsInstance(node.body[1], gast.If)
  self.assertIsInstance(node.body[1].body[0], gast.Assign)
  self.assertIsInstance(node.body[1].body[1], gast.Return)
def test_robust_error_on_list_visit(self):

  class BrokenTransformer(transformer.Base):

    def visit_If(self, node):
      # This is broken because visit expects a single node, not a list, and
      # the body of an if is a list.
      # Importantly, the default error handling in visit also expects a single
      # node. Therefore, mistakes like this need to trigger a type error
      # before the visit called here installs its error handler.
      # That type error can then be caught by the enclosing call to visit,
      # and correctly blame the If node.
      self.visit(node.body)
      return node

  def test_function(x):
    if x > 0:
      return x

  tr = BrokenTransformer(self._simple_context())

  node, _ = parser.parse_entity(test_function, future_features=())
  with self.assertRaises(ValueError) as cm:
    node = tr.visit(node)
  obtained_message = str(cm.exception)
  expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
  # assertRegex replaces the deprecated assertRegexpMatches alias, which was
  # removed in Python 3.12.
  self.assertRegex(obtained_message, expected_message)
def test_parse_entity(self):

  def f(x):
    return x + 1

  # In this API parse_entity returns a module; the function is its first
  # statement.
  module_node, _ = parser.parse_entity(f)
  fn_def = module_node.body[0]
  self.assertEqual('f', fn_def.name)
# Disabled test: checks origin annotations when `test_fn` contains a print()
# call. The expected line numbers start at 2 for the def, presumably because
# a future-import line is prepended to the parsed source (TODO confirm). The
# return statement is body[2] because print(x) is body[1].
# NOTE(review): this unpacks parse_entity's result as (node, source, _),
# while the sibling test_resolve unpacks (node, _, source). Confirm which
# element is the source before re-enabling.
# NOTE(review): this definition appears collapsed onto one line; the expected
# source-line strings depend on the original indentation.
def disabled_test_resolve_with_future_imports(self): def test_fn(x): """Docstring.""" print(x) return x # comment node, source, _ = parser.parse_entity(test_fn) origin_info.resolve(node, source) origin = anno.getanno(node, anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 2) self.assertEqual(origin.loc.col_offset, 0) self.assertEqual(origin.source_code_line, 'def test_fn(x):') self.assertIsNone(origin.comment) origin = anno.getanno(node.body[0], anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 3) self.assertEqual(origin.loc.col_offset, 2) self.assertEqual(origin.source_code_line, ' """Docstring."""') self.assertIsNone(origin.comment) origin = anno.getanno(node.body[2], anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 5) self.assertEqual(origin.loc.col_offset, 2) self.assertEqual(origin.source_code_line, ' return x # comment') self.assertEqual(origin.comment, 'comment')
def test_visit_block_postprocessing(self):

  class TestTransformer(transformer.Base):

    def _process_body_item(self, node):
      # Only `... = y` assignments get wrapped in `if x:`; everything else
      # passes through unchanged.
      is_y_assign = isinstance(node, gast.Assign) and (node.value.id == 'y')
      if not is_y_assign:
        return node, None
      condition = gast.Name(
          'x', ctx=gast.Load(), annotation=None, type_comment=None)
      wrapper = gast.If(condition, [node], [])
      return wrapper, wrapper.body

    def visit_FunctionDef(self, node):
      node.body = self.visit_block(
          node.body, after_visit=self._process_body_item)
      return node

  def test_function(x, y):
    z = x
    z = y
    return z

  tr = TestTransformer(self._simple_context())

  parsed, _ = parser.parse_entity(test_function, future_features=())
  parsed = tr.visit(parsed)

  body = parsed.body
  self.assertEqual(len(body), 2)
  self.assertIsInstance(body[0], gast.Assign)
  self.assertIsInstance(body[1], gast.If)
  if_body = body[1].body
  self.assertIsInstance(if_body[0], gast.Assign)
  self.assertIsInstance(if_body[1], gast.Return)
def test_robust_error_on_list_visit(self):

  class BrokenTransformer(transformer.Base):

    def visit_If(self, node):
      # This is broken because visit expects a single node, not a list, and
      # the body of an if is a list.
      # Importantly, the default error handling in visit also expects a single
      # node. Therefore, mistakes like this need to trigger a type error
      # before the visit called here installs its error handler.
      # That type error can then be caught by the enclosing call to visit,
      # and correctly blame the If node.
      self.visit(node.body)
      return node

  def test_function(x):
    if x > 0:
      return x

  tr = BrokenTransformer(self._simple_context())

  node, _ = parser.parse_entity(test_function, future_features=())
  with self.assertRaises(ValueError) as cm:
    node = tr.visit(node)
  obtained_message = str(cm.exception)
  expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
  # assertRegex replaces the deprecated assertRegexpMatches alias, which was
  # removed in Python 3.12.
  self.assertRegex(obtained_message, expected_message)
  # The exception should point at the if statement, not any place else. Could
  # also check the stack trace.
  self.assertIn('Occurred at node:\nIf', obtained_message, obtained_message)
  self.assertNotIn(
      'Occurred at node:\nFunctionDef', obtained_message, obtained_message)
  self.assertNotIn(
      'Occurred at node:\nReturn', obtained_message, obtained_message)
def prepare(self,
            test_fn,
            namespace,
            namer=None,
            arg_types=None,
            owner_type=None,
            recursive=True,
            strip_decorators=()):
  """Parses `test_fn` and runs standard analysis on its function node.

  Args:
    test_fn: the function to parse.
    namespace: dict made available to the conversion; it is mutated to
      expose `converter.ConversionOptions` under 'ConversionOptions'.
    namer: namer to use; a `FakeNamer()` is created when None.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo`.
    recursive: forwarded to `converter.ConversionOptions`.
    strip_decorators: forwarded to `converter.ConversionOptions`.

  Returns:
    A tuple `(node, ctx)` with the analyzed node and its EntityContext.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions

  module_node, source = parser.parse_entity(test_fn)
  fn_node = module_node.body[0]
  if namer is None:
    namer = FakeNamer()

  options = converter.ConversionOptions(
      recursive=recursive, strip_decorators=strip_decorators, verbose=True)
  program_ctx = converter.ProgramContext(
      options=options,
      partial_types=None,
      autograph_module=None,
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      owner_type=owner_type)
  ctx = converter.EntityContext(namer, entity_info, program_ctx)

  origin_info.resolve(fn_node, source, test_fn)
  fn_node = converter.standard_analysis(fn_node, ctx, is_initial=True)
  return fn_node, ctx
def prepare(self,
            test_fn,
            namespace,
            namer=None,
            arg_types=None,
            owner_type=None,
            recursive=True,
            strip_decorators=()):
  """Builds an EntityContext for `test_fn` and analyzes its function node.

  `namespace` is mutated to expose `converter.ConversionOptions` under the
  key 'ConversionOptions'. When `namer` is None a `FakeNamer()` is used.
  The remaining keyword arguments are forwarded to
  `converter.ConversionOptions` (recursive, strip_decorators) and to
  `transformer.EntityInfo` (arg_types, owner_type).

  Returns:
    A tuple `(node, ctx)`: the node after `converter.standard_analysis` and
    the context used for it.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions

  parsed, source = parser.parse_entity(test_fn)
  target = parsed.body[0]
  active_namer = FakeNamer() if namer is None else namer

  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(
          recursive=recursive,
          strip_decorators=strip_decorators,
          verbose=True),
      partial_types=None,
      autograph_module=None,
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      owner_type=owner_type)
  ctx = converter.EntityContext(active_namer, entity_info, program_ctx)

  origin_info.resolve(target, source, test_fn)
  target = converter.standard_analysis(target, ctx, is_initial=True)
  return target, ctx
# Checks that origin_info.resolve_entity annotates basic_definitions.
# simple_function with absolute line numbers, computed relative to the
# starting line reported by inspect.getsourcelines, plus column, source line
# and trailing comment.
# NOTE(review): this definition appears collapsed onto one line; the expected
# source-line strings must match the indentation in basic_definitions.py.
def test_resolve_entity(self): test_fn = basic_definitions.simple_function node, source = parser.parse_entity( test_fn, inspect_utils.getfutureimports(test_fn)) origin_info.resolve_entity(node, source, test_fn) # The line numbers below should match those in basic_definitions.py fn_start = inspect.getsourcelines(test_fn)[1] def_origin = anno.getanno(node, anno.Basic.ORIGIN) self.assertEqual(def_origin.loc.lineno, fn_start) self.assertEqual(def_origin.loc.col_offset, 0) self.assertEqual(def_origin.source_code_line, 'def simple_function(x):') self.assertIsNone(def_origin.comment) docstring_origin = anno.getanno(node.body[0], anno.Basic.ORIGIN) self.assertEqual(docstring_origin.loc.lineno, fn_start + 1) self.assertEqual(docstring_origin.loc.col_offset, 2) self.assertEqual(docstring_origin.source_code_line, ' """Docstring."""') self.assertIsNone(docstring_origin.comment) ret_origin = anno.getanno(node.body[1], anno.Basic.ORIGIN) self.assertEqual(ret_origin.loc.lineno, fn_start + 2) self.assertEqual(ret_origin.loc.col_offset, 2) self.assertEqual(ret_origin.source_code_line, ' return x # comment') self.assertEqual(ret_origin.comment, 'comment')
def convert_func_to_ast(f, program_ctx, do_rename=True):
  """Specialization of `convert_entity_to_ast` for callable functions.

  Args:
    f: the function or lambda to convert.
    program_ctx: program-wide conversion context; its `autograph_module` is
      added to the function's namespace.
    do_rename: for regular functions, whether to pick a fresh name via the
      namer; when False the function keeps `f.__name__`. Lambdas always get
      a fresh 'tf__lambda' symbol.

  Returns:
    A tuple `((node,), new_name, entity_info)`. For lambdas, `node` is an
    Assign binding the converted lambda to `new_name`.

  Raises:
    ValueError: if the source line of a lambda does not contain exactly one
      matching definition.
  """
  future_features = inspect_utils.getfutureimports(f)
  node, source = parser.parse_entity(f, future_features=future_features)
  logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
  # Parsed AST should contain future imports and one function def node.

  # inspect.getsource is inexact for lambdas: it locates them by regex around
  # the line number CPython records and returns the whole containing line,
  # which may hold several lambdas, e.g.
  #   x, y = lambda: 1, lambda: 2
  if f.__name__ == '<lambda>':
    candidates = ast_util.find_matching_definitions(node, f)
    if len(candidates) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    node, = candidates

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve_entity(node, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = naming.Namer(namespace)

  # The name is chosen from the parsed (pre-conversion) node type.
  if isinstance(node, gast.Lambda):
    new_name = namer.new_symbol('tf__lambda', ())
  elif do_rename:
    new_name = namer.function_name(f.__name__)
  else:
    new_name = f.__name__

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      future_features=future_features,
      namespace=namespace)
  context = converter.EntityContext(namer, entity_info, program_ctx, new_name)
  node = node_to_graph(node, context)

  # The post-processing is chosen from the converted node type.
  if isinstance(node, gast.Lambda):
    binding = gast.Name(
        new_name, ctx=gast.Store(), annotation=None, type_comment=None)
    node = gast.Assign(targets=[binding], value=node)
  elif do_rename:
    node.name = new_name
  else:
    assert node.name == new_name

  return (node,), new_name, entity_info
# Checks that the source recovered for a lambda assigned to a variable parses
# back to the same AST as the lambda node, i.e. the `lambda_lam = ` prefix is
# removed, and that the node matches the normalized 'lambda x: (x + 1)'.
# NOTE(review): this definition appears collapsed onto one line; the lambda's
# source text is the subject of the test, so keep its original line intact.
def test_parse_lambda_prefix_cleanup(self): lambda_lam = lambda x: x + 1 expected_node_src = 'lambda x: (x + 1)' node, source = parser.parse_entity(lambda_lam, future_features=()) self.assertAstMatches(node, source) self.assertAstMatches(node, expected_node_src)
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and annotates it with qualified names and activity.

  Returns:
    A tuple `(node, entity_info)`.
  """
  parsed, source = parser.parse_entity(test_fn, future_features=())
  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file=None,
      future_features=(),
      namespace={})
  with_qual_names = qual_names.resolve(parsed)
  ctx = transformer.Context(entity_info)
  analyzed = activity.resolve(with_qual_names, ctx)
  return analyzed, entity_info
def test_basic(self):

  def test_function():
    a = 0
    return a

  # ANF-transform the function, compile it back, and compare results.
  module_node, _ = parser.parse_entity(test_function)
  converted = anf.transform(module_node.body[0], self._simple_source_info())
  compiled_module, _ = compiler.ast_to_object(converted)
  self.assertEqual(test_function(), compiled_module.test_function())
def test_basic(self):

  def test_function():
    a = 0
    return a

  # ANF-transform the function, compile it back, and compare results.
  fn_node, _, _ = parser.parse_entity(test_function, future_imports=())
  converted = anf.transform(fn_node, self._simple_context())
  compiled_module, _ = compiler.ast_to_object(converted)
  self.assertEqual(test_function(), compiled_module.test_function())
# Checks that parse_entity accepts a function containing a multi-line string
# literal and returns the FunctionDef `f`.
# NOTE(review): the whitespace inside the string (presumably an unindented
# continuation line) is significant and appears collapsed here — confirm
# against the original file.
def test_parse_multiline_strings(self): def f(): print(""" multiline string""") node, _ = parser.parse_entity(f, future_features=()) self.assertEqual('f', node.name)
def do_parse_and_test(lam, **unused_kwargs):
  # Helper that parses `lam` and checks both its unparsed form and its
  # recovered source, which may carry the trailing ', named_arg=1)'.
  # `self` is not a parameter; presumably it is captured from an enclosing
  # test method (confirm in the original file).
  parsed, source = parser.parse_entity(lam, future_features=())
  unparsed = parser.unparse(parsed, include_encoding_marker=False)
  self.assertEqual(unparsed, '(lambda x: x)')
  self.assertMatchesWithPotentialGarbage(
      source, 'lambda x: x', ', named_arg=1)')
# Checks that the source recovered for `lambda_lam` is exactly the lambda
# expression, without the assignment prefix, and that unparsing the node gives
# the fully parenthesized form.
# NOTE(review): this definition appears collapsed onto one line; the lambda's
# source text is the subject of the test, so keep its original line intact.
def test_parse_lambda_prefix_cleanup(self): lambda_lam = lambda x: x + 1 node, source = parser.parse_entity(lambda_lam, future_features=()) self.assertEqual(parser.unparse(node, include_encoding_marker=False), '(lambda x: (x + 1))') self.assertEqual(source, 'lambda x: x + 1')
def test_entity_scope_tracking(self):

  class TestTransformer(transformer.Base):

    def _tag_with_entities(self, node):
      anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
      return self.generic_visit(node)

    # Assign is an arbitrary choice; it is simply easy to find in the tree.
    def visit_Assign(self, node):
      return self._tag_with_entities(node)

    # BinOp is the node that shows up inside the lambda.
    def visit_BinOp(self, node):
      return self._tag_with_entities(node)

  tr = TestTransformer(self._simple_context())

  def test_function():
    a = 0

    class TestClass(object):

      def test_method(self):
        b = 0
        def inner_function(x):
          c = 0
          d = lambda y: (x + y)
          return c, d
        return b, inner_function

    return a, TestClass

  node, _ = parser.parse_entity(test_function, future_features=())
  node = tr.visit(node)

  fn_node = node
  class_node = fn_node.body[1]
  method_node = class_node.body[0]
  inner_node = method_node.body[1]
  lambda_node = inner_node.body[1].value

  # Each tagged node must record every entity that encloses it, outermost
  # first.
  expectations = (
      (fn_node.body[0], (fn_node,)),
      (method_node.body[0], (fn_node, class_node, method_node)),
      (inner_node.body[0], (fn_node, class_node, method_node, inner_node)),
      (lambda_node.body,
       (fn_node, class_node, method_node, inner_node, lambda_node)),
  )
  for tagged, expected in expectations:
    self.assertEqual(expected, anno.getanno(tagged, 'enclosing_entities'))
def test_basic(self):

  def test_function():
    a = 0
    return a

  # ANF-transform the function, load it back, and compare results.
  fn_node, _ = parser.parse_entity(test_function, future_features=())
  converted = anf.transform(fn_node, self._simple_context())
  loaded_module, _, _ = loader.load_ast(converted)
  self.assertEqual(test_function(), loaded_module.test_function())
# Checks that parsing either the outer lambda `l` or the inner lambda `l(0)`
# fails with UnsupportedLanguageElementError. Both are defined on the same
# source line, and the message must mention "found multiple definitions" and
# list both candidates.
# NOTE(review): this definition appears collapsed onto one line; keep the
# nested lambdas on a single source line, since that is what the test relies on.
def test_parse_lambda_resolution_ambiguous(self): l = lambda x: lambda x: 2 * x expected_exception_text = re.compile( r'found multiple definitions' r'.+' r'\(lambda x: \(lambda x' r'.+' r'\(lambda x: \(2', re.DOTALL) with self.assertRaisesRegex(errors.UnsupportedLanguageElementError, expected_exception_text): parser.parse_entity(l, future_features=()) with self.assertRaisesRegex(errors.UnsupportedLanguageElementError, expected_exception_text): parser.parse_entity(l(0), future_features=())
def assert_body_anfs_as_expected(self, expected_fn, test_fn):
  """Asserts that ANF-transforming `test_fn` produces `expected_fn`.

  The snippets are wrapped in functions only so they highlight nicely; their
  statements are never executed. Function names are aligned before comparing,
  because both functions must coexist in one scope and so cannot share a name.
  The transform is also checked for idempotence.
  """

  def _to_anf(tree):
    return anf.transform(
        tree, self._simple_source_info(), gensym_source=DummyGensym)

  exp_node, _ = parser.parse_entity(expected_fn)
  node, _ = parser.parse_entity(test_fn)
  node = _to_anf(node)
  node.body[0].name = exp_node.body[0].name
  self.assert_same_ast(exp_node, node)

  node_repeated = _to_anf(node)
  self.assert_same_ast(node_repeated, node)
def _should_compile(self, node, fqn):
  """Determines whether an entity should be compiled in the context.

  Args:
    node: the call node under consideration; `node.func` is resolved to the
      target entity.
    fqn: fully qualified name of the target as a tuple; `fqn[0]` is treated
      as the module name.

  Returns:
    False if the target belongs to an uncompiled module, carries the
    `__ag_compiled` attribute, or is decorated with one of the decorators to
    strip. True otherwise, including for lambdas and for targets whose source
    cannot be retrieved.
  """
  # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
  module_name = fqn[0]
  for mod in self.ctx.program.uncompiled_modules:
    if module_name.startswith(mod[0] + '.'):
      return False

  for i in range(1, len(fqn)):
    if fqn[:i] in self.ctx.program.uncompiled_modules:
      return False

  target_entity = self._try_resolve_target(node.func)

  if target_entity is not None:

    # Currently, lambdas are always converted.
    # TODO(mdan): Allow markers of the kind f = ag.do_not_convert(lambda: ...)
    if inspect_utils.islambda(target_entity):
      return True

    # This may be reached when "calling" a callable attribute of an object.
    # For example:
    #
    #   self.fc = tf.keras.layers.Dense()
    #   self.fc()
    #
    # `__module__` may be absent or None for some callables. Treat those as
    # not belonging to any uncompiled module instead of crashing.
    target_module = getattr(target_entity, '__module__', None) or ''
    for mod in self.ctx.program.uncompiled_modules:
      if target_module.startswith(mod[0] + '.'):
        return False

    # Inspect the target function decorators. If any include a @convert
    # or @do_not_convert annotation, then they must be called as they are.
    # TODO(mdan): This may be quite heavy. Perhaps always dynamically convert?
    # To parse and re-analyze each function for every call site could be quite
    # wasteful. Maybe we could cache the parsed AST?
    try:
      target_node, _ = parser.parse_entity(target_entity)
      target_node = target_node.body[0]
    except (IOError, TypeError):
      # Functions whose source we cannot access are compilable (e.g. wrapped
      # to py_func). `inspect` raises TypeError for built-ins and
      # IOError/OSError (aliases on Python 3) when the source is unavailable.
      return True

    # This attribute is set when the decorator was applied before the
    # function was parsed. See api.py.
    if hasattr(target_entity, '__ag_compiled'):
      return False

    for dec in target_node.decorator_list:
      decorator_fn = self._resolve_decorator_name(dec)
      if (decorator_fn is not None and
          decorator_fn in self.ctx.program.options.strip_decorators):
        return False

  return True
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the function or lambda to convert.
    program_ctx: program-wide context; supplies the namer, the autograph
      module added to the namespace, and the name map that is updated.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo` and to the namer.

  Returns:
    A tuple `([node], new_name, namespace)`. For a converted lambda, `node` is
    an Assign binding it to a fresh 'tf__lambda' symbol.

  Raises:
    ValueError: if `f` is a lambda and its source line does not yield exactly
      one matching lambda definition.
  """
  node, source = parser.parse_entity(f)
  node = node.body[0]

  # TODO(mdan): Can we convert everything and scoop the lambda afterwards?
  if f.__name__ == '<lambda>':
    nodes = ast_util.find_matching_lambda_definitions(node, f)
    if len(nodes) != 1:
      # Zero matches are possible as well, so the message must not claim the
      # line contains several lambdas.
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching argument names. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    node, = nodes

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(node, source, f)
  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  node = node_to_graph(node, context)

  if isinstance(node, gast.Lambda):
    new_name = namer.new_symbol('tf__lambda', ())
    node = gast.Assign(
        targets=[gast.Name(new_name, gast.Store(), None)], value=node)

  else:
    # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
    new_name, did_rename = namer.compiled_function_name(f.__name__, f,
                                                        owner_type)
    if did_rename:
      node.name = new_name
    else:
      new_name = f.__name__
      assert node.name == new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return [node], new_name, namespace
def convert_func_to_ast(f, program_ctx, do_rename=True):
  """Specialization of `convert_entity_to_ast` for callable functions.

  Args:
    f: the function or lambda to convert.
    program_ctx: program-wide conversion context; its `autograph_module` is
      added to the function's namespace.
    do_rename: for regular functions, whether to pick a fresh name via the
      namer; when False the function keeps `f.__name__`.

  Returns:
    A tuple `((node,), new_name, entity_info)`. A converted lambda is wrapped
    in an Assign binding it to a fresh 'tf__lambda' symbol.

  Raises:
    ValueError: if the source line of a lambda does not contain exactly one
      matching definition.
    errors.InternalError: if `node_to_graph` fails with ValueError,
      AttributeError, KeyError or NotImplementedError.
  """
  future_features = inspect_utils.getfutureimports(f)
  node, source = parser.parse_entity(f, future_features=future_features)
  logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
  # Parsed AST should contain future imports and one function def node.

  # inspect.getsource is inexact for lambdas: it locates them by regex around
  # the line number CPython records and returns the whole containing line,
  # which may hold several lambdas, e.g.
  #   x, y = lambda: 1, lambda: 2
  if f.__name__ == '<lambda>':
    matches = ast_util.find_matching_definitions(node, f)
    if len(matches) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    node, = matches

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(node, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = naming.Namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      future_features=future_features,
      namespace=namespace)
  context = converter.EntityContext(namer, entity_info, program_ctx)

  # TODO(mdan): Catch and rethrow syntax errors.
  try:
    node = node_to_graph(node, context)
  except (ValueError, AttributeError, KeyError, NotImplementedError) as e:
    logging.error(1, 'Error converting %s', f, exc_info=True)
    raise errors.InternalError('conversion', e)

  if isinstance(node, gast.Lambda):
    new_name = namer.new_symbol('tf__lambda', ())
    binding = gast.Name(new_name, gast.Store(), None)
    node = gast.Assign(targets=[binding], value=node)
  elif do_rename:
    new_name = namer.function_name(f.__name__)
    node.name = new_name
  else:
    new_name = f.__name__
    assert node.name == new_name

  return (node,), new_name, entity_info
def test_entity_scope_tracking(self):

  class TestTransformer(transformer.Base):

    def _tag_with_entities(self, node):
      anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
      return self.generic_visit(node)

    # Assign is an arbitrary choice; it is simply easy to find in the tree.
    def visit_Assign(self, node):
      return self._tag_with_entities(node)

    # BinOp is the node that shows up inside the lambda.
    def visit_BinOp(self, node):
      return self._tag_with_entities(node)

  tr = TestTransformer(self._simple_context())

  def test_function():
    a = 0

    class TestClass(object):

      def test_method(self):
        b = 0
        def inner_function(x):
          c = 0
          d = lambda y: (x + y)
          return c, d
        return b, inner_function

    return a, TestClass

  parsed, _ = parser.parse_entity(test_function, future_features=())
  parsed = tr.visit(parsed)

  fn_node = parsed
  class_node = fn_node.body[1]
  method_node = class_node.body[0]
  inner_node = method_node.body[1]
  lambda_node = inner_node.body[1].value

  def entities_of(tagged):
    return anno.getanno(tagged, 'enclosing_entities')

  # Each tagged node records every entity enclosing it, outermost first.
  self.assertEqual((fn_node,), entities_of(fn_node.body[0]))
  self.assertEqual((fn_node, class_node, method_node),
                   entities_of(method_node.body[0]))
  self.assertEqual((fn_node, class_node, method_node, inner_node),
                   entities_of(inner_node.body[0]))
  self.assertEqual(
      (fn_node, class_node, method_node, inner_node, lambda_node),
      entities_of(lambda_node.body))
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and annotates it with reaching definitions.

  Runs qualified-name and activity analysis, builds the control flow graphs
  and resolves reaching definitions over them.

  Returns:
    The annotated AST returned by `reaching_definitions.resolve`.
  """
  tree, src = parser.parse_entity(test_fn, future_features=())
  info = transformer.EntityInfo(
      source_code=src,
      source_file=None,
      future_features=(),
      namespace={})
  tree = qual_names.resolve(tree)
  analysis_ctx = transformer.Context(info)
  tree = activity.resolve(tree, analysis_ctx)
  cfgs = cfg.build(tree)
  return reaching_definitions.resolve(
      tree, analysis_ctx, cfgs, reaching_definitions.Definition)
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and resolves qualified names and activity.

  Returns:
    A `(node, entity_info)` pair: the annotated AST and the
    `transformer.EntityInfo` describing the parsed source.
  """
  tree, src = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=src,
      source_file=None,
      namespace={},
      arg_values=None,
      arg_types=None,
      owner_type=None)
  tree = activity.resolve(qual_names.resolve(tree), info)
  return tree, info
def test_source_map_no_origin(self):
  """Code with no ORIGIN annotations produces an empty source map."""

  def test_fn(x):
    return x + 1

  module_node, _ = parser.parse_entity(test_fn)
  fn_node = module_node.body[0]
  generated = compiler.ast_to_source(fn_node)
  source_map = origin_info.create_source_map(
      fn_node, generated, 'test_filename', [0])
  self.assertEqual(len(source_map), 0)
def test_parser_compile_idempotent(self):
  """Parsing then compiling a function reproduces its original source."""

  def test_fn(x):
    a = True
    b = ''
    if a:
      b = x + 1
    return b

  # Evaluated in the same order as the original single-expression form.
  expected_source = textwrap.dedent(tf_inspect.getsource(test_fn))
  fn_ast = parser.parse_entity(test_fn)[0].body[0]
  compiled_module = compiler.ast_to_object(fn_ast)[0]
  self.assertEqual(
      expected_source, tf_inspect.getsource(compiled_module.test_fn))
def test_parser_compile_identity(self):
  """Compiling the parsed AST yields a module whose test_fn source matches."""

  def test_fn(x):
    a = True
    b = ''
    if a:
      b = x + 1
    return b

  parsed, _ = parser.parse_entity(test_fn, future_features=())
  compiled, _, _ = compiler.ast_to_object(parsed)
  original_source = textwrap.dedent(tf_inspect.getsource(test_fn))
  self.assertEqual(original_source, tf_inspect.getsource(compiled.test_fn))
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and annotates it with reaching definitions.

  Resolves qualified names and activity, builds the control flow graphs and
  runs reaching-definitions analysis using the EntityInfo directly.

  Returns:
    The annotated AST returned by `reaching_definitions.resolve`.
  """
  tree, src = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=src, source_file=None, namespace={},
      arg_values=None, arg_types=None, owner_type=None)
  tree = qual_names.resolve(tree)
  tree = activity.resolve(tree, info)
  return reaching_definitions.resolve(
      tree, info, cfg.build(tree), reaching_definitions.Definition)
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and annotates it with liveness information.

  Returns:
    The AST after qualified-name, activity and liveness analysis.
  """
  tree, src = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=src, source_file=None, namespace={},
      arg_values=None, arg_types=None, owner_type=None)
  tree = qual_names.resolve(tree)
  analysis_ctx = transformer.Context(info)
  tree = activity.resolve(tree, analysis_ctx)
  # The return value of liveness.resolve is not used here; the analysis is
  # presumably recorded on `tree` in place (confirm in the liveness module).
  liveness.resolve(tree, analysis_ctx, cfg.build(tree))
  return tree
def test_invalid_default(self):
  """`_map_args` raises ValueError for a directive with an unmappable default.

  The directive under test declares `invalid_default=object()`, which the
  converter is expected to reject with an 'Unexpected keyword' error.
  """

  def invalid_directive(valid_arg, invalid_default=object()):
    del valid_arg
    del invalid_default
    return

  def call_invalid_directive():
    invalid_directive(1)

  node, _ = parser.parse_entity(call_invalid_directive, ())
  # Find the call to the invalid directive
  node = node.body[0].value
  # assertRaisesRegexp is a deprecated alias of assertRaisesRegex and was
  # removed from unittest in Python 3.12.
  with self.assertRaisesRegex(ValueError, 'Unexpected keyword.*'):
    directives_converter._map_args(node, invalid_directive)
def test_local_scope_info_stack(self):
  """Local scopes isolate the strings collected inside each loop."""

  class TestTransformer(transformer.Base):
    """Concatenates the string constants found in each For/While body."""

    def visit_Str(self, node):
      collected = self.get_local('string', default='')
      self.set_local('string', collected + node.s)
      return self.generic_visit(node)

    def _annotate_loop(self, node):
      # Open a fresh scope so strings from outer loops are not mixed in, then
      # store what was collected under 'test' before closing the scope.
      self.enter_local_scope()
      visited = self.generic_visit(node)
      anno.setanno(visited, 'test', self.get_local('string'))
      self.exit_local_scope()
      return visited

    def visit_While(self, node):
      return self._annotate_loop(node)

    def visit_For(self, node):
      return self._annotate_loop(node)

  tr = TestTransformer(self._simple_context())

  def test_function(a):
    """Docstring."""
    assert a == 'This should not be counted'
    for i in range(3):
      _ = 'a'
      if i > 2:
        return 'b'
      else:
        _ = 'c'
        while True:
          raise '1'
    return 'nor this'

  node, _ = parser.parse_entity(test_function, future_features=())
  node = tr.visit(node)

  for_node = node.body[2]
  while_node = for_node.body[1].orelse[1]

  for loop_node, expected in ((for_node, 'abc'), (while_node, '1')):
    self.assertFalse(anno.hasanno(loop_node, 'string'))
    self.assertEqual(expected, anno.getanno(loop_node, 'test'))
def prepare(self, test_fn, namespace, recursive=True):
  """Parses `test_fn` and runs the standard analysis on it.

  Note that `namespace` is modified in place: `ConversionOptions` is added to
  it before anything else happens.

  Args:
    test_fn: The function to parse.
    namespace: Dict used as the entity's namespace.
    recursive: Forwarded to `converter.ConversionOptions`.

  Returns:
    A `(node, ctx)` pair: the analyzed AST and its `converter.EntityContext`.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions

  features = ('print_function', 'division')
  tree, src = parser.parse_entity(test_fn, future_features=features)
  namer = naming.Namer(namespace)
  options = converter.ConversionOptions(recursive=recursive)
  program_ctx = converter.ProgramContext(options=options, autograph_module=None)
  info = transformer.EntityInfo(
      source_code=src,
      source_file='<fragment>',
      future_features=features,
      namespace=namespace)
  ctx = converter.EntityContext(namer, info, program_ctx)
  origin_info.resolve(tree, src, test_fn)
  return converter.standard_analysis(tree, ctx, is_initial=True), ctx
def test_create_source_map(self):
  """A statement's ORIGIN annotation shows up in the created source map."""

  def test_fn(x):
    return x + 1

  module_node, _ = parser.parse_entity(test_fn)
  fn_node = module_node.body[0]

  fake_origin = origin_info.OriginInfo(
      loc=origin_info.Location('fake_filename', 3, 7),
      function_name='fake_function_name',
      source_code_line='fake source line',
      comment=None)
  anno.setanno(fn_node.body[0], anno.Basic.ORIGIN, fake_origin)

  generated = compiler.ast_to_source(fn_node)
  source_map = origin_info.create_source_map(
      fn_node, generated, 'test_filename', [0])

  expected_loc = origin_info.LineLocation('test_filename', 2)
  self.assertIn(expected_loc, source_map)
  self.assertIs(source_map[expected_loc], fake_origin)