Exemple #1
0
 def test_get_future_imports(self) -> None:
     """get_future_imports reports only the __future__ imports opening a module.

     Comments and a docstring may precede them. Any other statement first
     means nothing is reported, and non-__future__ imports never count.
     """
     cases = [
         ("\n", set()),
         ("from __future__ import black\n", {"black"}),
         ("from __future__ import multiple, imports\n", {"multiple", "imports"}),
         (
             "from __future__ import (parenthesized, imports)\n",
             {"parenthesized", "imports"},
         ),
         (
             "from __future__ import multiple\nfrom __future__ import imports\n",
             {"multiple", "imports"},
         ),
         ("# comment\nfrom __future__ import black\n", {"black"}),
         ('"""docstring"""\nfrom __future__ import black\n', {"black"}),
         ("some(other, code)\nfrom __future__ import black\n", set()),
         ("from some.module import black\n", set()),
     ]
     for source, wanted in cases:
         parsed = black.lib2to3_parse(source)
         self.assertEqual(wanted, black.get_future_imports(parsed))
Exemple #2
0
def fix_one(path: Path) -> bool:
    """Rewrite the module at *path* in place using ``TypingRewriter``.

    Args:
        path: the file to process; it is resolved to an absolute path first.

    Returns:
        True if the file was rewritten, False if it was skipped.

    A file is skipped (returning False) when:
      * ``shadowed_typing_names`` reports any names (warning on stderr);
      * its typing import is special-cased as ``import typing  # NoQA``
        (warning on stderr);
      * it contains no plain ``import typing`` line.
    """
    path = path.resolve()
    shadowed = list(shadowed_typing_names(path))
    if shadowed:
        print(
            f"warning: {path} skipped, shadows typing names: {shadowed}",
            file=sys.stderr,
        )
        return False

    with path.open() as f:
        contents = f.read()

    # Check the special-cased form first. "import typing  # NoQA\n" does not
    # contain the substring "import typing\n", so testing for the plain
    # import first would skip NoQA-only files silently, without the warning.
    if "import typing  # NoQA\n" in contents:
        print(
            f"warning: {path} skipped, typing import special-cased.",
            file=sys.stderr,
        )
        return False
    if "import typing\n" not in contents:
        return False

    code = lib2to3_parse(contents)
    # Build the complete output before opening for writing. Mode "w"
    # truncates the file, so an exception raised while visiting would
    # otherwise leave the source file partially written.
    rewritten = "".join(TypingRewriter().visit(code))
    with path.open("w") as f:
        f.write(rewritten)
    return True
Exemple #3
0
 def assertFormatEqual(self, expected: str, actual: str) -> None:
     """Assert that two formatted sources are equal.

     On a mismatch, each source is first parsed and walked with
     black.DebugVisitor (presumably dumping its tree) under a coloured
     heading. Parse/visit failures are reported via black.err rather than
     raised, so the final assertEqual still reports the real failure.
     """
     if actual != expected:
         for banner, colour, src in (
             ('Expected tree:', 'green', expected),
             ('Actual tree:', 'red', actual),
         ):
             black.out(banner, fg=colour)
             try:
                 tree = black.lib2to3_parse(src)
                 list(black.DebugVisitor().visit(tree))
             except Exception as exc:
                 black.err(str(exc))
     self.assertEqual(expected, actual)
Exemple #4
0
 def assertFormatEqual(self, expected: str, actual: str) -> None:
     """Assert that two formatted sources are equal.

     On a mismatch, and unless the SKIP_AST_PRINT environment variable is
     set, both sources are parsed and walked with black.DebugVisitor
     (presumably printing their trees) before the assertion runs.
     """

     def dump_tree(title: str, colour: str, source: str) -> None:
         # Print a heading, then walk the parse of `source`. Any failure is
         # reported with black.err instead of propagating.
         black.out(title, fg=colour)
         try:
             node = black.lib2to3_parse(source)
             visitor: black.DebugVisitor[Any] = black.DebugVisitor()
             list(visitor.visit(node))
         except Exception as exc:
             black.err(str(exc))

     if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
         dump_tree("Expected tree:", "green", expected)
         dump_tree("Actual tree:", "red", actual)
     self.assertEqual(expected, actual)
