コード例 #1
0
ファイル: control_flow.py プロジェクト: google/pyctr
    def visit_While(self, node):
        """Rewrites a `while` loop as a call to the overload's `while_stmt`.

        The loop condition, body and `else` clause are each wrapped in a
        generated nullary function, and the symbols modified by the body or
        the `else` clause are passed along as a tuple. When the overload module
        defines no `while_stmt`, the loop is only visited recursively.

        Args:
          node: the `While` AST node.

        Returns:
          The replacement produced by `templates.replace`, or the visited node
          when no `while_stmt` overload exists.
        """
        # Scope annotations are read before the children are visited.
        loop_writes = (
            anno.getanno(node, anno.Static.BODY_SCOPE).modified
            | anno.getanno(node, anno.Static.ORELSE_SCOPE).modified)

        node = self.generic_visit(node)
        if not hasattr(self.overload.module, 'while_stmt'):
            # No customization available: leave the loop as a plain `while`.
            return node

        template = """
      def test_name():
        return test
      def body_name():
        body
      def orelse_name():
        orelse
      overload.while_stmt(test_name, body_name, orelse_name, (local_writes,))
    """

        overload_name = self.overload.symbol_name
        namer = self.ctx.namer
        # Symbols are requested in the same order as before: test, body, orelse.
        test_fn_name = namer.new_symbol('while_test', set())
        body_fn_name = namer.new_symbol('while_body', set())
        orelse_fn_name = namer.new_symbol('while_orelse', set())
        # An empty `else` clause still needs a syntactically valid body.
        orelse_body = node.orelse or gast.Pass()

        return templates.replace(
            template,
            overload=overload_name,
            test_name=test_fn_name,
            test=node.test,
            body_name=body_fn_name,
            body=node.body,
            orelse_name=orelse_fn_name,
            orelse=orelse_body,
            local_writes=tuple(loop_writes))
