Ejemplo n.º 1
0
def load_config(filename: str) -> Dict:
    """Return Black settings for *filename*, layered over built-in defaults.

    The project root is located with ``black.find_project_root`` starting at
    *filename*. If that root contains a readable, parseable ``pyproject.toml``,
    every key of its ``[tool.black]`` table is normalised (``--`` removed,
    ``-`` turned into ``_``) and overrides / extends the defaults. When the
    file is missing, unreadable (``OSError``) or invalid TOML, the defaults
    are returned unchanged.
    """
    settings = dict(
        line_length=88,
        fast=False,
        py36=False,
        pyi=False,
        skip_string_normalization=False,
        skip_numeric_underscore_normalization=False,
    )

    pyproject = black.find_project_root((filename,)) / "pyproject.toml"
    if not pyproject.is_file():
        return settings

    try:
        document = toml.load(str(pyproject))
    except (toml.TomlDecodeError, OSError):
        return settings

    # File values win over defaults; unknown keys are kept as-is.
    black_table = document.get("tool", {}).get("black", {})
    for raw_key, raw_value in black_table.items():
        settings[raw_key.replace("--", "").replace("-", "_")] = raw_value
    return settings
Ejemplo n.º 2
0
    def isort(self, source):
        """Return *source* with its imports sorted by isort.

        isort receives the project root (located with
        ``black.find_project_root`` from ``self.file_path``) as its settings
        path. When ``self.only_when_a_project_config_is_found`` is true,
        *source* is returned unchanged unless that root holds an isort
        configuration. Any of the following counts:

        - a non-empty ``[tool.isort]`` table in ``pyproject.toml``;
        - an ``.isort.cfg`` file;
        - an ``[isort]`` / ``[tool:isort]`` section in ``setup.cfg``.
        """
        # Resolve the root unconditionally: the isort call below needs it even
        # when the config gate is off. It used to be bound only inside the gate,
        # so the ungated path raised UnboundLocalError.
        root = black.find_project_root((self.file_path,))

        def has_isort_config():
            # Probed in order: pyproject.toml, .isort.cfg, setup.cfg.
            pyproject = root / "pyproject.toml"
            if pyproject.is_file():
                pyproject_toml = toml.load(str(pyproject))
                if pyproject_toml.get("tool", {}).get("isort", {}):
                    return True
            if (root / ".isort.cfg").is_file():
                return True
            setup_cfg = root / "setup.cfg"
            if setup_cfg.is_file():
                import configparser

                parser = configparser.ConfigParser()
                with open(setup_cfg) as fp:
                    parser.read_file(fp)
                return parser.has_section('isort') or parser.has_section(
                    'tool:isort'
                )
            return False

        if self.only_when_a_project_config_is_found and not has_isort_config():
            return source

        return isort.code(
            code=source,
            config=isort.settings.Config(settings_path=root),
        )
Ejemplo n.º 3
0
def read_config_file(ctx, param, value):
    """Click callback: load ``[tool:brunette]`` settings into ``ctx.default_map``.

    If *value* (the config file path) is empty, ``setup.cfg`` in the project
    root found from the ``src`` parameter is used when it exists.

    Option names are normalised (``--`` stripped, ``-`` -> ``_``) and must
    match a parameter of ``ctx.command``. Values for parameters declared with
    ``multiple=True`` are split on commas; each item is whitespace-stripped
    and empty items are dropped.

    :returns: the path that was read, or ``None`` when there is no file, no
              ``[tool:brunette]`` section, or an empty section.
    :raises KeyError: if the section names an unknown parameter.
    """
    if not value:
        root = black.find_project_root(ctx.params.get('src', ()))
        path = root / 'setup.cfg'
        if not path.is_file():
            return None
        value = str(path)

    parser = configparser.ConfigParser()
    parser.read(value)
    try:
        config = dict(parser['tool:brunette'])
    except KeyError:
        return None
    if not config:
        return None

    if ctx.default_map is None:
        ctx.default_map = {}

    for key, raw in config.items():
        name = key.replace('--', '').replace('-', '_')
        # First command parameter with a matching name, as before.
        command_param = next(
            (p for p in ctx.command.params if p.name == name), None)
        if command_param is None:
            raise KeyError('Invalid parameter: {}'.format(name))
        if command_param.multiple:
            # "a, b" (or a multi-line value) must give ['a', 'b'],
            # not ['a', ' b']; a trailing comma must not add ''.
            raw = [item.strip() for item in raw.split(',') if item.strip()]
        ctx.default_map[name] = raw

    return value
