def visit_While(self, node):
  """Rewrites a `while` statement into a call to `overload.while_stmt`.

  The loop test, body and `else` clause are each wrapped in a generated
  function and handed to the overload together with the symbols written by
  the body or the `else` clause.

  Args:
    node: the `While` node being visited; it must carry BODY_SCOPE and
      ORELSE_SCOPE annotations.

  Returns:
    The visited node unchanged if the overload module has no `while_stmt`
    attribute, otherwise the result of `templates.replace`.
  """
  body_scope = anno.getanno(node, anno.Static.BODY_SCOPE)
  orelse_scope = anno.getanno(node, anno.Static.ORELSE_SCOPE)
  # Symbols written by either the loop body or its else clause.
  modified_in_cond = body_scope.modified | orelse_scope.modified

  node = self.generic_visit(node)

  if not hasattr(self.overload.module, 'while_stmt'):
    return node

  # The template must keep one statement per line to be parseable.
  template = """
    def test_name():
      return test
    def body_name():
      body
    def orelse_name():
      orelse
    overload.while_stmt(test_name, body_name, orelse_name, (local_writes,))
  """

  node = templates.replace(
      template,
      overload=self.overload.symbol_name,
      test_name=self.ctx.namer.new_symbol('while_test', set()),
      test=node.test,
      body_name=self.ctx.namer.new_symbol('while_body', set()),
      body=node.body,
      orelse_name=self.ctx.namer.new_symbol('while_orelse', set()),
      orelse=node.orelse if node.orelse else gast.Pass(),
      # Sort by name so the generated code does not depend on set ordering.
      local_writes=tuple(sorted(modified_in_cond, key=str)))

  return node
def test_if(self):
  # Checks the scopes annotated on both branches of an `if` statement, and
  # the scope that encloses them.

  def test_fn(x):
    if x > 0:
      x = -x
      y = 2 * x
      z = -y
    else:
      x = 2 * x
      y = -x
      u = -y
    return z, u

  module_node, _ = self._parse_and_analyze(test_fn)
  if_stmt = module_node.body[0].body[0]
  all_symbols = ('x', 'y', 'z', 'u')

  then_scope = anno.getanno(if_stmt, anno.Static.BODY_SCOPE)
  self.assertScopeIs(then_scope, ('x', 'y'), ('x', 'y', 'z'))
  self.assertScopeIs(then_scope.parent, all_symbols, all_symbols)

  else_scope = anno.getanno(if_stmt, anno.Static.ORELSE_SCOPE)
  self.assertScopeIs(else_scope, ('x', 'y'), ('x', 'y', 'u'))
  self.assertScopeIs(else_scope.parent, all_symbols, all_symbols)
def test_resolve(self):
  # Checks the ORIGIN annotation (location, source line, trailing comment)
  # that origin_info.resolve attaches to a function and its statements.

  # NOTE: the layout of test_fn is part of the test. The assertions below
  # expect a two-space body indent and two spaces before the comment.
  def test_fn(x):
    """Docstring."""
    return x  # comment

  node, source = parsing.parse_entity(test_fn)
  fn_node = node.body[0]
  origin_info.resolve(fn_node, source)

  origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 1)
  self.assertEqual(origin.loc.col_offset, 0)
  self.assertEqual(origin.source_code_line, 'def test_fn(x):')
  self.assertIsNone(origin.comment)

  origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 2)
  self.assertEqual(origin.loc.col_offset, 2)
  # The leading indent must agree with the col_offset asserted above.
  self.assertEqual(origin.source_code_line, '  """Docstring."""')
  self.assertIsNone(origin.comment)

  origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 3)
  self.assertEqual(origin.loc.col_offset, 2)
  self.assertEqual(origin.source_code_line, '  return x  # comment')
  self.assertEqual(origin.comment, 'comment')
