示例#1
0
    def attr(self, node, attr_name, attr_vals, deps=None, default=None):
        """Emit the stored formatting for one attribute of `node`.

        The stored data is only trusted while none of the node attributes named
        in `deps` differ from their recorded '<dep>__src' values; otherwise
        `default` is emitted instead. Each attribute is emitted at most once per
        node, and nodes without `_printer_info` are skipped entirely.

        Arguments:
          node: (ast.AST) An AST node to retrieve formatting information from.
          attr_name: (string) Name to load the formatting information from.
          attr_vals: (list of functions/strings) Unused here.
          deps: (optional, set of strings) Attributes of the node which the
            stored formatting data depends on.
          default: (string) Default formatted data for this attribute.
        """
        del attr_vals
        if not hasattr(node, '_printer_info'):
            return
        already_printed = node._printer_info[attr_name]
        if already_printed:
            return
        node._printer_info[attr_name] = True

        def deps_changed():
            # True if any dependency no longer matches its recorded source value.
            return any(
                getattr(node, dep, None) != fmt.get(node, dep + '__src')
                for dep in deps)

        text = fmt.get(node, attr_name)
        if text is None or (deps and deps_changed()):
            text = default
        self.code += '' if text is None else text
示例#2
0
                def test(self):
                    """Compare annotated prefix/suffix/indent data with golden data.

                    When the test case has a truthy `generate_goldens` attribute,
                    the golden file is written from the current output instead of
                    being checked.
                    """
                    with open(input_file, 'r') as handle:
                        src = handle.read()
                    t = ast_utils.parse(src, py_ver)
                    annotator = annotate.get_ast_annotator(py_ver)(src)
                    annotator.visit(t)

                    def escape(s):
                        return '' if s is None else s.replace('\n', '\\n')

                    result = '\n'.join(
                        '{0:12} {1:20} \tprefix=|{2}|\tsuffix=|{3}|\tindent=|{4}|'
                        .format(
                            str((getattr(n, 'lineno', -1),
                                 getattr(n, 'col_offset', -1))),
                            type(n).__name__ + ' ' + _get_node_identifier(n),
                            escape(fmt.get(n, 'prefix')),
                            escape(fmt.get(n, 'suffix')),
                            escape(fmt.get(n, 'indent')))
                        for n in pasta.ast_walk(t, py_ver)) + '\n'

                    # If specified, write the golden data instead of checking it.
                    # Without the write below this branch would silently skip the
                    # comparison and leave the golden file untouched.
                    if getattr(self, 'generate_goldens', False):
                        import os  # Local import keeps this block self-contained.
                        golden_dir = os.path.dirname(golden_file)
                        if golden_dir and not os.path.isdir(golden_dir):
                            os.makedirs(golden_dir)
                        # Use the same encoding the comparison path reads with.
                        with io.open(golden_file, 'w', encoding='UTF-8') as f:
                            f.write(result)
                        print('Wrote: ' + golden_file)
                        return

                    try:
                        with io.open(golden_file, 'r', encoding='UTF-8') as f:
                            golden = f.read()
                    except IOError:
                        self.fail('Missing golden data.')

                    self.assertMultiLineEqual(golden, result)
示例#3
0
def replace_child(parent, node, replace_with):
    """Swap `node` for `replace_with` among the children of `parent`.

    If `node` carries pasta formatting data, its 'prefix' and 'suffix' are
    copied onto `replace_with` so the surrounding formatting is preserved.

    Arguments:
      parent: (ast.AST) Parent node to replace a child of.
      node: (ast.AST) Child node to replace.
      replace_with: (ast.AST) New child node.

    Raises:
      errors.InvalidAstError: if `node` is not found in any field of `parent`.
    """
    # TODO(soupytwist): Don't refer to the formatting dict directly
    if hasattr(node, fmt.PASTA_DICT):
        for attr in ('prefix', 'suffix'):
            fmt.set(replace_with, attr, fmt.get(node, attr))

    for field_name in parent._fields:
        value = getattr(parent, field_name, None)
        if value == node:
            setattr(parent, field_name, replace_with)
            return
        if not isinstance(value, list):
            continue
        if node in value:
            value[value.index(node)] = replace_with
            return

    raise errors.InvalidAstError('Node %r is not a child of %r' %
                                 (node, parent))