Ejemplo n.º 4
0
    def _file_mode(self):
        """Return the black.FileMode to use for this file.

        Precedence:

        1. A truthy ``self.override_config`` is returned as-is.
        2. Otherwise the ``pyproject.toml`` at the project root is used. The
           root is found from the file's path, or from ``"."`` for stdin
           names. The mode is loaded once with ``load_black_mode`` and cached
           in the module-level ``black_config`` dict.
        3. If no such file exists, the entry cached under
           ``black_config[None]`` is returned.
        """
        override = self.override_config
        if override:
            return override

        # Unless using override, we look for pyproject.toml
        start = "." if self.filename in self.STDIN_NAMES else self.filename
        pyproject = black.find_project_root((start, )) / "pyproject.toml"

        cached = pyproject in black_config
        if not cached and not pyproject.is_file():
            # No project specific file, use default
            LOG.debug("flake8-black: %s using defaults", self.filename)
            return black_config[None]

        if cached:
            LOG.debug("flake8-black: %s using pre-loaded %s", self.filename,
                      pyproject)
        else:
            # Two workers may load and store the same file concurrently; both
            # store the result of loading the same path, so no lock is taken.
            black_config[pyproject] = load_black_mode(pyproject)
            LOG.debug("flake8-black: %s using newly loaded %s", self.filename,
                      pyproject)
        return black_config[pyproject]
Ejemplo n.º 5
0
def format_some(filenames, **config):
    """Format (or, with ``check=True``, only check) each file in *filenames*.

    Settings from ``[tool.black]`` in the project's ``pyproject.toml`` are
    used as defaults; keyword arguments in *config* take precedence. Keys
    from pyproject.toml are normalised (``--`` dropped, ``-`` -> ``_``) so
    Black's documented spelling is recognised. That spelling includes
    ``line-length`` and ``skip-string-normalization``; previously such keys
    were silently ignored.

    :returns: the number of files for which ``format_one`` reported a change.
    """
    # Materialise once: *filenames* feeds the root lookup and is iterated
    # again below. A tuple is also hashable if find_project_root is cached.
    filenames = tuple(filenames)
    check = config.get("check", False)
    root = find_project_root(filenames)
    path = root / "pyproject.toml"
    if path.is_file():
        pyproject_toml = toml.load(str(path))
        file_config = {
            key.replace("--", "").replace("-", "_"): value
            for key, value in pyproject_toml.get("tool", {}).get("black", {}).items()
        }
        # Explicit keyword arguments win over pyproject.toml values.
        config = {**file_config, **config}

    proc = BlackPreprocessor(
        line_length=config.get("line_length", DEFAULT_LINE_LENGTH),
        # NOTE(review): target versions read from pyproject.toml are strings
        # (e.g. "py37"); confirm BlackPreprocessor accepts them as given.
        target_versions=set(config.get("target_version", [])),
        string_normalization=not config.get("skip_string_normalization",
                                            False),
    )

    count = 0
    for filename in filenames:
        if not format_one(proc, filename, check=check):
            continue
        count += 1
        if check:
            print("Invalid format: {0}".format(filename))
        else:
            print("Formatted: {0}".format(filename))

    return count
Ejemplo n.º 6
0
def read_config_file(ctx, param, value):
    """Click callback that seeds ``ctx.default_map`` from ``[tool:brunette]``.

    Without an explicit *value*, the callback falls back to ``setup.cfg`` in
    the project root derived from the ``src`` parameter. It returns ``None``
    if that file is absent.

    A ``single-quotes`` entry equal to ``true`` (case-insensitive) replaces
    ``black.normalize_string_quotes`` with this module's
    ``normalize_string_quotes``. Every entry is then copied into the default
    map with ``--`` removed and ``-`` replaced by ``_``.

    :returns: the config path used, or ``None`` when the section is missing
              or empty.
    """
    if not value:
        candidate = black.find_project_root(ctx.params.get('src', ())) / 'setup.cfg'
        if not candidate.is_file():
            return None
        value = str(candidate)

    parser = configparser.ConfigParser()
    parser.read(value)
    if 'tool:brunette' not in parser:
        return None
    section = dict(parser['tool:brunette'])
    if not section:
        return None

    quotes_flag = section.get('single-quotes')
    if quotes_flag is not None and quotes_flag.lower() == 'true':
        black.normalize_string_quotes = normalize_string_quotes

    if ctx.default_map is None:
        ctx.default_map = {}
    normalised = {}
    for key, setting in section.items():
        normalised[key.replace('--', '').replace('-', '_')] = setting
    ctx.default_map.update(normalised)  # type: ignore  # bad types in .pyi
    return value
