Ejemplo n.º 1
0
 def refactor(self, path: Path) -> str:
     """Scan the file at ``path`` and return its refactored source text.

     The first element returned by ``utils.read(path)`` is used as the
     source. It is fed to a ``Scanner``; after traversal the scanner's own
     ``source`` and its unused imports are handed to ``refactor_string``.
     """
     file_source = utils.read(path)[0]
     scanner = Scanner(
         source=file_source,
         include_star_import=self.include_star_import,
     )
     scanner.traverse()
     return refactor_string(
         source=scanner.source,
         unused_imports=list(scanner.get_unused_imports()),
     )
Ejemplo n.º 2
0
 def __init__(self,
              config_file: Optional[Path] = None,
              *,
              include_star_import: bool = False,
              show_error: bool = False) -> None:
     """Store the error-reporting flag, pick a configuration and build a scanner.

     ``Config(config_file)`` is used only when ``config_file`` is truthy
     and its file name is listed in ``CONFIG_FILES``.
     """
     self.show_error = show_error
     if config_file and config_file.name in CONFIG_FILES:
         self.config = Config(config_file)
     else:
         # NOTE(review): the fallback stores the Config class itself rather
         # than an instance -- confirm this is intended.
         self.config = Config
     self.scanner = Scanner(
         include_star_import=include_star_import,
         show_error=self.show_error,
     )
Ejemplo n.º 3
0
 def refactor(self, action: str) -> str:
     """Dedent ``action``, scan it, and return the refactored source text.

     The dedented text is given to a ``Scanner``; after traversal the
     scanner's ``source`` and its unused imports are passed on to
     ``refactor_string``, whose result is returned unchanged.
     """
     dedented = textwrap.dedent(action)
     scanner = Scanner(
         source=dedented,
         include_star_import=self.include_star_import,
     )
     scanner.traverse()
     return refactor_string(
         source=scanner.source,
         unused_imports=list(scanner.get_unused_imports()),
     )
Ejemplo n.º 4
0
class ScannerTestCase(unittest.TestCase):
    """Test-case base that compares what a ``Scanner`` collected with expectations."""

    # Disable unittest's truncation of long assertion diffs.
    maxDiff = None

    def setUp(self):
        """Give every test its own ``Scanner`` instance."""
        self.scanner = Scanner()

    def assertUnimportEqual(
        self,
        source,
        expected_names,
        expected_classes,
        expected_functions,
        expected_imports,
    ):
        """Visit ``source`` and assert each collected group equals its expectation.

        Names are compared as a list built from ``get_names()``; classes,
        functions and imports are compared against the scanner attributes
        of the same name.
        """
        scanner = self.scanner
        scanner.run_visit(source)
        collected_names = list(scanner.get_names())
        self.assertEqual(expected_names, collected_names)
        self.assertEqual(expected_classes, scanner.classes)
        self.assertEqual(expected_functions, scanner.functions)
        self.assertEqual(expected_imports, scanner.imports)
Ejemplo n.º 5
0
 def assertSourceAfterScanningEqualToExpected(self,
                                              source,
                                              expected_unused_imports=None):
     """Assert that scanning ``source`` yields exactly the expected unused imports.

     Args:
         source: Source text; it is dedented before being scanned.
         expected_unused_imports: Expected list of unused imports. When
             omitted (``None``) an empty list is expected.
     """
     # A None sentinel replaces the former mutable ``[]`` default so no list
     # object is shared between calls.
     if expected_unused_imports is None:
         expected_unused_imports = []
     scanner = Scanner(
         source=textwrap.dedent(source),
         include_star_import=self.include_star_import,
     )
     scanner.traverse()
     super().assertEqual(
         expected_unused_imports,
         list(scanner.get_unused_imports()),
     )
     scanner.clear()