示例#4
0
 def test_indent_levels_same_line(self):
     """Statements on the same line as their block header have no indent_diff."""
     tree = pasta.parse('if a: b; c\n')
     first_stmt, second_stmt = tree.body[0].body
     for stmt in (first_stmt, second_stmt):
         self.assertIsNone(fmt.get(stmt, 'indent_diff'))
示例#5
0
      def test(self):
        """Compare the annotated formatting of every node with golden data.

        When the test case has a truthy `generate_goldens` attribute, the
        golden file is written from the current output instead.
        """
        with open(input_file, 'r') as source_handle:
          source = source_handle.read()
        tree = ast_utils.parse(source)
        annotate.AstAnnotator(source).visit(tree)

        def escape(text):
          # Render newlines as a literal backslash-n so each node is one row.
          if text is None:
            return ''
          return text.replace('\n', '\\n')

        row_format = "{0:12} {1:20} \tprefix=|{2}|\tsuffix=|{3}|\tindent=|{4}|"
        rows = []
        for n in ast.walk(tree):
          position = str((getattr(n, 'lineno', -1), getattr(n, 'col_offset', -1)))
          label = type(n).__name__ + ' ' + _get_node_identifier(n)
          rows.append(row_format.format(
              position, label,
              escape(fmt.get(n, 'prefix')),
              escape(fmt.get(n, 'suffix')),
              escape(fmt.get(n, 'indent'))))
        result = '\n'.join(rows) + '\n'

        # If specified, write the golden data instead of checking it
        if getattr(self, 'generate_goldens', False):
          golden_dir = os.path.dirname(golden_file)
          if not os.path.isdir(golden_dir):
            os.makedirs(golden_dir)
          with open(golden_file, 'w') as out:
            out.write(result)
          print('Wrote: ' + golden_file)
          return

        try:
          with open(golden_file, 'r') as golden_handle:
            golden = golden_handle.read()
        except IOError:
          self.fail('Missing golden data.')

        self.assertMultiLineEqual(golden, result)
示例#6
0
 def test_indent_multiline_string_with_newline(self):
     """A multi-line docstring does not disturb the indent of its siblings."""
     source = textwrap.dedent('''\
   class A:
     """Doc\n
        string."""
     pass
   ''')
     class_body = pasta.parse(source, py_ver).body[0].body
     docstring, pass_stmt = class_body
     expected_indent = '  '
     self.assertEqual(expected_indent, fmt.get(docstring, 'indent'))
     self.assertEqual(expected_indent, fmt.get(pass_stmt, 'indent'))
