def _fix_dict_comp(
        i: int,
        tokens: List[Token],
        arg: Union[ast.ListComp, ast.GeneratorExp],
) -> None:
    if not immediately_paren('dict', tokens, i):
        return

    dict_victims = victims(tokens, i + 1, arg, gen=True)
    elt_victims = victims(tokens, dict_victims.arg_index, arg.elt, gen=True)

    del dict_victims.starts[0]
    end_index = dict_victims.ends.pop()

    tokens[end_index] = Token('OP', '}')
    for index in reversed(dict_victims.ends):
        remove_brace(tokens, index)
    # See #6, Fix SyntaxError from rewriting dict((a, b)for a, b in y)
    if tokens[elt_victims.ends[-1] + 1].src == 'for':
        tokens.insert(elt_victims.ends[-1] + 1, Token(UNIMPORTANT_WS, ' '))
    for index in reversed(elt_victims.ends):
        remove_brace(tokens, index)
    assert elt_victims.first_comma_index is not None
    tokens[elt_victims.first_comma_index] = Token('OP', ':')
    for index in reversed(dict_victims.starts + elt_victims.starts):
        remove_brace(tokens, index)
    tokens[i:i + 2] = [Token('OP', '{')]

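# Illustration (added; not part of the original source): the rewrite this
# function performs,
#     dict((a, b) for a, b in y)    ->  {a: b for a, b in y}
#     dict([(a, b) for a, b in y])  ->  {a: b for a, b in y}
# the first comma inside the element tuple becomes the `:` of the resulting
# dict comprehension.
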
def _fix_six(contents_text):
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindSixUsage()
    visitor.visit(ast_obj)

    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.simple_names:
            node = visitor.simple_names[token.offset]
            tokens[i] = Token('CODE', SIX_SIMPLE_ATTRS[node.id])
        elif token.offset in visitor.simple_attrs:
            node = visitor.simple_attrs[token.offset]
            if tokens[i + 1].src == '.' and tokens[i + 2].src == node.attr:
                tokens[i:i + 3] = [Token('CODE', SIX_SIMPLE_ATTRS[node.attr])]
        elif token.offset in visitor.remove_decorators:
            if tokens[i - 1].src == '@':
                end = i + 1
                while tokens[end].name != 'NEWLINE':
                    end += 1
                del tokens[i - 1:end + 1]

    return tokens_to_src(tokens)

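# A minimal runnable sketch (added; not part of the original source) of the
# editing pattern _fix_six relies on: tokenize with tokenize-rt, mutate the
# token list in reverse so indices computed earlier stay valid, then
# re-serialize. The rewrite rule here (`text_type` -> `str`) is a
# hypothetical stand-in for the SIX_SIMPLE_ATTRS lookup.
from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src


def _demo_replace_name(contents_text: str, old: str, new: str) -> str:
    tokens = src_to_tokens(contents_text)
    for i, token in reversed_enumerate(tokens):
        if token.name == 'NAME' and token.src == old:
            tokens[i] = token._replace(src=new)
    return tokens_to_src(tokens)


# _demo_replace_name('print(text_type(x))\n', 'text_type', 'str')
# returns 'print(str(x))\n'
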
def _fix_add_metaclass(i: int, tokens: List[Token]) -> None:
    j = find_open_paren(tokens, i)
    func_args, end = parse_call_args(tokens, j)
    metaclass = f'metaclass={arg_str(tokens, *func_args[0])}'
    # insert `metaclass={args[0]}` into `class:`
    # search forward for the `class` token
    j = i + 1
    while tokens[j].src != 'class':
        j += 1
    class_token = j
    # then search forward for a `:` token, not inside a brace
    j = find_block_start(tokens, j)
    last_paren = -1
    for k in range(class_token, j):
        if tokens[k].src == ')':
            last_paren = k

    if last_paren == -1:
        tokens.insert(j, Token('CODE', f'({metaclass})'))
    else:
        insert = last_paren - 1
        while tokens[insert].name in NON_CODING_TOKENS:
            insert -= 1
        if tokens[insert].src == '(':  # no bases
            src = metaclass
        elif tokens[insert].src != ',':
            src = f', {metaclass}'
        else:
            src = f' {metaclass},'
        tokens.insert(insert + 1, Token('CODE', src))
    remove_decorator(i, tokens)

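# Illustration (added; not part of the original source):
#     @six.add_metaclass(M)
#     class C(B): ...
# becomes
#     class C(B, metaclass=M): ...
# with the keyword appended after the last base, or a new `(metaclass=M)`
# inserted when the class has no parenthesized bases at all.
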
def test_src_to_tokens_octal_literal_normalization():
    ret = src_to_tokens('0755\n')
    assert ret == [
        Token('NUMBER', '0755', line=1, utf8_byte_offset=0),
        Token('NEWLINE', '\n', line=1, utf8_byte_offset=4),
        Token('ENDMARKER', '', line=2, utf8_byte_offset=0),
    ]

def test_src_to_tokens_string_prefix_normalization(prefix):
    src = f"{prefix}'foo'\n"
    ret = src_to_tokens(src)
    assert ret == [
        Token('STRING', f"{prefix}'foo'", line=1, utf8_byte_offset=0),
        Token('NEWLINE', '\n', line=1, utf8_byte_offset=5 + len(prefix)),
        Token('ENDMARKER', '', line=2, utf8_byte_offset=0),
    ]

def test_src_to_tokens_long_literal_normalization(postfix):
    src = f'123{postfix}\n'
    ret = src_to_tokens(src)
    assert ret == [
        Token('NUMBER', f'123{postfix}', line=1, utf8_byte_offset=0),
        Token('NEWLINE', '\n', line=1, utf8_byte_offset=4),
        Token('ENDMARKER', '', line=2, utf8_byte_offset=0),
    ]

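# Note (added): the three tests above appear to pin tokenize-rt's
# normalization of Python 2-only literal spellings. The Python 3 stdlib
# tokenizer can split `0755`, `123L`/`123l`, or unfamiliar string prefixes
# into multiple tokens, and src_to_tokens is expected to re-join them into
# the single NUMBER/STRING tokens asserted here so that tokens_to_src
# round-trips the original source exactly.
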
def _fix_percent_format_dict(
        i: int,
        tokens: List[Token],
        *,
        node_right: ast.Dict,
) -> None:
    # TODO: handle \N escape sequences
    if r'\N' in tokens[i].src:
        return

    seen_keys: Set[str] = set()
    keys = {}

    for k in node_right.keys:
        # not a string key
        if not isinstance(k, ast.Str):
            return
        # duplicate key
        elif k.s in seen_keys:
            return
        # not an identifier
        elif not k.s.isidentifier():
            return
        # a keyword
        elif k.s in KEYWORDS:
            return
        seen_keys.add(k.s)
        keys[ast_to_offset(k)] = k

    # TODO: this is overly timid
    brace = i + 4
    if tokens_to_src(tokens[i + 1:brace + 1]) != ' % {':
        return

    fmt_victims = victims(tokens, brace, node_right, gen=False)
    brace_end = fmt_victims.ends[-1]

    key_indices = []
    for j, token in enumerate(tokens[brace:brace_end], brace):
        key = keys.pop(token.offset, None)
        if key is None:
            continue
        # we found the key, but the string didn't match (implicit join?)
        elif ast.literal_eval(token.src) != key.s:
            return
        # the map uses some strange syntax that's not `'key': value`
        elif tokens[j + 1].src != ':' or tokens[j + 2].src != ' ':
            return
        else:
            key_indices.append((j, key.s))
    assert not keys, keys

    tokens[brace_end] = tokens[brace_end]._replace(src=')')
    for key_index, s in reversed(key_indices):
        tokens[key_index:key_index + 3] = [Token('CODE', f'{s}=')]
    newsrc = _percent_to_format(tokens[i].src)
    tokens[i] = tokens[i]._replace(src=newsrc)
    tokens[i + 1:brace + 1] = [Token('CODE', '.format'), Token('OP', '(')]

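# Illustration (added; not part of the original source):
#     '%(key)s' % {'key': v}  ->  '{key}'.format(key=v)
# the early returns above make this bail out (leaving the source unchanged)
# for non-string, duplicate, non-identifier, or keyword keys.
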
def _delete_list_comp_brackets(i: int, tokens: List[Token]) -> None:
    start = find_comprehension_opening_bracket(i, tokens)
    end = find_closing_bracket(tokens, start)
    tokens[end] = Token('PLACEHOLDER', '')
    tokens[start] = Token('PLACEHOLDER', '')
    j = end + 1
    while j < len(tokens) and tokens[j].name in NON_CODING_TOKENS:
        j += 1
    if tokens[j].name == 'OP' and tokens[j].src == ',':
        tokens[j] = Token('PLACEHOLDER', '')

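# Inferred illustration (added; the upstream context is not shown here):
# replacing the comprehension's brackets, and a trailing comma after them,
# with zero-width PLACEHOLDER tokens presumably rewrites e.g.
#     sorted([x for x in y])  ->  sorted(x for x in y)
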
def _fix_optional(i: int, tokens: List[Token]) -> None:
    j = find_token(tokens, i, '[')
    k = find_closing_bracket(tokens, j)
    if tokens[j].line == tokens[k].line:
        tokens[k] = Token('CODE', ' | None')
        del tokens[i:j + 1]
    else:
        tokens[j] = tokens[j]._replace(src='(')
        tokens[k] = tokens[k]._replace(src=')')
        tokens[i:j] = [Token('CODE', 'None | ')]

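# Illustration (added; not part of the original source):
#     Optional[int]  ->  int | None
# when the subscript spans several lines, `[`/`]` become `(`/`)` and the
# rewrite produces `None | (...)` instead, keeping the layout valid.
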
def _process_is_literal(tokens, i, compare):
    while tokens[i].src != 'is':
        i -= 1
    if isinstance(compare, ast.Is):
        tokens[i] = tokens[i]._replace(src='==')
    else:
        tokens[i] = tokens[i]._replace(src='!=')
        # since we iterate backward, the dummy tokens keep the same length
        i += 1
        while tokens[i].src != 'not':
            tokens[i] = Token('DUMMY', '')
            i += 1
        tokens[i] = Token('DUMMY', '')

def _process_set_literal(tokens, start, arg):
    if _is_wtf('set', tokens, start):
        return

    set_victims = _get_victims(tokens, start + 1, arg)

    del set_victims.starts[0]
    end_index = set_victims.ends.pop()

    tokens[end_index] = Token('OP', '}')
    for index in reversed(set_victims.starts + set_victims.ends):
        _remove_brace(tokens, index)
    tokens[start:start + 2] = [Token('OP', '{')]

def _fix_union(
        i: int,
        tokens: List[Token],
        *,
        arg: ast.expr,
        arg_count: int,
) -> None:
    arg_offset = ast_to_offset(arg)
    j = find_token(tokens, i, '[')
    to_delete = []
    commas: List[int] = []

    arg_depth = -1
    depth = 1
    k = j + 1
    while depth:
        if tokens[k].src in OPENING:
            if arg_depth == -1:
                to_delete.append(k)
            depth += 1
        elif tokens[k].src in CLOSING:
            depth -= 1
            if 0 < depth < arg_depth:
                to_delete.append(k)
        elif tokens[k].offset == arg_offset:
            arg_depth = depth
        elif depth == arg_depth and tokens[k].src == ',':
            if len(commas) >= arg_count - 1:
                to_delete.append(k)
            else:
                commas.append(k)
        k += 1
    k -= 1

    if tokens[j].line == tokens[k].line:
        del tokens[k]
        for comma in commas:
            tokens[comma] = Token('CODE', ' |')
        for paren in reversed(to_delete):
            del tokens[paren]
        del tokens[i:j + 1]
    else:
        tokens[j] = tokens[j]._replace(src='(')
        tokens[k] = tokens[k]._replace(src=')')
        for comma in commas:
            tokens[comma] = Token('CODE', ' |')
        for paren in reversed(to_delete):
            del tokens[paren]
        del tokens[i:j]

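# Illustration (added; not part of the original source):
#     Union[int, str]    ->  int | str
#     Union[(int, str)]  ->  int | str
# commas at the argument depth become ` |`, and redundant brackets wrapping
# the arguments are queued in `to_delete` and removed.
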
def _replace_dict_brackets(i: int, tokens: List[Token]) -> None:
    closing = find_closing_bracket(tokens, i)
    j = closing - 1
    while tokens[j].name in NON_CODING_TOKENS and j > i:
        j -= 1
    if tokens[j].name == 'OP' and tokens[j].src == ',':
        tokens[j] = Token('PLACEHOLDER', '')
    if tokens[i].line == tokens[closing].line:
        tokens[i] = Token('PLACEHOLDER', '')
        tokens[closing] = Token('PLACEHOLDER', '')
    else:
        tokens[i] = Token('CODE', '(')
        tokens[closing] = Token('CODE', ')')

def _fix_set_literal(i: int, tokens: List[Token], *, arg: ast.expr) -> None:
    # TODO: this could be implemented with a little extra logic
    if not immediately_paren('set', tokens, i):
        return

    gen = isinstance(arg, ast.GeneratorExp)
    set_victims = victims(tokens, i + 1, arg, gen=gen)

    del set_victims.starts[0]
    end_index = set_victims.ends.pop()

    tokens[end_index] = Token('OP', '}')
    for index in reversed(set_victims.starts + set_victims.ends):
        remove_brace(tokens, index)
    tokens[i:i + 2] = [Token('OP', '{')]

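# Illustration (added; not part of the original source):
#     set((1, 2))        ->  {1, 2}
#     set(x for x in y)  ->  {x for x in y}
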
def _fix_open_mode(i: int, tokens: List[Token]) -> None:
    j = find_open_paren(tokens, i)
    func_args, end = parse_call_args(tokens, j)
    mode = tokens_to_src(tokens[slice(*func_args[1])])
    mode_stripped = mode.strip().strip('"\'')
    if mode_stripped in U_MODE_REMOVE:
        del tokens[func_args[0][1]:func_args[1][1]]
    elif mode_stripped in U_MODE_REPLACE_R:
        new_mode = mode.replace('U', 'r')
        tokens[slice(*func_args[1])] = [Token('SRC', new_mode)]
    elif mode_stripped in U_MODE_REMOVE_U:
        new_mode = mode.replace('U', '')
        tokens[slice(*func_args[1])] = [Token('SRC', new_mode)]
    else:
        raise AssertionError(f'unreachable: {mode!r}')

def _fix_percent_format_tuple(tokens, start, node):
    # TODO: this is overly timid
    paren = start + 4
    if tokens_to_src(tokens[start + 1:paren + 1]) != ' % (':
        return

    victims = _victims(tokens, paren, node.right, gen=False)
    victims.ends.pop()

    for index in reversed(victims.starts + victims.ends):
        _remove_brace(tokens, index)

    newsrc = _percent_to_format(tokens[start].src)
    tokens[start] = tokens[start]._replace(src=newsrc)
    tokens[start + 1:paren] = [Token('Format', '.format'), Token('OP', '(')]

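# Illustration (added; not part of the original source):
#     '%s %s' % (a, b)  ->  '{} {}'.format(a, b)
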
def _fix_percent_format_dict(tokens, start, node):
    seen_keys = set()
    keys = {}
    for k in node.right.keys:
        # not a string key
        if not isinstance(k, ast.Str):
            return
        # duplicate key
        elif k.s in seen_keys:
            return
        # not an identifier
        elif not IDENT_RE.match(k.s):
            return
        # a keyword
        elif k.s in keyword.kwlist:
            return
        seen_keys.add(k.s)
        keys[_ast_to_offset(k)] = k

    # TODO: this is overly timid
    brace = start + 4
    if tokens_to_src(tokens[start + 1:brace + 1]) != ' % {':
        return

    victims = _victims(tokens, brace, node.right, gen=False)
    brace_end = victims.ends[-1]

    key_indices = []
    for i, token in enumerate(tokens[brace:brace_end], brace):
        k = keys.pop(token.offset, None)
        if k is None:
            continue
        # we found the key, but the string didn't match (implicit join?)
        elif ast.literal_eval(token.src) != k.s:
            return
        # the map uses some strange syntax that's not `'k': v`
        elif tokens_to_src(tokens[i + 1:i + 3]) != ': ':
            return
        else:
            key_indices.append((i, k.s))
    assert not keys, keys

    tokens[brace_end] = tokens[brace_end]._replace(src=')')
    for (key_index, s) in reversed(key_indices):
        tokens[key_index:key_index + 3] = [Token('CODE', '{}='.format(s))]
    newsrc = _percent_to_format(tokens[start].src)
    tokens[start] = tokens[start]._replace(src=newsrc)
    tokens[start + 1:brace + 1] = [Token('CODE', '.format'), Token('OP', '(')]

def replace_call(
        tokens: List[Token],
        start: int,
        end: int,
        args: List[Tuple[int, int]],
        tmpl: str,
        *,
        parens: Sequence[int] = (),
) -> None:
    arg_strs = [arg_str(tokens, *arg) for arg in args]
    for paren in parens:
        arg_strs[paren] = f'({arg_strs[paren]})'

    start_rest = args[0][1] + 1
    while (
            start_rest < end and
            tokens[start_rest].name in {'COMMENT', UNIMPORTANT_WS}
    ):
        start_rest += 1

    # Remove trailing comma
    end_rest = end - 1
    while (
            tokens[end_rest - 1].name == 'OP' and
            tokens[end_rest - 1].src == ','
    ):
        end_rest -= 1

    rest = tokens_to_src(tokens[start_rest:end_rest])
    src = tmpl.format(args=arg_strs, rest=rest)
    tokens[start:end] = [Token('CODE', src)]

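# Hedged usage sketch (added; hypothetical template, not from the original
# source): given token spans for `f(s, a, b=c)` where args[0] covers `s`,
#     replace_call(tokens, start, end, args, '{args[0]}.format({rest})')
# would yield `s.format(a, b=c)`. `rest` is everything after the first
# argument with leading whitespace/comments and any trailing comma
# stripped, and parens=(0,) would first wrap the rendered first argument
# in parentheses.
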
def _remove_u_prefix(token: Token) -> Token:
    prefix, rest = parse_string_literal(token.src)
    if 'u' not in prefix.lower():
        return token
    else:
        new_prefix = prefix.replace('u', '').replace('U', '')
        return token._replace(src=new_prefix + rest)

def _remove_u_prefix(token):
    prefix, rest = _parse_string_literal(token.src)
    if 'u' not in prefix.lower():
        return token
    else:
        new_prefix = prefix.replace('u', '').replace('U', '')
        return Token('STRING', new_prefix + rest)

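# Design note (added): both variants rewrite u'foo' -> 'foo'. The typed
# variant above preserves the token's position information via _replace,
# while this older variant rebuilds a fresh STRING token whose line/offset
# fields fall back to their defaults.
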
def _fix_format_literals(contents_text):
    tokens = src_to_tokens(contents_text)

    to_replace = []
    string_start = None
    string_end = None
    seen_dot = False

    for i, token in enumerate(tokens):
        if string_start is None and token.name == 'STRING':
            string_start = i
            string_end = i + 1
        elif string_start is not None and token.name == 'STRING':
            string_end = i + 1
        elif string_start is not None and token.src == '.':
            seen_dot = True
        elif seen_dot and token.src == 'format':
            to_replace.append((string_start, string_end))
            string_start, string_end, seen_dot = None, None, False
        elif token.name not in NON_CODING_TOKENS:
            string_start, string_end, seen_dot = None, None, False

    for start, end in reversed(to_replace):
        src = tokens_to_src(tokens[start:end])
        new_src = _rewrite_string_literal(src)
        tokens[start:end] = [Token('STRING', new_src)]

    return tokens_to_src(tokens)

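# Illustration (added; not part of the original source): explicit positional
# indices that match the default ordering are removed,
#     '{0} {1}'.format(a, b)  ->  '{} {}'.format(a, b)
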
def remove_providing_args(
    tokens: list[Token], i: int, *, node: ast.Call
) -> None:
    j = find(tokens, i, name=OP, src="(")
    func_args, _ = parse_call_args(tokens, j)

    if len(node.args):
        start_idx, end_idx = func_args[0]
        if len(node.args) == 1:
            del tokens[start_idx:end_idx]
        else:
            # Have to replace with None
            tokens[start_idx:end_idx] = [Token(name=CODE, src="None")]
    else:
        for n, keyword in enumerate(node.keywords):
            if keyword.arg == "providing_args":
                start_idx, end_idx = func_args[n]

                start_idx = reverse_consume(
                    tokens, start_idx, name=UNIMPORTANT_WS
                )
                start_idx = reverse_consume(tokens, start_idx, name=INDENT)
                if n > 0:
                    start_idx = reverse_consume(
                        tokens, start_idx, name=OP, src=","
                    )

                if n < len(node.keywords) - 1:
                    end_idx = consume(tokens, end_idx, name=UNIMPORTANT_WS)
                    end_idx = consume(tokens, end_idx, name=OP, src=",")
                    end_idx = consume(tokens, end_idx, name=UNIMPORTANT_WS)
                    end_idx = consume(tokens, end_idx, name=COMMENT)
                    end_idx += 1

                del tokens[start_idx:end_idx]

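# Illustration (added; derived from the branches above) for Django's
# deprecated Signal(providing_args=...) argument:
#     Signal(providing_args=["x"])     ->  Signal()
#     Signal(["x"])                    ->  Signal()
#     Signal(["x"], use_caching=True)  ->  Signal(None, use_caching=True)
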
def _fix_open_mode(i: int, tokens: List[Token], *, arg_idx: int) -> None:
    j = find_open_paren(tokens, i)
    func_args, end = parse_call_args(tokens, j)
    mode = tokens_to_src(tokens[slice(*func_args[arg_idx])])
    mode_stripped = mode.split('=')[-1]
    mode_stripped = ast.literal_eval(mode_stripped.strip())
    if mode_stripped in U_MODE_REMOVE:
        delete_argument(arg_idx, tokens, func_args)
    elif mode_stripped in U_MODE_REPLACE_R:
        new_mode = mode.replace('U', 'r')
        tokens[slice(*func_args[arg_idx])] = [Token('SRC', new_mode)]
    elif mode_stripped in U_MODE_REMOVE_U:
        new_mode = mode.replace('U', '')
        tokens[slice(*func_args[arg_idx])] = [Token('SRC', new_mode)]
    else:
        raise AssertionError(f'unreachable: {mode!r}')

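# Illustration (added; not part of the original source), assuming the usual
# membership of the U_MODE_* sets:
#     open('f', 'U')    ->  open('f')          # U_MODE_REMOVE
#     open('f', 'Ub')   ->  open('f', 'rb')    # U_MODE_REPLACE_R
#     open('f', 'rUb')  ->  open('f', 'rb')    # U_MODE_REMOVE_U
# the `mode.split('=')[-1]` also lets this handle `open('f', mode='U')`.
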
def _fix_py3_block(i: int, tokens: List[Token]) -> None:
    if tokens[i].src == 'if':
        if_block = Block.find(tokens, i)
        if_block.dedent(tokens)
        del tokens[if_block.start:if_block.block]
    else:
        if_block = Block.find(tokens, _find_elif(tokens, i))
        if_block.replace_condition(tokens, [Token('NAME', 'else')])

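# Illustration (added; assumed from the Block.find/dedent semantics): for a
# block whose condition is always true on Python 3, the header is removed
# and the body dedented,
#     if PY3:        ->  bar()
#         bar()
# while an `elif PY3:` has its condition replaced so it becomes `else:`.
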
def _process_ctx_returned(
    tokens: List[Token],
    return_start: int,
    brace_start: int,
    brace_end: int,
) -> int:
    inserted = 0
    offset = brace_start + 1
    limit = brace_end - 1
    kwargs = tokens[offset:limit]
    indent = tokens[return_start].utf8_byte_offset

    for assignment in reversed(list(_split_assign(kwargs))):
        key = takewhile(lambda token: token.src != "=", assignment)
        value = dropwhile(lambda token: token.src != "=", assignment)
        name = list(strip(key, lambda token: token.src.isspace()))
        variable = list(
            strip(islice(value, 1, None), lambda token: token.src.isspace()))
        patch = [
            Token(name="NAME", src="ctx"),
            Token(name="OP", src="."),
            *name,
            Token(name="UNIMPORTANT_WS", src=" "),
            Token(name="OP", src="="),
            Token(name="UNIMPORTANT_WS", src=" "),
            *variable,
            Token(name="NEWLINE", src="\n"),
            Token(name="INDENT", src=" " * indent),
        ]
        tokens[return_start:return_start] = patch
        inserted += len(patch)

    return inserted

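# Illustration (added; inferred from the token patch built above): each
# `name=value` pair found between the braces is re-emitted ahead of the
# return statement, at its indentation, as
#     ctx.name = value
# one assignment per line; the function reports how many tokens it inserted.
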
def test_src_to_tokens_escaped_nl_windows():
    src = (
        'x = \\\r\n'
        '    5\r\n'
    )
    ret = src_to_tokens(src)
    assert ret == [
        Token('NAME', 'x', line=1, utf8_byte_offset=0),
        Token(UNIMPORTANT_WS, ' ', line=None, utf8_byte_offset=None),
        Token('OP', '=', line=1, utf8_byte_offset=2),
        Token(UNIMPORTANT_WS, ' ', line=None, utf8_byte_offset=None),
        Token(ESCAPED_NL, '\\\r\n', line=None, utf8_byte_offset=None),
        Token(UNIMPORTANT_WS, '    ', line=None, utf8_byte_offset=None),
        Token('NUMBER', '5', line=2, utf8_byte_offset=4),
        Token('NEWLINE', '\r\n', line=2, utf8_byte_offset=5),
        Token('ENDMARKER', '', line=3, utf8_byte_offset=0),
    ]

def _fix_is_literal(
        i: int,
        tokens: List[Token],
        *,
        op: Union[ast.Is, ast.IsNot],
) -> None:
    while tokens[i].src != 'is':
        i -= 1
    if isinstance(op, ast.Is):
        tokens[i] = tokens[i]._replace(src='==')
    else:
        tokens[i] = tokens[i]._replace(src='!=')
        # since we iterate backward, the empty tokens keep the same length
        i += 1
        while tokens[i].src != 'not':
            tokens[i] = Token('EMPTY', '')
            i += 1
        tokens[i] = Token('EMPTY', '')

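# Illustration (added; not part of the original source):
#     x is 5      ->  x == 5
#     x is not 5  ->  x != 5
# the `not` and the whitespace before it are blanked with zero-width EMPTY
# tokens so that token indices recorded earlier remain valid.
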
def _replace(i, mapping, node):
    # nested helper: `tokens` comes from the enclosing function's scope
    new_token = Token('CODE', _get_tmpl(mapping, node))
    if isinstance(node, ast.Name):
        tokens[i] = new_token
    else:
        j = i
        while tokens[j].src != node.attr:
            j += 1
        tokens[i:j + 1] = [new_token]

def _fix_escape_sequences(token: Token) -> Token:
    prefix, rest = parse_string_literal(token.src)
    actual_prefix = prefix.lower()

    if 'r' in actual_prefix or '\\' not in rest:
        return token

    is_bytestring = 'b' in actual_prefix

    def _is_valid_escape(match: Match[str]) -> bool:
        c = match.group()[1]
        return (
            c in ESCAPE_STARTS or
            (not is_bytestring and c in 'uU') or
            (
                not is_bytestring and
                c == 'N' and
                bool(NAMED_ESCAPE_NAME.match(rest, match.end()))
            )
        )

    has_valid_escapes = False
    has_invalid_escapes = False
    for match in ESCAPE_RE.finditer(rest):
        if _is_valid_escape(match):
            has_valid_escapes = True
        else:
            has_invalid_escapes = True

    def cb(match: Match[str]) -> str:
        matched = match.group()
        if _is_valid_escape(match):
            return matched
        else:
            return fr'\{matched}'

    if has_invalid_escapes and (has_valid_escapes or 'u' in actual_prefix):
        return token._replace(src=prefix + ESCAPE_RE.sub(cb, rest))
    elif has_invalid_escapes and not has_valid_escapes:
        return token._replace(src=prefix + 'r' + rest)
    else:
        return token

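# Illustration (added; not part of the original source): invalid escape
# sequences are repaired,
#     '\d'    ->  r'\d'     # only invalid escapes: make the string raw
#     '\n\d'  ->  '\n\\d'   # mixed: double only the invalid ones
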
def _replace_call(tokens, start, end, args, tmpl):
    arg_strs = [tokens_to_src(tokens[slice(*arg)]).strip() for arg in args]

    start_rest = args[0][1] + 1
    while (
            start_rest < end and
            tokens[start_rest].name in {'COMMENT', UNIMPORTANT_WS}
    ):
        start_rest += 1

    rest = tokens_to_src(tokens[start_rest:end - 1])
    src = tmpl.format(args=arg_strs, rest=rest)
    tokens[start:end] = [Token('CODE', src)]