def test_entity_scope_tracking(self):
  # Verifies that `enclosing_entities` lists every enclosing function, class,
  # method, nested function and lambda at the point a node is visited.

  class TestTransformer(transformer.Base):

    # The choice of node to annotate is arbitrary. Assign is used because
    # it's easy to find in the tree.
    def visit_Assign(self, node):
      anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
      return self.generic_visit(node)

    # BinOp gets the same treatment; it shows up in the lambda function.
    visit_BinOp = visit_Assign

  tr = TestTransformer(self._simple_source_info())

  def test_function():
    a = 0

    class TestClass(object):

      def test_method(self):
        b = 0

        def inner_function(x):
          c = 0
          d = lambda y: (x + y)
          return c, d

        return b, inner_function

    return a, TestClass

  node, _ = parsing.parse_entity(test_function)
  node = tr.visit(node)

  # Locate each entity definition in the transformed tree.
  test_function_node = node.body[0]
  test_class = test_function_node.body[1]
  test_method = test_class.body[0]
  inner_function = test_method.body[1]
  lambda_node = inner_function.body[1].value

  # Locate the annotated nodes inside each entity.
  a = test_function_node.body[0]
  b = test_method.body[0]
  c = inner_function.body[0]
  lambda_expr = lambda_node.body

  # Each expected chain extends the previous one by the newly entered
  # entities.
  within_function = (test_function_node,)
  within_method = within_function + (test_class, test_method)
  within_inner = within_method + (inner_function,)
  within_lambda = within_inner + (lambda_node,)

  self.assertEqual(within_function, anno.getanno(a, 'enclosing_entities'))
  self.assertEqual(within_method, anno.getanno(b, 'enclosing_entities'))
  self.assertEqual(within_inner, anno.getanno(c, 'enclosing_entities'))
  self.assertEqual(within_lambda,
                   anno.getanno(lambda_expr, 'enclosing_entities'))
def test_rename_symbols_annotations(self):
  # A custom annotation on the root node must still be the same object
  # after ast_util.rename_symbols.
  node = qual_names.resolve(parsing.parse_str('a[i]'))
  anno.setanno(node, 'foo', 'bar')
  orig_anno = anno.getanno(node, 'foo')

  renames = {qual_names.QN('a'): qual_names.QN('b')}
  node = ast_util.rename_symbols(node, renames)

  self.assertIs(anno.getanno(node, 'foo'), orig_anno)
def test_params(self):
  # Checks the body scope, its parent scope and the params recorded on the
  # args scope of a function.

  def test_fn(a, b):  # pylint: disable=unused-argument
    return b

  module_node, _ = self._parse_and_analyze(test_fn)
  fn_def = module_node.body[0]

  fn_body_scope = anno.getanno(fn_def, anno.Static.BODY_SCOPE)
  self.assertScopeIs(fn_body_scope, ('b',), ())
  self.assertScopeIs(fn_body_scope.parent, ('b',), ('a', 'b'))

  fn_args_scope = anno.getanno(fn_def.args, anno.Static.SCOPE)
  param_names = fn_args_scope.params.keys()
  self.assertSymbolSetsAre(('a', 'b'), param_names, 'params')
def visit_For(self, node):
  """Rewrites a `for` statement into a call to `overload.for_stmt`.

  An initializer is emitted for every loop target. The body and `else`
  clause are wrapped in generated functions and handed to the overload
  together with the symbols written by the body or the `else` clause.

  Args:
    node: the `For` node being visited; it must carry BODY_SCOPE and
      ORELSE_SCOPE annotations.

  Returns:
    The visited node unchanged if the overload module has no `for_stmt`
    attribute, otherwise the result of `templates.replace`.

  Raises:
    ValueError: if the loop target is not a Tuple, List or Name node.
  """
  body_scope = anno.getanno(node, anno.Static.BODY_SCOPE)
  orelse_scope = anno.getanno(node, anno.Static.ORELSE_SCOPE)
  # Symbols written by either the loop body or its else clause.
  modified_in_cond = body_scope.modified | orelse_scope.modified

  node = self.generic_visit(node)

  if not hasattr(self.overload.module, 'for_stmt'):
    return node

  # TODO(jmd1011): Handle extra_test
  if isinstance(node.target, (gast.Tuple, gast.List)):
    targets = list(node.target.elts)
  elif isinstance(node.target, gast.Name):
    targets = [node.target]
  else:
    raise ValueError(
        'For target must be gast.Tuple, gast.List, or gast.Name, got {}.'
        .format(type(node.target)))

  target_inits = [
      self._make_target_init(target, self.overload) for target in targets
  ]

  # The template must keep one statement per line to be parseable.
  template = """
    target_inits
    def body_name():
      body
    def orelse_name():
      orelse
    overload.for_stmt(target, iter_, body_name, orelse_name, (local_writes,))
  """

  node = templates.replace(
      template,
      target_inits=target_inits,
      target=node.target,
      body_name=self.ctx.namer.new_symbol('for_body', set()),
      body=node.body,
      orelse_name=self.ctx.namer.new_symbol('for_orelse', set()),
      orelse=node.orelse if node.orelse else gast.Pass(),
      overload=self.overload.symbol_name,
      iter_=node.iter,
      # Sort by name so the generated code does not depend on set ordering.
      local_writes=tuple(sorted(modified_in_cond, key=str)))

  return node