示例#7
0
    def close_scope(self,
                    node,
                    prefix_attr='prefix',
                    suffix_attr='suffix',
                    trailing_comma=False,
                    single_paren=False):
        """Close a parenthesized scope on the given node, if one is open.

        Consumes formatting tokens, ')' tokens and (when `trailing_comma` is
        set) ',' tokens from the current position. Each time a ')' is consumed,
        the value popped from `self._parens` is prepended to the node's
        `prefix_attr`. The text read since the previous ')' (including the ')')
        is appended to its `suffix_attr`. Consumption stops once `node` is no
        longer in the innermost open scope.

        Arguments:
          node: (ast.AST) Node whose enclosing parentheses should be closed.
          prefix_attr: (string) Formatting attribute that receives the text
            popped from `self._parens`.
          suffix_attr: (string) Formatting attribute that receives the closing
            text read from the token stream.
          trailing_comma: (bool) Whether ',' tokens may be consumed as well.
          single_paren: (bool) If True, at most one ')' is consumed. The text
            read before a second ')' is still appended to `suffix_attr`.
        """
        # Ensures the prefix + suffix are not None
        if fmt.get(node, prefix_attr) is None:
            fmt.set(node, prefix_attr, '')
        if fmt.get(node, suffix_attr) is None:
            fmt.set(node, suffix_attr, '')

        # Nothing to close unless `node` is in the innermost open scope.
        # NOTE(review): assumes self._scope_stack stays in step with
        # self._parens (one scope entry per open paren) -- confirm.
        if not self._parens or node not in self._scope_stack[-1]:
            return
        # Non-formatting token sources that may be consumed while closing.
        symbols = {')'}
        if trailing_comma:
            symbols.add(',')
        # Position just after the last ')' actually closed. The stream is reset
        # here at the end, so tokens read after the last closed paren are not
        # consumed.
        parsed_to_i = self._i
        parsed_to_loc = prev_loc = self._loc
        encountered_paren = False
        # Text read since the last ')' was closed.
        result = ''

        for tok in self.takewhile(
                lambda t: t.type in FORMATTING_TOKENS or t.src in symbols):
            # Consume all space up to this token
            result += self._space_between(prev_loc, tok)
            if tok.src == ')' and single_paren and encountered_paren:
                # Second ')' in single-paren mode: rewind() presumably un-reads
                # it; the text before it is kept in the suffix.
                self.rewind()
                parsed_to_i = self._i
                parsed_to_loc = tok.start
                fmt.append(node, suffix_attr, result)
                break

            # Consume the token itself
            result += tok.src

            if tok.src == ')':
                # Close out the open scope
                encountered_paren = True
                self._scope_stack.pop()
                fmt.prepend(node, prefix_attr, self._parens.pop())
                fmt.append(node, suffix_attr, result)
                result = ''
                parsed_to_i = self._i
                parsed_to_loc = tok.end
                # Stop once `node` is no longer in the innermost open scope.
                if not self._parens or node not in self._scope_stack[-1]:
                    break
            prev_loc = tok.end

        # Reset back to the last place where we parsed anything
        self._i = parsed_to_i
        self._loc = parsed_to_loc
示例#8
0
def get_last_child(node):
  """Get the last child node of a block statement.

  The input must be a block statement (e.g. ast.For, ast.With, etc).

  Examples:
    1. with first():
         second()
         last()

    2. try:
         first()
       except:
         second()
       finally:
         last()

    3. try:
         first()
       except:
         last()

  In all cases, the last child is the node for `last`.

  Arguments:
    node: (ast.AST) A block statement from either the Python 2 (ast27) or
      Python 3 (ast3) AST.

  Returns:
    The last statement nested in `node`, or None for a module with no body.
  """
  if isinstance(node, (ast27.Module, ast3.Module)):
    # An empty source file produces a module without any statements.
    return node.body[-1] if node.body else None
  if isinstance(node, (ast27.If, ast3.If)):
    # An `elif` is stored as a lone If (marked 'is_elif') inside orelse.
    if (len(node.orelse) == 1 and
        isinstance(node.orelse[0], (ast27.If, ast3.If)) and
        fmt.get(node.orelse[0], 'is_elif')):
      return get_last_child(node.orelse[0])
    if node.orelse:
      return node.orelse[-1]
  elif isinstance(node, (ast27.With, ast3.With)):
    # A lone nested With marked 'is_continued' is descended into (presumably
    # how `with a, b:` is represented).
    if (len(node.body) == 1 and
        isinstance(node.body[0], (ast27.With, ast3.With)) and
        fmt.get(node.body[0], 'is_continued')):
      return get_last_child(node.body[0])
  elif isinstance(node, ast3.Try):
    if node.finalbody:
      return node.finalbody[-1]
    if node.orelse:
      return node.orelse[-1]
    # Without `else`/`finally` the statement ends inside the last except
    # clause, exactly as in the ast27.TryExcept branch below.
    if node.handlers:
      return get_last_child(node.handlers[-1])
  elif isinstance(node, ast27.TryFinally):
    if node.finalbody:
      return node.finalbody[-1]
  elif isinstance(node, ast27.TryExcept):
    if node.orelse:
      return node.orelse[-1]
    if node.handlers:
      return get_last_child(node.handlers[-1])
  return node.body[-1]
示例#9
0
 def visit_Module(self, node):
     """Print a module: prefix, any stored 'bom' text, the body, then suffix."""
     self.prefix(node)
     # Text recorded under 'bom' (presumably a byte-order mark) precedes the
     # body when present.
     bom_text = fmt.get(node, 'bom')
     if bom_text is not None:
         self.code = self.code + bom_text
     self.generic_visit(node)
     self.suffix(node)
