Beispiel #1
0
def test_determine_specified_tool_names() -> None:
    """An unrecognized `--<goal>-only` value raises ValueError naming every valid tool.

    The valid names are drawn from both the style request types and `extra_valid_names`.
    """

    class MyToolRequest(StyleRequest):
        name = "my-tool"

    expected_message = (
        "Unrecognized name with the option `--fake-goal-only`: 'bad'\n\n"
        "All valid names: ['extra-tool', 'my-tool']"
    )

    with pytest.raises(ValueError) as exc_info:
        determine_specified_tool_names(
            "fake-goal",
            only_option=["bad"],
            all_style_requests=[MyToolRequest],
            extra_valid_names=["extra-tool"],
        )

    assert expected_message in str(exc_info.value)
Beispiel #2
0
async def check(
    console: Console,
    workspace: Workspace,
    targets: FilteredTargets,
    dist_dir: DistDir,
    union_membership: UnionMembership,
    check_subsystem: CheckSubsystem,
) -> Check:
    """Run the selected checkers against the applicable targets and summarize the outcome.

    Each registered `CheckRequest` type is instantiated with the field sets of the targets it
    applies to, restricted to the tools named by the goal's `only` option. Checkers that end up
    with no field sets are not run.

    Reports are written via `write_reports`. One status line is printed per non-skipped checker,
    sorted by checker name. The returned exit code is that of the last failing checker in that
    order, or 0 if none failed.
    """
    checker_types = cast("Iterable[type[StyleRequest]]", union_membership[CheckRequest])
    selected_names = determine_specified_tool_names(
        "check", check_subsystem.only, checker_types
    )

    # Every checker type gets a request. Unselected checkers simply receive no field sets.
    check_requests = []
    for checker_type in checker_types:
        is_selected = checker_type.name in selected_names
        field_set_type = checker_type.field_set_type
        check_requests.append(
            checker_type(
                field_set_type.create(tgt)
                for tgt in targets
                if is_selected and field_set_type.is_applicable(tgt)
            )
        )

    all_results = await MultiGet(
        Get(CheckResults, CheckRequest, req) for req in check_requests if req.field_sets
    )

    def checker_name_of(res: CheckResults) -> str:
        return res.checker_name

    write_reports(
        all_results,
        workspace,
        dist_dir,
        goal_name=CheckSubsystem.name,
        get_name=checker_name_of,
    )

    if all_results:
        console.print_stderr("")

    exit_code = 0
    for res in sorted(all_results, key=checker_name_of):
        if res.skipped:
            continue
        failed = res.exit_code != 0
        if failed:
            exit_code = res.exit_code
        sigil = console.sigil_failed() if failed else console.sigil_succeeded()
        status = "failed" if failed else "succeeded"
        console.print_stderr(f"{sigil} {res.checker_name} {status}.")

    return Check(exit_code)
