def test_local_scope_info_stack_checks_integrity(self):
  """Unbalanced enter/exit_local_scope calls make visiting fail an assertion."""

  class TestTransformer(transformer.Base):

    def visit_If(self, node):
      # Enters a local scope without ever exiting it.
      self.enter_local_scope()
      return self.generic_visit(node)

    def visit_For(self, node):
      # Exits a local scope that was never entered.
      visited = self.generic_visit(node)
      self.exit_local_scope()
      return visited

  checker = TestTransformer(self._simple_source_info())

  def no_exit(a):
    if a > 0:
      print(a)
    return None

  def no_entry(a):
    for _ in a:
      print(a)

  # Each entity triggers one of the two imbalances above.
  for entity in (no_exit, no_entry):
    entity_node, _ = parser.parse_entity(entity)
    with self.assertRaises(AssertionError):
      checker.visit(entity_node)
def test_local_scope_info_stack_checks_integrity(self):
  """Unbalanced enter/exit_local_scope calls make visiting fail an assertion."""

  class TestTransformer(transformer.Base):

    def visit_If(self, node):
      # Enters a local scope without ever exiting it.
      self.enter_local_scope()
      return self.generic_visit(node)

    def visit_For(self, node):
      # Exits a local scope that was never entered.
      visited = self.generic_visit(node)
      self.exit_local_scope()
      return visited

  checker = TestTransformer(self._context_for_testing())

  def no_exit(a):
    if a > 0:
      print(a)
    return None

  def no_entry(a):
    for _ in a:
      print(a)

  # Each entity triggers one of the two imbalances above.
  for entity in (no_exit, no_entry):
    entity_node, _ = parser.parse_entity(entity)
    with self.assertRaises(AssertionError):
      checker.visit(entity_node)
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the function to convert; its source is parsed with parser.parse_entity.
    program_ctx: program-wide conversion context; supplies autograph_module,
      new_namer and update_name_map.
    arg_values: forwarded to transformer.EntityInfo.
    arg_types: forwarded to transformer.EntityInfo.
    owner_type: forwarded to transformer.EntityInfo and to the namer.

  Returns:
    A (node, new_name, namespace) tuple: the converted function node, the name
    it was given, and the namespace obtained from `f`.

  Raises:
    NotImplementedError: if the namer did not rename the function and the
      parsed node's name differs from `f.__name__`.
  """
  module_node, source = parser.parse_entity(f)
  fn_node = module_node.body[0]
  # Attach source-origin information to the parsed nodes.
  origin_info.resolve(fn_node, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  fn_node = node_to_graph(fn_node, context)

  # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py
  new_name, did_rename = namer.compiled_function_name(
      f.__name__, f, owner_type)
  if not did_rename:
    # Keep the original name, which must match the parsed definition.
    new_name = f.__name__
    if fn_node.name != f.__name__:
      raise NotImplementedError(
          'Strange corner case. Send us offending code!')
  fn_node.name = new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.
  return fn_node, new_name, namespace
def test_visit_block_postprocessing(self):
  """visit_block's after_visit hook can wrap statements into a new block."""

  class TestTransformer(transformer.Base):

    def _process_body_item(self, node):
      # Wraps `<target> = y` into `if x: ...`. Per the assertions below, the
      # returned list is where the statements that follow end up.
      if not (isinstance(node, gast.Assign) and (node.value.id == 'y')):
        return node, None
      wrapper = gast.If(gast.Name('x', gast.Load(), None), [node], [])
      return wrapper, wrapper.body

    def visit_FunctionDef(self, node):
      node.body = self.visit_block(
          node.body, after_visit=self._process_body_item)
      return node

  def test_function(x, y):
    z = x
    z = y
    return z

  rewriter = TestTransformer(self._context_for_testing())
  module_node, _ = parser.parse_entity(test_function)
  fn_node = rewriter.visit(module_node).body[0]

  statements = fn_node.body
  self.assertEqual(len(statements), 2)
  self.assertIsInstance(statements[0], gast.Assign)
  self.assertIsInstance(statements[1], gast.If)
  wrapped = statements[1].body
  self.assertIsInstance(wrapped[0], gast.Assign)
  self.assertIsInstance(wrapped[1], gast.Return)
def test_parse_entity(self):
  """The first statement of the parsed result is the definition of `f`."""

  def f(x):
    return x + 1

  parsed_module = parser.parse_entity(f)[0]
  first_statement = parsed_module.body[0]
  self.assertEqual('f', first_statement.name)
def test_resolve(self):
  """origin_info.resolve annotates nodes with their line, column and comment.

  The expected `source_code_line` values must reproduce the raw source lines
  of `test_fn` exactly, including the 2-space body indentation that the
  `col_offset == 2` assertions rely on.
  """

  def test_fn(x):
    """Docstring."""
    return x  # comment

  node, source = parser.parse_entity(test_fn)
  fn_node = node.body[0]
  origin_info.resolve(fn_node, source)

  # The function definition itself.
  origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 1)
  self.assertEqual(origin.loc.col_offset, 0)
  self.assertEqual(origin.source_code_line, 'def test_fn(x):')
  self.assertIsNone(origin.comment)

  # The docstring statement (no trailing comment).
  origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 2)
  self.assertEqual(origin.loc.col_offset, 2)
  self.assertEqual(origin.source_code_line, '  """Docstring."""')
  self.assertIsNone(origin.comment)

  # The return statement, whose trailing comment text is extracted.
  origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 3)
  self.assertEqual(origin.loc.col_offset, 2)
  self.assertEqual(origin.source_code_line, '  return x  # comment')
  self.assertEqual(origin.comment, 'comment')