示例#10
0
 def optional_token(self, node, attr_name, token_val,
                    allow_whitespace_prefix=False, default=False):
   """Print an optional token stored on `node` under `attr_name`.

   If nothing is stored and `default` is true, `token_val` is printed;
   otherwise an empty or missing value prints nothing.
   """
   del allow_whitespace_prefix  # Unused by the printer.
   stored = fmt.get(node, attr_name)
   if stored is None:
     stored = token_val if default else None
   self.code += stored if stored else ''
示例#11
0
 def test_fstring(self):
     """Formatted values in an f-string are stored as numbered placeholders."""
     module = pasta.parse('f"a {b} c d {e}"', py_ver)
     joined_str = module.body[0].value
     expected_content = (
         'f"a {__pasta_fstring_val_0__} c d {__pasta_fstring_val_1__}"')
     self.assertEqual(fmt.get(joined_str, 'content'), expected_content)
示例#12
0
 def test_fstring_escaping(self):
   """Escaped braces are kept verbatim next to the placeholders."""
   module = pasta.parse('f"a {{{b} {{c}}"')
   joined_str = module.body[0].value
   expected = 'f"a {{{__pasta_fstring_val_0__} {{c}}"'
   self.assertEqual(fmt.get(joined_str, 'content'), expected)
示例#13
0
 def visit_Constant(self, node):
     """Print a constant, preferring its annotated source text.

     `...` is always printed literally. Otherwise the stored 'content' is used,
     falling back to repr() of the constant's value.
     """
     self.prefix(node)
     if node.value is Ellipsis:
         # repr(Ellipsis) is 'Ellipsis', so the literal is spelled out.
         content = '...'
     else:
         content = fmt.get(node, 'content')
     # Read `value` rather than the `s` alias: Constant.s has been deprecated
     # since Python 3.8 and was removed in Python 3.14.
     self.code += content if content is not None else repr(node.value)
     self.suffix(node)
示例#14
0
    def test_indent_levels(self):
        """Each call statement's stored indent matches its nesting depth."""
        src = textwrap.dedent('''\
        foo('begin')
        if a:
          foo('a1')
          if b:
            foo('b1')
            if c:
              foo('c1')
            foo('b2')
          foo('a2')
        foo('end')
        ''')
        t = pasta.parse(src)
        calls = sorted(ast_utils.find_nodes_by_type(t, (ast.Call, )),
                       key=lambda call: call.lineno)
        begin, a1, b1, c1, b2, a2, end = calls
        expectations = (
            (begin, ''),
            (a1, '  '),
            (b1, '    '),
            (c1, '      '),
            (b2, '    '),
            (a2, '  '),
            (end, ''),
        )
        for call, expected_indent in expectations:
            self.assertEqual(expected_indent, fmt.get(call, 'indent'))
示例#15
0
  def test_scope_trailing_comma(self):
    """Trailing commas are captured in each construct's closing-scope suffix."""
    cases = (
        # (source template, suffix attribute, strip leading spaces?)
        ('def foo(a, b{trailing_comma}): pass', 'args_suffix', True),
        ('class Foo(a, b{trailing_comma}): pass', 'bases_suffix', True),
        ('from mod import (a, b{trailing_comma})', 'names_suffix', False),
    )
    for template, suffix_attr, strip_spaces in cases:
      for trailing_comma in ('', ',', ' , '):
        tree = pasta.parse(template.format(trailing_comma=trailing_comma))
        if strip_spaces:
          expected = trailing_comma.lstrip(' ')
        else:
          expected = trailing_comma
        self.assertEqual(expected + ')', fmt.get(tree.body[0], suffix_attr))
示例#16
0
 def test_tabs_below_spaces_and_tab(self):
   """A tab-tab line below a spaces-plus-tab line has a single tab diff."""
   for num_spaces in range(1, 8):
     src = textwrap.dedent('''\
         if a:
         {WS}{ONETAB}if b:
         {ONETAB}{ONETAB}c
         ''').format(ONETAB='\t', WS=' ' * num_spaces)
     tree = pasta.parse(src)
     innermost = tree.body[0].body[0].body[0]
     self.assertEqual(fmt.get(innermost, 'indent_diff'), '\t')
