Example #1
def _fix_tokens(contents_text: str, min_version: Version) -> str:
    remove_u = (min_version >= (3, )
                or _imports_future(contents_text, 'unicode_literals'))

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:
        return contents_text
    for i, token in reversed_enumerate(tokens):
        if token.name == 'NUMBER':
            tokens[i] = token._replace(src=_fix_long(_fix_octal(token.src)))
        elif token.name == 'STRING':
            tokens[i] = _fix_ur_literals(tokens[i])
            if remove_u:
                tokens[i] = _remove_u_prefix(tokens[i])
            tokens[i] = _fix_escape_sequences(tokens[i])
        elif token.src == '(':
            _fix_extraneous_parens(tokens, i)
        elif token.src == 'format' and i > 0 and tokens[i - 1].src == '.':
            _fix_format_literal(tokens, i - 2)
        elif token.src == 'encode' and i > 0 and tokens[i - 1].src == '.':
            _fix_encode_to_binary(tokens, i)
        elif (min_version >= (3, ) and token.utf8_byte_offset == 0
              and token.line < 3 and token.name == 'COMMENT'
              and tokenize.cookie_re.match(token.src)):
            del tokens[i]
            assert tokens[i].name == 'NL', tokens[i].name
            del tokens[i]
        elif token.src == 'from' and token.utf8_byte_offset == 0:
            _fix_import_removals(tokens, i, min_version)
    return tokens_to_src(tokens).lstrip()
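
Every example in this collection follows the same tokenize-rt pattern: convert the source to a token list, walk it with reversed_enumerate so that deleting or inserting tokens never shifts the indices and offsets still to be visited, then render the result back with tokens_to_src. A minimal, self-contained sketch of that pattern (not taken from any of the projects below; the uppercasing rewrite is purely illustrative):

from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src


def _uppercase_names(src: str) -> str:
    tokens = src_to_tokens(src)
    for i, token in reversed_enumerate(tokens):
        # editing at index i cannot disturb the smaller indices that
        # reversed_enumerate has not yielded yet
        if token.name == 'NAME':
            tokens[i] = token._replace(src=token.src.upper())
    return tokens_to_src(tokens)


# _uppercase_names('x = y + 1\n') == 'X = Y + 1\n'
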
Example #2
def _fix_six(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindSixUsage()
    visitor.visit(ast_obj)

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.simple_names:
            node = visitor.simple_names[token.offset]
            tokens[i] = Token('CODE', SIX_SIMPLE_ATTRS[node.id])
        elif token.offset in visitor.simple_attrs:
            node = visitor.simple_attrs[token.offset]
            if tokens[i + 1].src == '.' and tokens[i + 2].src == node.attr:
                tokens[i:i + 3] = [Token('CODE', SIX_SIMPLE_ATTRS[node.attr])]
        elif token.offset in visitor.remove_decorators:
            if tokens[i - 1].src == '@':
                end = i + 1
                while tokens[end].name != 'NEWLINE':
                    end += 1
                del tokens[i - 1:end + 1]

    return tokens_to_src(tokens)
Example #3
def fix_file(filename: str, show_diff: bool = False, dry_run: bool = False) -> int:
    with open(filename, 'rb') as f:
        contents_bytes = f.read()

    try:
        contents_text = contents_bytes.decode()
    except UnicodeDecodeError:
        print(f'{filename} is non-utf8 (not supported)')
        return 1

    tokens = tokenize_rt.src_to_tokens(contents_text)

    tokens_no_comments = _remove_comments(tokens)
    src_no_comments = tokenize_rt.tokens_to_src(tokens_no_comments)

    if src_no_comments == contents_text:
        return 0

    with tempfile.NamedTemporaryFile(
        dir=os.path.dirname(filename),
        prefix=os.path.basename(filename),
        suffix='.py',
    ) as tmpfile:
        tmpfile.write(src_no_comments.encode())
        tmpfile.flush()
        flake8_results = _run_flake8(tmpfile.name)

    if any('E999' in v for v in flake8_results.values()):
        print(f'{filename}: syntax error (skipping)')
        return 0

    for i, token in tokenize_rt.reversed_enumerate(tokens):
        if token.name != 'COMMENT':
            continue

        if NOQA_RE.search(token.src):
            _rewrite_noqa_comment(tokens, i, flake8_results)
        elif NOQA_FILE_RE.match(token.src) and not flake8_results:
            if i == 0 or tokens[i - 1].name == 'NEWLINE':
                del tokens[i: i + 2]
            else:
                _remove_comment(tokens, i)

    newsrc = tokenize_rt.tokens_to_src(tokens)
    if newsrc != contents_text:
        if show_diff or dry_run:
            diff = difflib.unified_diff(
                contents_text.splitlines(keepends=True),
                newsrc.splitlines(keepends=True),
                fromfile=filename,
                tofile=filename,
            )
            print(''.join(diff), end='')
        if not dry_run:
            print(f'Rewriting {filename}')
            with open(filename, 'wb') as f:
                f.write(newsrc.encode())
        return 1
    else:
        return 0
Example #4
def _fix_py2_compatible(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text
    visitor = Py2CompatibleVisitor()
    visitor.visit(ast_obj)
    if not any((
            visitor.dicts,
            visitor.sets,
            visitor.set_empty_literals,
            visitor.is_literal,
    )):
        return contents_text

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.dicts:
            _process_dict_comp(tokens, i, visitor.dicts[token.offset])
        elif token.offset in visitor.set_empty_literals:
            _process_set_empty_literal(tokens, i)
        elif token.offset in visitor.sets:
            _process_set_literal(tokens, i, visitor.sets[token.offset])
        elif token.offset in visitor.is_literal:
            _process_is_literal(tokens, i, visitor.is_literal[token.offset])
    return tokens_to_src(tokens)
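
A hedged before/after sketch of what the four buckets above correspond to (spellings as produced by released pyupgrade; assumes the helper functions called above are available):

examples = {
    "dict((a, b) for a, b in y)\n": "{a: b for a, b in y}\n",
    "set((1, 2))\n": "{1, 2}\n",
    "set(())\n": "set()\n",
    "x is 'foo'\n": "x == 'foo'\n",
}
for before, after in examples.items():
    assert _fix_py2_compatible(before) == after
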
Example #5
def _fix_percent_format(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindPercentFormats()
    visitor.visit(ast_obj)

    if not visitor.found:
        return contents_text

    tokens = src_to_tokens(contents_text)

    for i, token in reversed_enumerate(tokens):
        node = visitor.found.get(token.offset)
        if node is None:
            continue

        # no .format() equivalent for bytestrings in py3
        # note that this code is only necessary when running in python2
        if _is_bytestring(tokens[i].src):  # pragma: no cover (py2-only)
            continue

        if isinstance(node.right, ast.Tuple):
            _fix_percent_format_tuple(tokens, i, node)
        elif isinstance(node.right, ast.Dict):
            _fix_percent_format_dict(tokens, i, node)

    return tokens_to_src(tokens)
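
A hedged illustration of the two branches above (tuple and dict right-hand sides; output spellings as in released pyupgrade, assuming the helpers are available):

assert _fix_percent_format("'%s %s' % (a, b)\n") == "'{} {}'.format(a, b)\n"
assert _fix_percent_format("'%(key)s' % {'key': v}\n") == "'{key}'.format(key=v)\n"
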
Example #6
def _fix_plugins(contents_text: str, settings: Settings) -> str:
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    callbacks = visit(FUNCS, ast_obj, settings)

    if not callbacks:
        return contents_text

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:  # pragma: no cover (bpo-2180)
        return contents_text

    _fixup_dedent_tokens(tokens)

    for i, token in reversed_enumerate(tokens):
        if not token.src:
            continue
        # though this is a defaultdict, by using `.get()` this function's
        # self time is almost 50% faster
        for callback in callbacks.get(token.offset, ()):
            callback(i, tokens)

    return tokens_to_src(tokens)
Example #7
def remove_trailing_semicolon(src: str) -> Tuple[str, bool]:
    """Remove trailing semicolon from Jupyter notebook cell.

    For example,

        fig, ax = plt.subplots()
        ax.plot(x_data, y_data);  # plot data

    would become

        fig, ax = plt.subplots()
        ax.plot(x_data, y_data)  # plot data

    Mirrors the logic in `quiet` from `IPython.core.displayhook`, but uses
    ``tokenize_rt`` so that round-tripping works fine.
    """
    from tokenize_rt import (
        src_to_tokens,
        tokens_to_src,
        reversed_enumerate,
    )

    tokens = src_to_tokens(src)
    trailing_semicolon = False
    for idx, token in reversed_enumerate(tokens):
        if token.name in TOKENS_TO_IGNORE:
            continue
        if token.name == "OP" and token.src == ";":
            del tokens[idx]
            trailing_semicolon = True
        break
    if not trailing_semicolon:
        return src, False
    return tokens_to_src(tokens), True
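
A hedged round trip using the docstring's own example (assumes TOKENS_TO_IGNORE covers the trailing comment, newline and endmarker tokens, as that example implies):

src = "fig, ax = plt.subplots()\nax.plot(x_data, y_data);  # plot data\n"
new_src, had_semicolon = remove_trailing_semicolon(src)
# new_src == "fig, ax = plt.subplots()\nax.plot(x_data, y_data)  # plot data\n"
# had_semicolon is True
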
Example #8
def _fix_fstrings(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindSimpleFormats()
    visitor.visit(ast_obj)

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        node = visitor.found.get(token.offset)
        if node is None:
            continue

        if _is_bytestring(token.src):  # pragma: no cover (py2-only)
            continue

        paren = i + 3
        if tokens_to_src(tokens[i + 1:paren + 1]) != '.format(':
            continue

        # we don't actually care about arg position, so we pass `node`
        victims = _victims(tokens, paren, node, gen=False)
        end = victims.ends[-1]
        # if it spans more than one line, bail
        if tokens[end].line != token.line:
            continue

        tokens[i] = token._replace(src=_to_fstring(token.src, node))
        del tokens[i + 1:end + 1]

    return tokens_to_src(tokens)
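
A hedged illustration of the rewrite above (single line, simple name or attribute arguments only; spellings as in released pyupgrade, assuming FindSimpleFormats and the helpers are available):

assert _fix_fstrings("'{} {}'.format(a, b)\n") == "f'{a} {b}'\n"
assert _fix_fstrings("'{}'.format(a.b)\n") == "f'{a.b}'\n"
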
Example #9
def _has_trailing_semicolon(src: str) -> Tuple[str, bool]:
    """
    Check if cell has trailing semicolon.

    Parameters
    ----------
    src
        Notebook cell source.

    Returns
    -------
    Tuple[str, bool]
        Modified cell source (with the trailing semicolon removed) and whether
        the cell had a trailing semicolon.
    """
    tokens = tokenize_rt.src_to_tokens(src)
    trailing_semicolon = False
    for idx, token in tokenize_rt.reversed_enumerate(tokens):
        if not token.src.strip(" \n") or token.name == "COMMENT":
            continue
        if token.name == "OP" and token.src == ";":
            tokens[idx] = token._replace(src="")
            trailing_semicolon = True
        break
    if not trailing_semicolon:
        return src, False
    return tokenize_rt.tokens_to_src(tokens), True
Example #10
def _remove_comments(tokens: Tokens) -> Tokens:
    tokens = list(tokens)
    for i, token in tokenize_rt.reversed_enumerate(tokens):
        if token.name == 'COMMENT':
            if NOQA_RE.search(token.src):
                _rewrite_noqa_comment(tokens, i, collections.defaultdict(set))
            elif NOQA_FILE_RE.search(token.src):
                _remove_comment(tokens, i)
    return tokens
Example #11
def _remove_comments(tokens: Tokens) -> Tokens:
    tokens = list(tokens)
    for i, token in tokenize_rt.reversed_enumerate(tokens):
        if token.name == 'COMMENT':
            if NOQA_RE.search(token.src):
                _mask_noqa_comment(tokens, i)
            elif NOQA_FILE_RE.search(token.src):
                _remove_comment(tokens, i)
    return tokens
Example #12
File: t.py Project: asottile/t
def _fix_calls(contents_text: str) -> str:
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = Visitor()
    visitor.visit(ast_obj)

    if not visitor.calls:
        return contents_text

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:  # pragma: no cover (bpo-2180)
        return contents_text

    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.calls:
            visitor.calls.discard(token.offset)

            # search forward for the opening brace
            while tokens[i].src != '(':
                i += 1

            call_start = i
            i += 1
            brace_depth = 1
            start = -1
            end = -1

            while brace_depth:
                if tokens[i].src in {'(', '{', '['}:
                    if brace_depth == 1:
                        start = i
                    brace_depth += 1
                elif tokens[i].src in {')', '}', ']'}:
                    brace_depth -= 1
                    if brace_depth == 1:
                        end = i
                i += 1

            assert start != -1
            assert end != -1
            call_end = i - 1

            # dedent everything inside the brackets
            for i in range(call_start, call_end):
                if (tokens[i - 1].name == 'NL'
                        and tokens[i].name == UNIMPORTANT_WS):
                    tokens[i] = tokens[i]._replace(src=tokens[i].src[4:])

            del tokens[end + 1:call_end]
            del tokens[call_start + 1:start]

    return tokens_to_src(tokens)
Example #13
def fix_file(filename: str) -> int:
    with open(filename, 'rb') as f:
        contents_bytes = f.read()

    try:
        contents_text = contents_bytes.decode()
    except UnicodeDecodeError:
        print(f'{filename} is non-utf8 (not supported)')
        return 1

    tokens = tokenize_rt.src_to_tokens(contents_text)

    tokens_no_comments = _remove_comments(tokens)
    src_no_comments = tokenize_rt.tokens_to_src(tokens_no_comments)

    if src_no_comments == contents_text:
        return 0

    fd, path = tempfile.mkstemp(
        dir=os.path.dirname(filename),
        prefix=os.path.basename(filename),
        suffix='.py',
    )
    try:
        with open(fd, 'wb') as f:
            f.write(src_no_comments.encode())
        flake8_results = _run_flake8(path)
    finally:
        os.remove(path)

    if any('E999' in v for v in flake8_results.values()):
        print(f'{filename}: syntax error (skipping)')
        return 0

    for i, token in tokenize_rt.reversed_enumerate(tokens):
        if token.name != 'COMMENT':
            continue

        if NOQA_RE.search(token.src):
            _rewrite_noqa_comment(tokens, i, flake8_results)
        elif NOQA_FILE_RE.match(token.src) and not flake8_results:
            if i == 0 or tokens[i - 1].name == 'NEWLINE':
                del tokens[i:i + 2]
            else:
                _remove_comment(tokens, i)

    newsrc = tokenize_rt.tokens_to_src(tokens)
    if newsrc != contents_text:
        print(f'Rewriting {filename}')
        with open(filename, 'wb') as f:
            f.write(newsrc.encode())
        return 1
    else:
        return 0
Example #14
def test_reversed_enumerate():
    tokens = src_to_tokens('x = 5\n')
    ret = reversed_enumerate(tokens)
    assert next(ret) == (6, Token('ENDMARKER', '', line=2, utf8_byte_offset=0))

    rest = list(ret)
    assert rest == [
        (5, Token(name='NEWLINE', src='\n', line=1, utf8_byte_offset=5)),
        (4, Token('NUMBER', '5', line=1, utf8_byte_offset=4)),
        (3, Token(UNIMPORTANT_WS, ' ')),
        (2, Token('OP', '=', line=1, utf8_byte_offset=2)),
        (1, Token(UNIMPORTANT_WS, ' ')),
        (0, Token('NAME', 'x', line=1, utf8_byte_offset=0)),
    ]
Example #15
def _fix_dictcomps(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text
    visitor = FindDictsVisitor()
    visitor.visit(ast_obj)
    if not visitor.dicts:
        return contents_text

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.dicts:
            _process_dict_comp(tokens, i, visitor.dicts[token.offset])
    return tokens_to_src(tokens)
Example #16
def _mutate_found(tokens: List[Token], visitor: _FindAssignment) -> None:
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.ctx_kwargs:
            brace_start = i + 1
            brace_end = _find_closing_brace(tokens, brace_start, "(")
            visitor.ctx_kwargs.remove(token.offset)
        elif token.offset in visitor.ctx_returned:
            return_start = i
            if not return_start < brace_start < brace_end:  # pragma: no cover
                raise Exception
            inserted = _process_ctx_returned(tokens, return_start, brace_start,
                                             brace_end)
            _process_ctx_kwargs(tokens, brace_start + inserted,
                                brace_end + inserted)
            visitor.ctx_returned.remove(token.offset)
Example #17
def _fix_sets(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text
    visitor = FindSetsVisitor()
    visitor.visit(ast_obj)
    if not visitor.sets and not visitor.set_empty_literals:
        return contents_text

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.set_empty_literals:
            _process_set_empty_literal(tokens, i)
        elif token.offset in visitor.sets:
            _process_set_literal(tokens, i, visitor.sets[token.offset])
    return tokens_to_src(tokens)
Example #18
def replace_inconsistent_pandas_namespace(visitor: Visitor,
                                          content: str) -> str:
    from tokenize_rt import (
        reversed_enumerate,
        src_to_tokens,
        tokens_to_src,
    )

    tokens = src_to_tokens(content)
    for n, i in reversed_enumerate(tokens):
        if (i.offset in visitor.pandas_namespace and
                visitor.pandas_namespace[i.offset] in visitor.no_namespace):
            # Replace `pd`
            tokens[n] = i._replace(src="")
            # Replace `.`
            tokens[n + 1] = tokens[n + 1]._replace(src="")

    new_src: str = tokens_to_src(tokens)
    return new_src
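
A hedged, visitor-free illustration of the same two-token surgery (hypothetical input; the real visitor only records `pd.` usages whose bare name also appears elsewhere in the file):

from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src

tokens = src_to_tokens("cat = pd.Categorical([1, 2])\n")
for n, tok in reversed_enumerate(tokens):
    if tok.src == "pd" and tokens[n + 1].src == ".":
        tokens[n] = tok._replace(src="")                 # drop `pd`
        tokens[n + 1] = tokens[n + 1]._replace(src="")   # drop `.`
assert tokens_to_src(tokens) == "cat = Categorical([1, 2])\n"
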
Example #19
def _restore_semicolon(
    source: str,
    cell_number: int,
    trailing_semicolons: Set[int],
) -> str:
    """
    Restore the semicolon at the end of the cell.

    Restore the trailing semicolon if the cell originally contained a semicolon
    and the third-party tool removed it.

    Parameters
    ----------
    source
        Portion of Python file between cell separators.
    cell_number
        Number of current cell.
    trailing_semicolons
        Numbers of the cells which originally had trailing semicolons.

    Returns
    -------
    str
        New source with removed semicolon restored.

    Raises
    ------
    AssertionError
        If code thought to be unreachable is reached.
    """
    if cell_number in trailing_semicolons:
        tokens = tokenize_rt.src_to_tokens(source)
        for idx, token in tokenize_rt.reversed_enumerate(tokens):
            if not token.src.strip(" \n") or token.name == "COMMENT":
                continue
            tokens[idx] = token._replace(src=token.src + ";")
            break
        else:  # pragma: nocover
            raise AssertionError(
                "Unreachable code, please report bug at https://github.com/nbQA-dev/nbQA/issues"
            )
        source = tokenize_rt.tokens_to_src(tokens)
    return source
Example #20
def put_trailing_semicolon_back(src: str, has_trailing_semicolon: bool) -> str:
    """Put trailing semicolon back if cell originally had it.

    Mirrors the logic in `quiet` from `IPython.core.displayhook`, but uses
    ``tokenize_rt`` so that round-tripping works fine.
    """
    if not has_trailing_semicolon:
        return src
    from tokenize_rt import src_to_tokens, tokens_to_src, reversed_enumerate

    tokens = src_to_tokens(src)
    for idx, token in reversed_enumerate(tokens):
        if token.name in TOKENS_TO_IGNORE:
            continue
        tokens[idx] = token._replace(src=token.src + ";")
        break
    else:  # pragma: nocover
        raise AssertionError(
            "INTERNAL ERROR: Was not able to reinstate trailing semicolon. "
            "Please report a bug on https://github.com/psf/black/issues.  "
        ) from None
    return str(tokens_to_src(tokens))
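
A hedged pairing with remove_trailing_semicolon shown earlier: strip the semicolon before handing the cell to the formatter, then put it back afterwards (assumes TOKENS_TO_IGNORE covers the trailing newline and endmarker tokens):

cell = "plt.plot(x);\n"
stripped, had_semicolon = remove_trailing_semicolon(cell)
# ... run the formatter on `stripped` here ...
restored = put_trailing_semicolon_back(stripped, had_semicolon)
# restored == "plt.plot(x);\n"
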
Example #21
def _fix_super(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindSuper()
    visitor.visit(ast_obj)

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        call = visitor.found.get(token.offset)
        if not call:
            continue

        while tokens[i].name != 'OP':
            i += 1

        victims = _victims(tokens, i, call, gen=False)
        del tokens[victims.starts[0] + 1:victims.ends[-1]]

    return tokens_to_src(tokens)
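
A hedged illustration of the argument removal above (as in released pyupgrade's py3+ mode, assuming FindSuper and _victims are available):

assert _fix_super(
    'class C(B):\n'
    '    def f(self):\n'
    '        super(C, self).f()\n'
) == (
    'class C(B):\n'
    '    def f(self):\n'
    '        super().f()\n'
)
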
Example #22
def _fix_tokens(contents_text, py3_plus):
    remove_u_prefix = py3_plus or _imports_unicode_literals(contents_text)

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:
        return contents_text
    for i, token in reversed_enumerate(tokens):
        if token.name == 'NUMBER':
            tokens[i] = token._replace(src=_fix_long(_fix_octal(token.src)))
        elif token.name == 'STRING':
            # when a string prefix is not recognized, the tokenizer produces a
            # NAME token followed by a STRING token
            if i > 0 and _is_string_prefix(tokens[i - 1]):
                tokens[i] = token._replace(src=tokens[i - 1].src + token.src)
                tokens[i - 1] = tokens[i - 1]._replace(src='')

            tokens[i] = _fix_ur_literals(tokens[i])
            if remove_u_prefix:
                tokens[i] = _remove_u_prefix(tokens[i])
            tokens[i] = _fix_escape_sequences(tokens[i])
        elif token.src == '(':
            _fix_extraneous_parens(tokens, i)
    return tokens_to_src(tokens)
Example #23
def decode(b: bytes, errors: str = 'strict') -> Tuple[str, int]:
    u, length = utf_8.decode(b, errors)

    # replace encoding cookie so there isn't a recursion problem
    lines = u.splitlines(True)
    for idx in (0, 1):
        if idx >= len(lines):
            break
        lines[idx] = tokenize.cookie_re.sub(_new_coding_cookie, lines[idx])
    u = ''.join(lines)

    visitor = Visitor()
    visitor.visit(_ast_parse(u))

    tokens = tokenize_rt.src_to_tokens(u)
    for i, token in tokenize_rt.reversed_enumerate(tokens):
        if token.offset in visitor.offsets:
            # look forward for a `:`, `,`, `=`, ')'
            depth = 0
            j = i + 1
            while depth or tokens[j].src not in {':', ',', '=', ')', '\n'}:
                if tokens[j].src in {'(', '{', '['}:
                    depth += 1
                elif tokens[j].src in {')', '}', ']'}:
                    depth -= 1
                j += 1
            j -= 1

            # look backward to delete whitespace / comments / etc.
            while tokens[j].name in tokenize_rt.NON_CODING_TOKENS:
                j -= 1

            quoted = repr(tokenize_rt.tokens_to_src(tokens[i:j + 1]))
            tokens[i:j + 1] = [tokenize_rt.Token('STRING', quoted)]

    return tokenize_rt.tokens_to_src(tokens), length
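
A hedged sketch of what this codec pass produces (this appears to be the future-annotations codec, which quotes each annotation expression found by the visitor so that newer typing syntax is not evaluated at runtime on older interpreters):

# before:
#     def f(x: int = 0) -> Optional[str]: ...
# after:
#     def f(x: 'int' = 0) -> 'Optional[str]': ...
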
Example #24
def _fix_py36_plus(contents_text: str) -> str:
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindPy36Plus()
    visitor.visit(ast_obj)

    if not any((
            visitor.fstrings,
            visitor.named_tuples,
            visitor.dict_typed_dicts,
            visitor.kw_typed_dicts,
    )):
        return contents_text

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:  # pragma: no cover (bpo-2180)
        return contents_text
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.fstrings:
            node = visitor.fstrings[token.offset]

            # TODO: handle \N escape sequences
            if r'\N' in token.src:
                continue

            paren = i + 3
            if tokens_to_src(tokens[i + 1:paren + 1]) != '.format(':
                continue

            # we don't actually care about arg position, so we pass `node`
            fmt_victims = victims(tokens, paren, node, gen=False)
            end = fmt_victims.ends[-1]
            # if it spans more than one line, bail
            if tokens[end].line != token.line:
                continue

            tokens[i] = token._replace(src=_to_fstring(token.src, node))
            del tokens[i + 1:end + 1]
        elif token.offset in visitor.named_tuples and token.name == 'NAME':
            call = visitor.named_tuples[token.offset]
            types: Dict[str, ast.expr] = {
                tup.elts[0].s: tup.elts[1]  # type: ignore  # (checked above)
                for tup in call.args[1].elts  # type: ignore  # (checked above)
            }
            _replace_typed_class(tokens, i, call, types)
        elif token.offset in visitor.kw_typed_dicts and token.name == 'NAME':
            call = visitor.kw_typed_dicts[token.offset]
            types = {
                arg.arg: arg.value  # type: ignore  # (checked above)
                for arg in call.keywords
            }
            _replace_typed_class(tokens, i, call, types)
        elif token.offset in visitor.dict_typed_dicts and token.name == 'NAME':
            call = visitor.dict_typed_dicts[token.offset]
            types = {
                k.s: v  # type: ignore  # (checked above)
                for k, v in zip(
                    call.args[1].keys,  # type: ignore  # (checked above)
                    call.args[1].values,  # type: ignore  # (checked above)
                )
            }
            _replace_typed_class(tokens, i, call, types)

    return tokens_to_src(tokens)
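
A hedged illustration of the functional-to-class rewrites dispatched above (spellings as in released pyupgrade's py36+ mode):

# NT = NamedTuple('NT', [('x', int)])    becomes
#
#     class NT(NamedTuple):
#         x: int
#
# D = TypedDict('D', x=int)              becomes
#
#     class D(TypedDict):
#         x: int
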
Example #25
def _fix_py36_plus(contents_text: str) -> str:
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindPy36Plus()
    visitor.visit(ast_obj)

    if not any((
            visitor.fstrings,
            visitor.named_tuples,
            visitor.dict_typed_dicts,
            visitor.kw_typed_dicts,
    )):
        return contents_text

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:  # pragma: no cover (bpo-2180)
        return contents_text
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.fstrings:
            # TODO: handle \N escape sequences
            if r'\N' in token.src:
                continue

            paren = i + 3
            if tokens_to_src(tokens[i + 1:paren + 1]) != '.format(':
                continue

            args, end = parse_call_args(tokens, paren)
            # if it spans more than one line, bail
            if tokens[end - 1].line != token.line:
                continue

            args_src = tokens_to_src(tokens[paren:end])
            if '\\' in args_src or '"' in args_src or "'" in args_src:
                continue

            tokens[i] = token._replace(
                src=_to_fstring(token.src, tokens, args),
            )
            del tokens[i + 1:end]
        elif token.offset in visitor.named_tuples and token.name == 'NAME':
            call = visitor.named_tuples[token.offset]
            types: Dict[str, ast.expr] = {
                tup.elts[0].s: tup.elts[1]  # type: ignore  # (checked above)
                for tup in call.args[1].elts  # type: ignore  # (checked above)
            }
            end, attrs = _typed_class_replacement(tokens, i, call, types)
            src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}'
            tokens[i:end] = [Token('CODE', src)]
        elif token.offset in visitor.kw_typed_dicts and token.name == 'NAME':
            call = visitor.kw_typed_dicts[token.offset]
            types = {
                arg.arg: arg.value  # type: ignore  # (checked above)
                for arg in call.keywords
            }
            end, attrs = _typed_class_replacement(tokens, i, call, types)
            src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}'
            tokens[i:end] = [Token('CODE', src)]
        elif token.offset in visitor.dict_typed_dicts and token.name == 'NAME':
            call = visitor.dict_typed_dicts[token.offset]
            types = {
                k.s: v  # type: ignore  # (checked above)
                for k, v in zip(
                    call.args[1].keys,  # type: ignore  # (checked above)
                    call.args[1].values,  # type: ignore  # (checked above)
                )
            }
            if call.keywords:
                total = call.keywords[0].value.value  # type: ignore # (checked above)  # noqa: E501
                end, attrs = _typed_class_replacement(tokens, i, call, types)
                src = (
                    f'class {tokens[i].src}('
                    f'{_unparse(call.func)}, total={total}'
                    f'):\n'
                    f'{attrs}'
                )
                tokens[i:end] = [Token('CODE', src)]
            else:
                end, attrs = _typed_class_replacement(tokens, i, call, types)
                src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}'
                tokens[i:end] = [Token('CODE', src)]

    return tokens_to_src(tokens)
Example #26
def _fix_py3_plus(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindPy3Plus()
    visitor.visit(ast_obj)

    if not any((
            visitor.bases_to_remove,
            visitor.six_b,
            visitor.six_calls,
            visitor.six_raises,
            visitor.six_remove_decorators,
            visitor.six_simple,
            visitor.six_type_ctx,
            visitor.six_with_metaclass,
            visitor.super_calls,
    )):
        return contents_text

    def _replace(i, mapping, node):
        new_token = Token('CODE', _get_tmpl(mapping, node))
        if isinstance(node, ast.Name):
            tokens[i] = new_token
        else:
            j = i
            while tokens[j].src != node.attr:
                j += 1
            tokens[i:j + 1] = [new_token]

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        if not token.src:
            continue
        elif token.offset in visitor.bases_to_remove:
            _remove_base_class(tokens, i)
        elif token.offset in visitor.six_type_ctx:
            _replace(i, SIX_TYPE_CTX_ATTRS, visitor.six_type_ctx[token.offset])
        elif token.offset in visitor.six_simple:
            _replace(i, SIX_SIMPLE_ATTRS, visitor.six_simple[token.offset])
        elif token.offset in visitor.six_remove_decorators:
            if tokens[i - 1].src == '@':
                end = i + 1
                while tokens[end].name != 'NEWLINE':
                    end += 1
                del tokens[i - 1:end + 1]
        elif token.offset in visitor.six_b:
            j = _find_open_paren(tokens, i)
            if (tokens[j + 1].name == 'STRING' and _is_ascii(tokens[j + 1].src)
                    and tokens[j + 2].src == ')'):
                func_args, end = _parse_call_args(tokens, j)
                _replace_call(tokens, i, end, func_args, SIX_B_TMPL)
        elif token.offset in visitor.six_calls:
            j = _find_open_paren(tokens, i)
            func_args, end = _parse_call_args(tokens, j)
            node = visitor.six_calls[token.offset]
            template = _get_tmpl(SIX_CALLS, node.func)
            _replace_call(tokens, i, end, func_args, template)
        elif token.offset in visitor.six_raises:
            j = _find_open_paren(tokens, i)
            func_args, end = _parse_call_args(tokens, j)
            node = visitor.six_raises[token.offset]
            template = _get_tmpl(SIX_RAISES, node.func)
            _replace_call(tokens, i, end, func_args, template)
        elif token.offset in visitor.six_with_metaclass:
            j = _find_open_paren(tokens, i)
            func_args, end = _parse_call_args(tokens, j)
            if len(func_args) == 1:
                tmpl = WITH_METACLASS_NO_BASES_TMPL
            else:
                tmpl = WITH_METACLASS_BASES_TMPL
            _replace_call(tokens, i, end, func_args, tmpl)
        elif token.offset in visitor.super_calls:
            i = _find_open_paren(tokens, i)
            call = visitor.super_calls[token.offset]
            victims = _victims(tokens, i, call, gen=False)
            del tokens[victims.starts[0] + 1:victims.ends[-1]]

    return tokens_to_src(tokens)
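
A hedged sampler of the six rewrites dispatched above (spellings as in released pyupgrade):

# six.text_type                        ->  str
# six.b('123')                         ->  b'123'
# six.raise_from(exc, exc_from)        ->  raise exc from exc_from
# class C(six.with_metaclass(M, B)):   ->  class C(B, metaclass=M):
# @six.python_2_unicode_compatible     ->  decorator removed entirely
# super(C, self).method()              ->  super().method()
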
Example #27
def _fix_new_style_classes(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindNewStyleClasses()
    visitor.visit(ast_obj)

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        base = visitor.found.get(token.offset)
        if not base:
            continue

        # single base, look forward until the colon to find the ), then look
        # backward to find the matching (
        if (len(base.node.bases) == 1
                and not getattr(base.node, 'keywords', None)):
            j = i
            while tokens[j].src != ':':
                j += 1
            while tokens[j].src != ')':
                j -= 1

            end_index = j
            brace_stack = [')']
            while brace_stack:
                j -= 1
                if tokens[j].src == ')':
                    brace_stack.append(')')
                elif tokens[j].src == '(':
                    brace_stack.pop()
            start_index = j

            del tokens[start_index:end_index + 1]
        # multiple bases, look forward and remove a comma
        elif base.index == 0:
            j = i
            brace_stack = []
            while tokens[j].src != ',':
                if tokens[j].src == ')':
                    brace_stack.append(')')
                j += 1
            end_index = j

            j = i
            while brace_stack:
                j -= 1
                if tokens[j].src == '(':
                    brace_stack.pop()
            start_index = j

            # if there's space afterwards remove that too
            if tokens[end_index + 1].name == UNIMPORTANT_WS:
                end_index += 1

            # if it is on its own line, remove it
            if (tokens[start_index - 1].name == UNIMPORTANT_WS
                    and tokens[start_index - 2].name == 'NL'
                    and tokens[end_index + 1].name == 'NL'):
                start_index -= 1
                end_index += 1

            del tokens[start_index:end_index + 1]
        # multiple bases, look backward and remove a comma
        else:
            j = i
            brace_stack = []
            while tokens[j].src != ',':
                if tokens[j].src == '(':
                    brace_stack.append('(')
                j -= 1
            start_index = j

            j = i
            while brace_stack:
                j += 1
                if tokens[j].src == ')':
                    brace_stack.pop()
            end_index = j

            del tokens[start_index:end_index + 1]
    return tokens_to_src(tokens)
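
A hedged illustration of the base removal above (assuming FindNewStyleClasses records explicit `object` bases, as the name suggests; the three branches handle a lone base, the first of several, and a later base):

# class C(object): ...       ->  class C: ...
# class C(object, B): ...    ->  class C(B): ...
# class C(B, object): ...    ->  class C(B): ...
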
Example #28
def _fix_six(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindSixUsage()
    visitor.visit(ast_obj)

    def _replace_name(i, mapping, node):
        tokens[i] = Token('CODE', mapping[node.id])

    def _replace_attr(i, mapping, node):
        if tokens[i + 1].src == '.' and tokens[i + 2].src == node.attr:
            tokens[i:i + 3] = [Token('CODE', mapping[node.attr])]

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.type_ctx_names:
            node = visitor.type_ctx_names[token.offset]
            _replace_name(i, SIX_TYPE_CTX_ATTRS, node)
        elif token.offset in visitor.type_ctx_attrs:
            node = visitor.type_ctx_attrs[token.offset]
            _replace_attr(i, SIX_TYPE_CTX_ATTRS, node)
        elif token.offset in visitor.simple_names:
            node = visitor.simple_names[token.offset]
            _replace_name(i, SIX_SIMPLE_ATTRS, node)
        elif token.offset in visitor.simple_attrs:
            node = visitor.simple_attrs[token.offset]
            _replace_attr(i, SIX_SIMPLE_ATTRS, node)
        elif token.offset in visitor.remove_decorators:
            if tokens[i - 1].src == '@':
                end = i + 1
                while tokens[end].name != 'NEWLINE':
                    end += 1
                del tokens[i - 1:end + 1]
        elif token.offset in visitor.call_names:
            node = visitor.call_names[token.offset]
            if tokens[i + 1].src == '(':
                func_args, end = _parse_call_args(tokens, i + 1)
                template = SIX_CALLS[node.func.id]
                _replace_call(tokens, i, end, func_args, template)
        elif token.offset in visitor.call_attrs:
            node = visitor.call_attrs[token.offset]
            if (tokens[i + 1].src == '.'
                    and tokens[i + 2].src == node.func.attr
                    and tokens[i + 3].src == '('):
                func_args, end = _parse_call_args(tokens, i + 3)
                template = SIX_CALLS[node.func.attr]
                _replace_call(tokens, i, end, func_args, template)
        elif token.offset in visitor.raise_names:
            node = visitor.raise_names[token.offset]
            if tokens[i + 1].src == '(':
                func_args, end = _parse_call_args(tokens, i + 1)
                template = SIX_RAISES[node.func.id]
                _replace_call(tokens, i, end, func_args, template)
        elif token.offset in visitor.raise_attrs:
            node = visitor.raise_attrs[token.offset]
            if (tokens[i + 1].src == '.'
                    and tokens[i + 2].src == node.func.attr
                    and tokens[i + 3].src == '('):
                func_args, end = _parse_call_args(tokens, i + 3)
                template = SIX_RAISES[node.func.attr]
                _replace_call(tokens, i, end, func_args, template)

    return tokens_to_src(tokens)