def _fix_tokens(contents_text: str, min_version: Version) -> str:
    remove_u = (
        min_version >= (3,) or
        _imports_future(contents_text, 'unicode_literals')
    )

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:
        return contents_text
    for i, token in reversed_enumerate(tokens):
        if token.name == 'NUMBER':
            tokens[i] = token._replace(src=_fix_long(_fix_octal(token.src)))
        elif token.name == 'STRING':
            tokens[i] = _fix_ur_literals(tokens[i])
            if remove_u:
                tokens[i] = _remove_u_prefix(tokens[i])
            tokens[i] = _fix_escape_sequences(tokens[i])
        elif token.src == '(':
            _fix_extraneous_parens(tokens, i)
        elif token.src == 'format' and i > 0 and tokens[i - 1].src == '.':
            _fix_format_literal(tokens, i - 2)
        elif token.src == 'encode' and i > 0 and tokens[i - 1].src == '.':
            _fix_encode_to_binary(tokens, i)
        elif (
                min_version >= (3,) and
                token.utf8_byte_offset == 0 and
                token.line < 3 and
                token.name == 'COMMENT' and
                tokenize.cookie_re.match(token.src)
        ):
            del tokens[i]
            assert tokens[i].name == 'NL', tokens[i].name
            del tokens[i]
        elif token.src == 'from' and token.utf8_byte_offset == 0:
            _fix_import_removals(tokens, i, min_version)
    return tokens_to_src(tokens).lstrip()


def _fix_six(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindSixUsage()
    visitor.visit(ast_obj)

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.simple_names:
            node = visitor.simple_names[token.offset]
            tokens[i] = Token('CODE', SIX_SIMPLE_ATTRS[node.id])
        elif token.offset in visitor.simple_attrs:
            node = visitor.simple_attrs[token.offset]
            if tokens[i + 1].src == '.' and tokens[i + 2].src == node.attr:
                tokens[i:i + 3] = [Token('CODE', SIX_SIMPLE_ATTRS[node.attr])]
        elif token.offset in visitor.remove_decorators:
            if tokens[i - 1].src == '@':
                end = i + 1
                while tokens[end].name != 'NEWLINE':
                    end += 1
                del tokens[i - 1:end + 1]

    return tokens_to_src(tokens)


def fix_file(filename: str, show_diff: bool = False, dry_run: bool = False) -> int:
    with open(filename, 'rb') as f:
        contents_bytes = f.read()

    try:
        contents_text = contents_bytes.decode()
    except UnicodeDecodeError:
        print(f'{filename} is non-utf8 (not supported)')
        return 1

    tokens = tokenize_rt.src_to_tokens(contents_text)

    tokens_no_comments = _remove_comments(tokens)
    src_no_comments = tokenize_rt.tokens_to_src(tokens_no_comments)

    if src_no_comments == contents_text:
        return 0

    with tempfile.NamedTemporaryFile(
        dir=os.path.dirname(filename),
        prefix=os.path.basename(filename),
        suffix='.py',
    ) as tmpfile:
        tmpfile.write(src_no_comments.encode())
        tmpfile.flush()
        flake8_results = _run_flake8(tmpfile.name)

    if any('E999' in v for v in flake8_results.values()):
        print(f'{filename}: syntax error (skipping)')
        return 0

    for i, token in tokenize_rt.reversed_enumerate(tokens):
        if token.name != 'COMMENT':
            continue

        if NOQA_RE.search(token.src):
            _rewrite_noqa_comment(tokens, i, flake8_results)
        elif NOQA_FILE_RE.match(token.src) and not flake8_results:
            if i == 0 or tokens[i - 1].name == 'NEWLINE':
                del tokens[i: i + 2]
            else:
                _remove_comment(tokens, i)

    newsrc = tokenize_rt.tokens_to_src(tokens)
    if newsrc != contents_text:
        if show_diff or dry_run:
            diff = difflib.unified_diff(
                contents_text.splitlines(keepends=True),
                newsrc.splitlines(keepends=True),
                fromfile=filename,
                tofile=filename,
            )
            print(''.join(diff), end='')
        if not dry_run:
            print(f'Rewriting {filename}')
            with open(filename, 'wb') as f:
                f.write(newsrc.encode())
        return 1
    else:
        return 0


def _fix_py2_compatible(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text
    visitor = Py2CompatibleVisitor()
    visitor.visit(ast_obj)
    if not any((
            visitor.dicts,
            visitor.sets,
            visitor.set_empty_literals,
            visitor.is_literal,
    )):
        return contents_text

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.dicts:
            _process_dict_comp(tokens, i, visitor.dicts[token.offset])
        elif token.offset in visitor.set_empty_literals:
            _process_set_empty_literal(tokens, i)
        elif token.offset in visitor.sets:
            _process_set_literal(tokens, i, visitor.sets[token.offset])
        elif token.offset in visitor.is_literal:
            _process_is_literal(tokens, i, visitor.is_literal[token.offset])
    return tokens_to_src(tokens)


def _fix_percent_format(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindPercentFormats()
    visitor.visit(ast_obj)

    if not visitor.found:
        return contents_text

    tokens = src_to_tokens(contents_text)

    for i, token in reversed_enumerate(tokens):
        node = visitor.found.get(token.offset)
        if node is None:
            continue

        # no .format() equivalent for bytestrings in py3
        # note that this code is only necessary when running in python2
        if _is_bytestring(tokens[i].src):  # pragma: no cover (py2-only)
            continue

        if isinstance(node.right, ast.Tuple):
            _fix_percent_format_tuple(tokens, i, node)
        elif isinstance(node.right, ast.Dict):
            _fix_percent_format_dict(tokens, i, node)

    return tokens_to_src(tokens)


def _fix_plugins(contents_text: str, settings: Settings) -> str:
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    callbacks = visit(FUNCS, ast_obj, settings)

    if not callbacks:
        return contents_text

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:  # pragma: no cover (bpo-2180)
        return contents_text

    _fixup_dedent_tokens(tokens)

    for i, token in reversed_enumerate(tokens):
        if not token.src:
            continue
        # though this is a defaultdict, by using `.get()` this function's
        # self time is almost 50% faster
        for callback in callbacks.get(token.offset, ()):
            callback(i, tokens)

    return tokens_to_src(tokens)


def remove_trailing_semicolon(src: str) -> Tuple[str, bool]:
    """Remove trailing semicolon from Jupyter notebook cell.

    For example,

        fig, ax = plt.subplots()
        ax.plot(x_data, y_data);  # plot data

    would become

        fig, ax = plt.subplots()
        ax.plot(x_data, y_data)  # plot data

    Mirrors the logic in `quiet` from `IPython.core.displayhook`, but uses
    ``tokenize_rt`` so that round-tripping works fine.
    """
    from tokenize_rt import (
        src_to_tokens,
        tokens_to_src,
        reversed_enumerate,
    )

    tokens = src_to_tokens(src)
    trailing_semicolon = False
    for idx, token in reversed_enumerate(tokens):
        if token.name in TOKENS_TO_IGNORE:
            continue
        if token.name == "OP" and token.src == ";":
            del tokens[idx]
            trailing_semicolon = True
        break
    if not trailing_semicolon:
        return src, False
    return tokens_to_src(tokens), True


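A hedged usage sketch (not from black itself) mirroring the docstring example above; the cell text is hypothetical data and is only tokenized, never executed:

cell = "fig, ax = plt.subplots()\nax.plot(x_data, y_data);  # plot data\n"
stripped, had_semicolon = remove_trailing_semicolon(cell)
# had_semicolon is True; `stripped` is the same cell with the trailing ';'
# dropped, ready to be reformatted and later handed to
# put_trailing_semicolon_back (defined further below).

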
def _fix_fstrings(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindSimpleFormats()
    visitor.visit(ast_obj)

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        node = visitor.found.get(token.offset)
        if node is None:
            continue

        if _is_bytestring(token.src):  # pragma: no cover (py2-only)
            continue

        paren = i + 3
        if tokens_to_src(tokens[i + 1:paren + 1]) != '.format(':
            continue

        # we don't actually care about arg position, so we pass `node`
        victims = _victims(tokens, paren, node, gen=False)
        end = victims.ends[-1]
        # if it spans more than one line, bail
        if tokens[end].line != token.line:
            continue

        tokens[i] = token._replace(src=_to_fstring(token.src, node))
        del tokens[i + 1:end + 1]

    return tokens_to_src(tokens)


def _has_trailing_semicolon(src: str) -> Tuple[str, bool]:
    """
    Check if cell has trailing semicolon.

    Parameters
    ----------
    src
        Notebook cell source.

    Returns
    -------
    Tuple[str, bool]
        Cell source with any trailing semicolon blanked out, and whether the
        cell had a trailing semicolon.
    """
    tokens = tokenize_rt.src_to_tokens(src)
    trailing_semicolon = False
    for idx, token in tokenize_rt.reversed_enumerate(tokens):
        if not token.src.strip(" \n") or token.name == "COMMENT":
            continue
        if token.name == "OP" and token.src == ";":
            tokens[idx] = token._replace(src="")
            trailing_semicolon = True
        break

    if not trailing_semicolon:
        return src, False
    return tokenize_rt.tokens_to_src(tokens), True


def _remove_comments(tokens: Tokens) -> Tokens:
    tokens = list(tokens)
    for i, token in tokenize_rt.reversed_enumerate(tokens):
        if token.name == 'COMMENT':
            if NOQA_RE.search(token.src):
                _rewrite_noqa_comment(tokens, i, collections.defaultdict(set))
            elif NOQA_FILE_RE.search(token.src):
                _remove_comment(tokens, i)
    return tokens


def _remove_comments(tokens: Tokens) -> Tokens:
    tokens = list(tokens)
    for i, token in tokenize_rt.reversed_enumerate(tokens):
        if token.name == 'COMMENT':
            if NOQA_RE.search(token.src):
                _mask_noqa_comment(tokens, i)
            elif NOQA_FILE_RE.search(token.src):
                _remove_comment(tokens, i)
    return tokens


def _fix_calls(contents_text: str) -> str:
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = Visitor()
    visitor.visit(ast_obj)

    if not visitor.calls:
        return contents_text

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:  # pragma: no cover (bpo-2180)
        return contents_text

    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.calls:
            visitor.calls.discard(token.offset)

            # search forward for the opening brace
            while tokens[i].src != '(':
                i += 1
            call_start = i
            i += 1
            brace_depth = 1

            start = -1
            end = -1
            while brace_depth:
                if tokens[i].src in {'(', '{', '['}:
                    if brace_depth == 1:
                        start = i
                    brace_depth += 1
                elif tokens[i].src in {')', '}', ']'}:
                    brace_depth -= 1
                    if brace_depth == 1:
                        end = i
                i += 1
            assert start != -1
            assert end != -1
            call_end = i - 1

            # dedent everything inside the brackets
            for i in range(call_start, call_end):
                if (
                        tokens[i - 1].name == 'NL' and
                        tokens[i].name == UNIMPORTANT_WS
                ):
                    tokens[i] = tokens[i]._replace(src=tokens[i].src[4:])

            del tokens[end + 1:call_end]
            del tokens[call_start + 1:start]

    return tokens_to_src(tokens)


def fix_file(filename: str) -> int:
    with open(filename, 'rb') as f:
        contents_bytes = f.read()

    try:
        contents_text = contents_bytes.decode()
    except UnicodeDecodeError:
        print(f'{filename} is non-utf8 (not supported)')
        return 1

    tokens = tokenize_rt.src_to_tokens(contents_text)

    tokens_no_comments = _remove_comments(tokens)
    src_no_comments = tokenize_rt.tokens_to_src(tokens_no_comments)

    if src_no_comments == contents_text:
        return 0

    fd, path = tempfile.mkstemp(
        dir=os.path.dirname(filename),
        prefix=os.path.basename(filename),
        suffix='.py',
    )
    try:
        with open(fd, 'wb') as f:
            f.write(src_no_comments.encode())
        flake8_results = _run_flake8(path)
    finally:
        os.remove(path)

    if any('E999' in v for v in flake8_results.values()):
        print(f'{filename}: syntax error (skipping)')
        return 0

    for i, token in tokenize_rt.reversed_enumerate(tokens):
        if token.name != 'COMMENT':
            continue

        if NOQA_RE.search(token.src):
            _rewrite_noqa_comment(tokens, i, flake8_results)
        elif NOQA_FILE_RE.match(token.src) and not flake8_results:
            if i == 0 or tokens[i - 1].name == 'NEWLINE':
                del tokens[i:i + 2]
            else:
                _remove_comment(tokens, i)

    newsrc = tokenize_rt.tokens_to_src(tokens)
    if newsrc != contents_text:
        print(f'Rewriting {filename}')
        with open(filename, 'wb') as f:
            f.write(newsrc.encode())
        return 1
    else:
        return 0


def test_reversed_enumerate():
    tokens = src_to_tokens('x = 5\n')
    ret = reversed_enumerate(tokens)
    assert next(ret) == (6, Token('ENDMARKER', '', line=2, utf8_byte_offset=0))

    rest = list(ret)
    assert rest == [
        (5, Token(name='NEWLINE', src='\n', line=1, utf8_byte_offset=5)),
        (4, Token('NUMBER', '5', line=1, utf8_byte_offset=4)),
        (3, Token(UNIMPORTANT_WS, ' ')),
        (2, Token('OP', '=', line=1, utf8_byte_offset=2)),
        (1, Token(UNIMPORTANT_WS, ' ')),
        (0, Token('NAME', 'x', line=1, utf8_byte_offset=0)),
    ]


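The test above pins down the iteration order of reversed_enumerate. Below is a minimal sketch (not taken from any of the projects quoted here) of the pattern the surrounding fixers share: because the token list is walked from the end, tokens can be deleted or replaced in place without invalidating the indices that are still to be visited.

from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src


def _strip_comments(contents_text: str) -> str:
    # Hypothetical helper: tokenize, mutate from the end, re-serialize.
    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        if token.name == 'COMMENT':
            # safe while iterating: only indices < i remain to be yielded
            del tokens[i]
    return tokens_to_src(tokens)

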
def _fix_dictcomps(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text
    visitor = FindDictsVisitor()
    visitor.visit(ast_obj)
    if not visitor.dicts:
        return contents_text

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.dicts:
            _process_dict_comp(tokens, i, visitor.dicts[token.offset])
    return tokens_to_src(tokens)


def _mutate_found(tokens: List[Token], visitor: _FindAssignment) -> None:
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.ctx_kwargs:
            brace_start = i + 1
            brace_end = _find_closing_brace(tokens, brace_start, "(")
            visitor.ctx_kwargs.remove(token.offset)
        elif token.offset in visitor.ctx_returned:
            return_start = i
            if not return_start < brace_start < brace_end:  # pragma: no cover
                raise Exception
            inserted = _process_ctx_returned(tokens, return_start, brace_start, brace_end)
            _process_ctx_kwargs(tokens, brace_start + inserted, brace_end + inserted)
            visitor.ctx_returned.remove(token.offset)


def _fix_sets(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text
    visitor = FindSetsVisitor()
    visitor.visit(ast_obj)
    if not visitor.sets and not visitor.set_empty_literals:
        return contents_text

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.set_empty_literals:
            _process_set_empty_literal(tokens, i)
        elif token.offset in visitor.sets:
            _process_set_literal(tokens, i, visitor.sets[token.offset])
    return tokens_to_src(tokens)


def replace_inconsistent_pandas_namespace(visitor: Visitor, content: str) -> str:
    from tokenize_rt import (
        reversed_enumerate,
        src_to_tokens,
        tokens_to_src,
    )

    tokens = src_to_tokens(content)
    for n, i in reversed_enumerate(tokens):
        if (
            i.offset in visitor.pandas_namespace
            and visitor.pandas_namespace[i.offset] in visitor.no_namespace
        ):
            # Replace `pd`
            tokens[n] = i._replace(src="")
            # Replace `.`
            tokens[n + 1] = tokens[n + 1]._replace(src="")

    new_src: str = tokens_to_src(tokens)
    return new_src


def _restore_semicolon(
    source: str,
    cell_number: int,
    trailing_semicolons: Set[int],
) -> str:
    """
    Restore the semicolon at the end of the cell.

    Restore the trailing semicolon if the cell originally contained a semicolon
    and the third party tool removed it.

    Parameters
    ----------
    source
        Portion of Python file between cell separators.
    cell_number
        Number of current cell.
    trailing_semicolons
        Cells which originally had trailing semicolons.

    Returns
    -------
    str
        New source with removed semicolon restored.

    Raises
    ------
    AssertionError
        If code thought to be unreachable is reached.
    """
    if cell_number in trailing_semicolons:
        tokens = tokenize_rt.src_to_tokens(source)
        for idx, token in tokenize_rt.reversed_enumerate(tokens):
            if not token.src.strip(" \n") or token.name == "COMMENT":
                continue
            tokens[idx] = token._replace(src=token.src + ";")
            break
        else:  # pragma: nocover
            raise AssertionError(
                "Unreachable code, please report bug at https://github.com/nbQA-dev/nbQA/issues"
            )
        source = tokenize_rt.tokens_to_src(tokens)

    return source


def put_trailing_semicolon_back(src: str, has_trailing_semicolon: bool) -> str:
    """Put trailing semicolon back if cell originally had it.

    Mirrors the logic in `quiet` from `IPython.core.displayhook`, but uses
    ``tokenize_rt`` so that round-tripping works fine.
    """
    if not has_trailing_semicolon:
        return src
    from tokenize_rt import src_to_tokens, tokens_to_src, reversed_enumerate

    tokens = src_to_tokens(src)
    for idx, token in reversed_enumerate(tokens):
        if token.name in TOKENS_TO_IGNORE:
            continue
        tokens[idx] = token._replace(src=token.src + ";")
        break
    else:  # pragma: nocover
        raise AssertionError(
            "INTERNAL ERROR: Was not able to reinstate trailing semicolon. "
            "Please report a bug on https://github.com/psf/black/issues. "
        ) from None
    return str(tokens_to_src(tokens))


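A hedged round-trip sketch combining remove_trailing_semicolon with put_trailing_semicolon_back; `format_cell` is only a stand-in for whatever transformation runs on the semicolon-free source and is not a real API.

def _format_cell_roundtrip(src: str, format_cell) -> str:
    # Hypothetical glue: strip the ';', transform the cell, then reinstate
    # the ';' only if the original cell ended with one.
    stripped, had_semicolon = remove_trailing_semicolon(src)
    formatted = format_cell(stripped)
    return put_trailing_semicolon_back(formatted, had_semicolon)

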
def _fix_super(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindSuper()
    visitor.visit(ast_obj)

    tokens = src_to_tokens(contents_text)

    for i, token in reversed_enumerate(tokens):
        call = visitor.found.get(token.offset)
        if not call:
            continue

        while tokens[i].name != 'OP':
            i += 1

        victims = _victims(tokens, i, call, gen=False)
        del tokens[victims.starts[0] + 1:victims.ends[-1]]

    return tokens_to_src(tokens)


def _fix_tokens(contents_text, py3_plus):
    remove_u_prefix = py3_plus or _imports_unicode_literals(contents_text)

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:
        return contents_text
    for i, token in reversed_enumerate(tokens):
        if token.name == 'NUMBER':
            tokens[i] = token._replace(src=_fix_long(_fix_octal(token.src)))
        elif token.name == 'STRING':
            # when a string prefix is not recognized, the tokenizer produces a
            # NAME token followed by a STRING token
            if i > 0 and _is_string_prefix(tokens[i - 1]):
                tokens[i] = token._replace(src=tokens[i - 1].src + token.src)
                tokens[i - 1] = tokens[i - 1]._replace(src='')

            tokens[i] = _fix_ur_literals(tokens[i])
            if remove_u_prefix:
                tokens[i] = _remove_u_prefix(tokens[i])
            tokens[i] = _fix_escape_sequences(tokens[i])
        elif token.src == '(':
            _fix_extraneous_parens(tokens, i)
    return tokens_to_src(tokens)


def decode(b: bytes, errors: str = 'strict') -> Tuple[str, int]:
    u, length = utf_8.decode(b, errors)

    # replace encoding cookie so there isn't a recursion problem
    lines = u.splitlines(True)
    for idx in (0, 1):
        if idx >= len(lines):
            break
        lines[idx] = tokenize.cookie_re.sub(_new_coding_cookie, lines[idx])
    u = ''.join(lines)

    visitor = Visitor()
    visitor.visit(_ast_parse(u))

    tokens = tokenize_rt.src_to_tokens(u)
    for i, token in tokenize_rt.reversed_enumerate(tokens):
        if token.offset in visitor.offsets:
            # look forward for a `:`, `,`, `=`, ')'
            depth = 0
            j = i + 1
            while depth or tokens[j].src not in {':', ',', '=', ')', '\n'}:
                if tokens[j].src in {'(', '{', '['}:
                    depth += 1
                elif tokens[j].src in {')', '}', ']'}:
                    depth -= 1
                j += 1
            j -= 1

            # look backward to delete whitespace / comments / etc.
            while tokens[j].name in tokenize_rt.NON_CODING_TOKENS:
                j -= 1

            quoted = repr(tokenize_rt.tokens_to_src(tokens[i:j + 1]))
            tokens[i:j + 1] = [tokenize_rt.Token('STRING', quoted)]

    return tokenize_rt.tokens_to_src(tokens), length


def _fix_py36_plus(contents_text: str) -> str:
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindPy36Plus()
    visitor.visit(ast_obj)

    if not any((
            visitor.fstrings,
            visitor.named_tuples,
            visitor.dict_typed_dicts,
            visitor.kw_typed_dicts,
    )):
        return contents_text

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:  # pragma: no cover (bpo-2180)
        return contents_text

    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.fstrings:
            node = visitor.fstrings[token.offset]

            # TODO: handle \N escape sequences
            if r'\N' in token.src:
                continue

            paren = i + 3
            if tokens_to_src(tokens[i + 1:paren + 1]) != '.format(':
                continue

            # we don't actually care about arg position, so we pass `node`
            fmt_victims = victims(tokens, paren, node, gen=False)
            end = fmt_victims.ends[-1]
            # if it spans more than one line, bail
            if tokens[end].line != token.line:
                continue

            tokens[i] = token._replace(src=_to_fstring(token.src, node))
            del tokens[i + 1:end + 1]
        elif token.offset in visitor.named_tuples and token.name == 'NAME':
            call = visitor.named_tuples[token.offset]
            types: Dict[str, ast.expr] = {
                tup.elts[0].s: tup.elts[1]  # type: ignore  # (checked above)
                for tup in call.args[1].elts  # type: ignore  # (checked above)
            }
            _replace_typed_class(tokens, i, call, types)
        elif token.offset in visitor.kw_typed_dicts and token.name == 'NAME':
            call = visitor.kw_typed_dicts[token.offset]
            types = {
                arg.arg: arg.value  # type: ignore  # (checked above)
                for arg in call.keywords
            }
            _replace_typed_class(tokens, i, call, types)
        elif token.offset in visitor.dict_typed_dicts and token.name == 'NAME':
            call = visitor.dict_typed_dicts[token.offset]
            types = {
                k.s: v  # type: ignore  # (checked above)
                for k, v in zip(
                    call.args[1].keys,  # type: ignore  # (checked above)
                    call.args[1].values,  # type: ignore  # (checked above)
                )
            }
            _replace_typed_class(tokens, i, call, types)

    return tokens_to_src(tokens)


def _fix_py36_plus(contents_text: str) -> str:
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindPy36Plus()
    visitor.visit(ast_obj)

    if not any((
            visitor.fstrings,
            visitor.named_tuples,
            visitor.dict_typed_dicts,
            visitor.kw_typed_dicts,
    )):
        return contents_text

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:  # pragma: no cover (bpo-2180)
        return contents_text

    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.fstrings:
            # TODO: handle \N escape sequences
            if r'\N' in token.src:
                continue

            paren = i + 3
            if tokens_to_src(tokens[i + 1:paren + 1]) != '.format(':
                continue

            args, end = parse_call_args(tokens, paren)
            # if it spans more than one line, bail
            if tokens[end - 1].line != token.line:
                continue

            args_src = tokens_to_src(tokens[paren:end])
            if '\\' in args_src or '"' in args_src or "'" in args_src:
                continue

            tokens[i] = token._replace(
                src=_to_fstring(token.src, tokens, args),
            )
            del tokens[i + 1:end]
        elif token.offset in visitor.named_tuples and token.name == 'NAME':
            call = visitor.named_tuples[token.offset]
            types: Dict[str, ast.expr] = {
                tup.elts[0].s: tup.elts[1]  # type: ignore  # (checked above)
                for tup in call.args[1].elts  # type: ignore  # (checked above)
            }
            end, attrs = _typed_class_replacement(tokens, i, call, types)
            src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}'
            tokens[i:end] = [Token('CODE', src)]
        elif token.offset in visitor.kw_typed_dicts and token.name == 'NAME':
            call = visitor.kw_typed_dicts[token.offset]
            types = {
                arg.arg: arg.value  # type: ignore  # (checked above)
                for arg in call.keywords
            }
            end, attrs = _typed_class_replacement(tokens, i, call, types)
            src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}'
            tokens[i:end] = [Token('CODE', src)]
        elif token.offset in visitor.dict_typed_dicts and token.name == 'NAME':
            call = visitor.dict_typed_dicts[token.offset]
            types = {
                k.s: v  # type: ignore  # (checked above)
                for k, v in zip(
                    call.args[1].keys,  # type: ignore  # (checked above)
                    call.args[1].values,  # type: ignore  # (checked above)
                )
            }
            if call.keywords:
                total = call.keywords[0].value.value  # type: ignore # (checked above)  # noqa: E501
                end, attrs = _typed_class_replacement(tokens, i, call, types)
                src = (
                    f'class {tokens[i].src}('
                    f'{_unparse(call.func)}, total={total}'
                    f'):\n'
                    f'{attrs}'
                )
                tokens[i:end] = [Token('CODE', src)]
            else:
                end, attrs = _typed_class_replacement(tokens, i, call, types)
                src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}'
                tokens[i:end] = [Token('CODE', src)]

    return tokens_to_src(tokens)


def _fix_py3_plus(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindPy3Plus()
    visitor.visit(ast_obj)

    if not any((
            visitor.bases_to_remove,
            visitor.six_b,
            visitor.six_calls,
            visitor.six_raises,
            visitor.six_remove_decorators,
            visitor.six_simple,
            visitor.six_type_ctx,
            visitor.six_with_metaclass,
            visitor.super_calls,
    )):
        return contents_text

    def _replace(i, mapping, node):
        new_token = Token('CODE', _get_tmpl(mapping, node))
        if isinstance(node, ast.Name):
            tokens[i] = new_token
        else:
            j = i
            while tokens[j].src != node.attr:
                j += 1
            tokens[i:j + 1] = [new_token]

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        if not token.src:
            continue
        elif token.offset in visitor.bases_to_remove:
            _remove_base_class(tokens, i)
        elif token.offset in visitor.six_type_ctx:
            _replace(i, SIX_TYPE_CTX_ATTRS, visitor.six_type_ctx[token.offset])
        elif token.offset in visitor.six_simple:
            _replace(i, SIX_SIMPLE_ATTRS, visitor.six_simple[token.offset])
        elif token.offset in visitor.six_remove_decorators:
            if tokens[i - 1].src == '@':
                end = i + 1
                while tokens[end].name != 'NEWLINE':
                    end += 1
                del tokens[i - 1:end + 1]
        elif token.offset in visitor.six_b:
            j = _find_open_paren(tokens, i)
            if (
                    tokens[j + 1].name == 'STRING' and
                    _is_ascii(tokens[j + 1].src) and
                    tokens[j + 2].src == ')'
            ):
                func_args, end = _parse_call_args(tokens, j)
                _replace_call(tokens, i, end, func_args, SIX_B_TMPL)
        elif token.offset in visitor.six_calls:
            j = _find_open_paren(tokens, i)
            func_args, end = _parse_call_args(tokens, j)
            node = visitor.six_calls[token.offset]
            template = _get_tmpl(SIX_CALLS, node.func)
            _replace_call(tokens, i, end, func_args, template)
        elif token.offset in visitor.six_raises:
            j = _find_open_paren(tokens, i)
            func_args, end = _parse_call_args(tokens, j)
            node = visitor.six_raises[token.offset]
            template = _get_tmpl(SIX_RAISES, node.func)
            _replace_call(tokens, i, end, func_args, template)
        elif token.offset in visitor.six_with_metaclass:
            j = _find_open_paren(tokens, i)
            func_args, end = _parse_call_args(tokens, j)
            if len(func_args) == 1:
                tmpl = WITH_METACLASS_NO_BASES_TMPL
            else:
                tmpl = WITH_METACLASS_BASES_TMPL
            _replace_call(tokens, i, end, func_args, tmpl)
        elif token.offset in visitor.super_calls:
            i = _find_open_paren(tokens, i)
            call = visitor.super_calls[token.offset]
            victims = _victims(tokens, i, call, gen=False)
            del tokens[victims.starts[0] + 1:victims.ends[-1]]

    return tokens_to_src(tokens)


def _fix_new_style_classes(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindNewStyleClasses()
    visitor.visit(ast_obj)

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        base = visitor.found.get(token.offset)
        if not base:
            continue

        # single base, look forward until the colon to find the ), then look
        # backward to find the matching (
        if (
                len(base.node.bases) == 1 and
                not getattr(base.node, 'keywords', None)
        ):
            j = i
            while tokens[j].src != ':':
                j += 1
            while tokens[j].src != ')':
                j -= 1
            end_index = j

            brace_stack = [')']
            while brace_stack:
                j -= 1
                if tokens[j].src == ')':
                    brace_stack.append(')')
                elif tokens[j].src == '(':
                    brace_stack.pop()
            start_index = j

            del tokens[start_index:end_index + 1]
        # multiple bases, look forward and remove a comma
        elif base.index == 0:
            j = i
            brace_stack = []
            while tokens[j].src != ',':
                if tokens[j].src == ')':
                    brace_stack.append(')')
                j += 1
            end_index = j

            j = i
            while brace_stack:
                j -= 1
                if tokens[j].src == '(':
                    brace_stack.pop()
            start_index = j

            # if there's space afterwards remove that too
            if tokens[end_index + 1].name == UNIMPORTANT_WS:
                end_index += 1

            # if it is on its own line, remove it
            if (
                    tokens[start_index - 1].name == UNIMPORTANT_WS and
                    tokens[start_index - 2].name == 'NL' and
                    tokens[end_index + 1].name == 'NL'
            ):
                start_index -= 1
                end_index += 1

            del tokens[start_index:end_index + 1]
        # multiple bases, look backward and remove a comma
        else:
            j = i
            brace_stack = []
            while tokens[j].src != ',':
                if tokens[j].src == '(':
                    brace_stack.append('(')
                j -= 1
            start_index = j

            j = i
            while brace_stack:
                j += 1
                if tokens[j].src == ')':
                    brace_stack.pop()
            end_index = j

            del tokens[start_index:end_index + 1]

    return tokens_to_src(tokens)


def _fix_six(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindSixUsage()
    visitor.visit(ast_obj)

    def _replace_name(i, mapping, node):
        tokens[i] = Token('CODE', mapping[node.id])

    def _replace_attr(i, mapping, node):
        if tokens[i + 1].src == '.' and tokens[i + 2].src == node.attr:
            tokens[i:i + 3] = [Token('CODE', mapping[node.attr])]

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.type_ctx_names:
            node = visitor.type_ctx_names[token.offset]
            _replace_name(i, SIX_TYPE_CTX_ATTRS, node)
        elif token.offset in visitor.type_ctx_attrs:
            node = visitor.type_ctx_attrs[token.offset]
            _replace_attr(i, SIX_TYPE_CTX_ATTRS, node)
        elif token.offset in visitor.simple_names:
            node = visitor.simple_names[token.offset]
            _replace_name(i, SIX_SIMPLE_ATTRS, node)
        elif token.offset in visitor.simple_attrs:
            node = visitor.simple_attrs[token.offset]
            _replace_attr(i, SIX_SIMPLE_ATTRS, node)
        elif token.offset in visitor.remove_decorators:
            if tokens[i - 1].src == '@':
                end = i + 1
                while tokens[end].name != 'NEWLINE':
                    end += 1
                del tokens[i - 1:end + 1]
        elif token.offset in visitor.call_names:
            node = visitor.call_names[token.offset]
            if tokens[i + 1].src == '(':
                func_args, end = _parse_call_args(tokens, i + 1)
                template = SIX_CALLS[node.func.id]
                _replace_call(tokens, i, end, func_args, template)
        elif token.offset in visitor.call_attrs:
            node = visitor.call_attrs[token.offset]
            if (
                    tokens[i + 1].src == '.' and
                    tokens[i + 2].src == node.func.attr and
                    tokens[i + 3].src == '('
            ):
                func_args, end = _parse_call_args(tokens, i + 3)
                template = SIX_CALLS[node.func.attr]
                _replace_call(tokens, i, end, func_args, template)
        elif token.offset in visitor.raise_names:
            node = visitor.raise_names[token.offset]
            if tokens[i + 1].src == '(':
                func_args, end = _parse_call_args(tokens, i + 1)
                template = SIX_RAISES[node.func.id]
                _replace_call(tokens, i, end, func_args, template)
        elif token.offset in visitor.raise_attrs:
            node = visitor.raise_attrs[token.offset]
            if (
                    tokens[i + 1].src == '.' and
                    tokens[i + 2].src == node.func.attr and
                    tokens[i + 3].src == '('
            ):
                func_args, end = _parse_call_args(tokens, i + 3)
                template = SIX_RAISES[node.func.attr]
                _replace_call(tokens, i, end, func_args, template)

    return tokens_to_src(tokens)