示例#17
0
        def test_indent_extra_newlines(self):
            """A blank line before a block's body does not change its indent_diff."""
            src = textwrap.dedent("""\
          if a:

            b
          """)
            body_stmt = pasta.parse(src, py_ver).body[0].body[0]
            self.assertEqual('  ', fmt.get(body_stmt, 'indent_diff'))
示例#18
0
文件: annotate.py 项目: junk13/pasta
 def wrapped(self, node, *args, **kwargs):
   """Visit a block statement, then handle its trailing formatting.

   Uses block_suffix() (indented like the last child) when the visitor
   defines it; otherwise falls back to a plain suffix with comments.
   """
   self.prefix(node, default=self._indent)
   f(self, node, *args, **kwargs)
   if not hasattr(self, 'block_suffix'):
     self.suffix(node, comment=True)
     return
   last_child = ast_utils.get_last_child(node)
   # Workaround for ast.Module which does not have a lineno
   if not last_child or last_child.lineno == getattr(node, 'lineno', 0):
     return
   last_prefix = fmt.get(last_child, 'prefix') or '\n'
   self.block_suffix(node, last_prefix.splitlines()[-1])
示例#19
0
 def test_tab_below_spaces(self):
     """A tab line below an n-space line has an indent_diff of 8 - n spaces."""
     for num_spaces in range(1, 8):
         src = textwrap.dedent("""\
     if a:
     {WS}if b:
     {ONETAB}c
     """).format(ONETAB='\t', WS=' ' * num_spaces)
         tree = pasta.parse(src, py_ver)
         innermost = tree.body[0].body[0].body[0]
         expected_diff = ' ' * (8 - num_spaces)
         self.assertEqual(fmt.get(innermost, 'indent_diff'), expected_diff)
示例#20
0
    def test_indent_extra_newlines_with_comment(self):
        """A deeper-indented comment line does not affect the body's indent_diff."""
        src = textwrap.dedent('''\
        if a:
            #not here

          b
        ''')
        body_stmt = pasta.parse(src).body[0].body[0]
        self.assertEqual('  ', fmt.get(body_stmt, 'indent_diff'))
示例#21
0
        def test_block_suffix(self):
            """Comments after a block's last statement form its block suffix.

            For every kind of block, the comments #b, #c and #d (but not the
            dedented #e) are expected in 'block_suffix_<children_attr>' of the
            node whose child list holds the `pass`. The source must also
            round-trip unchanged through pasta.dump.
            """
            src_tpl = textwrap.dedent("""\
          {open_block}
            pass #a
            #b
              #c

            #d
          #e
          a
          """)
            test_cases = (
                # first: attribute of the node with the last block
                # second: code snippet to open a block
                ('body', 'def x():'),
                ('body', 'class X:'),
                ('body', 'if x:'),
                ('orelse', 'if x:\n  y\nelse:'),
                ('body', 'if x:\n  y\nelif y:'),
                ('body', 'while x:'),
                ('orelse', 'while x:\n  y\nelse:'),
                ('finalbody', 'try:\n  x\nfinally:'),
                ('body', 'try:\n  x\nexcept:'),
                ('orelse', 'try:\n  x\nexcept:\n  y\nelse:'),
                ('body', 'with x:'),
                ('body', 'with x, y:'),
                ('body', 'with x:\n with y:'),
                ('body', 'for x in y:'),
            )

            def is_node_for_suffix(node, children_attr):
                # Return True if this node's `children_attr` list starts with
                # the 'pass' statement. Empty lists (e.g. an `orelse` without
                # an else clause) are rejected instead of raising IndexError.
                val = getattr(node, children_attr, None)
                return (isinstance(val, list) and bool(val) and
                        isinstance(val[0], (ast27.Pass, ast3.Pass)))

            for children_attr, open_block in test_cases:
                src = src_tpl.format(open_block=open_block)
                t = pasta.parse(src, py_ver)
                node_finder = ast_utils.get_find_node_visitor(
                    lambda node: is_node_for_suffix(node, children_attr),
                    py_ver)
                node_finder.visit(t)
                node = node_finder.results[0]
                expected = '  #b\n    #c\n\n  #d\n'
                actual = str(fmt.get(node, 'block_suffix_%s' % children_attr))
                self.assertMultiLineEqual(
                    expected, actual,
                    'Incorrect suffix for code:\n%s\nNode: %s (line %d)\nDiff:\n%s'
                    % (src, node, node.lineno, '\n'.join(
                        _get_diff(actual, expected))))
                self.assertMultiLineEqual(src, pasta.dump(t, py_ver))