コード例 #2
0
  def test_if(self):
    """Checks the scopes of an if/else whose branches define different symbols."""

    def test_fn(x):
      if x > 0:
        x = -x
        y = 2 * x
        z = -y
      else:
        x = 2 * x
        y = -x
        u = -y
      return z, u

    node, _ = self._parse_and_analyze(test_fn)
    if_stmt = node.body[0].body[0]

    body_scope = anno.getanno(if_stmt, anno.Static.BODY_SCOPE)
    self.assertScopeIs(body_scope, ('x', 'y'), ('x', 'y', 'z'))
    self.assertScopeIs(
        body_scope.parent, ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))

    orelse_scope = anno.getanno(if_stmt, anno.Static.ORELSE_SCOPE)
    self.assertScopeIs(orelse_scope, ('x', 'y'), ('x', 'y', 'u'))
    self.assertScopeIs(
        orelse_scope.parent, ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
コード例 #3
0
ファイル: origin_info_test.py プロジェクト: google/pyctr
  def test_resolve(self):
    """Checks origin info for a function header, its docstring and its return."""

    def test_fn(x):
      """Docstring."""
      return x  # comment

    node, source = parsing.parse_entity(test_fn)
    fn_node = node.body[0]

    origin_info.resolve(fn_node, source)

    def check_origin(target, lineno, col_offset, source_line):
      # Verifies the location fields and hands back the origin for the
      # caller to check its comment.
      origin = anno.getanno(target, anno.Basic.ORIGIN)
      self.assertEqual(origin.loc.lineno, lineno)
      self.assertEqual(origin.loc.col_offset, col_offset)
      self.assertEqual(origin.source_code_line, source_line)
      return origin

    self.assertIsNone(check_origin(fn_node, 1, 0, 'def test_fn(x):').comment)
    self.assertIsNone(
        check_origin(fn_node.body[0], 2, 2, '  """Docstring."""').comment)
    self.assertEqual(
        check_origin(fn_node.body[1], 3, 2, '  return x  # comment').comment,
        'comment')
コード例 #4
0
    def test_entity_scope_tracking(self):
        """Checks enclosing_entities for nodes nested in defs, a class and a lambda."""

        class EntityRecorder(transformer.Base):
            """Annotates visited nodes with the transformer's enclosing entities."""

            def _annotate_with_entities(self, node):
                anno.setanno(node, 'enclosing_entities',
                             self.enclosing_entities)
                return self.generic_visit(node)

            # The choice of node to annotate is arbitrary; Assign is used
            # because it is easy to find in the tree.
            def visit_Assign(self, node):
                return self._annotate_with_entities(node)

            # BinOp is annotated because it shows up inside the lambda.
            def visit_BinOp(self, node):
                return self._annotate_with_entities(node)

        recorder = EntityRecorder(self._simple_source_info())

        def test_function():
            a = 0

            class TestClass(object):
                def test_method(self):
                    b = 0

                    def inner_function(x):
                        c = 0
                        d = lambda y: (x + y)
                        return c, d

                    return b, inner_function

            return a, TestClass

        node, _ = parsing.parse_entity(test_function)
        node = recorder.visit(node)

        function_def = node.body[0]
        class_def = function_def.body[1]
        method_def = class_def.body[0]
        inner_def = method_def.body[1]
        lambda_node = inner_def.body[1].value

        def entities(target):
            return anno.getanno(target, 'enclosing_entities')

        self.assertEqual((function_def, ), entities(function_def.body[0]))
        self.assertEqual((function_def, class_def, method_def),
                         entities(method_def.body[0]))
        self.assertEqual((function_def, class_def, method_def, inner_def),
                         entities(inner_def.body[0]))
        self.assertEqual(
            (function_def, class_def, method_def, inner_def, lambda_node),
            entities(lambda_node.body))
コード例 #5
0
ファイル: ast_util_test.py プロジェクト: google/pyctr
  def test_rename_symbols_annotations(self):
    """Checks that an annotation on the root node survives rename_symbols."""
    node = qual_names.resolve(parsing.parse_str('a[i]'))
    anno.setanno(node, 'foo', 'bar')
    expected_anno = anno.getanno(node, 'foo')

    renamed = ast_util.rename_symbols(
        node, {qual_names.QN('a'): qual_names.QN('b')})

    self.assertIs(anno.getanno(renamed, 'foo'), expected_anno)
コード例 #6
0
  def test_params(self):
    """Checks function parameter tracking in the body, parent and args scopes."""

    def test_fn(a, b):  # pylint: disable=unused-argument
      return b

    node, _ = self._parse_and_analyze(test_fn)
    fn_def = node.body[0]
    fn_scope = anno.getanno(fn_def, anno.Static.BODY_SCOPE)
    self.assertScopeIs(fn_scope, ('b',), ())
    self.assertScopeIs(fn_scope.parent, ('b',), ('a', 'b'))

    param_names = anno.getanno(fn_def.args, anno.Static.SCOPE).params.keys()
    self.assertSymbolSetsAre(('a', 'b'), param_names, 'params')
コード例 #7
0
ファイル: control_flow.py プロジェクト: google/pyctr
    def visit_For(self, node):
        """Rewrites a `for` loop as a call to the overload's `for_stmt`.

        Each loop target gets an initializer built by `_make_target_init`, and
        the body and `else` clause are wrapped in generated nullary functions.
        The symbols modified by the body or the `else` clause are passed along
        as a tuple. When the overload module defines no `for_stmt`, the loop is
        only visited recursively.

        Args:
          node: the `For` AST node.

        Returns:
          The replacement produced by `templates.replace`, or the visited node
          when no `for_stmt` overload exists.

        Raises:
          ValueError: if the loop target is not a `gast.Tuple`, `gast.List` or
            `gast.Name`.
        """
        body_scope = anno.getanno(node, anno.Static.BODY_SCOPE)
        orelse_scope = anno.getanno(node, anno.Static.ORELSE_SCOPE)
        # Symbols written by the loop body or the `else` clause.
        loop_writes = body_scope.modified | orelse_scope.modified

        node = self.generic_visit(node)

        if not hasattr(self.overload.module, 'for_stmt'):
            return node

        # TODO(jmd1011): Handle extra_test

        if isinstance(node.target, (gast.Tuple, gast.List)):
            # NOTE(review): only the top level of an unpacking target is split;
            # nested targets such as `(a, b), c` reach `_make_target_init` as
            # Tuple/List nodes.
            targets = list(node.target.elts)
        elif isinstance(node.target, gast.Name):
            targets = [node.target]
        else:
            raise ValueError(
                'For target must be gast.Tuple, gast.List, or gast.Name, got {}.'
                .format(type(node.target)))

        target_inits = [
            self._make_target_init(target, self.overload) for target in targets
        ]

        template = """
      target_inits
      def body_name():
        body
      def orelse_name():
        orelse
      overload.for_stmt(target, iter_, body_name, orelse_name, (local_writes,))
    """

        return templates.replace(
            template,
            target_inits=target_inits,
            target=node.target,
            body_name=self.ctx.namer.new_symbol('for_body', set()),
            body=node.body,
            orelse_name=self.ctx.namer.new_symbol('for_orelse', set()),
            # An empty `else` clause still needs a syntactically valid body.
            orelse=node.orelse if node.orelse else gast.Pass(),
            overload=self.overload.symbol_name,
            iter_=node.iter,
            local_writes=tuple(loop_writes))
コード例 #8
0
ファイル: origin_info.py プロジェクト: google/pyctr
def create_source_map(nodes, code, filename, indices_in_code):
  """Creates a source map between an annotated AST and the code it compiles to.

  Args:
    nodes: Iterable[ast.AST, ...]
    code: Text
    filename: Optional[Text]
    indices_in_code: Union[int, Iterable[int, ...]], the positions at which
      nodes appear in code. Parsing code always returns a module. This
      argument indicates the position in that module's body at which the
      corresponding node should appear. A single int is treated as a
      one-element sequence.

  Returns:
    Dict[LineLocation, OriginInfo], mapping line locations in code to the
    origins indicated by the origin annotations in nodes.
  """
  if isinstance(indices_in_code, int):
    # The documented int form cannot be iterated directly.
    indices_in_code = (indices_in_code,)

  module_node = parsing.parse_str(code)
  reparsed_nodes = [module_node.body[i] for i in indices_in_code]

  resolve(reparsed_nodes, code)
  result = {}

  for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
    # Note: generated code might not be mapped back to its origin.
    # TODO(mdanatg): Generated code should always be mapped to something.
    origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
    final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
    if origin_info is None or final_info is None:
      continue

    line_loc = LineLocation(filename, final_info.loc.lineno)

    existing_origin = result.get(line_loc)
    if existing_origin is not None:
      # Overlaps may exist because of child nodes, but almost never to
      # different line locations. The exception is decorated functions, where
      # both lines are mapped to the same line in the AST.

      # Line overlaps: keep bottom node.
      # NOTE(review): if `loc.line_loc` is built from (filename, lineno), equal
      # line_locs imply equal linenos, so this always keeps the existing
      # entry; confirm that matches the intended "keep bottom node" rule.
      if existing_origin.loc.line_loc == origin_info.loc.line_loc:
        if existing_origin.loc.lineno >= origin_info.loc.lineno:
          continue

      # In case of overlaps, keep the leftmost node.
      if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
        continue

    result[line_loc] = origin_info

  return result
コード例 #9
0
ファイル: activity.py プロジェクト: google/pyctr
 def _node_sets_self_attribute(self, node):
     """Returns whether `node`'s QN is an attribute whose parent QN is ('self',).

     Nodes without a QN annotation yield False.
     """
     if not anno.hasanno(node, anno.Basic.QN):
         return False
     qn = anno.getanno(node, anno.Basic.QN)
     # TODO(mdanatg): The 'self' argument is not guaranteed to be called 'self'.
     return bool(qn.has_attr and qn.parent.qn == ('self', ))
コード例 #10
0
  def test_for(self):
    """Checks the scopes of a for loop that reads and updates an outer symbol."""

    def test_fn(a):
      b = a
      for _ in a:
        c = b
        b -= 1
      return b, c

    node, _ = self._parse_and_analyze(test_fn)
    loop = node.body[0].body[1]
    loop_body_scope = anno.getanno(loop, anno.Static.BODY_SCOPE)
    self.assertScopeIs(loop_body_scope, ('b',), ('b', 'c'))
    self.assertScopeIs(
        loop_body_scope.parent, ('a', 'b', 'c'), ('b', 'c', '_'))
コード例 #11
0
 def visit_Attribute(self, node):
     """Visits `node`, then gives it a QN extending its value's QN, if any."""
     node = self.generic_visit(node)
     if not anno.hasanno(node.value, anno.Basic.QN):
         return node
     base_qn = anno.getanno(node.value, anno.Basic.QN)
     anno.setanno(node, anno.Basic.QN, QN(base_qn, attr=node.attr))
     return node
コード例 #12
0
  def test_nested_if(self):
    """Checks that both branches of an inner if read b and modify a."""

    def test_fn(b):
      if b > 0:
        if b < 5:
          a = b
        else:
          a = b * b
      return a

    node, _ = self._parse_and_analyze(test_fn)
    outer_if = node.body[0].body[0]
    inner_if = outer_if.body[0]
    for scope_key in (anno.Static.BODY_SCOPE, anno.Static.ORELSE_SCOPE):
      self.assertScopeIs(anno.getanno(inner_if, scope_key), ('b',), ('a',))
コード例 #13
0
ファイル: activity.py プロジェクト: google/pyctr
    def _track_symbol(self, node, composite_writes_alter_parent=False):
        """Records a use of the symbol named by `node` in the current scope.

        The kind of use is picked by `node.ctx`:
        - Store marks the symbol as modified, and also as read inside an
          augmented assignment.
        - Load marks it as read.
        - Param declares a function or lambda parameter.
        - Del marks it as both read and deleted.

        Args:
          node: an AST node that may carry an `anno.Basic.QN` annotation and
            has a `ctx` attribute.
          composite_writes_alter_parent: whether a write to a composite symbol
            also marks its parent as modified.

        Raises:
          NotImplementedError: for a Param context outside function arguments
            or a lambda.
          ValueError: for an unrecognized context type.
        """
        # A QN may be missing when we have an attribute (or subscript) on a
        # function call. Example: a().b
        if not anno.hasanno(node, anno.Basic.QN):
            return
        qn = anno.getanno(node, anno.Basic.QN)

        # A lambda's arguments, and attributes or slices of them, belong to the
        # lambda and are ignored here.
        for lambda_state in self.state[_Lambda]:
            lambda_args = lambda_state.args
            if qn in lambda_args or qn.owner_set & set(lambda_args):
                return

        # Comprehension targets are ignored the same way, but only in Python 3.
        # Python 2 leaks them into the enclosing scope.
        if six.PY3:
            for comprehension_state in self.state[_Comprehension]:
                targets = comprehension_state.targets
                if qn in targets or qn.owner_set & set(targets):
                    return

        ctx = node.ctx
        if isinstance(ctx, gast.Store):
            in_py3_comprehension = (
                six.PY3 and self.state[_Comprehension].level > 0)
            if in_py3_comprehension:
                # Like a lambda's args, comprehension targets are tracked
                # separately in Python 3.
                self.state[_Comprehension].targets.add(qn)
            else:
                self.scope.mark_modified(qn)
                if qn.is_composite and composite_writes_alter_parent:
                    self.scope.mark_modified(qn.parent)
                if self._in_aug_assign:
                    # An augmented assignment also reads its target.
                    self.scope.mark_read(qn)
        elif isinstance(ctx, gast.Load):
            self.scope.mark_read(qn)
        elif isinstance(ctx, gast.Param):
            if self._in_function_def_args:
                # In function definitions, a parameter defines a variable.
                self.scope.mark_modified(qn)
                self.scope.mark_param(qn, self.enclosing_entities[-1])
            elif self.state[_Lambda].level:
                # Lambda parameters are tracked separately.
                self.state[_Lambda].args.add(qn)
            else:
                # TODO(mdanatg): Is this case possible at all?
                raise NotImplementedError(
                    'Param "{}" outside a function arguments or lambda.'.format(
                        qn))
        elif isinstance(ctx, gast.Del):
            # The read matches the Python semantics - attempting to delete an
            # undefined symbol is illegal.
            self.scope.mark_read(qn)
            self.scope.mark_deleted(qn)
        else:
            raise ValueError('Unknown context {} for node "{}".'.format(
                type(ctx), qn))
コード例 #14
0
  def test_while(self):
    """Checks the body, parent and condition scopes of a while loop."""

    def test_fn(a):
      b = a
      while b > 0:
        c = b
        b -= 1
      return b, c

    node, _ = self._parse_and_analyze(test_fn)
    loop = node.body[0].body[1]
    loop_body_scope = anno.getanno(loop, anno.Static.BODY_SCOPE)
    self.assertScopeIs(loop_body_scope, ('b',), ('b', 'c'))
    self.assertScopeIs(loop_body_scope.parent, ('a', 'b', 'c'), ('b', 'c'))
    cond_scope = anno.getanno(loop, anno.Static.COND_SCOPE)
    self.assertScopeIs(cond_scope, ('b',), ())
コード例 #15
0
  def test_aug_assign(self):
    """Checks that `a += b` reads both operands and modifies only `a`."""

    def test_fn(a, b):
      a += b

    node, _ = self._parse_and_analyze(test_fn)
    fn_node = node.body[0]
    # ('a',) is a one-element tuple; the previous ('a') was just the string
    # 'a', which only matched if assertScopeIs happened to iterate it.
    self.assertScopeIs(
        anno.getanno(fn_node, anno.Static.BODY_SCOPE), ('a', 'b'), ('a',))
コード例 #16
0
 def _process(self, node):
     """Replaces `node` with a Name when its QN appears in `self.name_map`.

     The replacement carries over every annotation of the original node.
     Nodes whose QN is not mapped are visited recursively instead.
     """
     qn = anno.getanno(node, anno.Basic.QN)
     if qn not in self.name_map:
         return self.generic_visit(node)
     replacement = gast.Name(str(self.name_map[qn]), node.ctx, None)
     for key in anno.keys(node):
         anno.copyanno(node, replacement, key)
     return replacement
コード例 #17
0
  def test_return_vars_are_read(self):
    """Checks that only the returned argument is read by the function body."""

    def test_fn(a, b, c):  # pylint: disable=unused-argument
      return c

    node, _ = self._parse_and_analyze(test_fn)
    fn_scope = anno.getanno(node.body[0], anno.Static.BODY_SCOPE)
    self.assertScopeIs(fn_scope, ('c',), ())
コード例 #18
0
  def test_aug_assign_subscripts(self):
    """Checks an augmented assignment to a subscript of a parameter."""

    def test_fn(a):
      a[0] += 1

    node, _ = self._parse_and_analyze(test_fn)
    fn_scope = anno.getanno(node.body[0], anno.Static.BODY_SCOPE)
    self.assertScopeIs(fn_scope, ('a', 'a[0]'), ('a[0]',))
コード例 #19
0
  def test_lambda_nested(self):
    """Checks symbol tracking through nested lambdas that shadow arguments."""

    def test_fn(a, b, c, d, e):  # pylint: disable=unused-argument
      a = lambda a, b: d(lambda b: a + b + c)  # pylint: disable=undefined-variable

    node, _ = self._parse_and_analyze(test_fn)
    fn_scope = anno.getanno(node.body[0], anno.Static.BODY_SCOPE)
    self.assertScopeIs(fn_scope, ('c', 'd'), ('a',))
    self.assertSymbolSetsAre((), fn_scope.params.keys(), 'params')
コード例 #20
0
  def test_lambda_complex(self):
    """Checks symbol tracking for an immediately-called lambda in an expression."""

    def test_fn(a, b, c, d):  # pylint: disable=unused-argument
      a = (lambda a, b, c: a + b + c)(d, 1, 2) + b

    node, _ = self._parse_and_analyze(test_fn)
    fn_scope = anno.getanno(node.body[0], anno.Static.BODY_SCOPE)
    self.assertScopeIs(fn_scope, ('b', 'd'), ('a',))
    self.assertSymbolSetsAre((), fn_scope.params.keys(), 'params')
コード例 #21
0
  def test_lambda_params_are_isolated(self):
    """Checks that a lambda parameter shadowing an argument is not tracked."""

    def test_fn(a, b):  # pylint: disable=unused-argument
      return lambda a: a + b

    node, _ = self._parse_and_analyze(test_fn)
    fn_scope = anno.getanno(node.body[0], anno.Static.BODY_SCOPE)
    self.assertScopeIs(fn_scope, ('b',), ())
    self.assertSymbolSetsAre((), fn_scope.params.keys(), 'params')
コード例 #22
0
  def test_lambda_captures_reads(self):
    """Checks that symbols read inside a lambda count as reads of the body."""

    def test_fn(a, b):
      return lambda: a + b

    node, _ = self._parse_and_analyze(test_fn)
    fn_scope = anno.getanno(node.body[0], anno.Static.BODY_SCOPE)
    self.assertScopeIs(fn_scope, ('a', 'b'), ())
    # Nothing local to the lambda is tracked.
    self.assertSymbolSetsAre((), fn_scope.params.keys(), 'params')
コード例 #23
0
ファイル: ast_util_test.py プロジェクト: google/pyctr
 def test_copy_clean_preserves_annotations(self):
   """Checks that copy_clean keeps only the annotations named in preserve_annos."""
   node = parsing.parse_str(
       textwrap.dedent("""
     def f(a):
       return a + 1
   """))
   fn_def = node.body[0]
   anno.setanno(fn_def, 'foo', 'bar')
   anno.setanno(fn_def, 'baz', 1)

   copied_fn = ast_util.copy_clean(node, preserve_annos={'foo'}).body[0]
   self.assertEqual(anno.getanno(copied_fn, 'foo'), 'bar')
   self.assertFalse(anno.hasanno(copied_fn, 'baz'))
コード例 #24
0
  def test_print_statement(self):
    """Checks which symbols the arguments of a print statement/call read."""

    def test_fn(a):
      b = 0
      c = 1
      print(a, b)
      return c

    node, _ = self._parse_and_analyze(test_fn)
    print_node = node.body[0].body[2]
    if not isinstance(print_node, gast.Print):
      # Python 3: print is a call wrapped in an expression statement, and the
      # call node is the one being annotated. A TestCase assertion is used
      # because a bare `assert` is stripped under `python -O`.
      self.assertIsInstance(print_node, gast.Expr)
      print_node = print_node.value
    # In Python 2, the Print statement itself carries the annotation.
    print_args_scope = anno.getanno(print_node, anno.Static.ARGS_SCOPE)
    # We basically need to detect which variables are captured by the call
    # arguments.
    self.assertScopeIs(print_args_scope, ('a', 'b'), ())
コード例 #25
0
  def test_if_attributes(self):
    """Checks attribute reads and writes in both branches of an if/else."""

    def test_fn(a):
      if a > 0:
        a.b = -a.c
        d = 2 * a
      else:
        a.b = a.c
        d = 1
      return d

    node, _ = self._parse_and_analyze(test_fn)
    if_stmt = node.body[0].body[0]
    body_scope = anno.getanno(if_stmt, anno.Static.BODY_SCOPE)
    self.assertScopeIs(body_scope, ('a', 'a.c'), ('a.b', 'd'))
    orelse_scope = anno.getanno(if_stmt, anno.Static.ORELSE_SCOPE)
    self.assertScopeIs(orelse_scope, ('a', 'a.c'), ('a.b', 'd'))
    self.assertScopeIs(body_scope.parent, ('a', 'a.c', 'd'), ('a.b', 'd'))
コード例 #26
0
  def test_constructor_attributes(self):
    """Checks attribute writes on `self` inside a constructor."""

    class TestClass(object):

      def __init__(self, a):
        self.b = a
        self.b.c = 1

    node, _ = self._parse_and_analyze(TestClass)
    init_def = node.body[0].body[0]
    init_scope = anno.getanno(init_def, anno.Static.BODY_SCOPE)
    self.assertScopeIs(
        init_scope, ('self', 'a', 'self.b'), ('self', 'self.b', 'self.b.c'))
コード例 #27
0
  def test_call_args(self):
    """Checks that a call's args scope captures only the passed symbols."""

    def test_fn(a):
      b = 0
      c = 1
      foo(a, b)  # pylint:disable=undefined-variable
      return c

    node, _ = self._parse_and_analyze(test_fn)
    foo_call = node.body[0].body[2].value
    args_scope = anno.getanno(foo_call, anno.Static.ARGS_SCOPE)
    # Only the variables passed as call arguments should be captured.
    self.assertScopeIs(args_scope, ('a', 'b'), ())
コード例 #28
0
  def test_if_subscripts(self):
    """Checks subscript reads and writes in both branches of an if/else."""

    def test_fn(a, b, c, e):
      if a > 0:
        a[b] = -a[c]
        d = 2 * a
      else:
        a[0] = e
        d = 1
      return d

    node, _ = self._parse_and_analyze(test_fn)
    if_stmt = node.body[0].body[0]
    self.assertScopeIs(
        anno.getanno(if_stmt, anno.Static.BODY_SCOPE),
        ('a', 'b', 'c', 'a[c]'), ('a[b]', 'd'))
    orelse_scope = anno.getanno(if_stmt, anno.Static.ORELSE_SCOPE)
    # TODO(mdanatg): Should subscript writes (a[0] = 1) be considered to read "a"?
    self.assertScopeIs(orelse_scope, ('a', 'e'), ('a[0]', 'd'))
    self.assertScopeIs(
        orelse_scope.parent, ('a', 'b', 'c', 'd', 'e', 'a[c]'),
        ('d', 'a[b]', 'a[0]'))
コード例 #29
0
  def test_aug_assign_rvalues(self):
    """Checks an augmented assignment whose target subscripts a call result."""

    a = {'bar': 3}

    def foo():
      return a

    def test_fn(x):
      foo()['bar'] += x

    node, _ = self._parse_and_analyze(test_fn)
    fn_scope = anno.getanno(node.body[0], anno.Static.BODY_SCOPE)
    # Presumably `foo()['bar']` carries no QN, so only `foo` and `x` appear.
    self.assertScopeIs(fn_scope, ('foo', 'x'), ())
コード例 #30
0
  def test_call_args_attributes(self):
    """Checks that attribute arguments of a call are captured in its args scope."""

    def foo(*_):
      pass

    def test_fn(a):
      a.c = 0
      foo(a.b, a.c)
      return a.d

    node, _ = self._parse_and_analyze(test_fn)
    foo_call = node.body[0].body[1].value
    args_scope = anno.getanno(foo_call, anno.Static.ARGS_SCOPE)
    self.assertScopeIs(args_scope, ('a', 'a.b', 'a.c'), ())