Ejemplo n.º 7
0
    def __init__(self,
                 python_version: PythonVersion,
                 settings_path: Optional[Path] = None):
        """Prepare Black and isort settings for code targeting *python_version*.

        ``[tool.black]`` is read from ``pyproject.toml`` at the project root
        found from *settings_path*, which defaults to the resolved current
        directory. For isort releases whose version does not start with
        ``4.``, an ``isort.Config`` for that path is also built.
        """
        settings_path = settings_path or Path().resolve()

        pyproject = black.find_project_root((settings_path, )) / "pyproject.toml"
        if pyproject.is_file():
            black_settings = toml.load(str(pyproject)).get("tool", {}).get("black", {})
        else:
            black_settings = {}

        # NOTE(review): "skip-string-normalization" defaults to True here, so
        # string normalisation stays off unless pyproject.toml sets it to
        # false. This is the inverse of Black's own default; presumably
        # intentional, confirm before changing.
        # The attribute keeps its existing name ``back_mode`` (sic), which
        # callers presumably read.
        self.back_mode = black.FileMode(
            target_versions={BLACK_PYTHON_VERSION[python_version]},
            line_length=black_settings.get("line-length",
                                           black.DEFAULT_LINE_LENGTH),
            string_normalization=not black_settings.get(
                "skip-string-normalization", True),
        )

        self.settings_path: str = str(settings_path)
        # isort 4.x gets no Config object here.
        self.isort_config = (
            None if isort.__version__.startswith('4.')
            else isort.Config(settings_path=self.settings_path)
        )
Ejemplo n.º 8
0
def load_config(srcs: Iterable[str]) -> DarkerConfig:
    """Find and load Darker configuration from given path or pyproject.toml

    :param srcs: File(s) and directory/directories which will be processed. Their paths
                 are used to look for the ``pyproject.toml`` configuration file.
    :return: the ``[tool.darker]`` table (after ``replace_log_level_name``),
             or an empty dict when there is no ``pyproject.toml`` or the
             table is missing/empty.

    """
    # An empty *srcs* means "search from the current directory".
    search_roots = tuple(srcs or ["."])
    pyproject = find_project_root(search_roots) / "pyproject.toml"
    if not pyproject.is_file():
        return {}
    darker_section: DarkerConfig = (
        toml.load(pyproject).get("tool", {}).get("darker", {}) or {}
    )
    replace_log_level_name(darker_section)
    return darker_section
Ejemplo n.º 9
0
    def _load_black_config(self):
        """Return the project's ``[tool.black]`` table with normalised keys.

        For stdin input, the project root is searched from the current working
        directory. Keys lose any ``--`` and have ``-`` mapped to ``_``.
        ``None`` is returned when the project root has no ``pyproject.toml``.
        """
        if self.filename in self.STDIN_NAMES:
            source_path = Path.cwd().as_posix()
        else:
            source_path = self.filename
        pyproject = black.find_project_root((Path(source_path), )) / "pyproject.toml"
        if not pyproject.is_file():
            return None

        black_table = toml.load(str(pyproject)).get("tool", {}).get("black", {})
        normalised = {}
        for option, setting in black_table.items():
            normalised[option.replace("--", "").replace("-", "_")] = setting
        return normalised
Ejemplo n.º 10
0
def apply_isort(
    content: TextDocument,
    src: Path,
    config: Optional[str] = None,
    line_length: Optional[int] = None,
) -> TextDocument:
    """Sort the imports in *content* with isort.

    :param content: document whose ``string`` is passed to isort.
    :param src: path of the file; when *config* is not given, isort's
                ``settings_path`` is the project root found from it.
    :param config: explicit isort settings file, passed as ``settings_file``.
    :param line_length: line-length override, applied only when truthy.
    :return: a new TextDocument with the sorted text and the input's encoding.
    """
    isort_args = IsortArgs()
    if not config:
        isort_args["settings_path"] = str(find_project_root((str(src), )))
    else:
        isort_args["settings_file"] = config
    if line_length:
        isort_args["line_length"] = line_length

    rendered_args = ", ".join(f"{key}={val!r}" for key, val in isort_args.items())
    logger.debug("isort.code(code=..., {})".format(rendered_args))
    sorted_code = isort_code(code=content.string, **isort_args)
    return TextDocument.from_str(sorted_code, encoding=content.encoding)
Ejemplo n.º 11
0
def apply_isort(
    content: str,
    src: Path,
    config: Optional[str] = None,
    line_length: Optional[int] = None,
) -> str:
    """Return *content* with its imports sorted by isort.

    :param content: Python source text.
    :param src: path of the file; when *config* is not given, isort's
                ``settings_path`` is the project root found from it.
    :param config: explicit isort settings file, passed as ``settings_file``.
    :param line_length: line-length override, applied only when truthy.
    """
    isort_args = IsortArgs()
    if not config:
        isort_args["settings_path"] = str(find_project_root((str(src), )))
    else:
        isort_args["settings_file"] = config
    if line_length:
        isort_args["line_length"] = line_length

    rendered_args = ", ".join(f"{key}={val!r}" for key, val in isort_args.items())
    logger.debug("isort.code(code=..., {})".format(rendered_args))
    return isort.code(code=content, **isort_args)