Exemple #5
0
 def assertFormatEqual(self, expected: str, actual: str) -> None:
     """Compare formatter output against the expected text.

     When they differ and SKIP_AST_PRINT is not set in the environment, the
     parse of each side is walked with black.DebugVisitor under a green
     ("Expected") or red ("Actual") heading. Errors there go to black.err.
     The comparison itself is always made with assertEqual.
     """
     show_trees = actual != expected and not os.environ.get("SKIP_AST_PRINT")
     if show_trees:
         bdv: black.DebugVisitor[Any]
         panels = (
             ("Expected tree:", "green", expected),
             ("Actual tree:", "red", actual),
         )
         for heading, fg, text in panels:
             black.out(heading, fg=fg)
             try:
                 parsed = black.lib2to3_parse(text)
                 bdv = black.DebugVisitor()
                 list(bdv.visit(parsed))
             except Exception as problem:
                 black.err(str(problem))
     self.assertEqual(expected, actual)
Exemple #6
0
def shed(
    *, source_code: str,
    first_party_imports: FrozenSet[str] = frozenset()) -> str:
    """Process the source code of a single module.

    The code goes through autoflake, isort, pyupgrade and finally black.
    If that pass changed anything, the whole pipeline runs again on the
    result, until a pass leaves the code unchanged (a fixpoint).
    """
    assert isinstance(source_code, str)
    assert isinstance(first_party_imports, frozenset)
    assert all(isinstance(name, str) for name in first_party_imports)
    assert all(name.isidentifier() for name in first_party_imports)

    # Ask black which target versions the code is compatible with, keeping
    # only Python 3.6 and newer.
    parsed = black.lib2to3_parse(source_code.lstrip(), set(_version_map))
    oldest_allowed = black.TargetVersion.PY36.value
    target_versions = {
        version
        for version in black.detect_target_versions(parsed)
        if version.value >= oldest_allowed
    }
    assert target_versions
    lowest_target = min(target_versions, key=attrgetter("value"))

    # autoflake: expand star imports, drop unused imports/variables and
    # duplicate dict keys.
    result = autoflake.fix_code(
        source_code,
        expand_star_imports=True,
        remove_all_unused_imports=True,
        remove_duplicate_keys=True,
        remove_unused_variables=True,
    )

    # isort, via the pre-5.0 API.
    # TODO: once isort 5.0 is out, switch to isort.api.sorted_imports with
    # known_first_party=first_party_imports (black-compatible, cleaner config).
    result = isort.SortImports(file_contents=result).output

    # pyupgrade, following the steps of pyupgrade._fix_file.
    result = pyupgrade._fix_tokens(
        result, min_version=_version_map[lowest_target])
    result = pyupgrade._fix_percent_format(result)
    result = pyupgrade._fix_py3_plus(result)

    # black goes last.
    black_mode = black.FileMode(target_versions=target_versions)
    result = black.format_str(result, mode=black_mode)

    if result != source_code:
        # One fix can unlock another (e.g. "pass;#" -> "pass\n#\n" -> "#\n"),
        # so run the pipeline again until nothing changes.
        return shed(source_code=result,
                    first_party_imports=first_party_imports)
    return result
Exemple #7
0
def _assert_format_equal(expected: str, actual: str) -> None:
    if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
        bdv: DebugVisitor[Any]
        out("Expected tree:", fg="green")
        try:
            exp_node = black.lib2to3_parse(expected)
            bdv = DebugVisitor()
            list(bdv.visit(exp_node))
        except Exception as ve:
            err(str(ve))
        out("Actual tree:", fg="red")
        try:
            exp_node = black.lib2to3_parse(actual)
            bdv = DebugVisitor()
            list(bdv.visit(exp_node))
        except Exception as ve:
            err(str(ve))

    if actual != expected:
        out(diff(expected, actual, "expected", "actual"))

    assert actual == expected