Beispiel #3
0
async def lint(
    console: Console,
    workspace: Workspace,
    specs: Specs,
    lint_subsystem: LintSubsystem,
    union_membership: UnionMembership,
    dist_dir: DistDir,
) -> Lint:
    """Run all selected linters (and formatters, as linters) and report the results.

    Three kinds of tools are collected from the union membership:
      * `LintTargetsRequest`: linters that receive field sets of applicable targets.
      * `FmtRequest`: formatters, run unless `lint_subsystem.skip_formatters` is set. A formatter
        reporting `did_change` is recorded as a failing (exit code 1) lint result.
      * `LintFilesRequest`: linters that receive every file path matched by `specs`.

    The goal's `only` option restricts which tools run. Target field sets are split into
    batches per tool according to `lint_subsystem.batch_size`. All batches of one tool are
    merged into a single `LintResults`; these are written as reports, printed, and turned into
    the goal's exit code via `_get_error_code`.
    """
    lint_target_request_types = cast(
        "Iterable[type[LintTargetsRequest]]", union_membership.get(LintTargetsRequest)
    )
    fmt_target_request_types = cast("Iterable[type[FmtRequest]]", union_membership.get(FmtRequest))
    file_request_types = cast(
        "Iterable[type[LintFilesRequest]]", union_membership[LintFilesRequest]
    )

    _check_ambiguous_request_names(
        *lint_target_request_types, *fmt_target_request_types, *file_request_types
    )

    specified_names = determine_specified_tool_names(
        "lint",
        lint_subsystem.only,
        [*lint_target_request_types, *fmt_target_request_types],
        extra_valid_names={request.name for request in file_request_types},
    )

    def is_specified(request_type: type[StyleRequest] | type[LintFilesRequest]) -> bool:
        return request_type.name in specified_names

    # NB: Materialize the selections. A bare `filter` object is always truthy, so the emptiness
    # checks below could never take the `Specs.empty()` branch if these stayed lazy iterators.
    lint_target_request_types = tuple(filter(is_specified, lint_target_request_types))
    fmt_target_request_types = tuple(filter(is_specified, fmt_target_request_types))
    file_request_types = tuple(filter(is_specified, file_request_types))

    # Only resolve targets / file paths when some selected tool will consume them.
    _get_targets = Get(
        FilteredTargets,
        Specs,
        specs if lint_target_request_types or fmt_target_request_types else Specs.empty(),
    )
    _get_specs_paths = Get(SpecsPaths, Specs, specs if file_request_types else Specs.empty())
    targets, specs_paths = await MultiGet(_get_targets, _get_specs_paths)

    def batch(field_sets: Iterable[FieldSet]) -> Iterator[list[FieldSet]]:
        """Split `field_sets` into batches keyed by address spec, sized by the subsystem."""
        yield from partition_sequentially(
            field_sets,
            key=lambda fs: fs.address.spec,
            size_target=lint_subsystem.batch_size,
            size_max=4 * lint_subsystem.batch_size,
        )

    def batch_by_type(
        request_types: Iterable[type[_SR]],
    ) -> tuple[tuple[type[_SR], list[FieldSet]], ...]:
        """Pair each request type with each batch of field sets for targets it applies to."""
        return tuple(
            (request_type, field_set_batch)
            for request_type in request_types
            for field_set_batch in batch(
                request_type.field_set_type.create(target)
                for target in targets
                if request_type.field_set_type.is_applicable(target)
            )
        )

    lint_target_requests = (
        request_type(field_set_batch)
        for request_type, field_set_batch in batch_by_type(lint_target_request_types)
    )

    fmt_requests: Iterable[FmtRequest] = ()
    if not lint_subsystem.skip_formatters:
        batched_fmt_request_pairs = batch_by_type(fmt_target_request_types)
        # FmtRequests are built with a snapshot of their sources, so hydrate each batch first.
        # Field sets may expose either a plural `sources` or a singular `source` field.
        all_fmt_source_batches = await MultiGet(
            Get(
                SourceFiles,
                SourceFilesRequest(
                    cast(
                        SourcesField,
                        getattr(field_set, "sources", getattr(field_set, "source", None)),
                    )
                    for field_set in field_set_batch
                ),
            )
            for _, field_set_batch in batched_fmt_request_pairs
        )
        fmt_requests = (
            request_type(
                field_set_batch,
                snapshot=source_files_snapshot.snapshot,
            )
            for (request_type, field_set_batch), source_files_snapshot in zip(
                batched_fmt_request_pairs, all_fmt_source_batches
            )
        )

    file_requests = (
        tuple(request_type(specs_paths.files) for request_type in file_request_types)
        if specs_paths.files
        else ()
    )

    all_requests = [
        *(Get(LintResults, LintTargetsRequest, request) for request in lint_target_requests),
        *(Get(FmtResult, FmtRequest, request) for request in fmt_requests),
        *(Get(LintResults, LintFilesRequest, request) for request in file_requests),
    ]
    all_batch_results = cast(
        "tuple[LintResults | FmtResult, ...]",
        await MultiGet(all_requests),  # type: ignore[arg-type]
    )

    def key_fn(results: LintResults | FmtResult):
        if isinstance(results, FmtResult):
            return results.formatter_name
        return results.linter_name

    # NB: We must pre-sort the data for itertools.groupby() to work properly.
    sorted_all_batch_results = sorted(all_batch_results, key=key_fn)

    # Set to True if any formatter would have changed files; passed on to `_print_results`.
    formatter_failed = False

    def coerce_to_lintresult(batch_results: LintResults | FmtResult) -> tuple[LintResult, ...]:
        """Express a formatter result as a lint result (exit code 1 when it changed files)."""
        if isinstance(batch_results, FmtResult):
            nonlocal formatter_failed
            formatter_failed = formatter_failed or batch_results.did_change
            return (
                LintResult(
                    1 if batch_results.did_change else 0,
                    batch_results.stdout,
                    batch_results.stderr,
                ),
            )
        return batch_results.results

    # We consolidate all results for each linter into a single `LintResults`.
    all_results = tuple(
        sorted(
            (
                LintResults(
                    itertools.chain.from_iterable(
                        coerce_to_lintresult(batch_results) for batch_results in results
                    ),
                    linter_name=linter_name,
                )
                for linter_name, results in itertools.groupby(sorted_all_batch_results, key=key_fn)
            ),
            key=key_fn,
        )
    )

    def get_name(res: LintResults) -> str:
        return res.linter_name

    write_reports(
        all_results,
        workspace,
        dist_dir,
        goal_name=LintSubsystem.name,
        get_name=get_name,
    )

    _print_results(
        console,
        all_results,
        formatter_failed,
    )
    return Lint(_get_error_code(all_results))