def create_source_map(nodes, code, filename, indices_in_code):
  """Creates a source map between an annotated AST and the code it compiles to.

  Args:
    nodes: Iterable[ast.AST, ...]
    code: Text
    filename: Optional[Text]
    indices_in_code: Union[int, Iterable[int, ...]], the positions at which
      nodes appear in code. The parsing always returns a module when parsing
      code. This argument indicates the position in that module's body at
      which the corresponding node should appear. A bare int is treated as
      a one-element sequence.

  Returns:
    Dict[LineLocation, OriginInfo], mapping line locations in code to the
    origin annotations found on nodes.
  """
  # Honor the documented contract: a single index is also accepted.
  if isinstance(indices_in_code, int):
    indices_in_code = (indices_in_code,)

  reparsed_nodes = parsing.parse_str(code)
  reparsed_nodes = [reparsed_nodes.body[i] for i in indices_in_code]
  resolve(reparsed_nodes, code)

  result = {}

  for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
    # Note: generated code might not be mapped back to its origin.
    # TODO(mdanatg): Generated code should always be mapped to something.
    origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
    final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
    if origin_info is None or final_info is None:
      continue

    line_loc = LineLocation(filename, final_info.loc.lineno)

    existing_origin = result.get(line_loc)
    if existing_origin is not None:
      # Overlaps may exist because of child nodes, but almost never to
      # different line locations. Decorated functions are the exception:
      # both lines are mapped to the same line in the AST.

      # Line overlaps: keep bottom node.
      if existing_origin.loc.line_loc == origin_info.loc.line_loc:
        if existing_origin.loc.lineno >= origin_info.loc.lineno:
          continue

      # In case of overlaps, keep the leftmost node.
      if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
        continue

    result[line_loc] = origin_info

  return result
def _node_sets_self_attribute(self, node):
  """Returns True if `node`'s qualified name is an attribute of `self`.

  Nodes without a QN annotation yield False.
  """
  if not anno.hasanno(node, anno.Basic.QN):
    return False
  qual_name = anno.getanno(node, anno.Basic.QN)
  # TODO(mdanatg): The 'self' argument is not guaranteed to be called 'self'.
  return bool(qual_name.has_attr and qual_name.parent.qn == ('self',))
def test_for(self):
  # Checks the body scope of a `for` loop and the scope enclosing it.

  def test_fn(a):
    b = a
    for _ in a:
      c = b
      b -= 1
    return b, c

  module_node, _ = self._parse_and_analyze(test_fn)
  for_stmt = module_node.body[0].body[1]
  loop_body_scope = anno.getanno(for_stmt, anno.Static.BODY_SCOPE)

  self.assertScopeIs(loop_body_scope, ('b',), ('b', 'c'))
  self.assertScopeIs(
      loop_body_scope.parent, ('a', 'b', 'c'), ('b', 'c', '_'))
def visit_Attribute(self, node):
  """Annotates an attribute node with a QN built from its value's QN.

  Children are visited first. Nothing is annotated when the value has no
  QN annotation.
  """
  node = self.generic_visit(node)
  if not anno.hasanno(node.value, anno.Basic.QN):
    return node
  value_qn = anno.getanno(node.value, anno.Basic.QN)
  anno.setanno(node, anno.Basic.QN, QN(value_qn, attr=node.attr))
  return node
def test_nested_if(self):
  # Both branches of the inner `if` are expected to have the same scope
  # contents.

  def test_fn(b):
    if b > 0:
      if b < 5:
        a = b
      else:
        a = b * b
    return a

  module_node, _ = self._parse_and_analyze(test_fn)
  inner_if_stmt = module_node.body[0].body[0].body[0]

  for scope_key in (anno.Static.BODY_SCOPE, anno.Static.ORELSE_SCOPE):
    branch_scope = anno.getanno(inner_if_stmt, scope_key)
    self.assertScopeIs(branch_scope, ('b',), ('a',))