Ejemplo n.º 12
0
def load_config(filename: str) -> Dict:
    """Return Black settings for *filename*, limited to a fixed set of keys.

    Only ``line_length``, ``fast``, ``pyi`` (default: whether *filename* ends
    in ``.pyi``), ``skip_string_normalization`` and ``target_version`` are
    returned. Values come from ``[tool.black]`` in the project's
    ``pyproject.toml`` (keys normalised: ``--`` dropped, ``-`` -> ``_``).
    The defaults are returned when the file is missing, unreadable or
    invalid TOML.

    ``target_version`` is resolved as follows:

    - an explicit list is mapped through ``black.TargetVersion[NAME.upper()]``;
    - otherwise a truthy legacy ``py36`` flag selects ``black.PY36_VERSIONS``;
    - otherwise it is an empty set.
    """
    defaults = {
        "line_length": 88,
        "fast": False,
        "pyi": filename.endswith(".pyi"),
        "skip_string_normalization": False,
        "target_version": set(),
    }

    pyproject = black.find_project_root((filename,)) / "pyproject.toml"
    if not pyproject.is_file():
        return defaults
    try:
        document = toml.load(str(pyproject))
    except (toml.TomlDecodeError, OSError):
        return defaults

    file_config = {}
    for option, setting in document.get("tool", {}).get("black", {}).items():
        file_config[option.replace("--", "").replace("-", "_")] = setting

    # Keys outside `defaults` are dropped on purpose.
    config = {}
    for key, fallback in defaults.items():
        config[key] = file_config.get(key, fallback)

    requested = file_config.get("target_version")
    if requested:
        config["target_version"] = {
            black.TargetVersion[name.upper()] for name in requested
        }
    elif file_config.get("py36"):
        config["target_version"] = black.PY36_VERSIONS
    else:
        config["target_version"] = set()

    return config
Ejemplo n.º 13
0
def apply_black(code: str, python_version: PythonVersion) -> str:
    """Format *code* with Black for the Black target matching *python_version*.

    ``line-length`` and ``skip-string-normalization`` are read from
    ``[tool.black]`` in the ``pyproject.toml`` at the project root found
    from the current directory.
    """
    pyproject = black.find_project_root((Path().resolve(),)) / "pyproject.toml"
    settings = {}
    if pyproject.is_file():
        settings = toml.load(str(pyproject)).get("tool", {}).get("black", {})

    mode = black.FileMode(
        target_versions={BLACK_PYTHON_VERSION[python_version]},
        line_length=settings.get("line-length", black.DEFAULT_LINE_LENGTH),
        # Default True: string normalisation is off unless pyproject.toml
        # explicitly sets skip-string-normalization = false.
        string_normalization=not settings.get("skip-string-normalization", True),
    )
    return black.format_str(code, mode=mode)
Ejemplo n.º 14
0
    def black(self, source):
        """Format *source* with Black using the project's ``[tool.black]`` table.

        When ``self.file_path`` is set, ``pyproject.toml`` at the project root
        found from it is read and its keys are normalised (``--`` dropped,
        ``-`` -> ``_``). If ``self.only_when_a_project_config_is_found`` is
        set and no non-empty ``[tool.black]`` table was found, *source* is
        returned unchanged.
        """
        config = {}
        if self.file_path:
            root = black.find_project_root((self.file_path,))
            path = root / "pyproject.toml"
            if path.is_file():
                pyproject_toml = toml.load(str(path))
                config = {
                    k.replace("--", "").replace("-", "_"): v
                    for k, v in pyproject_toml.get("tool", {}).get("black", {}).items()
                }
        # Non-empty after normalisation exactly when the raw table was non-empty.
        has_conf = bool(config)

        if self.only_when_a_project_config_is_found and not has_conf:
            return source

        target_versions = config.get("target_version")
        versions = (
            {black.TargetVersion[v.upper()] for v in target_versions}
            if target_versions
            else set()
        )

        mode = black.FileMode(
            target_versions=versions,
            line_length=config.get("line_length", black.DEFAULT_LINE_LENGTH),
            # Coerce to bool: an absent key previously passed None to FileMode.
            is_pyi=bool(config.get("pyi", False)),
            string_normalization=not config.get("skip_string_normalization", False),
        )

        return black.format_str(source, mode=mode)
Ejemplo n.º 15
0
def collect_files(src, include, exclude):
    """Lazily yield the paths to process for each ``Path`` in *src*.

    Directories are expanded with ``gen_python_files``, which is given
    *include*, *exclude*, an empty force-exclude and the project root's
    gitignore. Existing files and ``-`` are yielded as-is. Any other path is
    reported on stderr and skipped.
    """
    root = black.find_project_root(tuple(src))
    gitignore = black.get_gitignore(root)
    report = black.Report()
    force_exclude = ""  # no force-exclude pattern

    for path in src:
        if path.is_dir():
            yield from gen_python_files(
                path.iterdir(), root, include, exclude, force_exclude, report, gitignore,
            )
            continue
        if path.is_file() or str(path) == "-":
            yield path
            continue
        print(f"invalid path: {path}", file=sys.stderr)