示例#22
0
文件: annotate.py 项目: junk13/pasta
 def indented(self, node, children_attr):
   """Yield the children in `children_attr` with the indentation raised.

   The extra indentation comes from the node's stored 'indent' value and
   defaults to two spaces. Once every child has been consumed, the
   'block_suffix_<children_attr>' attribute is handled and the previous
   indentation state is put back.
   """
   saved_indent, saved_diff = self._indent, self._indent_diff
   stored_diff = fmt.get(node, 'indent')
   self._indent_diff = '  ' if stored_diff is None else stored_diff
   self._indent = saved_indent + self._indent_diff
   for child in getattr(node, children_attr):
     yield child
   self.attr(node, 'block_suffix_%s' % children_attr, [])
   self._indent, self._indent_diff = saved_indent, saved_diff
示例#23
0
    def _modify_function_name(func_def_node, new_func_name):
        """Rename a function definition and keep pasta's formatting in sync.

        Args:
            func_def_node: (ast.FunctionDef) The function definition to rename.
            new_func_name: (str) The new function name.

        Raises:
            NodeTypeNotSupport: if `func_def_node` is not an ast.FunctionDef.
        """
        if not isinstance(func_def_node, ast.FunctionDef):
            raise NodeTypeNotSupport('It is not ast.FunctionDef node type.')

        old_func_name = func_def_node.name
        func_def_node.name = new_func_name

        # Modify formatting information stored by pasta.
        old_function_def = fmt.get(func_def_node, 'function_def')
        if old_function_def:
            # Replace only the last occurrence of the old name. str.replace()
            # would also rewrite earlier matching substrings: if the stored text
            # is e.g. 'def f', renaming `f` to `g` would give 'deg g'.
            # NOTE(review): assumes the stored text ends with the function name
            # (optionally followed by whitespace) -- confirm with the annotator.
            head, found, tail = old_function_def.rpartition(old_func_name)
            if found:
                new_function_def = head + new_func_name + tail
            else:
                new_function_def = old_function_def
            fmt.set(func_def_node, 'function_def', new_function_def)
            fmt.set(func_def_node, 'name__src', new_func_name)
示例#24
0
    def visit_JoinedStr(self, node):
        """Print an f-string.

        The annotated 'content' (the f-string source with every formatted value
        replaced by a numbered placeholder) is used when present. Otherwise an
        equivalent template in the same format is rebuilt from `node.values`.
        The placeholders are then replaced with the printed formatted values.
        """
        self.prefix(node)
        content = fmt.get(node, 'content')

        if content is None:
            parts = []
            value_index = 0
            for val in node.values:
                if isinstance(val, ast.FormattedValue):
                    # Number placeholders by formatted value only, matching the
                    # order of fstring_utils.get_formatted_values() below, and
                    # wrap them in braces like the annotated content.
                    parts.append('{%s}' % fstring_utils.placeholder(value_index))
                    value_index += 1
                else:
                    # Python 3.8+ parses literal fragments as ast.Constant; older
                    # versions produce ast.Str, whose text is in `.s` (ast.Str
                    # itself was removed in Python 3.14).
                    text = val.value if isinstance(val, ast.Constant) else val.s
                    # Literal braces must be doubled inside an f-string.
                    parts.append(text.replace('{', '{{').replace('}', '}}'))
            content = 'f' + repr(''.join(parts))

        values = [to_str(v) for v in fstring_utils.get_formatted_values(node)]
        self.code += fstring_utils.perform_replacements(content, values)
        self.suffix(node)