Ejemplo n.º 6
0
 def assertUnimportEqual(
     self,
     source,
     expected_names=None,
     expected_imports=None,
 ):
     """Assert that scanning ``source`` collects the expected names and imports.

     Args:
         source: Source text; it is dedented before being scanned.
         expected_names: Expected value of ``scanner.names``. When omitted
             (``None``) an empty list is expected.
         expected_imports: Expected value of ``scanner.imports``. When
             omitted (``None``) an empty list is expected.
     """
     # None sentinels replace the former mutable ``[]`` defaults so no list
     # object is shared between calls.
     if expected_names is None:
         expected_names = []
     if expected_imports is None:
         expected_imports = []
     scanner = Scanner(
         source=textwrap.dedent(source),
         include_star_import=self.include_star_import,
     )
     scanner.traverse()
     self.assertEqual(expected_names, scanner.names)
     self.assertEqual(expected_imports, scanner.imports)
     scanner.clear()
Ejemplo n.º 7
0
 def __init__(self, config_file=None, include_star_import=False):
     """Build the ``Config`` for ``config_file`` and the scanner used afterwards."""
     self.config = Config(config_file)
     self.scanner = Scanner(
         include_star_import=include_star_import,
     )
Ejemplo n.º 8
0
class Session:
    """Find unused imports in Python sources and produce refactored text or diffs.

    A single ``Scanner`` is shared by all calls and cleared after each
    refactor.
    """

    # Recursive glob applied when the start path is a directory.
    GLOB_PATTERN = "**/*.py"
    # Default regexes, searched against the string form of each candidate path.
    INCLUDE_REGEX_PATTERN = "\\.(py)$"
    EXCLUDE_REGEX_PATTERN = "^$"

    def __init__(self, config_file=None, include_star_import=False):
        """Create the configuration for ``config_file`` and the shared scanner."""
        self.config = Config(config_file)
        self.scanner = Scanner(include_star_import=include_star_import)

    def _read(self, path: Path):
        """Return ``(source, encoding)`` for ``path``.

        The file is opened with ``tokenize.open``. On ``OSError`` or
        ``SyntaxError`` the error is printed and ``("", "utf-8")`` is
        returned instead of raising.
        """
        try:
            with tokenize.open(path) as stream:
                source = stream.read()
                encoding = stream.encoding
        except (OSError, SyntaxError) as exc:
            print(f"{exc} Can't read")
            return "", "utf-8"
        return source, encoding

    def _list_paths(self, start, include, exclude):
        """Yield paths under ``start`` that match ``include`` and not ``exclude``.

        Args:
            start: A directory (searched with ``GLOB_PATTERN``) or a single path.
            include: Regex a path must match; falsy selects the class default.
            exclude: Regex a path must not match; falsy selects the class default.
        """
        include_regex = re.compile(include or self.INCLUDE_REGEX_PATTERN)
        exclude_regex = re.compile(exclude or self.EXCLUDE_REGEX_PATTERN)
        if start.is_dir():
            file_names = start.glob(self.GLOB_PATTERN)
        else:
            file_names = [start]
        for filename in file_names:
            name = str(filename)
            if include_regex.search(name) and not exclude_regex.search(name):
                yield filename

    def refactor(self, source: str) -> str:
        """Return ``source`` as rewritten by ``refactor_string`` with the scanner's unused imports."""
        try:
            self.scanner.run_visit(source)
            return refactor_string(
                source=source,
                unused_imports=list(self.scanner.get_unused_imports()),
            )
        finally:
            # The scanner is shared between calls; clear it even when
            # scanning or refactoring raises so no state leaks into the next
            # call.
            self.scanner.clear()

    def refactor_file(self, path: Path, apply: bool = False):
        """Refactor the file at ``path`` and return the resulting text.

        When ``apply`` is true the result is written back using the
        encoding the file was read with, but only if it differs from what
        was read. That skips no-op rewrites and avoids overwriting a file
        whose read failed (``_read`` yields ``""`` in that case) with that
        empty placeholder.
        """
        source, encoding = self._read(path)
        result = self.refactor(source)
        if apply and result != source:
            path.write_text(result, encoding=encoding)
        return result

    def diff(self, source: str) -> tuple:
        """Return the unified diff lines between ``source`` and its refactored form."""
        return tuple(
            difflib.unified_diff(source.splitlines(),
                                 self.refactor(source).splitlines()))

    def diff_file(self, path: Path) -> tuple:
        """Return the unified diff lines for the file at ``path``; nothing is written."""
        # Read once and refactor the text directly instead of reading the
        # file a second time through refactor_file().
        source, _ = self._read(path)
        result = self.refactor(source)
        return tuple(
            difflib.unified_diff(source.splitlines(),
                                 result.splitlines(),
                                 fromfile=str(path)))
