Ejemplo n.º 1
0
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Return `src_contents` reformatted according to `mode`.

    `mode` carries the formatting options, for instance the number of
    characters allowed per line.  A minimal call:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    The same idea with explicit options:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    # Leading whitespace is stripped before the source is handed to the parser.
    tree = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    pieces = []
    future_imports = get_future_imports(tree)
    # Explicitly configured targets win; otherwise infer them from the tree.
    versions = mode.target_versions or detect_target_versions(tree)
    normalize_fmt_off(tree)

    # The `u` prefix is dropped when the module opts in via
    # `unicode_literals` or when the targets support UNICODE_LITERALS.
    drop_u_prefix = "unicode_literals" in future_imports or supports_feature(
        versions, Feature.UNICODE_LITERALS
    )
    line_gen = LineGenerator(mode=mode, remove_u_prefix=drop_u_prefix)
    tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    blank = Line(mode=mode)
    pending_after = 0

    # Keep only the trailing-comma features the selected versions support.
    split_line_features = set()
    for candidate in {
        Feature.TRAILING_COMMA_IN_CALL,
        Feature.TRAILING_COMMA_IN_DEF,
    }:
        if supports_feature(versions, candidate):
            split_line_features.add(candidate)

    for current_line in line_gen.visit(tree):
        # Blank lines requested *after* the previous line are emitted first,
        # followed by the ones requested *before* the current line.
        pieces.append(str(blank) * pending_after)
        before, pending_after = tracker.maybe_empty_lines(current_line)
        pieces.append(str(blank) * before)
        pieces.extend(
            str(piece)
            for piece in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )
    return "".join(pieces)
