def test_parse_comments(self):
  """parsing.parse_entity must raise ValueError on an unindented comment.

  The comment inside `f` is intentionally placed at column zero; it is the
  input under test and must stay that way.
  """
  def f():
# unindented comment
    pass

  self.assertRaises(ValueError, parsing.parse_entity, f)
def test_parse_multiline_strings(self):
  """parsing.parse_entity must raise ValueError on a column-0 string body.

  The continuation lines of the string literal inside `f` intentionally
  start at column zero; that layout is the input under test.
  """
  def f():
    print("""
some
multiline
string""")

  self.assertRaises(ValueError, parsing.parse_entity, f)
def test_resolve(self):
  """Checks the ORIGIN annotations attached by origin_info.resolve.

  Line numbers count from the `def` line of `test_fn` (line 1, column 0).
  The recorded source line of each body statement keeps the indentation
  implied by its column offset. A trailing `#` comment is extracted into
  `origin.comment`.
  """

  def test_fn(x):
    """Docstring."""
    return x  # comment

  node, source = parsing.parse_entity(test_fn)
  fn_node = node.body[0]

  origin_info.resolve(fn_node, source)

  origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 1)
  self.assertEqual(origin.loc.col_offset, 0)
  self.assertEqual(origin.source_code_line, 'def test_fn(x):')
  self.assertIsNone(origin.comment)

  # Body statements start at column 2, so the expected source lines must
  # carry exactly two leading spaces to agree with col_offset.
  origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 2)
  self.assertEqual(origin.loc.col_offset, 2)
  self.assertEqual(origin.source_code_line, '  """Docstring."""')
  self.assertIsNone(origin.comment)

  origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 3)
  self.assertEqual(origin.loc.col_offset, 2)
  self.assertEqual(origin.source_code_line, '  return x  # comment')
  self.assertEqual(origin.comment, 'comment')
def test_robust_error_on_ast_corruption(self):
  """Error handling in transformer.Base must survive a corrupted AST.

  A subclass should not be able to break the tree so badly that the error
  handling in `visit` itself raises. If it did, the original error location
  would be dropped, and a handler higher up the call stack would report
  misleading information. This checks that `visit` still blames the
  original exception even after the AST has been corrupted.
  """

  class NotANode(object):
    pass

  class BrokenTransformer(transformer.Base):

    def visit_If(self, node):
      # Corrupt the tree first, then fail.
      node.body = NotANode()
      raise ValueError('I blew up')

  def test_function(x):
    if x > 0:
      return x

  broken = BrokenTransformer(self._simple_source_info())
  fn_ast, _ = parsing.parse_entity(test_function)
  with self.assertRaises(ValueError) as raised:
    broken.visit(fn_ast)

  # The message must come from the exception actually raised in visit_If,
  # not from anything raised by the error handler.
  message = str(raised.exception)
  self.assertIn('I blew up', message, message)
def convert(func, overload_module, transformers):
  """Main entry point for converting a function using Pyct.

  Args:
    func: function to be converted
    overload_module: module containing overloaded functionality
    transformers: list of transformers to be applied

  Returns:
    gen_func: converted function
  """
  # parse_entity yields (ast_node, source_text). Keep both: the tree is what
  # gets transformed, and the text is what EntityInfo.source_code should hold.
  node, source_code = parsing.parse_entity(func)
  entity_info = transformer.EntityInfo(
      source_code=source_code,
      source_file='<fragment>',
      namespace={},
      arg_values=None,
      arg_types={},
      owner_type=None)
  namer = naming.Namer(entity_info.namespace)
  ctx = transformer.EntityContext(namer, entity_info)

  # Fresh symbol name (no extra reserved names) under which the overload
  # module is configured.
  overload_name = ctx.namer.new_symbol('overload', set())
  overload = config.VirtualizationConfig(overload_module, overload_name)

  node = _transform(node, ctx, overload, transformers)
  gen_func = _wrap_in_generator(func, node, namer, overload)
  gen_func = _attach_closure(func, gen_func)
  return gen_func
def test_robust_error_on_list_visit(self):
  """Visiting a list instead of a node must raise and blame the If node."""

  class BrokenTransformer(transformer.Base):

    def visit_If(self, node):
      # This is broken because visit expects a single node, not a list, and
      # the body of an if is a list.
      # Importantly, the default error handling in visit also expects a
      # single node. Therefore, mistakes like this need to trigger a type
      # error before the visit called here installs its error handler.
      # That type error can then be caught by the enclosing call to visit,
      # and correctly blame the If node.
      self.visit(node.body)
      return node

  def test_function(x):
    if x > 0:
      return x

  tr = BrokenTransformer(self._simple_source_info())

  node, _ = parsing.parse_entity(test_function)
  with self.assertRaises(ValueError) as cm:
    node = tr.visit(node)
  obtained_message = str(cm.exception)
  expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
  # assertRegexpMatches was a deprecated alias and is gone in Python 3.12.
  self.assertRegex(obtained_message, expected_message)