Ejemplo n.º 9
0
def main(argv: Optional[Sequence[str]] = None) -> int:
    """Run the ``unimport`` command line interface.

    Args:
        argv: Command line arguments without the program name. ``None``
            means ``sys.argv[1:]``.

    Returns:
        ``1`` when at least one unused import was found or the user
        answered ``q`` at a permission prompt, otherwise ``0``.
    """
    parser = argparse.ArgumentParser(
        prog="unimport",
        description=C.DESCRIPTION,
        epilog="Get rid of all unused imports 🥳",
    )
    exclusive_group = parser.add_mutually_exclusive_group(required=False)
    parser.add_argument(
        "sources",
        default=default_config.sources,
        nargs="*",
        help="files and folders to find the unused imports.",
        action="store",
        type=Path,
    )
    parser.add_argument(
        "-c",
        "--config",
        default=".",
        help="read configuration from PATH.",
        metavar="PATH",
        action="store",
        type=Path,
    )
    parser.add_argument(
        "--include",
        help="file include pattern.",
        metavar="include",
        action="store",
        default=default_config.include,
        type=str,
    )
    parser.add_argument(
        "--exclude",
        help="file exclude pattern.",
        metavar="exclude",
        action="store",
        default=default_config.exclude,
        type=str,
    )
    parser.add_argument(
        "--gitignore",
        action="store_true",
        help="exclude .gitignore patterns. if present.",
        default=default_config.gitignore,
    )
    parser.add_argument(
        "--include-star-import",
        action="store_true",
        help="Include star imports during scanning and refactor.",
        default=default_config.include_star_import,
    )
    parser.add_argument(
        "-d",
        "--diff",
        action="store_true",
        help="Prints a diff of all the changes unimport would make to a file.",
        default=default_config.diff,
    )
    exclusive_group.add_argument(
        "-r",
        "--remove",
        action="store_true",
        help="remove unused imports automatically.",
        default=default_config.remove,
    )
    exclusive_group.add_argument(
        "-p",
        "--permission",
        action="store_true",
        help="Refactor permission after see diff.",
        default=default_config.permission,
    )
    parser.add_argument(
        "--requirements",
        action="store_true",
        help=
        "Include requirements.txt file, You can use it with all other arguments",
        default=default_config.requirements,
    )
    parser.add_argument(
        "--check",
        action="store_true",
        help="Prints which file the unused imports are in.",
        default=default_config.check,
    )
    parser.add_argument(
        "-v",
        "--version",
        action="version",
        version=f"Unimport {C.VERSION}",
        help="Prints version of unimport",
    )
    argv = argv if argv is not None else sys.argv[1:]
    args = parser.parse_args(argv)
    # A configuration file is parsed only when its name is a known config
    # file name; the command line values are then merged on top.
    config = (Config(args.config).parse() if args.config
              and args.config.name in CONFIG_FILES else default_config)
    config = config.merge(**vars(args))
    # The permission prompt is defined in terms of the diff, so the diff is
    # computed and shown whenever either option is active. This keeps the
    # prompt from reading a diff state that was never computed for the file.
    show_diff = bool(config.diff or config.permission)
    unused_modules = set()
    used_packages: Set[str] = set()
    for source_path in config.sources:
        for py_path in utils.list_paths(source_path, config.include,
                                        config.exclude):
            source, encoding = utils.read(py_path)
            scanner = Scanner(
                source=source,
                path=py_path,
                include_star_import=config.include_star_import,
            )
            scanner.traverse()
            unused_imports = list(scanner.get_unused_imports())
            unused_modules.update({imp.name for imp in unused_imports})
            used_packages.update(
                utils.get_used_packages(scanner.imports, unused_imports))
            if config.check:
                show(unused_imports, py_path)
            # Per-file state: a "yes" given for one file must not carry over
            # to the files that follow, and nothing from the previous file
            # may be reused.
            apply_changes = bool(config.remove)
            exists_diff = False
            refactor_result = source
            if show_diff or config.remove:
                refactor_result = refactor_string(
                    source=source,
                    unused_imports=unused_imports,
                )
            if show_diff:
                diff = utils.diff(
                    source=source,
                    refactor_result=refactor_result,
                    fromfile=py_path,
                )
                exists_diff = bool(diff)
                if exists_diff:
                    print(color.difference(diff))
            if config.permission and exists_diff:
                action = input(
                    f"Apply suggested changes to '{color.paint(str(py_path), color.YELLOW)}' [Y/n/q] ? >"
                ).lower()
                if action == "q":
                    return 1
                if utils.actiontobool(action):
                    apply_changes = True
            if apply_changes and source != refactor_result:
                py_path.write_text(refactor_result, encoding=encoding)
                print(
                    f"Refactoring '{color.paint(str(py_path), color.GREEN)}'")
            scanner.clear()
    if not unused_modules and config.check:
        print(
            color.paint(
                "✨ Congratulations there is no unused import in your project. ✨",
                color.GREEN,
            ))
    if config.requirements:
        for requirements in Path(".").glob("requirements*.txt"):
            source = requirements.read_text()
            requirement_lines = source.splitlines()
            # Work on a copy so entries can be dropped while the original
            # lines are being enumerated.
            copy_source = list(requirement_lines)
            for index, requirement in enumerate(requirement_lines):
                module_name = utils.package_name_from_metadata(
                    requirement.split("==")[0])
                if module_name is None:
                    print(color.paint(requirement + " not found", color.RED))
                    continue
                if module_name not in used_packages:
                    copy_source.remove(requirement)
                    if config.check:
                        print(
                            f"{color.paint(requirement, color.CYAN)} at "
                            f"{color.paint(requirements.as_posix(), color.CYAN)}:{color.paint(str(index + 1), color.CYAN)}"
                        )
            refactor_result = "\n".join(copy_source)
            # Same per-file handling as for the Python sources above.
            apply_changes = bool(config.remove)
            exists_diff = False
            if show_diff:
                diff = utils.diff(
                    source=source,
                    refactor_result=refactor_result,
                    fromfile=requirements,
                )
                exists_diff = bool(diff)
                if exists_diff:
                    print(color.difference(diff))
            if config.permission and exists_diff:
                action = input(
                    f"Apply suggested changes to '{color.paint(requirements.as_posix(), color.CYAN)}' [Y/n/q] ? >"
                ).lower()
                if action == "q":
                    return 1
                if utils.actiontobool(action):
                    apply_changes = True
            if apply_changes:
                requirements.write_text(refactor_result)
                print(
                    f"Refactoring '{color.paint(requirements.as_posix(), color.CYAN)}'"
                )
    return 1 if unused_modules else 0