def _track_symbol(self, node, composite_writes_alter_parent=False):
  """Records in the current scope how the symbol behind `node` is used.

  The symbol is the QN annotation of `node`; nodes without one are ignored.
  The kind of use is taken from `node.ctx` (Store, Load, Param or Del).

  Args:
    node: an AST node carrying a `ctx`, and usually a QN annotation.
    composite_writes_alter_parent: if True, a store to a composite QN also
      marks the QN's parent as modified.

  Raises:
    NotImplementedError: for a Param context that is neither inside function
      def arguments nor inside a lambda.
    ValueError: for any other, unknown context type.
  """
  # A QN may be missing when we have an attribute (or subscript) on a function
  # call. Example: a().b
  if not anno.hasanno(node, anno.Basic.QN):
    return
  qn = anno.getanno(node, anno.Basic.QN)

  # When inside a lambda, ignore any of the lambda's arguments.
  # This includes attributes or slices of those arguments.
  for l in self.state[_Lambda]:
    if qn in l.args:
      return
    if qn.owner_set & set(l.args):
      return

  # When inside a comprehension, ignore any of the comprehension's targets.
  # This includes attributes or slices of those arguments.
  # This is not true in Python2, which leaks symbols.
  if six.PY3:
    for l in self.state[_Comprehension]:
      if qn in l.targets:
        return
      if qn.owner_set & set(l.targets):
        return

  if isinstance(node.ctx, gast.Store):
    # In comprehensions, modified symbols are the comprehension targets.
    if six.PY3 and self.state[_Comprehension].level > 0:
      # Like a lambda's args, they are tracked separately in Python3.
      self.state[_Comprehension].targets.add(qn)
    else:
      self.scope.mark_modified(qn)
      if qn.is_composite and composite_writes_alter_parent:
        self.scope.mark_modified(qn.parent)
      # An augmented assignment also reads the symbol it writes.
      if self._in_aug_assign:
        self.scope.mark_read(qn)
  elif isinstance(node.ctx, gast.Load):
    self.scope.mark_read(qn)
  elif isinstance(node.ctx, gast.Param):
    if self._in_function_def_args:
      # In function defs have the meaning of defining a variable.
      self.scope.mark_modified(qn)
      self.scope.mark_param(qn, self.enclosing_entities[-1])
    elif self.state[_Lambda].level:
      # In lambdas, they are tracked separately.
      self.state[_Lambda].args.add(qn)
    else:
      # TODO(mdanatg): Is this case possible at all?
      raise NotImplementedError(
          'Param "{}" outside a function arguments or lambda.'.
          format(qn))
  elif isinstance(node.ctx, gast.Del):
    # The read matches the Python semantics - attempting to delete an
    # undefined symbol is illegal.
    self.scope.mark_read(qn)
    self.scope.mark_deleted(qn)
  else:
    raise ValueError('Unknown context {} for node "{}".'.format(
        type(node.ctx), qn))
def test_while(self):
  # Checks the body scope, the enclosing scope and the condition scope
  # annotated on a `while` loop.

  def test_fn(a):
    b = a
    while b > 0:
      c = b
      b -= 1
    return b, c

  module_node, _ = self._parse_and_analyze(test_fn)
  while_stmt = module_node.body[0].body[1]

  loop_body_scope = anno.getanno(while_stmt, anno.Static.BODY_SCOPE)
  self.assertScopeIs(loop_body_scope, ('b',), ('b', 'c'))
  self.assertScopeIs(loop_body_scope.parent, ('a', 'b', 'c'), ('b', 'c'))

  loop_cond_scope = anno.getanno(while_stmt, anno.Static.COND_SCOPE)
  self.assertScopeIs(loop_cond_scope, ('b',), ())
def test_aug_assign(self):
  # An augmented assignment both reads and modifies its target.

  def test_fn(a, b):
    a += b

  node, _ = self._parse_and_analyze(test_fn)
  fn_node = node.body[0]
  # ('a',) is a real 1-tuple; a bare ('a') would be just the string 'a'.
  self.assertScopeIs(
      anno.getanno(fn_node, anno.Static.BODY_SCOPE), ('a', 'b'), ('a',))
def _process(self, node):
  """Replaces `node` with a renamed Name node if its QN is in `name_map`.

  Nodes whose QN is not mapped are visited generically instead.
  """
  qual_name = anno.getanno(node, anno.Basic.QN)
  if qual_name not in self.name_map:
    return self.generic_visit(node)

  renamed = gast.Name(str(self.name_map[qual_name]), node.ctx, None)
  # All annotations get carried over.
  for key in anno.keys(node):
    anno.copyanno(node, renamed, key)
  return renamed