示例#25
0
def to_str(tree):
    """Convenient function to get the python source for an AST.

    The most common non-empty 'indent_diff' recorded anywhere in `tree` becomes
    the printer's default for indented nodes without formatting data.
    """
    p = Printer()

    # Detect the most prevalent indentation style in the file and use it when
    # printing indented nodes which don't have formatting data. Only non-empty
    # diffs are counted: '' (no data) and a stored None are never usable as an
    # indentation step, and filtering them here avoids a KeyError when no node
    # reports ''.
    seen_indent_diffs = collections.Counter()
    for node in ast.walk(tree):
        indent_diff = fmt.get(node, 'indent_diff', '')
        if indent_diff:
            seen_indent_diffs[indent_diff] += 1
    if seen_indent_diffs:
        # most_common(1) keeps the first-seen diff on ties, like max().
        indent_diff, _ = seen_indent_diffs.most_common(1)[0]
        p.set_default_indent_diff(indent_diff)

    p.visit(tree)
    return p.code
示例#26
0
文件: annotate.py 项目: junk13/pasta
  def visit_If(self, node):
    """Annotate an `if`/`elif` statement together with its else branch.

    When `orelse` holds a single If that check_is_elif() accepts, that node is
    marked 'is_elif' and visited directly instead of emitting an `else:`.
    """
    keyword = 'elif' if fmt.get(node, 'is_elif') else 'if'
    self.attr(node, 'open_if', [keyword, self.ws], default=keyword + ' ')
    self.visit(node.test)
    self.attr(node, 'open_block', [self.ws, ':', self.ws_oneline],
              default=':\n')
    for stmt in self.indented(node, 'body'):
      self.visit(stmt)

    if not node.orelse:
      return
    first_else = node.orelse[0]
    is_elif_chain = (len(node.orelse) == 1 and
                     isinstance(first_else, ast.If) and
                     self.check_is_elif(first_else))
    if is_elif_chain:
      fmt.set(first_else, 'is_elif', True)
      self.visit(first_else)
      return
    self.attr(node, 'elseprefix', [self.ws])
    self.token('else')
    self.attr(node, 'open_else', [self.ws, ':', self.ws_oneline],
              default=':\n')
    for stmt in self.indented(node, 'orelse'):
      self.visit(stmt)
示例#27
0
    def test_indent_depths(self):
        """indent and indent_diff are recorded for every nesting combination."""
        template = 'if a:\n{first}if b:\n{first}{second}foo()\n'
        indents = (' ', ' ' * 2, ' ' * 4, ' ' * 8, '\t', '\t' * 2)

        for first, second in itertools.product(indents, indents):
            tree = pasta.parse(template.format(first=first, second=second))
            outer_if = tree.body[0]
            inner_if = outer_if.body[0]
            call_stmt = inner_if.body[0]
            expectations = (
                (outer_if, '', ''),
                (inner_if, first, first),
                (call_stmt, first + second, second),
            )
            for stmt, indent, indent_diff in expectations:
                self.assertEqual(indent, fmt.get(stmt, 'indent'))
                self.assertEqual(indent_diff, fmt.get(stmt, 'indent_diff'))
示例#28
0
 def test_no_block_suffix_for_single_line_statement(self):
     """An `if` whose body shares its line records no block suffix."""
     if_stmt = pasta.parse('if x:  return y\n  #a\n#b\n').body[0]
     self.assertIsNone(fmt.get(if_stmt, 'block_suffix_body'))
示例#29
0
 def test_module_suffix(self):
     """Comments after the last statement are stored as the module's suffix."""
     src = 'foo\n#bar\n\n#baz\n'
     t = pasta.parse(src)
     # assertEquals is a deprecated alias that was removed in Python 3.12.
     self.assertEqual(src[src.index('#bar'):], fmt.get(t, 'suffix'))
示例#30
0
 def test_statement_prefix_suffix(self):
     """A statement's prefix holds the blank line before it; suffix is empty."""
     src = 'a\n\ndef foo():\n  return bar\n\n\nb\n'
     func_def = pasta.parse(src).body[1]
     expected = {'prefix': '\n', 'suffix': ''}
     for attr in ('prefix', 'suffix'):
         self.assertEqual(expected[attr], fmt.get(func_def, attr))