Beispiel #4
0
async def lint(
    console: Console,
    workspace: Workspace,
    targets: Targets,
    specs_snapshot: SpecsSnapshot,
    lint_subsystem: LintSubsystem,
    union_membership: UnionMembership,
    dist_dir: DistDir,
) -> Lint:
    """Run the selected target-based and file-based linters, then report and summarize.

    Target-based linters (`LintTargetsRequest`) get the field sets of the targets they apply
    to, split into batches sized by `lint_subsystem.batch_size`. File-based linters
    (`LintFilesRequest`) get every file in `specs_snapshot`, but only when that snapshot is
    non-empty. Only tools selected by the goal's `only` option run.

    All batches of one linter are merged into a single `LintResults`. The exit code is that of
    the last failing linter in linter-name order, or 0 if none failed.
    """
    target_linter_types = cast(
        "Iterable[type[LintTargetsRequest]]", union_membership[LintTargetsRequest]
    )
    file_linter_types = union_membership[LintFilesRequest]
    selected_names = determine_specified_tool_names(
        "lint",
        lint_subsystem.only,
        target_linter_types,
        extra_valid_names={linter_type.name for linter_type in file_linter_types},
    )

    # One request per target linter type. Unselected linters end up with no field sets.
    per_linter_requests = []
    for linter_type in target_linter_types:
        is_selected = linter_type.name in selected_names
        field_set_type = linter_type.field_set_type
        per_linter_requests.append(
            linter_type(
                field_set_type.create(tgt)
                for tgt in targets
                if is_selected and field_set_type.is_applicable(tgt)
            )
        )

    matched_files = specs_snapshot.snapshot.files
    file_linter_requests: tuple[LintFilesRequest, ...] = ()
    if matched_files:
        file_linter_requests = tuple(
            linter_type(matched_files)
            for linter_type in file_linter_types
            if linter_type.name in selected_names
        )

    def by_address(fs: FieldSet) -> str:
        return fs.address.spec

    # Target linters first (one Get per batch), then file linters.
    lint_gets = []
    for request in per_linter_requests:
        if not request.field_sets:
            continue
        for field_set_batch in partition_sequentially(
            request.field_sets,
            key=by_address,
            size_target=lint_subsystem.batch_size,
            size_max=4 * lint_subsystem.batch_size,
        ):
            lint_gets.append(
                Get(LintResults, LintTargetsRequest, request.__class__(field_set_batch))
            )
    lint_gets.extend(
        Get(LintResults, LintFilesRequest, request) for request in file_linter_requests
    )

    batch_results = cast(
        "tuple[LintResults, ...]",
        await MultiGet(lint_gets),  # type: ignore[arg-type]
    )

    def linter_name_of(res: LintResults) -> str:
        return res.linter_name

    # Merge all batches of the same linter into one `LintResults`. `groupby` only groups
    # adjacent items, hence the pre-sort on the same key.
    merged_results = []
    for linter_name, linter_batches in itertools.groupby(
        sorted(batch_results, key=linter_name_of), key=linter_name_of
    ):
        merged_results.append(
            LintResults(
                itertools.chain.from_iterable(b.results for b in linter_batches),
                linter_name=linter_name,
            )
        )
    all_results = tuple(sorted(merged_results, key=linter_name_of))

    write_reports(
        all_results,
        workspace,
        dist_dir,
        goal_name=LintSubsystem.name,
        get_name=linter_name_of,
    )

    if all_results:
        console.print_stderr("")

    exit_code = 0
    for res in all_results:
        if res.skipped:
            continue
        if res.exit_code == 0:
            sigil, status = console.sigil_succeeded(), "succeeded"
        else:
            exit_code = res.exit_code
            sigil, status = console.sigil_failed(), "failed"
        console.print_stderr(f"{sigil} {res.linter_name} {status}.")

    return Lint(exit_code)