Ejemplo n.º 10
0
class Session:
    """Find unused imports in Python sources and produce refactored text or diffs.

    A single ``Scanner`` is shared by all calls and cleared after each
    refactor. ``show_error`` controls whether read errors are printed and is
    also forwarded to the scanner and to ``refactor_string``.
    """

    # Recursive glob applied when the start path is a directory.
    GLOB_PATTERN = "**/*.py"
    # Default regexes, searched against the string form of each candidate path.
    INCLUDE_REGEX_PATTERN = "\\.(py)$"
    EXCLUDE_REGEX_PATTERN = "^$"

    def __init__(
        self,
        config_file: Optional[Path] = None,
        *,
        include_star_import: bool = False,
        show_error: bool = False
    ) -> None:
        """Set up the error flag, the optional configuration and the scanner.

        ``config`` is a ``Config`` only when ``config_file`` is truthy and
        its file name is listed in ``CONFIG_FILES``; otherwise it is ``None``.
        """
        self.show_error = show_error
        self.config = (
            Config(config_file)
            if config_file and config_file.name in CONFIG_FILES
            else None
        )
        self.scanner = Scanner(
            include_star_import=include_star_import, show_error=self.show_error
        )

    def read(self, path: Path) -> Tuple[str, str]:
        """Return ``(source, encoding)`` for ``path``.

        The file is opened with ``tokenize.open``. On ``OSError`` or
        ``SyntaxError`` the result is ``("", "utf-8")``; the error is
        printed only when ``show_error`` is set.
        """
        try:
            with tokenize.open(path) as stream:
                source = stream.read()
                encoding = stream.encoding
        except (OSError, SyntaxError) as err:
            if self.show_error:
                print(Color(str(err)).red)
            return "", "utf-8"
        return source, encoding

    def list_paths(
        self,
        start: Path,
        include: Optional[str] = None,
        exclude: Optional[str] = None,
    ) -> Iterator[Path]:
        """Yield paths under ``start`` that match ``include`` and not ``exclude``.

        Args:
            start: A directory (searched with ``GLOB_PATTERN``) or a single path.
            include: Regex a path must match; falsy selects the class default.
            exclude: Regex a path must not match; falsy selects the class default.
        """
        include_regex = re.compile(include or self.INCLUDE_REGEX_PATTERN)
        exclude_regex = re.compile(exclude or self.EXCLUDE_REGEX_PATTERN)
        file_names: Iterable[Path]
        if start.is_dir():
            file_names = start.glob(self.GLOB_PATTERN)
        else:
            file_names = [start]
        for filename in file_names:
            name = str(filename)
            if include_regex.search(name) and not exclude_regex.search(name):
                yield filename

    def refactor(self, source: str) -> str:
        """Return ``source`` as rewritten by ``refactor_string`` with the scanner's unused imports."""
        try:
            self.scanner.run_visit(source)
            return refactor_string(
                source=source,
                unused_imports=self.scanner.unused_imports,
                show_error=self.show_error,
            )
        finally:
            # The scanner is shared between calls; clear it even when
            # scanning or refactoring raises so no state leaks into the next
            # call.
            self.scanner.clear()

    def refactor_file(self, path: Path, apply: bool = False) -> str:
        """Refactor the file at ``path`` and return the resulting text.

        When ``apply`` is true the result is written back using the
        encoding the file was read with, but only if it differs from what
        was read. That skips no-op rewrites and avoids overwriting a file
        whose read failed (``read`` yields ``""`` in that case) with that
        empty placeholder.
        """
        source, encoding = self.read(path)
        result = self.refactor(source)
        if apply and result != source:
            path.write_text(result, encoding=encoding)
        return result

    def diff(self, source: str) -> Tuple[str, ...]:
        """Return the unified diff lines between ``source`` and its refactored form."""
        refactored = self.refactor(source)
        return tuple(
            difflib.unified_diff(source.splitlines(), refactored.splitlines())
        )

    def diff_file(self, path: Path) -> Tuple[str, ...]:
        """Return the unified diff lines for the file at ``path``; nothing is written."""
        # Read once and refactor the text directly instead of reading the
        # file a second time through refactor_file().
        source, _ = self.read(path)
        result = self.refactor(source)
        return tuple(
            difflib.unified_diff(
                source.splitlines(), result.splitlines(), fromfile=str(path)
            )
        )
Ejemplo n.º 11
0
 def setUp(self):
     """Give every test its own ``Scanner`` instance."""
     fresh_scanner = Scanner()
     self.scanner = fresh_scanner