def test_robust_error_on_list_visit(self):
  """Visiting a list raises AutographParseError that blames the If node."""

  class BrokenTransformer(transformer.Base):

    def visit_If(self, node):
      # This is broken because visit expects a single node, not a list, and
      # the body of an if is a list.
      # Importantly, the default error handling in visit also expects a single
      # node. Therefore, mistakes like this need to trigger a type error
      # before the visit called here installs its error handler.
      # That type error can then be caught by the enclosing call to visit,
      # and correctly blame the If node.
      self.visit(node.body)
      return node

  def test_function(x):
    if x > 0:
      return x

  tr = BrokenTransformer(self._simple_source_info())

  node, _ = parser.parse_entity(test_function)
  with self.assertRaises(transformer.AutographParseError) as cm:
    node = tr.visit(node)
  obtained_message = str(cm.exception)
  # Accept both the Python 2 ("type") and Python 3 ("class") repr of list.
  expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
  self.assertRegexpMatches(obtained_message, expected_message)
  # The exception should point at the if statement, not any place else. Could
  # also check the stack trace.
  self.assertTrue('Occurred at node:\nIf' in obtained_message,
                  obtained_message)
  self.assertTrue(
      'Occurred at node:\nFunctionDef' not in obtained_message,
      obtained_message)
  self.assertTrue('Occurred at node:\nReturn' not in obtained_message,
                  obtained_message)
def prepare(self,
            test_fn,
            namespace,
            namer=None,
            arg_types=None,
            owner_type=None,
            recursive=True,
            autograph_decorators=()):
  """Parses `test_fn` and runs converter.standard_analysis on it.

  Args:
    test_fn: the function whose source is parsed.
    namespace: namespace recorded in the EntityInfo.
    namer: namer for the EntityContext; a FakeNamer is created when None.
    arg_types: forwarded to transformer.EntityInfo.
    owner_type: forwarded to transformer.EntityInfo.
    recursive: forwarded to converter.ProgramContext.
    autograph_decorators: forwarded to converter.ProgramContext.

  Returns:
    A (node, ctx) tuple: the analyzed node (the first statement of the parsed
    module) and the EntityContext that was used.
  """
  module_node, source = parser.parse_entity(test_fn)
  fn_node = module_node.body[0]
  effective_namer = FakeNamer() if namer is None else namer

  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=autograph_decorators,
      partial_types=None,
      autograph_module=None,
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      owner_type=owner_type)
  ctx = converter.EntityContext(effective_namer, entity_info, program_ctx)

  analyzed = converter.standard_analysis(fn_node, ctx, is_initial=True)
  return analyzed, ctx