Ejemplo n.º 16
0
def get_source_files(paths: Iterable[str]) -> Iterable[Path]:
    """Collect the Python source files named by *paths*.

    Directories are expanded with ``gen_python_files_in_dir`` using
    ``INCLUDES`` / ``EXCLUDES`` and the project root's gitignore; plain files
    are taken as given. An argument that is neither prints an error and
    exits with status 1.

    :returns: a set of ``Path`` objects.
    """
    # Materialise once: *paths* feeds the root lookup and the loop below.
    # Previously a one-shot iterator was drained by the lookup, so the loop
    # saw nothing.
    paths = tuple(paths)
    report = Report()
    root = find_project_root(paths)
    gitignore = None  # read at most once, and only if a directory is given
    sources: Set[Path] = set()
    for filename in paths:
        path = Path(filename)
        if path.is_dir():
            if gitignore is None:
                gitignore = get_gitignore(root)
            sources.update(
                gen_python_files_in_dir(
                    path=path,
                    root=root,
                    include=INCLUDES,
                    exclude=EXCLUDES,
                    report=report,
                    gitignore=gitignore,
                ))
        elif path.is_file():
            sources.add(path)
        else:
            print(f"Error: invalid path: {path}")
            # SystemExit directly: the exit() builtin is injected by `site`
            # and is unavailable under `python -S` / frozen apps.
            raise SystemExit(1)
    return sources
Ejemplo n.º 17
0
def main(src):
    # Check every Python file reachable from *src*; exit with status 1 if any
    # file fails check_file. Raises RuntimeError for a path that is neither a
    # directory, a file, nor "-". (Kept as comments rather than a docstring:
    # if this is a click command, a docstring would change its --help text.)
    # Path handling inspired on the implementation of Black: https://github.com/psf/black
    # A tuple is hashable (find_project_root may be lru_cached) and can be
    # iterated twice even if *src* was an iterator.
    src = tuple(src)
    root = find_project_root(src)
    gitignore = get_gitignore(root)
    sources: Set[Path] = set()

    for s in src:
        p = Path(s)
        if p.is_dir():
            sources.update(get_files_in_dir(p, root, gitignore))
        elif p.is_file() or s == "-":
            # if a file was explicitly given, we don't care about its extension
            sources.add(p)
        else:
            raise RuntimeError(f"invalid path: {s}")

    # Check every file (no short-circuit) so that all problems are reported;
    # sorting makes the report order deterministic.
    results = [check_file(source) for source in sorted(sources)]
    if not all(results):
        sys.exit(1)
Ejemplo n.º 18
0
def main(args: Optional[Sequence[str]] = None) -> Exit:
    """Main function.

    Options are read from ``[tool.datamodel-codegen]`` in the project's
    ``pyproject.toml`` (keys with ``-`` mapped to ``_``) and then merged with
    the parsed command-line arguments.

    :param args: argument list; defaults to ``sys.argv[1:]``.
    :returns: ``Exit.OK`` on success, ``Exit.ERROR`` on any failure.
    """

    # add cli completion support
    argcomplete.autocomplete(arg_parser)

    if args is None:
        args = sys.argv[1:]

    namespace: Namespace = arg_parser.parse_args(args)

    if namespace.version:  # pragma: no cover
        from datamodel_code_generator.version import version

        print(version)
        # sys.exit, not the `site`-injected exit() builtin (absent under
        # `python -S` and in frozen executables).
        sys.exit(0)

    root = black.find_project_root((Path().resolve(),))
    pyproject_toml_path = root / "pyproject.toml"
    if pyproject_toml_path.is_file():
        pyproject_toml: Dict[str, Any] = {
            k.replace('-', '_'): v
            for k, v in toml.load(str(pyproject_toml_path))
            .get('tool', {})
            .get('datamodel-codegen', {})
            .items()
        }
    else:
        pyproject_toml = {}

    try:
        config = Config.parse_obj(pyproject_toml)
        config.merge_args(namespace)
    except Error as e:
        # Report configuration errors like generation errors below rather
        # than letting them escape as a traceback.
        print(str(e), file=sys.stderr)
        return Exit.ERROR

    if config.input is not None:
        input_name: str = config.input.name  # type: ignore
        input_text: str = config.input.read()
    else:
        input_name = '<stdin>'
        input_text = sys.stdin.read()

    if config.debug:  # pragma: no cover
        enable_debug_message()

    extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]]
    if config.extra_template_data is not None:
        with config.extra_template_data as data:
            try:
                # Nested objects become defaultdict(dict, ...) instances.
                extra_template_data = json.load(
                    data, object_hook=lambda d: defaultdict(dict, **d)
                )
            except json.JSONDecodeError as e:
                print(f"Unable to load extra template data: {e}", file=sys.stderr)
                return Exit.ERROR
    else:
        extra_template_data = None

    if config.aliases is not None:
        with config.aliases as data:
            try:
                aliases = json.load(data)
            except json.JSONDecodeError as e:
                print(f"Unable to load alias mapping: {e}", file=sys.stderr)
                return Exit.ERROR
        # Plain `dict` rather than typing.Dict for the runtime type check.
        if not isinstance(aliases, dict) or not all(
            isinstance(k, str) and isinstance(v, str) for k, v in aliases.items()
        ):
            print(
                'Alias mapping must be a JSON string mapping (e.g. {"from": "to", ...})',
                file=sys.stderr,
            )
            return Exit.ERROR
    else:
        aliases = None

    try:
        generate(
            input_name=input_name,
            input_text=input_text,
            input_file_type=config.input_file_type,
            output=config.output,
            target_python_version=config.target_python_version,
            base_class=config.base_class,
            custom_template_dir=config.custom_template_dir,
            validation=config.validation,
            field_constraints=config.field_constraints,
            snake_case_field=config.snake_case_field,
            strip_default_none=config.strip_default_none,
            extra_template_data=extra_template_data,
            aliases=aliases,
            disable_timestamp=config.disable_timestamp,
            allow_population_by_field_name=config.allow_population_by_field_name,
            use_default_on_required_field=config.use_default,
        )
        return Exit.OK
    except Error as e:
        print(str(e), file=sys.stderr)
        return Exit.ERROR
    except Exception:
        import traceback

        print(traceback.format_exc(), file=sys.stderr)
        return Exit.ERROR