Ejemplo n.º 2
0
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run one formatting pass over `src_contents` and return the result.

    The source (with leading whitespace stripped) is parsed, `fmt: off`
    handling is normalized on the tree, and every line produced by
    `LineGenerator` is passed through `transform_line`.  The stringified
    pieces, together with the blank lines requested by `EmptyLineTracker`,
    are concatenated into the returned string.
    """
    tree = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    chunks = []
    future_imports = get_future_imports(tree)
    # Explicitly configured targets win; otherwise infer them from the tree.
    versions = mode.target_versions or detect_target_versions(
        tree, future_imports=future_imports
    )

    normalize_fmt_off(tree, preview=mode.preview)
    line_gen = LineGenerator(mode=mode)
    tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    blank = Line(mode=mode)
    trailing = 0

    # Keep only the trailing-comma features the selected versions support.
    split_line_features = set()
    for candidate in {
        Feature.TRAILING_COMMA_IN_CALL,
        Feature.TRAILING_COMMA_IN_DEF,
    }:
        if supports_feature(versions, candidate):
            split_line_features.add(candidate)

    for current_line in line_gen.visit(tree):
        # Flush blank lines owed after the previous line, then the ones
        # required before the current line.
        chunks.append(str(blank) * trailing)
        leading, trailing = tracker.maybe_empty_lines(current_line)
        chunks.append(str(blank) * leading)
        chunks.extend(
            str(piece)
            for piece in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )
    return "".join(chunks)
Ejemplo n.º 3
0
    def show(cls, code: Union[str, Leaf, Node]) -> None:
        """Pretty-print the lib2to3 AST for `code`; a debugging convenience.

        `code` may be source text, which is parsed with `lib2to3_parse`
        first, or an already-built `Leaf`/`Node`, which is visited as given.
        """
        # NOTE(review): the first parameter is `cls`, so this is presumably
        # wrapped in @classmethod outside this excerpt -- confirm.
        visitor: DebugVisitor[None] = DebugVisitor()
        tree = lib2to3_parse(code) if isinstance(code, str) else code
        # Iterate the visitor to completion; the yielded values are discarded.
        for _ in visitor.visit(tree):
            pass
Ejemplo n.º 4
0
def shed(
    source_code: str,
    *,
    refactor: bool = False,
    first_party_imports: FrozenSet[str] = frozenset(),
    min_version: Tuple[int, int] = _default_min_version,
    _location: str = "string passed to shed.shed()",
    _remove_unused_imports: bool = True,
) -> str:
    """Process the source code of a single module.

    The code is run through black, pyupgrade, (optionally) com2ann and the
    codemods, isort, autoflake, black again if anything changed, and finally
    `docshed`.

    Args:
        source_code: The module source to process.
        refactor: When true, additionally run com2ann and `_run_codemods`.
        first_party_imports: Identifier names passed to isort as
            `known_first_party`.
        min_version: Minimum supported Python version; must be one of the
            values of `_version_map`.
        _location: Description of where the code came from; used in the
            warning message and to compute the warning's `stacklevel`.
        _remove_unused_imports: Forwarded to autoflake as
            `remove_all_unused_imports`.

    Returns:
        `""` for empty input.  If the source cannot be parsed, a
        `ShedSyntaxWarning` is emitted and the result of `docshed` on the
        unmodified source is returned.  Otherwise the processed source with
        trailing whitespace stripped and a single trailing newline.

    Raises:
        AssertionError: If an argument fails the type/value checks below.
    """
    # Validate arguments up front.
    assert isinstance(source_code, str)
    assert isinstance(refactor, bool)
    assert isinstance(first_party_imports, frozenset)
    assert all(isinstance(name, str) for name in first_party_imports)
    assert all(name.isidentifier() for name in first_party_imports)
    assert min_version in _version_map.values()

    # Nothing to do for an empty module; note that no newline is added here.
    if source_code == "":
        return ""

    # Use black to autodetect our target versions
    try:
        parsed = lib2to3_parse(
            source_code.lstrip(),
            target_versions={
                k
                for k, v in _version_map.items() if v >= min_version
            },
        )
        # black.InvalidInput, blib2to3.pgen2.tokenize.TokenError, SyntaxError...
        # for forwards-compatibility I'm just going general here.
    except Exception as err:
        msg = f"Could not parse {_location}"
        # Append a hint for every suggestion pattern that matches the source.
        for pattern, blocktype in _SUGGESTIONS:
            if re.search(pattern, source_code, flags=re.MULTILINE):
                msg += f"\n    Perhaps you should use a {blocktype!r} block instead?"
        # If the builtin compiler accepts the code, the parse failure above
        # is treated as a bug worth reporting rather than a user error.
        try:
            compile(source_code, "<string>", "exec")
        except SyntaxError:
            pass
        else:
            msg += "\n    The syntax is valid Python, so please report this as a bug."
        w = ShedSyntaxWarning(msg)
        w.__cause__ = err
        # The stacklevel grows with each " block in " found in _location.
        warnings.warn(w, stacklevel=_location.count(" block in ") + 2)
        # Even if the code itself has invalid syntax, we might be able to
        # regex-match and therefore reformat code embedded in docstrings.
        return docshed(
            source_code,
            refactor=refactor,
            first_party_imports=first_party_imports,
            min_version=min_version,
            _location=_location,
        )
    # Keep only the versions in _version_map that black also detects for
    # this source.
    target_versions = set(_version_map) & set(
        black.detect_target_versions(parsed))
    assert target_versions
    # Raise min_version to the lowest detected target if that is newer.
    min_version = max(
        min_version,
        _version_map[min(target_versions, key=attrgetter("value"))],
    )

    if refactor:
        # Here we have a deferred imports section, which is pretty ugly.
        # It does however have one crucial advantage: several hundred milliseconds
        # of startup latency in the common case where --refactor was *not* passed.
        # This is a big deal for interactive use-cases such as pre-commit hooks
        # or format-on-save in editors (though I prefer Black for the latter).
        global com2ann
        global _run_codemods
        # `com2ann is None` marks that the deferred imports have not run yet.
        if com2ann is None:
            from ._codemods import _run_codemods  # type: ignore

            try:
                from com2ann import com2ann
            except ImportError:  # pragma: no cover  # on Python 3.8
                assert sys.version_info < (3, 8)
                com2ann = _fallback
            # OK, everything's imported, back to the runtime logic!

        # Some tools assume that the file is multi-line, but empty files are valid input.
        source_code += "\n"
        # Use com2ann to convert type comments to annotations on Python 3.8+
        annotated = com2ann(
            source_code,
            drop_ellipsis=True,
            silent=True,
            python_minor_version=min(min_version[1], sys.version_info[1]),
        )
        if annotated:  # pragma: no branch
            # This can only be None if ast.parse() raises a SyntaxError,
            # which is possible but rare after the parsing checks above.
            source_code, _ = annotated

    # One tricky thing: running `isort` or `autoflake` can "unlock" further fixes
    # for `black`, e.g. "pass;#" -> "pass\n#\n" -> "#\n".  We therefore run it
    # before other fixers, and then (if they made changes) again afterwards.
    black_mode = black.Mode(target_versions=target_versions)  # type: ignore
    # `blackened` remembers black's output so later steps can detect changes.
    source_code = blackened = black.format_str(source_code, mode=black_mode)

    # The version handed to pyupgrade is capped at the largest entry of
    # pyupgrade's IMPORT_REMOVALS.
    pyupgrade_min = min(min_version, max(pyupgrade._main.IMPORT_REMOVALS))
    pu_settings = pyupgrade._main.Settings(min_version=pyupgrade_min)
    source_code = pyupgrade._main._fix_plugins(source_code,
                                               settings=pu_settings)
    if source_code != blackened:
        # Second step to converge: https://github.com/asottile/pyupgrade/issues/273
        source_code = pyupgrade._main._fix_plugins(source_code,
                                                   settings=pu_settings)
    source_code = pyupgrade._main._fix_tokens(source_code,
                                              min_version=pyupgrade_min)

    if refactor:
        source_code = _run_codemods(source_code, min_version=min_version)

    try:
        source_code = isort.code(
            source_code,
            known_first_party=first_party_imports,
            known_local_folder={"tests"},
            profile="black",
            combine_as_imports=True,
        )
    except FileSkipComment:
        # isort declined to process this source; keep the code as it was.
        pass

    source_code = autoflake.fix_code(
        source_code,
        expand_star_imports=True,
        remove_all_unused_imports=_remove_unused_imports,
    )

    # Re-run black only if a later fixer changed its first-pass output.
    if source_code != blackened:
        source_code = black.format_str(source_code, mode=black_mode)

    # Then shed.docshed (below) formats any code blocks in documentation
    source_code = docshed(
        source_code,
        refactor=refactor,
        first_party_imports=first_party_imports,
        min_version=min_version,
        _location=_location,
    )
    # Remove any extra trailing whitespace
    return source_code.rstrip() + "\n"