def test_return_vars_are_read(self):
  # A returned variable must show up in the first symbol set of the body
  # scope.

  def test_fn(a, b, c):  # pylint: disable=unused-argument
    return c

  module_node, _ = self._parse_and_analyze(test_fn)
  fn_body_scope = anno.getanno(module_node.body[0], anno.Static.BODY_SCOPE)
  self.assertScopeIs(fn_body_scope, ('c',), ())
def test_aug_assign_subscripts(self):
  # Checks the body scope for an augmented assignment to a subscript.

  def test_fn(a):
    a[0] += 1

  module_node, _ = self._parse_and_analyze(test_fn)
  fn_body_scope = anno.getanno(module_node.body[0], anno.Static.BODY_SCOPE)
  self.assertScopeIs(fn_body_scope, ('a', 'a[0]'), ('a[0]',))
def test_lambda_nested(self):
  # Arguments of nested lambdas must not appear in the enclosing function's
  # body scope.

  def test_fn(a, b, c, d, e):  # pylint: disable=unused-argument
    a = lambda a, b: d(lambda b: a + b + c)  # pylint: disable=undefined-variable

  module_node, _ = self._parse_and_analyze(test_fn)
  fn_body_scope = anno.getanno(module_node.body[0], anno.Static.BODY_SCOPE)

  self.assertScopeIs(fn_body_scope, ('c', 'd'), ('a',))
  param_names = fn_body_scope.params.keys()
  self.assertSymbolSetsAre((), param_names, 'params')
def test_lambda_complex(self):
  # Checks the body scope for a lambda that is called immediately inside a
  # larger expression.

  def test_fn(a, b, c, d):  # pylint: disable=unused-argument
    a = (lambda a, b, c: a + b + c)(d, 1, 2) + b

  module_node, _ = self._parse_and_analyze(test_fn)
  fn_body_scope = anno.getanno(module_node.body[0], anno.Static.BODY_SCOPE)

  self.assertScopeIs(fn_body_scope, ('b', 'd'), ('a',))
  param_names = fn_body_scope.params.keys()
  self.assertSymbolSetsAre((), param_names, 'params')
def test_lambda_params_are_isolated(self):
  # A lambda parameter that shadows an outer name must not be reported in
  # the body scope.

  def test_fn(a, b):  # pylint: disable=unused-argument
    return lambda a: a + b

  module_node, _ = self._parse_and_analyze(test_fn)
  fn_body_scope = anno.getanno(module_node.body[0], anno.Static.BODY_SCOPE)

  self.assertScopeIs(fn_body_scope, ('b',), ())
  param_names = fn_body_scope.params.keys()
  self.assertSymbolSetsAre((), param_names, 'params')
def test_lambda_captures_reads(self):
  # Free variables used by a lambda must appear in the enclosing body scope.

  def test_fn(a, b):
    return lambda: a + b

  module_node, _ = self._parse_and_analyze(test_fn)
  fn_body_scope = anno.getanno(module_node.body[0], anno.Static.BODY_SCOPE)

  self.assertScopeIs(fn_body_scope, ('a', 'b'), ())
  # Nothing local to the lambda is tracked.
  param_names = fn_body_scope.params.keys()
  self.assertSymbolSetsAre((), param_names, 'params')
def test_copy_clean_preserves_annotations(self):
  # copy_clean must keep only the annotations named in `preserve_annos`.
  source = textwrap.dedent(""" def f(a): return a + 1 """)
  module_node = parsing.parse_str(source)
  fn_def = module_node.body[0]
  anno.setanno(fn_def, 'foo', 'bar')
  anno.setanno(fn_def, 'baz', 1)

  copied = ast_util.copy_clean(module_node, preserve_annos={'foo'})
  copied_fn = copied.body[0]

  self.assertEqual(anno.getanno(copied_fn, 'foo'), 'bar')
  self.assertFalse(anno.hasanno(copied_fn, 'baz'))
def test_print_statement(self):
  # Checks the ARGS_SCOPE annotation of a print, whether it is parsed as a
  # statement (Python 2) or as a call expression (Python 3).

  def test_fn(a):
    b = 0
    c = 1
    print(a, b)
    return c

  node, _ = self._parse_and_analyze(test_fn)
  print_node = node.body[0].body[2]
  if not isinstance(print_node, gast.Print):
    # Python 3: print is a call wrapped in an expression statement.
    # assertIsInstance is used because a bare assert is stripped under -O.
    self.assertIsInstance(print_node, gast.Expr)
    # The call node should be the one being annotated.
    print_node = print_node.value
  # In Python 2 the Print statement itself carries the annotation.
  print_args_scope = anno.getanno(print_node, anno.Static.ARGS_SCOPE)

  # We basically need to detect which variables are captured by the call
  # arguments.
  self.assertScopeIs(print_args_scope, ('a', 'b'), ())