Ejemplo n.º 19
0
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    fast: bool,
    pyi: bool,
    py36: bool,
    skip_string_normalization: bool,
    single_quotes: bool,
    quiet: bool,
    verbose: bool,
    include: str,
    exclude: str,
    src: Tuple[str],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # (The docstring is presumably click's --help text, so it stays as-is.)
    # Build the FileMode, then either format the --code string or reformat
    # every file reachable from *src*, exiting with the report's return code.
    write_back = WriteBack.from_configuration(check=check, diff=diff)
    if target_version:
        if py36:
            err('Cannot use both --target-version and --py36')
            ctx.exit(2)
        else:
            versions = set(target_version)
    elif py36:
        err('--py36 is deprecated and will be removed in a future version. '
            'Use --target-version py36 instead.')
        versions = PY36_VERSIONS
    else:
        # We'll autodetect later.
        versions = set()
    mode = FileMode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        string_normalization=not skip_string_normalization,
    )

    if single_quotes:
        # Global monkeypatch of black's quote normaliser.
        black.normalize_string_quotes = patched_normalize_string_quotes

    if config and verbose:
        out(f'Using configuration from {config}.', bold=False, fg='blue')
    if code is not None:
        print(format_str(code, mode=mode))
        ctx.exit(0)
    try:
        include_regex = re_compile_maybe_verbose(include)
    except re.error:
        err(f'Invalid regular expression for include given: {include!r}')
        ctx.exit(2)
    try:
        exclude_regex = re_compile_maybe_verbose(exclude)
    except re.error:
        err(f'Invalid regular expression for exclude given: {exclude!r}')
        ctx.exit(2)
    report = Report(check=check, quiet=quiet, verbose=verbose)
    root = find_project_root(src)
    # Read the root .gitignore at most once, and only if a directory argument
    # needs it (it used to be re-read for every directory).
    gitignore = None
    sources: Set[Path] = set()
    path_empty(src=src, quiet=quiet, verbose=verbose, ctx=ctx, msg=None)
    for s in src:
        p = Path(s)
        if p.is_dir():
            if gitignore is None:
                gitignore = get_gitignore(root)
            sources.update(
                gen_python_files_in_dir(
                    p,
                    root,
                    include_regex,
                    exclude_regex,
                    report,
                    gitignore,
                ))
        elif p.is_file() or s == '-':
            # if a file was explicitly given, we don't care about its extension
            sources.add(p)
        else:
            err(f'invalid path: {s}')
    if len(sources) == 0:
        if verbose or not quiet:
            out('No Python files are present to be formatted. Nothing to do 😴')
        ctx.exit(0)

    reformat_many(
        sources=sources,
        fast=fast,
        write_back=write_back,
        mode=mode,
        report=report,
    )

    if verbose or not quiet:
        out('Oh no! 💥 💔 💥' if report.return_code else 'All done! ✨ 🍰 ✨')
        click.secho(str(report), err=True)
    ctx.exit(report.return_code)