def test_visit_block_postprocessing(self):
  """Checks the after_visit hook of visit_block.

  The hook wraps the `z = y` assignment in a new If node. It returns that
  If's body as the place where later statements go, so the trailing
  `return z` ends up inside the If.
  """

  class TestTransformer(transformer.Base):

    def _process_body_item(self, node):
      wraps = isinstance(node, gast.Assign) and node.value.id == 'y'
      if not wraps:
        return node, None
      guard = gast.If(gast.Name('x', gast.Load(), None), [node], [])
      return guard, guard.body

    def visit_FunctionDef(self, node):
      node.body = self.visit_block(
          node.body, after_visit=self._process_body_item)
      return node

  def test_function(x, y):
    z = x
    z = y
    return z

  visitor = TestTransformer(self._simple_source_info())
  module_node, _ = parsing.parse_entity(test_function)
  fn_node = visitor.visit(module_node).body[0]

  self.assertLen(fn_node.body, 2)
  first_stmt, guard = fn_node.body
  self.assertIsInstance(first_stmt, gast.Assign)
  self.assertIsInstance(guard, gast.If)
  self.assertIsInstance(guard.body[0], gast.Assign)
  self.assertIsInstance(guard.body[1], gast.Return)
def test_entity_scope_tracking(self):
  """Checks that enclosing_entities lists every enclosing def/class/lambda."""

  class TestTransformer(transformer.Base):

    def _annotate_with_enclosing_entities(self, node):
      anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
      return self.generic_visit(node)

    # Assign is an arbitrary choice of node to annotate; it is simply easy to
    # find in the tree.
    def visit_Assign(self, node):
      return self._annotate_with_enclosing_entities(node)

    # BinOp is annotated so that the body of the lambda is covered too.
    def visit_BinOp(self, node):
      return self._annotate_with_enclosing_entities(node)

  visitor = TestTransformer(self._simple_source_info())

  def test_function():
    a = 0

    class TestClass(object):

      def test_method(self):
        b = 0

        def inner_function(x):
          c = 0
          d = lambda y: (x + y)
          return c, d

        return b, inner_function

    return a, TestClass

  module_node, _ = parsing.parse_entity(test_function)
  module_node = visitor.visit(module_node)

  fn_node = module_node.body[0]
  class_node = fn_node.body[1]
  method_node = class_node.body[0]
  inner_fn_node = method_node.body[1]
  lambda_node = inner_fn_node.body[1].value

  # (annotated node, expected chain of enclosing entities), innermost last.
  expectations = (
      (fn_node.body[0], (fn_node,)),
      (method_node.body[0], (fn_node, class_node, method_node)),
      (inner_fn_node.body[0],
       (fn_node, class_node, method_node, inner_fn_node)),
      (lambda_node.body,
       (fn_node, class_node, method_node, inner_fn_node, lambda_node)),
  )
  for annotated, expected_scope in expectations:
    self.assertEqual(expected_scope,
                     anno.getanno(annotated, 'enclosing_entities'))
