def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Reformat a string and return new contents.

    `mode` determines formatting options, such as how many characters per line are
    allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    tree = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    imported_futures = get_future_imports(tree)
    # Explicitly configured targets win; otherwise infer them from the parsed source.
    target_versions = mode.target_versions or detect_target_versions(tree)
    normalize_fmt_off(tree)

    line_generator = LineGenerator(
        mode=mode,
        remove_u_prefix=(
            "unicode_literals" in imported_futures
            or supports_feature(target_versions, Feature.UNICODE_LITERALS)
        ),
    )
    blank_tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    blank_line = Line(mode=mode)

    # Only the trailing-comma features every target version supports are
    # handed to the line transformer.
    trailing_comma_candidates = (
        Feature.TRAILING_COMMA_IN_CALL,
        Feature.TRAILING_COMMA_IN_DEF,
    )
    usable_features = {
        candidate
        for candidate in trailing_comma_candidates
        if supports_feature(target_versions, candidate)
    }

    pieces = []
    blanks_after = 0
    for logical_line in line_generator.visit(tree):
        # Blank lines owed after the previous line are emitted first, then
        # the ones required before the current line.
        pieces.append(str(blank_line) * blanks_after)
        blanks_before, blanks_after = blank_tracker.maybe_empty_lines(logical_line)
        pieces.append(str(blank_line) * blanks_before)
        pieces.extend(
            str(output_line)
            for output_line in transform_line(
                logical_line, mode=mode, features=usable_features
            )
        )
    return "".join(pieces)
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run a single formatting pass over `src_contents` and return the result.

    The source is parsed with leading whitespace stripped, target versions are
    taken from `mode` or detected from the tree, and each generated line is
    transformed and concatenated together with the blank lines the
    `EmptyLineTracker` asks for.
    """
    tree = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    imported_futures = get_future_imports(tree)
    # Explicitly configured targets win; otherwise infer them from the parsed source.
    target_versions = mode.target_versions or detect_target_versions(
        tree, future_imports=imported_futures
    )
    normalize_fmt_off(tree, preview=mode.preview)

    line_generator = LineGenerator(mode=mode)
    blank_tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    blank_line = Line(mode=mode)

    # Only the trailing-comma features every target version supports are
    # handed to the line transformer.
    trailing_comma_candidates = (
        Feature.TRAILING_COMMA_IN_CALL,
        Feature.TRAILING_COMMA_IN_DEF,
    )
    usable_features = {
        candidate
        for candidate in trailing_comma_candidates
        if supports_feature(target_versions, candidate)
    }

    chunks = []
    blanks_after = 0
    for logical_line in line_generator.visit(tree):
        # Blank lines owed after the previous line are emitted first, then
        # the ones required before the current line.
        chunks.append(str(blank_line) * blanks_after)
        blanks_before, blanks_after = blank_tracker.maybe_empty_lines(logical_line)
        chunks.append(str(blank_line) * blanks_before)
        chunks.extend(
            str(output_line)
            for output_line in transform_line(
                logical_line, mode=mode, features=usable_features
            )
        )
    return "".join(chunks)
def show(cls, code: Union[str, Leaf, Node]) -> None:
    """Pretty-print the lib2to3 AST of a given string of `code`.

    Convenience method for debugging.  A `str` argument is parsed with
    `lib2to3_parse` first; a `Leaf` or `Node` is visited as given.
    """
    visitor: DebugVisitor[None] = DebugVisitor()
    tree = lib2to3_parse(code) if isinstance(code, str) else code
    # `visit` is lazy: drain it so the whole tree is walked.  The yielded
    # values themselves are discarded.
    for _ in visitor.visit(tree):
        pass
def shed(
    source_code: str,
    *,
    refactor: bool = False,
    first_party_imports: FrozenSet[str] = frozenset(),
    min_version: Tuple[int, int] = _default_min_version,
    _location: str = "string passed to shed.shed()",
    _remove_unused_imports: bool = True,
) -> str:
    """Process the source code of a single module.

    The code is parsed once to detect target versions, then passed through
    black, pyupgrade, (optionally) the codemods, isort, autoflake, black again
    if anything changed since the first black pass, and finally `docshed`.

    Args:
        source_code: Source text of the module.  The empty string is returned
            unchanged.
        refactor: When true, type comments are converted to annotations with
            com2ann and `_run_codemods` is applied.  The modules needed for
            this are imported lazily on first use.
        first_party_imports: Identifier names passed to isort as
            `known_first_party` and forwarded to `docshed`.
        min_version: Minimum Python version to target; must be one of the
            values of `_version_map`.  It is raised to the version mapped from
            the lowest target version that black detects in the code.
        _location: Description of where the code came from.  Used in the
            warning message, used to compute the warning's stacklevel, and
            forwarded to `docshed`.
        _remove_unused_imports: Passed to autoflake as
            `remove_all_unused_imports`.

    Returns:
        The processed source with trailing whitespace stripped and a single
        newline appended.  If the code cannot be parsed, the result of
        `docshed` on the unmodified source is returned instead.

    Raises:
        AssertionError: If an argument fails the type/value checks below
            (these are plain `assert` statements).

    Warns:
        ShedSyntaxWarning: If `lib2to3_parse` raises; the original exception
            is attached as `__cause__`.
    """
    assert isinstance(source_code, str)
    assert isinstance(refactor, bool)
    assert isinstance(first_party_imports, frozenset)
    assert all(isinstance(name, str) for name in first_party_imports)
    assert all(name.isidentifier() for name in first_party_imports)
    assert min_version in _version_map.values()

    if source_code == "":
        return ""

    # Use black to autodetect our target versions
    try:
        parsed = lib2to3_parse(
            source_code.lstrip(),
            target_versions={
                k for k, v in _version_map.items() if v >= min_version
            },
        )
    # black.InvalidInput, blib2to3.pgen2.tokenize.TokenError, SyntaxError...
    # for forwards-compatibility I'm just going general here.
    except Exception as err:
        msg = f"Could not parse {_location}"
        # Append a hint for each suggestion pattern that matches the source.
        for pattern, blocktype in _SUGGESTIONS:
            if re.search(pattern, source_code, flags=re.MULTILINE):
                msg += f"\n Perhaps you should use a {blocktype!r} block instead?"
        # If the builtin compiler accepts the code, the parse failure is
        # reported as a likely bug rather than a problem with the input.
        try:
            compile(source_code, "<string>", "exec")
        except SyntaxError:
            pass
        else:
            msg += "\n The syntax is valid Python, so please report this as a bug."
        w = ShedSyntaxWarning(msg)
        w.__cause__ = err
        # The stacklevel grows with each " block in " in `_location`, so the
        # number of frames between this call and the caller matters here.
        warnings.warn(w, stacklevel=_location.count(" block in ") + 2)
        # Even if the code itself has invalid syntax, we might be able to
        # regex-match and therefore reformat code embedded in docstrings.
        return docshed(
            source_code,
            refactor=refactor,
            first_party_imports=first_party_imports,
            min_version=min_version,
            _location=_location,
        )
    # Keep only the detected target versions that `_version_map` knows about.
    target_versions = set(_version_map) & set(
        black.detect_target_versions(parsed))
    assert target_versions
    # Never target anything older than the lowest version black detected.
    min_version = max(
        min_version,
        _version_map[min(target_versions, key=attrgetter("value"))],
    )

    if refactor:
        # Here we have a deferred imports section, which is pretty ugly.
        # It does however have one crucial advantage: several hundred milliseconds
        # of startup latency in the common case where --refactor was *not* passed.
        # This is a big deal for interactive use-cases such as pre-commit hooks
        # or format-on-save in editors (though I prefer Black for the latter).
        global com2ann
        global _run_codemods
        if com2ann is None:
            from ._codemods import _run_codemods  # type: ignore

            try:
                from com2ann import com2ann
            except ImportError:  # pragma: no cover  # on Python 3.8
                assert sys.version_info < (3, 8)
                com2ann = _fallback
        # OK, everything's imported, back to the runtime logic!

        # Some tools assume that the file is multi-line, but empty files are valid input.
        source_code += "\n"
        # Use com2ann to convert type comments to annotations on Python 3.8+
        annotated = com2ann(
            source_code,
            drop_ellipsis=True,
            silent=True,
            python_minor_version=min(min_version[1], sys.version_info[1]),
        )
        if annotated:  # pragma: no branch
            # This can only be None if ast.parse() raises a SyntaxError,
            # which is possible but rare after the parsing checks above.
            source_code, _ = annotated

    # One tricky thing: running `isort` or `autoflake` can "unlock" further fixes
    # for `black`, e.g. "pass;#" -> "pass\n#\n" -> "#\n".  We therefore run it
    # before other fixers, and then (if they made changes) again afterwards.
    black_mode = black.Mode(target_versions=target_versions)  # type: ignore
    # `blackened` keeps the first black output so later steps can tell
    # whether any other tool changed the code.
    source_code = blackened = black.format_str(source_code, mode=black_mode)

    # pyupgrade is given at most the highest version listed in its
    # IMPORT_REMOVALS.
    pyupgrade_min = min(min_version, max(pyupgrade._main.IMPORT_REMOVALS))
    pu_settings = pyupgrade._main.Settings(min_version=pyupgrade_min)
    source_code = pyupgrade._main._fix_plugins(source_code, settings=pu_settings)
    if source_code != blackened:
        # Second step to converge: https://github.com/asottile/pyupgrade/issues/273
        source_code = pyupgrade._main._fix_plugins(source_code, settings=pu_settings)
    source_code = pyupgrade._main._fix_tokens(source_code, min_version=pyupgrade_min)

    if refactor:
        source_code = _run_codemods(source_code, min_version=min_version)

    try:
        source_code = isort.code(
            source_code,
            known_first_party=first_party_imports,
            known_local_folder={"tests"},
            profile="black",
            combine_as_imports=True,
        )
    except FileSkipComment:
        # isort raised FileSkipComment; keep the code as it was before isort.
        pass

    source_code = autoflake.fix_code(
        source_code,
        expand_star_imports=True,
        remove_all_unused_imports=_remove_unused_imports,
    )

    # Only pay for a second black pass if something changed since the first.
    if source_code != blackened:
        source_code = black.format_str(source_code, mode=black_mode)

    # Then shed.docshed (below) formats any code blocks in documentation
    source_code = docshed(
        source_code,
        refactor=refactor,
        first_party_imports=first_party_imports,
        min_version=min_version,
        _location=_location,
    )
    # Remove any extra trailing whitespace
    return source_code.rstrip() + "\n"