Beispiel #5
0
async def fmt(
    console: Console,
    targets: Targets,
    fmt_subsystem: FmtSubsystem,
    workspace: Workspace,
    union_membership: UnionMembership,
) -> Fmt:
    """Run the selected formatters over `targets` and write any changes into the workspace.

    Targets are grouped by the exact sequence of applicable `FmtRequest` types (restricted to
    the goal's `only` option). Each group is split into batches sized by
    `fmt_subsystem.batch_size`, and each batch is formatted via `_LanguageFmtRequest`.

    Changed outputs are merged and written in one step. One status line is printed per
    formatter that was not skipped everywhere. The exit code is always 0 (see the note at the
    end of the function).
    """
    request_types = union_membership[FmtRequest]
    specified_names = determine_specified_tool_names("fmt", fmt_subsystem.only,
                                                     request_types)

    # Group targets by the sequence of FmtRequests that apply to them.
    targets_by_fmt_request_order = defaultdict(list)
    for target in targets:
        fmt_requests = []
        for fmt_request in request_types:
            valid_name = fmt_request.name in specified_names
            if valid_name and fmt_request.field_set_type.is_applicable(
                    target):  # type: ignore[misc]
                fmt_requests.append(fmt_request)
        if fmt_requests:
            targets_by_fmt_request_order[tuple(fmt_requests)].append(target)

    # Spawn sequential formatting per unique sequence of FmtRequests, one Get per batch.
    per_language_results = await MultiGet(
        Get(
            _LanguageFmtResults,
            _LanguageFmtRequest(request_sequence, Targets(target_batch)),
        )
        for request_sequence, grouped_targets in targets_by_fmt_request_order.items()
        for target_batch in partition_sequentially(
            grouped_targets,
            key=lambda t: t.address.spec,
            size_target=fmt_subsystem.batch_size,
            size_max=4 * fmt_subsystem.batch_size,
        )
    )

    individual_results = list(
        itertools.chain.from_iterable(
            language_result.results for language_result in per_language_results
        )
    )

    if not individual_results:
        return Fmt(exit_code=0)

    changed_digests = tuple(
        language_result.output
        for language_result in per_language_results
        if language_result.did_change
    )
    if changed_digests:
        # NB: this will fail if there are any conflicting changes, which we want to happen rather
        # than silently having one result override the other. In practice, this should never
        # happen due to us grouping each language's formatters into a single digest.
        merged_formatted_digest = await Get(Digest, MergeDigests(changed_digests))
        workspace.write_digest(merged_formatted_digest)

    # `individual_results` is guaranteed non-empty here (early return above).
    console.print_stderr("")

    # A formatter can produce several results: its targets may be split into multiple batches,
    # and it may appear in more than one request sequence. Group them so each formatter gets a
    # single final status line in the summary.
    formatter_to_results = defaultdict(set)
    for result in individual_results:
        formatter_to_results[result.formatter_name].add(result)

    for formatter, results in sorted(formatter_to_results.items()):
        if any(result.did_change for result in results):
            sigil = console.sigil_succeeded_with_edits()
            status = "made changes"
        elif all(result.skipped for result in results):
            continue
        else:
            sigil = console.sigil_succeeded()
            status = "made no changes"
        console.print_stderr(f"{sigil} {formatter} {status}.")

    # Since the rules to produce FmtResult should use ExecuteRequest, rather than
    # FallibleProcess, we assume that there were no failures.
    return Fmt(exit_code=0)