Ejemplo n.º 20
0
def cli(
    ctx: click.Context,
    line_length: int,
    check: bool,
    include: str,
    exclude: str,
    quiet: bool,
    verbose: bool,
    clear_output: bool,
    src: Tuple[str],
    config: Optional[str],
) -> None:
    """
    The uncompromising code formatter, for Jupyter notebooks.
    """
    # (The docstring is presumably click's --help text, so it stays as-is.)
    # Reformat every path reachable from *src* with reformat_one and exit
    # with the report's return code; exit code 2 on an invalid regex.
    write_back = black.WriteBack.from_configuration(check=check, diff=False)
    mode = black.FileMode.from_configuration(
        py36=True,
        pyi=False,
        skip_string_normalization=False,
        skip_numeric_underscore_normalization=False,
    )

    if config and verbose:
        black.out(f"Using configuration from {config}.", bold=False, fg="blue")

    try:
        include_regex = black.re_compile_maybe_verbose(include)
    except re.error:
        black.err(f"Invalid regular expression for include given: {include!r}")
        ctx.exit(2)
    try:
        exclude_regex = black.re_compile_maybe_verbose(exclude)
    except re.error:
        black.err(f"Invalid regular expression for exclude given: {exclude!r}")
        ctx.exit(2)

    report = black.Report(check=check, quiet=quiet, verbose=verbose)
    root = black.find_project_root(src)
    sources: Set[Path] = set()
    for s in src:
        p = Path(s)
        if p.is_dir():
            sources.update(
                black.gen_python_files_in_dir(
                    p, root, include_regex, exclude_regex, report
                )
            )
        elif p.is_file() or s == "-":
            # if a file was explicitly given, we don't care about its extension
            sources.add(p)
        else:
            black.err(f"invalid path: {s}")
    if len(sources) == 0:
        if verbose or not quiet:
            black.out("No paths given. Nothing to do.")
        ctx.exit(0)

    for source in sources:
        reformat_one(
            src=source,
            line_length=line_length,
            write_back=write_back,
            mode=mode,
            clear_output=clear_output,
            report=report,
            quiet=quiet,
            verbose=verbose,
        )

    if verbose or not quiet:
        # Plain literal: the former f-prefix had no placeholders.
        black.out("All done!")
        click.secho(str(report), err=True)
    ctx.exit(report.return_code)
Ejemplo n.º 21
0
def main(args: Optional[Sequence[str]] = None) -> Exit:
    """Main function.

    Options are read from ``[tool.datamodel-codegen]`` in the project's
    ``pyproject.toml`` (keys with ``-`` mapped to ``_``) and merged with the
    parsed command-line arguments before calling ``generate``.

    :param args: argument list; defaults to ``sys.argv[1:]``.
    :returns: ``Exit.OK`` on success, ``Exit.ERROR`` on any failure.
    """

    # add cli completion support
    argcomplete.autocomplete(arg_parser)

    if args is None:
        args = sys.argv[1:]

    namespace: Namespace = arg_parser.parse_args(args)

    if namespace.version:  # pragma: no cover
        from datamodel_code_generator.version import version

        print(version)
        # sys.exit, not the `site`-injected exit() builtin (absent under
        # `python -S` and in frozen executables).
        sys.exit(0)

    root = black.find_project_root((Path().resolve(), ))
    pyproject_toml_path = root / "pyproject.toml"
    if pyproject_toml_path.is_file():
        pyproject_toml: Dict[str, Any] = {
            k.replace('-', '_'): v
            for k, v in toml.load(str(pyproject_toml_path)).get(
                'tool', {}).get('datamodel-codegen', {}).items()
        }
    else:
        pyproject_toml = {}

    try:
        config = Config.parse_obj(pyproject_toml)
        config.merge_args(namespace)
    except Error as e:
        print(e.message, file=sys.stderr)
        return Exit.ERROR

    if not is_supported_in_black(
            config.target_python_version):  # pragma: no cover
        print(
            f"Installed black doesn't support Python version {config.target_python_version.value}.\n"
            f"You have to install a newer black.\n"
            f"Installed black version: {black.__version__}",
            file=sys.stderr,
        )
        return Exit.ERROR

    if config.debug:  # pragma: no cover
        enable_debug_message()

    extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]]
    if config.extra_template_data is None:
        extra_template_data = None
    else:
        with config.extra_template_data as data:
            try:
                # Nested objects become defaultdict(dict, ...) instances.
                extra_template_data = json.load(
                    data, object_hook=lambda d: defaultdict(dict, **d))
            except json.JSONDecodeError as e:
                print(f"Unable to load extra template data: {e}",
                      file=sys.stderr)
                return Exit.ERROR

    if config.aliases is None:
        aliases = None
    else:
        with config.aliases as data:
            try:
                aliases = json.load(data)
            except json.JSONDecodeError as e:
                print(f"Unable to load alias mapping: {e}", file=sys.stderr)
                return Exit.ERROR
        if not isinstance(aliases, dict) or not all(
                isinstance(k, str) and isinstance(v, str)
                for k, v in aliases.items()):
            print(
                'Alias mapping must be a JSON string mapping (e.g. {"from": "to", ...})',
                file=sys.stderr,
            )
            return Exit.ERROR

    try:
        generate(
            input_=config.url or config.input or sys.stdin.read(),
            input_file_type=config.input_file_type,
            output=config.output,
            target_python_version=config.target_python_version,
            base_class=config.base_class,
            custom_template_dir=config.custom_template_dir,
            validation=config.validation,
            field_constraints=config.field_constraints,
            snake_case_field=config.snake_case_field,
            strip_default_none=config.strip_default_none,
            extra_template_data=extra_template_data,
            aliases=aliases,
            disable_timestamp=config.disable_timestamp,
            allow_population_by_field_name=config.
            allow_population_by_field_name,
            apply_default_values_for_required_fields=config.use_default,
            force_optional_for_required_fields=config.force_optional,
            class_name=config.class_name,
            use_standard_collections=config.use_standard_collections,
            use_schema_description=config.use_schema_description,
            reuse_model=config.reuse_model,
            encoding=config.encoding,
            enum_field_as_literal=config.enum_field_as_literal,
            set_default_enum_member=config.set_default_enum_member,
            strict_nullable=config.strict_nullable,
            use_generic_container_types=config.use_generic_container_types,
            enable_faux_immutability=config.enable_faux_immutability,
            disable_appending_item_suffix=config.disable_appending_item_suffix,
            strict_types=config.strict_types,
            empty_enum_field_name=config.empty_enum_field_name,
            field_extra_keys=config.field_extra_keys,
            field_include_all_keys=config.field_include_all_keys,
        )
        return Exit.OK
    except InvalidClassNameError as e:
        print(f'{e} You have to set `--class-name` option', file=sys.stderr)
        return Exit.ERROR
    except Error as e:
        print(str(e), file=sys.stderr)
        return Exit.ERROR
    except Exception:
        import traceback

        print(traceback.format_exc(), file=sys.stderr)
        return Exit.ERROR