Exemple #8
0
 def test_get_future_imports(self) -> None:
     """Only leading __future__ imports are collected by get_future_imports.

     They may follow a comment or a module docstring. Any other statement
     before them, or a regular `from x import y`, yields an empty set.
     """

     def check(expected, source):
         tree = black.lib2to3_parse(source)
         self.assertEqual(expected, black.get_future_imports(tree))

     check(set(), "\n")
     check({"black"}, "from __future__ import black\n")
     check({"multiple", "imports"}, "from __future__ import multiple, imports\n")
     check(
         {"parenthesized", "imports"},
         "from __future__ import (parenthesized, imports)\n",
     )
     check(
         {"multiple", "imports"},
         "from __future__ import multiple\nfrom __future__ import imports\n",
     )
     check({"black"}, "# comment\nfrom __future__ import black\n")
     check({"black"}, '"""docstring"""\nfrom __future__ import black\n')
     check(set(), "some(other, code)\nfrom __future__ import black\n")
     check(set(), "from some.module import black\n")
Exemple #9
0
def fuzz(buf):
    """Fuzz entry point: feed *buf* through black's line generation.

    *buf* is stripped of leading whitespace and decoded as UTF-8 (so it is
    expected to be bytes). Input that fails to decode or parse is ignored.
    Each generated line is then passed through an EmptyLineTracker.
    """
    try:
        tree = lib2to3_parse(buf.lstrip().decode('utf-8'))
    except Exception:
        # Bad input is not what this harness is looking for.
        return

    futures = get_future_imports(tree)
    generator = LineGenerator(
        remove_u_prefix="unicode_literals" in futures,
        normalize_strings=True,
    )
    tracker = EmptyLineTracker()

    for line in generator.visit(tree):
        _before, _after = tracker.maybe_empty_lines(line)
Exemple #10
0
 def test_is_python36(self) -> None:
     """is_python36 detects syntax that requires Python 3.6.

     A trailing comma after a keyword-only parameter, or an f-string, marks a
     snippet as 3.6-only. Both halves of the 'function' fixture are expected
     to be 3.6-only; neither half of the 'expression' fixture is.
     """
     for snippet, py36_only in (
         ("def f(*, arg): ...\n", False),
         ("def f(*, arg,): ...\n", True),
         ("def f(*, arg): f'string'\n", True),
     ):
         verdict = black.is_python36(black.lib2to3_parse(snippet))
         (self.assertTrue if py36_only else self.assertFalse)(verdict)
     for fixture, py36_only in (('function', True), ('expression', False)):
         source, expected = read_data(fixture)
         for text in (source, expected):
             verdict = black.is_python36(black.lib2to3_parse(text))
             (self.assertTrue if py36_only else self.assertFalse)(verdict)
Exemple #11
0
 def test_is_python36(self) -> None:
     """Check is_python36 on inline snippets and the two data fixtures.

     Plain keyword-only parameters are not 3.6-specific. A trailing comma
     after them, or an f-string, is. The "function" fixture is 3.6-only
     before and after formatting; the "expression" fixture is not.
     """

     def expect(is_36, source):
         node = black.lib2to3_parse(source)
         if is_36:
             self.assertTrue(black.is_python36(node))
         else:
             self.assertFalse(black.is_python36(node))

     expect(False, "def f(*, arg): ...\n")
     expect(True, "def f(*, arg,): ...\n")
     expect(True, "def f(*, arg): f'string'\n")
     source, expected = read_data("function")
     expect(True, source)
     expect(True, expected)
     source, expected = read_data("expression")
     expect(False, source)
     expect(False, expected)