def test_source_map_no_origin(self):
  """create_source_map returns an empty map when no ORIGIN was attached."""

  def test_fn(x):
    return x + 1

  fn_node = parsing.parse_entity(test_fn)[0].body[0]
  source_map = origin_info.create_source_map(
      fn_node, parsing.ast_to_source(fn_node), 'test_filename', [0])
  self.assertEmpty(source_map)
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn`, then runs qualified-name and activity resolution.

  Returns:
    A tuple (node, entity_info): the resolved AST and the EntityInfo built
    from the function's source.
  """
  tree, source = parsing.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=source,
      source_file=None,
      namespace={},
      arg_values=None,
      arg_types=None,
      owner_type=None)
  tree = qual_names.resolve(tree)
  context = transformer.Context(info)
  return activity.resolve(tree, context), info
def test_parsing_compile_idempotent(self):
  """Parse → compile must round-trip to the same source text."""

  # The exact text of this function is what gets compared; keep its layout.
  def test_fn(x):
    a = True
    b = ''
    if a:
      b = x + 1
    return b

  expected = textwrap.dedent(inspect.getsource(test_fn))
  fn_node = parsing.parse_entity(test_fn)[0].body[0]
  compiled_module = parsing.ast_to_object(fn_node)[0]
  actual = inspect.getsource(compiled_module.test_fn)
  self.assertEqual(expected, actual)
def get_scopes(self, func):
  """Runs scoping.ScopeTransformer over `func` and returns its scopes.

  Args:
    func: the function to analyze.

  Returns:
    The `scopes` attribute of the ScopeTransformer after visiting the AST.
  """
  # parse_entity yields (ast_node, source_text). The tree is visited; the text
  # belongs in EntityInfo.source_code (previously the tree was passed there).
  node, source_code = parsing.parse_entity(func)
  entity_info = transformer.EntityInfo(
      source_code=source_code,
      source_file='<fragment>',
      namespace={},
      arg_values=None,
      arg_types={},
      owner_type=None)
  namer = naming.Namer(entity_info.namespace)
  ctx = transformer.EntityContext(namer, entity_info)
  scope_transformer = scoping.ScopeTransformer(ctx)
  scope_transformer.visit(node)
  return scope_transformer.scopes
def test_create_source_map(self):
  """Checks that an ORIGIN annotation on a statement appears in the source map.

  The map entry is keyed by LineLocation('test_filename', 2) and points at
  the same OriginInfo object.
  """

  def test_fn(x):
    return x + 1

  module_node, _ = parsing.parse_entity(test_fn)
  fake_origin = origin_info.OriginInfo(
      loc=origin_info.Location('fake_filename', 3, 7),
      function_name='fake_function_name',
      source_code_line='fake source line',
      comment=None)

  fn_node = module_node.body[0]
  anno.setanno(fn_node.body[0], anno.Basic.ORIGIN, fake_origin)
  generated_code = parsing.ast_to_source(fn_node)
  source_map = origin_info.create_source_map(
      fn_node, generated_code, 'test_filename', [0])

  expected_key = origin_info.LineLocation('test_filename', 2)
  self.assertIn(expected_key, source_map)
  self.assertIs(source_map[expected_key], fake_origin)
def test_parse_entity(self):
  """Checks that the parsed tree's first body statement is the definition of `f`."""

  def f(x):
    return x + 1

  tree = parsing.parse_entity(f)[0]
  self.assertEqual('f', tree.body[0].name)
def _build_cfg(self, fn):
  """Parses `fn` and returns whatever cfg.build produces for its AST."""
  tree = parsing.parse_entity(fn)[0]
  return cfg.build(tree)
def test_state_tracking(self):
  """Checks that per-construct state entered in visit_While/visit_If is scoped.

  Every visited node is annotated with the current LoopState and CondState
  values. A while loop should change only the loop state, and an if should
  change only the cond state.
  """

  class LoopState(object):
    pass

  class CondState(object):
    pass

  class TestTransformer(transformer.Base):

    def visit(self, node):
      anno.setanno(node, 'loop_state', self.state[LoopState].value)
      anno.setanno(node, 'cond_state', self.state[CondState].value)
      return super(TestTransformer, self).visit(node)

    def _within_state(self, state_type, node):
      # Enter, visit children, then exit the given state.
      self.state[state_type].enter()
      node = self.generic_visit(node)
      self.state[state_type].exit()
      return node

    def visit_While(self, node):
      return self._within_state(LoopState, node)

    def visit_If(self, node):
      return self._within_state(CondState, node)

  visitor = TestTransformer(self._simple_source_info())

  # The statement layout below is indexed directly by the assertions.
  def test_function(a):
    a = 1
    while a:
      _ = 'a'
      if a > 2:
        _ = 'b'
        while True:
          raise '1'
      if a > 3:
        _ = 'c'
        while True:
          raise '1'

  module_node, _ = parsing.parse_entity(test_function)
  module_node = visitor.visit(module_node)

  fn_stmts = module_node.body[0].body
  outer_loop = fn_stmts[1].body
  # Entering the outer while: loop state changes, cond state does not.
  self.assertSameAnno(fn_stmts[0], outer_loop[0], 'cond_state')
  self.assertDifferentAnno(fn_stmts[0], outer_loop[0], 'loop_state')

  first_branch = outer_loop[1].body
  # Entering the first if: cond state changes, loop state does not.
  self.assertDifferentAnno(outer_loop[0], first_branch[0], 'cond_state')
  self.assertSameAnno(outer_loop[0], first_branch[0], 'loop_state')

  first_inner_loop = first_branch[1].body
  self.assertSameAnno(first_branch[0], first_inner_loop[0], 'cond_state')
  self.assertDifferentAnno(first_branch[0], first_inner_loop[0],
                           'loop_state')

  second_branch = outer_loop[2].body
  # Sibling ifs get distinct cond states but share the enclosing loop state.
  self.assertDifferentAnno(first_branch[0], second_branch[0], 'cond_state')
  self.assertSameAnno(first_branch[0], second_branch[0], 'loop_state')

  second_inner_loop = second_branch[1].body
  self.assertDifferentAnno(first_inner_loop[0], second_inner_loop[0],
                           'cond_state')
  self.assertDifferentAnno(first_inner_loop[0], second_inner_loop[0],
                           'loop_state')