Ejemplo n.º 22
0
def find_pyproject_toml(start_path: Iterable[str]) -> Optional[str]:
    """Return the path of ``pyproject.toml`` at the project root, if present.

    :param start_path: paths from which the project root is searched.
    :returns: the file path as a string, or ``None`` if the project root has
              no ``pyproject.toml`` file.
    """
    # Pass a tuple: find_project_root may be lru_cached and would reject an
    # unhashable argument such as a list.
    root = find_project_root(tuple(start_path))
    config_file = root / "pyproject.toml"
    return str(config_file) if config_file.is_file() else None
Ejemplo n.º 23
0
async def api(
    *,
    src: Iterable[str],
    work_dir: str,
    line_length: int = black.DEFAULT_LINE_LENGTH,
    check: bool = False,
    diff: bool = False,
    fast: bool = False,
    pyi: bool = False,
    py36: bool = False,
    skip_string_normalization: bool = False,
    quiet: bool = False,
    verbose: bool = False,
    include: str = black.DEFAULT_INCLUDES,
    exclude: str = black.DEFAULT_EXCLUDES,
    config: Optional[str] = None,
) -> int:
    """The uncompromising code formatter.

    Coroutine variant of Black's CLI. Each entry of *src* is resolved relative
    to *work_dir*, and directories are expanded with
    ``black.gen_python_files_in_dir``. A single file is reformatted with
    ``black.reformat_one``; several files go through
    ``black.schedule_formatting`` on ``PROCESS_POOL``.

    :returns: the report's return code; 2 for an invalid include/exclude
              regular expression; 0 when there is nothing to do.
    """
    src = tuple(src)
    work_dir = Path(work_dir)
    loop = asyncio.get_event_loop()
    write_back = black.WriteBack.from_configuration(check=check, diff=diff)
    mode = black.FileMode.from_configuration(
        py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
    )
    if config and verbose:
        black.out(f"Using configuration from {config}.", bold=False, fg="blue")
    try:
        include_regex = black.re_compile_maybe_verbose(include)
    except re.error:
        black.err(f"Invalid regular expression for include given: {include!r}")
        return 2
    try:
        exclude_regex = black.re_compile_maybe_verbose(exclude)
    except re.error:
        black.err(f"Invalid regular expression for exclude given: {exclude!r}")
        return 2
    report = black.Report(check=check, quiet=quiet, verbose=verbose)
    root = black.find_project_root((work_dir,))
    sources: Set[Path] = set()
    for s in src:
        p = work_dir / Path(s)
        if p.is_dir():
            sources.update(
                black.gen_python_files_in_dir(
                    p, root, include_regex, exclude_regex, report
                )
            )
        elif p.is_file() or s == "-":
            # if a file was explicitly given, we don't care about its extension
            sources.add(p)
        else:
            black.err(f"invalid path: {s}")
    if len(sources) == 0:
        if verbose or not quiet:
            black.out("No paths given. Nothing to do 😴")
        return 0

    if len(sources) == 1:
        black.reformat_one(
            src=sources.pop(),
            line_length=line_length,
            fast=fast,
            write_back=write_back,
            mode=mode,
            report=report,
        )
    else:
        await black.schedule_formatting(
            sources=sources,
            line_length=line_length,
            fast=fast,
            write_back=write_back,
            mode=mode,
            report=report,
            executor=PROCESS_POOL,
            loop=loop,
        )
    if verbose or not quiet:
        # Black's own CLI prints this line with click.secho; `black.secho` is
        # not part of black's module namespace, so it raised AttributeError.
        # click is a dependency of black, so it is importable here.
        # NOTE(review): confirm against the pinned black version.
        import click

        bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
        black.out(f"All done! {bang}")
        click.secho(str(report), err=True)
    return report.return_code