Exemple #12
0
def shed(
        source_code: str,
        *,
        refactor: bool = False,
        first_party_imports: FrozenSet[str] = frozenset(),
) -> str:
    """Process the source code of a single module.

    Args:
        source_code: the module's source text.
        refactor: if True, first run the com2ann, teyit (``_teyit_refactor``)
            and libcst-based ``_pybetter_fixers`` refactoring passes.
        first_party_imports: module names that isort should treat as
            first-party. Each must be a valid identifier.

    Returns:
        The processed source with trailing whitespace stripped and exactly
        one trailing newline.
    """
    assert isinstance(source_code, str)
    assert isinstance(refactor, bool)
    assert isinstance(first_party_imports, frozenset)
    assert all(isinstance(name, str) for name in first_party_imports)
    assert all(name.isidentifier() for name in first_party_imports)

    # Use black to autodetect our target versions (keeping only PY36+)
    target_versions = {
        v
        for v in black.detect_target_versions(
            black.lib2to3_parse(source_code.lstrip(), set(_version_map)))
        if v.value >= black.TargetVersion.PY36.value
    }
    assert target_versions
    # Oldest detected target, mapped through _version_map; presumably a
    # (major, minor) tuple, given the min_version[1] use below - confirm.
    min_version = _version_map[min(target_versions, key=attrgetter("value"))]

    if refactor:
        # Some tools assume that the file is multi-line, but empty files are valid input.
        source_code += "\n"
        # Use com2ann to convert type comments to annotations on Python 3.8+
        source_code, _ = com2ann(
            source_code,
            drop_ellipsis=True,
            silent=True,
            python_minor_version=min_version[1],
        )
        # Use teyit to replace old unittest.assertX methods on Python 3.9+
        # NOTE(review): called regardless of min_version; presumably
        # _teyit_refactor does its own version gating - confirm.
        source_code, _ = _teyit_refactor(source_code)
        # Then apply pybetter's fixes with libcst, each fixer mapping a tree to a tree
        tree = libcst.parse_module(source_code)
        for fixer in _pybetter_fixers:
            tree = fixer(tree)
        source_code = tree.code
    # Then shed.docshed (below) formats any code blocks in documentation
    source_code = docshed(source=source_code,
                          first_party_imports=first_party_imports)
    # And pyupgrade - see pyupgrade._fix_file - is our last stable fixer
    # Calculate separate minver because pyupgrade doesn't have py39-specific logic yet
    pyupgrade_min_ver = min(min_version, max(pyupgrade.IMPORT_REMOVALS.keys()))
    source_code = pyupgrade._fix_tokens(source_code,
                                        min_version=pyupgrade_min_ver)
    source_code = pyupgrade._fix_percent_format(source_code)
    source_code = pyupgrade._fix_py3_plus(source_code,
                                          min_version=pyupgrade_min_ver)
    source_code = pyupgrade._fix_py36_plus(source_code)

    # One tricky thing: running `isort` or `autoflake` can "unlock" further fixes
    # for `black`, e.g. "pass;#" -> "pass\n#\n" -> "#\n".  We therefore loop until
    # neither of them have made a change in the last loop body, trusting that
    # `black` itself is idempotent because that's tested upstream.
    # NOTE(review): because prev starts as "", the loop body is skipped
    # entirely when source_code is already "" at this point.
    prev = ""
    black_mode = black.FileMode(target_versions=target_versions)
    while prev != source_code:
        # prev holds black's output; the loop ends once autoflake and isort
        # leave that output unchanged.
        prev = source_code = black.format_str(source_code, mode=black_mode)
        source_code = autoflake.fix_code(
            source_code,
            expand_star_imports=True,
            remove_all_unused_imports=True,
            remove_duplicate_keys=True,
            remove_unused_variables=True,
        )
        source_code = isort.code(
            source_code,
            known_first_party=first_party_imports,
            profile="black",
            combine_as_imports=True,
        )

    # Remove any extra trailing whitespace, ending with exactly one newline
    return source_code.rstrip() + "\n"
Exemple #13
0
 def test_endmarker(self) -> None:
     """Parsing a lone newline gives a file_input whose only child is ENDMARKER."""
     tree = black.lib2to3_parse("\n")
     self.assertEqual(tree.type, black.syms.file_input)
     kids = tree.children
     self.assertEqual(len(kids), 1)
     self.assertEqual(kids[0].type, black.token.ENDMARKER)
Exemple #14
0
 def test_endmarker(self) -> None:
     """A source consisting of one newline parses to a bare ENDMARKER node."""
     file_node = black.lib2to3_parse("\n")
     self.assertEqual(file_node.type, black.syms.file_input)
     self.assertEqual(len(file_node.children), 1)
     only_child = file_node.children[0]
     self.assertEqual(only_child.type, black.token.ENDMARKER)