def function_to_graph(f, conversion_map, arg_values, arg_types,
                      owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the function to convert; its source is parsed with parser.parse_entity.
    conversion_map: conversion state; supplies api_module, new_namer,
      recursive, nocompile_decorators, update_name_map and
      additional_imports.
    arg_values: forwarded to context.EntityContext.
    arg_types: forwarded to context.EntityContext.
    owner_type: forwarded to context.EntityContext and to the namer.

  Returns:
    A (node, new_name, namespace) tuple: the converted function node, the name
    it was given, and the namespace obtained from `f`.

  Raises:
    NotImplementedError: if the namer did not rename the function and the
      parsed node's name differs from `f.__name__`.
  """
  module_node, source = parser.parse_entity(f)
  fn_node = module_node.body[0]

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, conversion_map.api_module)
  namer = conversion_map.new_namer(namespace)

  ctx = context.EntityContext(
      namer=namer,
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type,
      recursive=conversion_map.recursive,
      type_annotation_func=type_hints.set_element_type)
  fn_node, deps = node_to_graph(fn_node, ctx,
                                conversion_map.nocompile_decorators)

  # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py
  new_name, did_rename = namer.compiled_function_name(
      f.__name__, f, owner_type)
  if not did_rename:
    # Keep the original name, which must match the parsed definition.
    new_name = f.__name__
    if fn_node.name != f.__name__:
      raise NotImplementedError(
          'Strange corner case. Send us offending code!')
  fn_node.name = new_name

  conversion_map.update_name_map(namer)
  # TODO(mdan): Use this at compilation.
  conversion_map.additional_imports.update(deps)
  return fn_node, new_name, namespace
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the function to convert; its source is parsed with parser.parse_entity.
    program_ctx: program-wide conversion context; supplies autograph_module,
      new_namer and update_name_map.
    arg_values: forwarded to transformer.EntityInfo.
    arg_types: forwarded to transformer.EntityInfo.
    owner_type: forwarded to transformer.EntityInfo and to the namer.

  Returns:
    A (node, new_name, namespace) tuple: the converted function node, the name
    it was given, and the namespace obtained from `f`.

  Raises:
    NotImplementedError: if the namer did not rename the function and the
      parsed node's name differs from `f.__name__`.
  """
  module_node, source = parser.parse_entity(f)
  fn_node = module_node.body[0]

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  fn_node = node_to_graph(fn_node, context)

  # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py
  new_name, did_rename = namer.compiled_function_name(
      f.__name__, f, owner_type)
  if not did_rename:
    # Keep the original name, which must match the parsed definition.
    new_name = f.__name__
    if fn_node.name != f.__name__:
      raise NotImplementedError(
          'Strange corner case. Send us offending code!')
  fn_node.name = new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.
  return fn_node, new_name, namespace
def function_to_graph(f, conversion_map, arg_values, arg_types,
                      owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the function to convert; its source is parsed with parser.parse_entity.
    conversion_map: conversion state; supplies api_module, new_namer,
      recursive, nocompile_decorators, update_name_map and
      additional_imports.
    arg_values: forwarded to context.EntityContext.
    arg_types: forwarded to context.EntityContext.
    owner_type: forwarded to context.EntityContext and to the namer.

  Returns:
    A (node, new_name) tuple: the converted function node and the name it was
    given.

  Raises:
    NotImplementedError: if the namer did not rename the function and the
      parsed node's name differs from `f.__name__`.
  """
  module_node, source = parser.parse_entity(f)
  fn_node = module_node.body[0]

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, conversion_map.api_module)
  namer = conversion_map.new_namer(namespace)

  ctx = context.EntityContext(
      namer=namer,
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type,
      recursive=conversion_map.recursive,
      type_annotation_func=type_hints.set_element_type)
  fn_node, deps = node_to_graph(fn_node, ctx,
                                conversion_map.nocompile_decorators)

  # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py
  new_name, did_rename = namer.compiled_function_name(
      f.__name__, f, owner_type)
  if not did_rename:
    # Keep the original name, which must match the parsed definition.
    new_name = f.__name__
    if fn_node.name != f.__name__:
      raise NotImplementedError(
          'Strange corner case. Send us offending code!')
  fn_node.name = new_name

  conversion_map.update_name_map(namer)
  # TODO(mdan): Use this at compilation.
  conversion_map.additional_imports.update(deps)
  return fn_node, new_name
def parse_and_analyze(self,
                      test_fn,
                      namespace,
                      namer=None,
                      arg_types=None,
                      include_type_analysis=True,
                      owner_type=None,
                      recursive=True):
  """Parses `test_fn` and runs the static analysis passes on it.

  Args:
    test_fn: the function whose source is parsed.
    namespace: namespace recorded in the EntityContext.
    namer: namer for the EntityContext; a FakeNamer is used when falsy.
    arg_types: forwarded to context.EntityContext.
    include_type_analysis: when true, also runs type_info and a second
      live_values pass.
    owner_type: forwarded to context.EntityContext.
    recursive: forwarded to context.EntityContext.

  Returns:
    The analyzed node. The EntityContext is stored on `self.ctx`.
  """
  tree, source = parser.parse_entity(test_fn)
  chosen_namer = namer if namer else FakeNamer()
  ctx = context.EntityContext(
      namer=chosen_namer,
      source_code=source,
      source_file=None,
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      owner_type=owner_type,
      recursive=recursive,
      type_annotation_func=utils.set_element_type)

  tree = qual_names.resolve(tree)
  tree = activity.resolve(tree, ctx)
  tree = live_values.resolve(tree, ctx, {})
  if include_type_analysis:
    # live_values runs again after type_info, with a fresh literals dict.
    tree = type_info.resolve(tree, ctx)
    tree = live_values.resolve(tree, ctx, {})

  self.ctx = ctx
  return tree
def parse_and_analyze(self,
                      test_fn,
                      namespace,
                      namer=None,
                      arg_types=None,
                      include_type_analysis=True,
                      owner_type=None,
                      recursive=True,
                      autograph_decorators=()):
  """Parses `test_fn` and runs the static analysis passes on it.

  Args:
    test_fn: the function whose source is parsed.
    namespace: namespace recorded in the EntityInfo.
    namer: namer for the EntityContext; a FakeNamer is created when None.
    arg_types: forwarded to transformer.EntityInfo.
    include_type_analysis: when true, also runs type_info and a second
      live_values pass.
    owner_type: forwarded to transformer.EntityInfo.
    recursive: forwarded to converter.ProgramContext.
    autograph_decorators: forwarded to converter.ProgramContext.

  Returns:
    The analyzed node. The EntityContext is stored on `self.ctx`; the
    analysis passes themselves receive the EntityInfo.
  """
  tree, source = parser.parse_entity(test_fn)
  chosen_namer = FakeNamer() if namer is None else namer

  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=autograph_decorators,
      partial_types=None,
      autograph_module=None,
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      owner_type=owner_type)
  ctx = converter.EntityContext(chosen_namer, info, program_ctx)

  tree = qual_names.resolve(tree)
  tree = activity.resolve(tree, info)
  tree = live_values.resolve(tree, info, {})
  if include_type_analysis:
    # live_values runs again after type_info, with a fresh literals dict.
    tree = type_info.resolve(tree, info)
    tree = live_values.resolve(tree, info, {})

  self.ctx = ctx
  return tree
def prepare(self,
            test_fn,
            namespace,
            namer=None,
            arg_types=None,
            owner_type=None,
            recursive=True,
            autograph_decorators=()):
  """Parses `test_fn` and runs converter.standard_analysis on the result.

  Args:
    test_fn: the function whose source is parsed.
    namespace: namespace recorded in the EntityInfo.
    namer: namer for the EntityContext; a FakeNamer is created when None.
    arg_types: forwarded to transformer.EntityInfo.
    owner_type: forwarded to transformer.EntityInfo.
    recursive: forwarded to converter.ProgramContext.
    autograph_decorators: forwarded to converter.ProgramContext.

  Returns:
    A (node, ctx) tuple: the analyzed parse result (the whole node returned by
    parse_entity) and the EntityContext that was used.
  """
  parsed, source = parser.parse_entity(test_fn)
  effective_namer = FakeNamer() if namer is None else namer

  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=autograph_decorators,
      partial_types=None,
      autograph_module=None,
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      owner_type=owner_type)
  ctx = converter.EntityContext(effective_namer, entity_info, program_ctx)

  analyzed = converter.standard_analysis(parsed, ctx, is_initial=True)
  return analyzed, ctx
def test_visit_block_postprocessing(self):
  """visit_block's after_visit hook can wrap statements into a new block."""

  class TestTransformer(transformer.Base):

    def _process_body_item(self, node):
      # Wraps `<target> = y` into `if x: ...`. Per the assertions below, the
      # returned list is where the statements that follow end up.
      if not (isinstance(node, gast.Assign) and (node.value.id == 'y')):
        return node, None
      wrapper = gast.If(gast.Name('x', gast.Load(), None), [node], [])
      return wrapper, wrapper.body

    def visit_FunctionDef(self, node):
      node.body = self.visit_block(
          node.body, after_visit=self._process_body_item)
      return node

  def test_function(x, y):
    z = x
    z = y
    return z

  rewriter = TestTransformer(self._simple_source_info())
  module_node, _ = parser.parse_entity(test_function)
  fn_node = rewriter.visit(module_node).body[0]

  statements = fn_node.body
  self.assertEqual(len(statements), 2)
  self.assertIsInstance(statements[0], gast.Assign)
  self.assertIsInstance(statements[1], gast.If)
  wrapped = statements[1].body
  self.assertIsInstance(wrapped[0], gast.Assign)
  self.assertIsInstance(wrapped[1], gast.Return)
def test_basic(self):
  """A trivial function behaves the same after the ANF transform."""

  def test_function():
    a = 0
    return a

  module_node, _ = parser.parse_entity(test_function)
  converted = anf.transform(module_node.body[0], self._simple_source_info())
  compiled, _ = compiler.ast_to_object(converted)
  expected = test_function()
  self.assertEqual(expected, compiled.test_function())
def test_entity_scope_tracking(self):
  """enclosing_entities lists every enclosing function/class/lambda node."""

  class TestTransformer(transformer.Base):

    # The choice of node to assign to is arbitrary. Using Assign because it's
    # easy to find in the tree.
    def visit_Assign(self, node):
      anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
      return self.generic_visit(node)

    # This will show up in the lambda function.
    def visit_BinOp(self, node):
      anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
      return self.generic_visit(node)

  tr = TestTransformer(self._simple_source_info())

  def test_function():
    a = 0

    class TestClass(object):

      def test_method(self):
        b = 0

        def inner_function(x):
          c = 0
          d = lambda y: (x + y)
          return c, d

        return b, inner_function

    return a, TestClass

  node, _ = parser.parse_entity(test_function)
  node = tr.visit(node)

  # Walk down to each nested entity; the indices follow the statement order
  # inside test_function above.
  test_function_node = node.body[0]
  test_class = test_function_node.body[1]
  test_method = test_class.body[0]
  inner_function = test_method.body[1]
  lambda_node = inner_function.body[1].value

  a = test_function_node.body[0]
  b = test_method.body[0]
  c = inner_function.body[0]
  lambda_expr = lambda_node.body

  self.assertEqual((test_function_node, ),
                   anno.getanno(a, 'enclosing_entities'))
  self.assertEqual((test_function_node, test_class, test_method),
                   anno.getanno(b, 'enclosing_entities'))
  self.assertEqual(
      (test_function_node, test_class, test_method, inner_function),
      anno.getanno(c, 'enclosing_entities'))
  self.assertEqual((test_function_node, test_class, test_method,
                    inner_function, lambda_node),
                   anno.getanno(lambda_expr, 'enclosing_entities'))
def assert_body_anfs_as_expected(self, expected_fn, test_fn):
  """Asserts that ANF-transforming `test_fn` yields the AST of `expected_fn`.

  Also asserts that ANF is idempotent: transforming the result again must not
  change it.

  Args:
    expected_fn: function whose body is the expected ANF form.
    test_fn: function whose body is transformed.
  """
  import copy  # Only needed for the idempotence check below.

  # Testing the code bodies only. Wrapping them in functions so the
  # syntax highlights nicely, but Python doesn't try to execute the
  # statements.
  exp_node, _ = parser.parse_entity(expected_fn)
  node, _ = parser.parse_entity(test_fn)
  node = anf.transform(
      node, self._simple_source_info(), gensym_source=DummyGensym)
  exp_name = exp_node.body[0].name
  # Ignoring the function names in the result because they can't be
  # the same (because both functions have to exist in the same scope
  # at the same time).
  node.body[0].name = exp_name
  self.assert_same_ast(exp_node, node)

  # Check that ANF is idempotent. Transform a deep copy: if anf.transform
  # rewrites its input in place and returns it (presumably the case for a
  # NodeTransformer-style pass), comparing its result against `node` would
  # compare the tree with itself and always pass.
  node_repeated = anf.transform(
      copy.deepcopy(node), self._simple_source_info(),
      gensym_source=DummyGensym)
  self.assert_same_ast(node_repeated, node)
def test_entity_scope_tracking(self):
  """enclosing_entities lists every enclosing function/class/lambda node."""

  class TestTransformer(transformer.Base):

    # The choice of node to assign to is arbitrary. Using Assign because it's
    # easy to find in the tree.
    def visit_Assign(self, node):
      anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
      return self.generic_visit(node)

    # This will show up in the lambda function.
    def visit_BinOp(self, node):
      anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
      return self.generic_visit(node)

  tr = TestTransformer(self._context_for_testing())

  def test_function():
    a = 0

    class TestClass(object):

      def test_method(self):
        b = 0

        def inner_function(x):
          c = 0
          d = lambda y: (x + y)
          return c, d

        return b, inner_function

    return a, TestClass

  node, _ = parser.parse_entity(test_function)
  node = tr.visit(node)

  # Walk down to each nested entity; the indices follow the statement order
  # inside test_function above.
  test_function_node = node.body[0]
  test_class = test_function_node.body[1]
  test_method = test_class.body[0]
  inner_function = test_method.body[1]
  lambda_node = inner_function.body[1].value

  a = test_function_node.body[0]
  b = test_method.body[0]
  c = inner_function.body[0]
  lambda_expr = lambda_node.body

  self.assertEqual(
      (test_function_node,), anno.getanno(a, 'enclosing_entities'))
  self.assertEqual((test_function_node, test_class, test_method),
                   anno.getanno(b, 'enclosing_entities'))
  self.assertEqual(
      (test_function_node, test_class, test_method, inner_function),
      anno.getanno(c, 'enclosing_entities'))
  self.assertEqual((test_function_node, test_class, test_method,
                    inner_function, lambda_node),
                   anno.getanno(lambda_expr, 'enclosing_entities'))
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and resolves qualified names.

  Returns:
    A (node, entity_info) tuple, where entity_info is an EntityInfo with an
    empty namespace and no source file.
  """
  tree, source = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=source,
      source_file=None,
      namespace={},
      arg_values=None,
      arg_types=None,
      owner_type=None)
  return qual_names.resolve(tree), info
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn`, resolves qualified names and runs activity analysis.

  Returns:
    A (node, entity_info) tuple, where entity_info is an EntityInfo with an
    empty namespace and no source file.
  """
  tree, source = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=source,
      source_file=None,
      namespace={},
      arg_values=None,
      arg_types=None,
      owner_type=None)
  tree = activity.resolve(qual_names.resolve(tree), info)
  return tree, info
def assert_body_anfs_as_expected(self, expected_fn, test_fn):
  """Asserts that ANF-transforming `test_fn` yields the AST of `expected_fn`.

  Also asserts that ANF is idempotent: transforming the result again must not
  change it.

  Args:
    expected_fn: function whose body is the expected ANF form.
    test_fn: function whose body is transformed.
  """
  import copy  # Only needed for the idempotence check below.

  # Testing the code bodies only. Wrapping them in functions so the
  # syntax highlights nicely, but Python doesn't try to execute the
  # statements.
  exp_node, _ = parser.parse_entity(expected_fn)
  node, _ = parser.parse_entity(test_fn)
  node = anf.transform(node, self._simple_source_info(),
                       gensym_source=DummyGensym)
  exp_name = exp_node.body[0].name
  # Ignoring the function names in the result because they can't be
  # the same (because both functions have to exist in the same scope
  # at the same time).
  node.body[0].name = exp_name
  self.assert_same_ast(exp_node, node)

  # Check that ANF is idempotent. Transform a deep copy: if anf.transform
  # rewrites its input in place and returns it (presumably the case for a
  # NodeTransformer-style pass), comparing its result against `node` would
  # compare the tree with itself and always pass.
  node_repeated = anf.transform(copy.deepcopy(node),
                                self._simple_source_info(),
                                gensym_source=DummyGensym)
  self.assert_same_ast(node_repeated, node)
def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
  """Parses `test_fn` and resolves qualified names.

  Args:
    test_fn: the function whose source is parsed.
    namespace: namespace recorded in the EntityContext.
    arg_types: forwarded to the EntityContext; falsy values become {}.

  Returns:
    A (node, ctx) tuple with a recursive EntityContext that has no namer.
  """
  tree, source = parser.parse_entity(test_fn)
  ctx = context.EntityContext(
      namer=None,
      source_code=source,
      source_file=None,
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types if arg_types else {},
      owner_type=None,
      recursive=True)
  return qual_names.resolve(tree), ctx
def test_parser_compile_idempotent(self):
  """Parsing then compiling a function reproduces its original source text."""

  def test_fn(x):
    a = True
    b = ''
    if a:
      b = x + 1
    return b

  # Round trip: source -> AST (parse_entity) -> object (ast_to_object, whose
  # first result exposes test_fn) -> source (getsource); text must match.
  # NOTE(review): the comparison is whitespace-sensitive, so test_fn must be
  # laid out exactly as the compiler emits code (presumably 2-space indent,
  # no comments or blank lines) — confirm.
  self.assertEqual(
      textwrap.dedent(tf_inspect.getsource(test_fn)),
      tf_inspect.getsource(
          compiler.ast_to_object(
              parser.parse_entity(test_fn)[0].body[0])[0].test_fn))
def _should_compile(self, node, fqn):
  """Determines whether an entity should be compiled in the context.

  Args:
    node: the call node being inspected; `node.func` is resolved to find the
      call target.
    fqn: fully qualified name of the target as a sequence; fqn[0] is treated
      as the module name and its prefixes are checked against
      `self.uncompiled_modules`.

  Returns:
    False when the target lives in an uncompiled module, the node carries a
    'graph_ready' annotation, the target is itself a compile or nocompile
    decorator, or the target is decorated with a nocompile decorator.
    True otherwise, including when the target cannot be resolved or its
    source cannot be parsed (TypeError).
  """
  # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
  module_name = fqn[0]
  if any(module_name.startswith(mod[0] + '.')
         for mod in self.uncompiled_modules):
    return False
  if any(fqn[:i] in self.uncompiled_modules for i in range(1, len(fqn))):
    return False

  # Check for local decorations.
  if anno.hasanno(node, 'graph_ready'):
    return False

  # The decorators themselves are not to be converted.
  # If present, the decorators should appear as static functions.
  target_entity = self._try_resolve_target(node.func)
  if target_entity is None:
    return True

  # This attribute is set by the decorator itself.
  # TODO(mdan): This may not play nicely with other wrapping decorators.
  if hasattr(target_entity, '__pyct_is_compile_decorator'):
    return False
  if target_entity in self.nocompile_decorators:
    return False

  # Inspect the target function decorators. If any include a @convert
  # or @graph_ready annotation, then they must be called as they are.
  # TODO(mdan): This may be quite heavy.
  # To parse and re-analyze each function for every call site could be quite
  # wasteful. Maybe we could cache the parsed AST?
  try:
    parsed, _ = parser.parse_entity(target_entity)
    target_node = parsed.body[0]
  except TypeError:
    # Functions whose source we cannot access are compilable (e.g. wrapped
    # to py_func).
    return True

  for decorator in target_node.decorator_list:
    decorator_fn = self._resolve_name(decorator)
    if decorator_fn is None:
      continue
    if decorator_fn in self.nocompile_decorators:
      return False
  return True
def test_invalid_default(self):
  """_map_args raises ValueError for an arbitrary-object keyword default."""

  def invalid_directive(valid_arg, invalid_default=object()):
    del valid_arg
    del invalid_default
    return

  def call_invalid_directive():
    invalid_directive(1)

  module_node, _ = parser.parse_entity(call_invalid_directive)
  # module -> FunctionDef -> first statement (the call expression) -> Call.
  call_node = module_node.body[0].body[0].value
  with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'):
    directives_converter._map_args(call_node, invalid_directive)
def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
  """Parses `test_fn` and resolves qualified names.

  Args:
    test_fn: the function whose source is parsed.
    namespace: namespace recorded in the EntityContext.
    arg_types: forwarded to the EntityContext; falsy values become {}.

  Returns:
    A (node, ctx) tuple with a recursive EntityContext that has no namer.
  """
  tree, source = parser.parse_entity(test_fn)
  ctx = context.EntityContext(
      namer=None,
      source_code=source,
      source_file=None,
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types if arg_types else {},
      owner_type=None,
      recursive=True)
  return qual_names.resolve(tree), ctx
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and runs the analyses needed for reaching definitions.

  Runs qual_names, activity, builds the CFGs with cfg.build, and then runs
  reaching_definitions with reaching_definitions.Definition.

  Returns:
    The annotated node.
  """
  tree, source = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=source,
      source_file=None,
      namespace={},
      arg_values=None,
      arg_types=None,
      owner_type=None)
  tree = activity.resolve(qual_names.resolve(tree), info)
  control_flow_graphs = cfg.build(tree)
  return reaching_definitions.resolve(tree, info, control_flow_graphs,
                                      reaching_definitions.Definition)
def _should_compile(self, node, fqn):
  """Determines whether an entity should be compiled in the context.

  Args:
    node: the call node being inspected; `node.func` is resolved to find the
      call target.
    fqn: fully qualified name of the target as a sequence; fqn[0] is treated
      as the module name and its prefixes are checked against the program's
      uncompiled_modules.

  Returns:
    False when the target lives in an uncompiled module, the node carries a
    'graph_ready' annotation, the target is itself a compile or autograph
    decorator, or the target is decorated with an autograph decorator.
    True otherwise, including when the target cannot be resolved or its
    source cannot be parsed (TypeError).
  """
  # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
  module_name = fqn[0]
  if any(module_name.startswith(mod[0] + '.')
         for mod in self.ctx.program.uncompiled_modules):
    return False
  if any(fqn[:i] in self.ctx.program.uncompiled_modules
         for i in range(1, len(fqn))):
    return False

  # Check for local decorations.
  if anno.hasanno(node, 'graph_ready'):
    return False

  # The decorators themselves are not to be converted.
  # If present, the decorators should appear as static functions.
  target_entity = self._try_resolve_target(node.func)
  if target_entity is None:
    return True

  # This attribute is set by the decorator itself.
  # TODO(mdan): This may not play nicely with other wrapping decorators.
  if hasattr(target_entity, '__pyct_is_compile_decorator'):
    return False
  if target_entity in self.ctx.program.autograph_decorators:
    return False

  # Inspect the target function decorators. If any include a @convert
  # or @graph_ready annotation, then they must be called as they are.
  # TODO(mdan): This may be quite heavy.
  # To parse and re-analyze each function for every call site could be quite
  # wasteful. Maybe we could cache the parsed AST?
  try:
    parsed, _ = parser.parse_entity(target_entity)
    target_node = parsed.body[0]
  except TypeError:
    # Functions whose source we cannot access are compilable (e.g. wrapped
    # to py_func).
    return True

  for decorator in target_node.decorator_list:
    decorator_fn = self._resolve_name(decorator)
    if decorator_fn is None:
      continue
    if decorator_fn in self.ctx.program.autograph_decorators:
      return False
  return True
def test_local_scope_info_stack(self):
  """Local-scope values are isolated per For/While and kept off the nodes."""

  class TestTransformer(transformer.Base):

    # Extract all string constants from the block.
    def visit_Str(self, node):
      self.set_local('string', self.get_local('string', default='') + node.s)
      return self.generic_visit(node)

    # Visits `node` inside a fresh local scope and records the strings
    # collected in that scope as the node's 'test' annotation.
    def _annotate_result(self, node):
      self.enter_local_scope()
      node = self.generic_visit(node)
      anno.setanno(node, 'test', self.get_local('string'))
      self.exit_local_scope()
      return node

    def visit_While(self, node):
      return self._annotate_result(node)

    def visit_For(self, node):
      return self._annotate_result(node)

  tr = TestTransformer(self._context_for_testing())

  def test_function(a):
    """Docstring."""
    assert a == 'This should not be counted'
    for i in range(3):
      _ = 'a'
      if i > 2:
        return 'b'
      else:
        _ = 'c'
        while True:
          raise '1'
    return 'nor this'

  node, _ = parser.parse_entity(test_function)
  node = tr.visit(node)

  # body[2] skips the docstring and the assert; the while loop sits in the
  # else branch (orelse[1]) of the if inside the for loop.
  for_node = node.body[0].body[2]
  while_node = for_node.body[1].orelse[1]

  # The for loop sees 'a', 'b', 'c'; the while loop's '1' stays in its own
  # scope.
  self.assertFalse(anno.hasanno(for_node, 'string'))
  self.assertEqual('abc', anno.getanno(for_node, 'test'))
  self.assertFalse(anno.hasanno(while_node, 'string'))
  self.assertEqual('1', anno.getanno(while_node, 'test'))
def test_local_scope_info_stack(self):
  """Local-scope values are isolated per For/While and kept off the nodes."""

  class TestTransformer(transformer.Base):

    # Extract all string constants from the block.
    def visit_Str(self, node):
      self.set_local('string', self.get_local('string', default='') + node.s)
      return self.generic_visit(node)

    # Visits `node` inside a fresh local scope and records the strings
    # collected in that scope as the node's 'test' annotation.
    def _annotate_result(self, node):
      self.enter_local_scope()
      node = self.generic_visit(node)
      anno.setanno(node, 'test', self.get_local('string'))
      self.exit_local_scope()
      return node

    def visit_While(self, node):
      return self._annotate_result(node)

    def visit_For(self, node):
      return self._annotate_result(node)

  tr = TestTransformer(self._simple_source_info())

  def test_function(a):
    """Docstring."""
    assert a == 'This should not be counted'
    for i in range(3):
      _ = 'a'
      if i > 2:
        return 'b'
      else:
        _ = 'c'
        while True:
          raise '1'
    return 'nor this'

  node, _ = parser.parse_entity(test_function)
  node = tr.visit(node)

  # body[2] skips the docstring and the assert; the while loop sits in the
  # else branch (orelse[1]) of the if inside the for loop.
  for_node = node.body[0].body[2]
  while_node = for_node.body[1].orelse[1]

  # The for loop sees 'a', 'b', 'c'; the while loop's '1' stays in its own
  # scope.
  self.assertFalse(anno.hasanno(for_node, 'string'))
  self.assertEqual('abc', anno.getanno(for_node, 'test'))
  self.assertFalse(anno.hasanno(while_node, 'string'))
  self.assertEqual('1', anno.getanno(while_node, 'test'))
def test_robust_error_on_ast_corruption(self):
  """A visitor that corrupts the AST still gets its original error reported."""
  # A child class should not be able to be so broken that it causes the error
  # handling in `transformer.Base` to raise an exception. Why not? Because
  # then the original error location is dropped, and an error handler higher
  # up in the call stack gives misleading information.

  # Here we test that the error handling in `visit` completes, and blames the
  # correct original exception, even if the AST gets corrupted.

  class NotANode(object):
    pass

  class BrokenTransformer(transformer.Base):

    # Replaces the If body with a non-AST object, then fails.
    def visit_If(self, node):
      node.body = NotANode()
      raise ValueError('I blew up')

  def test_function(x):
    if x > 0:
      return x

  tr = BrokenTransformer(self._simple_source_info())

  node, _ = parser.parse_entity(test_function)
  with self.assertRaises(transformer.AutographParseError) as cm:
    node = tr.visit(node)
  obtained_message = str(cm.exception)
  # The message should reference the exception actually raised, not anything
  # from the exception handler.
  expected_substring = 'I blew up'
  self.assertTrue(expected_substring in obtained_message, obtained_message)
  # Expect the exception to have failed to parse the corrupted AST
  self.assertTrue(
      '<could not convert AST to source>' in obtained_message,
      obtained_message)
  # The exception should point at the if statement, not any place else. Could
  # also check the stack trace.
  self.assertTrue('Occurred at node:\nIf' in obtained_message,
                  obtained_message)
  self.assertTrue(
      'Occurred at node:\nFunctionDef' not in obtained_message,
      obtained_message)
  self.assertTrue('Occurred at node:\nReturn' not in obtained_message,
                  obtained_message)
def test_robust_error_on_ast_corruption(self):
  """A visitor that corrupts the AST still gets its original error reported."""
  # A child class should not be able to be so broken that it causes the error
  # handling in `transformer.Base` to raise an exception. Why not? Because
  # then the original error location is dropped, and an error handler higher
  # up in the call stack gives misleading information.

  # Here we test that the error handling in `visit` completes, and blames the
  # correct original exception, even if the AST gets corrupted.

  class NotANode(object):
    pass

  class BrokenTransformer(transformer.Base):

    # Replaces the If body with a non-AST object, then fails.
    def visit_If(self, node):
      node.body = NotANode()
      raise ValueError('I blew up')

  def test_function(x):
    if x > 0:
      return x

  tr = BrokenTransformer(self._simple_source_info())

  node, _ = parser.parse_entity(test_function)
  with self.assertRaises(transformer.AutographParseError) as cm:
    node = tr.visit(node)
  obtained_message = str(cm.exception)
  # The message should reference the exception actually raised, not anything
  # from the exception handler.
  expected_substring = 'I blew up'
  self.assertTrue(expected_substring in obtained_message, obtained_message)
  # Expect the exception to have failed to parse the corrupted AST
  self.assertTrue(
      '<could not convert AST to source>' in obtained_message,
      obtained_message)
  # The exception should point at the if statement, not any place else. Could
  # also check the stack trace.
  self.assertTrue(
      'Occurred at node:\nIf' in obtained_message, obtained_message)
  self.assertTrue(
      'Occurred at node:\nFunctionDef' not in obtained_message,
      obtained_message)
  self.assertTrue(
      'Occurred at node:\nReturn' not in obtained_message, obtained_message)
def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
  """Parses `test_fn` and runs the analyses up to and including type_info.

  Pass order: qual_names, activity, live_values, type_info, live_values.
  Each live_values pass gets its own empty literals dict.

  Returns:
    The annotated node.
  """
  tree, source = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=source,
      source_file=None,
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      owner_type=None)
  tree = activity.resolve(qual_names.resolve(tree), info)
  tree = live_values.resolve(tree, info, {})
  tree = type_info.resolve(tree, info)
  return live_values.resolve(tree, info, {})
def test_source_map(self):
  """source_map maps traced generated lines back to their original origin.

  The expected `source_code_line` values must reproduce the raw source lines
  of `test_fn` exactly, including its 2-space body indentation.
  """

  def test_fn(x):
    if x > 0:
      x += 1
    return x

  node, source = parser.parse_entity(test_fn)
  fn_node = node.body[0]
  origin_info.resolve(fn_node, source)

  # Insert a traced line.
  new_node = parser.parse_str('x = abs(x)').body[0]
  anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN)
  fn_node.body.insert(0, new_node)

  # Insert an untraced line.
  fn_node.body.insert(0, parser.parse_str('x = 0').body[0])

  modified_source = compiler.ast_to_source(fn_node)

  source_map = origin_info.source_map(fn_node, modified_source,
                                      'test_filename', [0])

  loc = origin_info.LineLocation('test_filename', 1)
  origin = source_map[loc]
  self.assertEqual(origin.source_code_line, 'def test_fn(x):')
  self.assertEqual(origin.loc.lineno, 1)

  # The untraced line, inserted second.
  loc = origin_info.LineLocation('test_filename', 2)
  self.assertFalse(loc in source_map)

  # The traced line, inserted first. It copied the origin of the `if`.
  loc = origin_info.LineLocation('test_filename', 3)
  origin = source_map[loc]
  self.assertEqual(origin.source_code_line, '  if x > 0:')
  self.assertEqual(origin.loc.lineno, 2)

  # The original `if` statement.
  loc = origin_info.LineLocation('test_filename', 4)
  origin = source_map[loc]
  self.assertEqual(origin.source_code_line, '  if x > 0:')
  self.assertEqual(origin.loc.lineno, 2)
def _parse_and_analyze(self, test_fn, namespace, literals=None,
                       arg_types=None):
  """Parses `test_fn` and runs the analyses up to and including type_info.

  Pass order: qual_names, activity, live_values, type_info, live_values.
  Both live_values passes receive the same `literals` mapping; a falsy value
  is replaced by a new empty dict.

  Returns:
    The annotated node.
  """
  known_literals = literals if literals else {}
  tree, source = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=source,
      source_file=None,
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      owner_type=None)
  tree = activity.resolve(qual_names.resolve(tree), info)
  tree = live_values.resolve(tree, info, known_literals)
  tree = type_info.resolve(tree, info)
  return live_values.resolve(tree, info, known_literals)
def test_robust_error_on_list_visit(self):
  """Visiting a list raises AutographParseError that blames the If node."""

  class BrokenTransformer(transformer.Base):

    def visit_If(self, node):
      # This is broken because visit expects a single node, not a list, and
      # the body of an if is a list.
      # Importantly, the default error handling in visit also expects a single
      # node. Therefore, mistakes like this need to trigger a type error
      # before the visit called here installs its error handler.
      # That type error can then be caught by the enclosing call to visit,
      # and correctly blame the If node.
      self.visit(node.body)
      return node

  def test_function(x):
    if x > 0:
      return x

  tr = BrokenTransformer(self._simple_source_info())

  node, _ = parser.parse_entity(test_function)
  with self.assertRaises(transformer.AutographParseError) as cm:
    node = tr.visit(node)
  obtained_message = str(cm.exception)
  # Accept both the Python 2 ("type") and Python 3 ("class") repr of list.
  expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
  self.assertRegexpMatches(obtained_message, expected_message)
  # The exception should point at the if statement, not any place else. Could
  # also check the stack trace.
  self.assertTrue(
      'Occurred at node:\nIf' in obtained_message, obtained_message)
  self.assertTrue(
      'Occurred at node:\nFunctionDef' not in obtained_message,
      obtained_message)
  self.assertTrue(
      'Occurred at node:\nReturn' not in obtained_message, obtained_message)
def parse_and_analyze(self,
                      test_fn,
                      namespace,
                      namer=None,
                      arg_types=None,
                      include_type_analysis=True,
                      owner_type=None,
                      recursive=True,
                      autograph_decorators=()):
  """Parses `test_fn` and runs the static analysis passes on it.

  Args:
    test_fn: the function whose source is parsed.
    namespace: namespace recorded in the EntityInfo.
    namer: namer for the EntityContext; a FakeNamer is created when None.
    arg_types: forwarded to transformer.EntityInfo.
    include_type_analysis: when true, also runs type_info and a second
      live_values pass.
    owner_type: forwarded to transformer.EntityInfo.
    recursive: forwarded to converter.ProgramContext.
    autograph_decorators: forwarded to converter.ProgramContext.

  Returns:
    The analyzed node. The EntityContext is stored on `self.ctx`; the
    analysis passes themselves receive the EntityInfo.
  """
  tree, source = parser.parse_entity(test_fn)
  chosen_namer = FakeNamer() if namer is None else namer

  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=autograph_decorators,
      partial_types=None,
      autograph_module=None,
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      owner_type=owner_type)
  ctx = converter.EntityContext(chosen_namer, info, program_ctx)

  tree = qual_names.resolve(tree)
  tree = activity.resolve(tree, info)
  tree = live_values.resolve(tree, info, {})
  if include_type_analysis:
    # live_values runs again after type_info, with a fresh literals dict.
    tree = type_info.resolve(tree, info)
    tree = live_values.resolve(tree, info, {})

  self.ctx = ctx
  return tree
def _build_cfg(self, fn):
  """Parses `fn` and returns the result of cfg.build on the parsed node."""
  parsed, _ = parser.parse_entity(fn)
  return cfg.build(parsed)
def test_state_tracking(self):
  """Entering a While/If pushes a new loop/cond state for nodes inside it."""

  # Marker classes used only as keys into self.state.
  class LoopState(object):
    pass

  class CondState(object):
    pass

  class TestTransformer(transformer.Base):

    # Records the current loop and cond state values on every visited node.
    def visit(self, node):
      anno.setanno(node, 'loop_state', self.state[LoopState].value)
      anno.setanno(node, 'cond_state', self.state[CondState].value)
      return super(TestTransformer, self).visit(node)

    def visit_While(self, node):
      self.state[LoopState].enter()
      node = self.generic_visit(node)
      self.state[LoopState].exit()
      return node

    def visit_If(self, node):
      self.state[CondState].enter()
      node = self.generic_visit(node)
      self.state[CondState].exit()
      return node

  tr = TestTransformer(self._simple_source_info())

  def test_function(a):
    a = 1
    while a:
      _ = 'a'
      if a > 2:
        _ = 'b'
        while True:
          raise '1'
      if a > 3:
        _ = 'c'
        while True:
          raise '1'

  node, _ = parser.parse_entity(test_function)
  node = tr.visit(node)

  # Each pair below compares the first statements of two blocks: crossing a
  # while boundary changes only loop_state, crossing an if boundary changes
  # only cond_state.
  fn_body = node.body[0].body
  outer_while_body = fn_body[1].body
  self.assertSameAnno(fn_body[0], outer_while_body[0], 'cond_state')
  self.assertDifferentAnno(fn_body[0], outer_while_body[0], 'loop_state')

  first_if_body = outer_while_body[1].body
  self.assertDifferentAnno(outer_while_body[0], first_if_body[0],
                           'cond_state')
  self.assertSameAnno(outer_while_body[0], first_if_body[0], 'loop_state')

  first_inner_while_body = first_if_body[1].body
  self.assertSameAnno(first_if_body[0], first_inner_while_body[0],
                      'cond_state')
  self.assertDifferentAnno(first_if_body[0], first_inner_while_body[0],
                           'loop_state')

  # Sibling ifs/whiles get distinct states too.
  second_if_body = outer_while_body[2].body
  self.assertDifferentAnno(first_if_body[0], second_if_body[0], 'cond_state')
  self.assertSameAnno(first_if_body[0], second_if_body[0], 'loop_state')

  second_inner_while_body = second_if_body[1].body
  self.assertDifferentAnno(first_inner_while_body[0],
                           second_inner_while_body[0], 'cond_state')
  self.assertDifferentAnno(first_inner_while_body[0],
                           second_inner_while_body[0], 'loop_state')