def test_if_attributes(self):
  # Checks that attribute accesses are tracked as composite symbols in both
  # branches of an `if` statement and in the enclosing scope.

  def test_fn(a):
    if a > 0:
      a.b = -a.c
      d = 2 * a
    else:
      a.b = a.c
      d = 1
    return d

  module_node, _ = self._parse_and_analyze(test_fn)
  if_stmt = module_node.body[0].body[0]
  then_scope = anno.getanno(if_stmt, anno.Static.BODY_SCOPE)
  else_scope = anno.getanno(if_stmt, anno.Static.ORELSE_SCOPE)

  self.assertScopeIs(then_scope, ('a', 'a.c'), ('a.b', 'd'))
  self.assertScopeIs(else_scope, ('a', 'a.c'), ('a.b', 'd'))
  self.assertScopeIs(then_scope.parent, ('a', 'a.c', 'd'), ('a.b', 'd'))
def test_constructor_attributes(self):
  # Checks the body scope of a constructor that writes attributes of `self`.

  class TestClass(object):

    def __init__(self, a):
      self.b = a
      self.b.c = 1

  module_node, _ = self._parse_and_analyze(TestClass)
  init_def = module_node.body[0].body[0]
  init_body_scope = anno.getanno(init_def, anno.Static.BODY_SCOPE)

  self.assertScopeIs(
      init_body_scope,
      ('self', 'a', 'self.b'),
      ('self', 'self.b', 'self.b.c'))
def test_call_args(self):
  # Checks the ARGS_SCOPE annotation attached to a call node.

  def test_fn(a):
    b = 0
    c = 1
    foo(a, b)  # pylint:disable=undefined-variable
    return c

  module_node, _ = self._parse_and_analyze(test_fn)
  call_expr = module_node.body[0].body[2].value
  call_args_scope = anno.getanno(call_expr, anno.Static.ARGS_SCOPE)

  # We basically need to detect which variables are captured by the call
  # arguments.
  self.assertScopeIs(call_args_scope, ('a', 'b'), ())
def test_if_subscripts(self):
  # Checks that subscript accesses are tracked as composite symbols in both
  # branches of an `if` statement and in the enclosing scope.

  def test_fn(a, b, c, e):
    if a > 0:
      a[b] = -a[c]
      d = 2 * a
    else:
      a[0] = e
      d = 1
    return d

  module_node, _ = self._parse_and_analyze(test_fn)
  if_stmt = module_node.body[0].body[0]

  then_scope = anno.getanno(if_stmt, anno.Static.BODY_SCOPE)
  self.assertScopeIs(then_scope, ('a', 'b', 'c', 'a[c]'), ('a[b]', 'd'))

  # TODO(mdanatg): Should subscript writes (a[0] = 1) be considered to read "a"?
  else_scope = anno.getanno(if_stmt, anno.Static.ORELSE_SCOPE)
  self.assertScopeIs(else_scope, ('a', 'e'), ('a[0]', 'd'))
  self.assertScopeIs(
      else_scope.parent,
      ('a', 'b', 'c', 'd', 'e', 'a[c]'),
      ('d', 'a[b]', 'a[0]'))
def test_aug_assign_rvalues(self):
  # Checks the body scope for an augmented assignment whose target is a
  # subscript of a call result.

  a = dict(bar=3)

  def foo():
    return a

  def test_fn(x):
    foo()['bar'] += x

  module_node, _ = self._parse_and_analyze(test_fn)
  fn_body_scope = anno.getanno(module_node.body[0], anno.Static.BODY_SCOPE)
  self.assertScopeIs(fn_body_scope, ('foo', 'x'), ())
def test_call_args_attributes(self):
  # Checks the ARGS_SCOPE annotation of a call whose arguments are
  # attributes.

  def foo(*_):
    pass

  def test_fn(a):
    a.c = 0
    foo(a.b, a.c)
    return a.d

  module_node, _ = self._parse_and_analyze(test_fn)
  call_expr = module_node.body[0].body[1].value
  call_args_scope = anno.getanno(call_expr, anno.Static.ARGS_SCOPE)

  self.assertScopeIs(call_args_scope, ('a', 'a.b', 'a.c'), ())