def test_determine_specified_tool_names() -> None:
    """An unrecognized `--only` name must raise with a message listing all valid names."""

    class StyleReq(StyleRequest):
        name = "my-tool"

    # The valid-name listing is expected to be sorted and to include both the
    # registered request's name and any extra names passed in.
    expected_message = (
        "Unrecognized name with the option `--fake-goal-only`: 'bad'\n\n"
        "All valid names: ['extra-tool', 'my-tool']"
    )

    with pytest.raises(ValueError) as exc_info:
        determine_specified_tool_names(
            "fake-goal",
            only_option=["bad"],
            all_style_requests=[StyleReq],
            extra_valid_names=["extra-tool"],
        )

    assert expected_message in str(exc_info.value)
async def check(
    console: Console,
    workspace: Workspace,
    targets: FilteredTargets,
    dist_dir: DistDir,
    union_membership: UnionMembership,
    check_subsystem: CheckSubsystem,
) -> Check:
    """Run every applicable checker over the targets and print a per-checker summary.

    Builds one request per registered `CheckRequest` type (restricted to the
    names selected via `--check-only`, if any), runs them concurrently, writes
    any report files, and prints a succeeded/failed line per checker.
    """
    checker_types = cast("Iterable[type[StyleRequest]]", union_membership[CheckRequest])
    specified_names = determine_specified_tool_names("check", check_subsystem.only, checker_types)

    # One request per checker type, containing the field sets of every target
    # that checker both matches and was selected for.
    requests = tuple(
        checker_type(
            checker_type.field_set_type.create(tgt)
            for tgt in targets
            if checker_type.name in specified_names
            and checker_type.field_set_type.is_applicable(tgt)
        )
        for checker_type in checker_types
    )
    # Skip checkers that ended up with no applicable targets.
    all_results = await MultiGet(
        Get(CheckResults, CheckRequest, request)
        for request in requests
        if request.field_sets
    )

    def get_name(res: CheckResults) -> str:
        return res.checker_name

    write_reports(
        all_results, workspace, dist_dir, goal_name=CheckSubsystem.name, get_name=get_name
    )

    # Summarize, sorted by checker name. The goal's exit code is the exit code
    # of the last failing checker (0 if none failed).
    exit_code = 0
    if all_results:
        console.print_stderr("")
    for results in sorted(all_results, key=lambda r: r.checker_name):
        if results.skipped:
            continue
        if results.exit_code == 0:
            sigil, status = console.sigil_succeeded(), "succeeded"
        else:
            sigil, status = console.sigil_failed(), "failed"
            exit_code = results.exit_code
        console.print_stderr(f"{sigil} {results.checker_name} {status}.")
    return Check(exit_code)
async def lint(
    console: Console,
    workspace: Workspace,
    specs: Specs,
    lint_subsystem: LintSubsystem,
    union_membership: UnionMembership,
    dist_dir: DistDir,
) -> Lint:
    """Run all selected linters, formatters (in check-only mode), and file-level lints.

    Three kinds of requests are gathered from the union membership, filtered by
    `--lint-only`, batched, executed concurrently, and then merged into one
    `LintResults` per tool before reporting.
    """
    lint_target_request_types = cast(
        "Iterable[type[LintTargetsRequest]]", union_membership.get(LintTargetsRequest)
    )
    fmt_target_request_types = cast("Iterable[type[FmtRequest]]", union_membership.get(FmtRequest))
    file_request_types = cast(
        "Iterable[type[LintFilesRequest]]", union_membership[LintFilesRequest]
    )
    # Fail fast if two request types registered the same tool name.
    _check_ambiguous_request_names(
        *lint_target_request_types, *fmt_target_request_types, *file_request_types
    )
    specified_names = determine_specified_tool_names(
        "lint",
        lint_subsystem.only,
        [*lint_target_request_types, *fmt_target_request_types],
        extra_valid_names={request.name for request in file_request_types},
    )

    def is_specified(request_type: type[StyleRequest] | type[LintFilesRequest]):
        # True when the tool was selected (or no `--only` restriction was given).
        return request_type.name in specified_names

    lint_target_request_types = filter(is_specified, lint_target_request_types)
    fmt_target_request_types = filter(is_specified, fmt_target_request_types)
    file_request_types = filter(is_specified, file_request_types)

    # NOTE(review): `filter(...)` objects are always truthy, so the two
    # `... if <filter obj> else Specs.empty()` conditions below always choose
    # `specs` — presumably the intent was "only resolve if any tool remains";
    # confirm whether these should be materialized (e.g. into tuples) first.
    _get_targets = Get(
        FilteredTargets,
        Specs,
        specs if lint_target_request_types or fmt_target_request_types else Specs.empty(),
    )
    _get_specs_paths = Get(SpecsPaths, Specs, specs if file_request_types else Specs.empty())
    targets, specs_paths = await MultiGet(_get_targets, _get_specs_paths)

    def batch(field_sets: Iterable[FieldSet]) -> Iterator[list[FieldSet]]:
        """Split field sets into address-ordered batches bounded by `--lint-batch-size`."""
        partitions = partition_sequentially(
            field_sets,
            key=lambda fs: fs.address.spec,
            size_target=lint_subsystem.batch_size,
            size_max=4 * lint_subsystem.batch_size,
        )
        for partition in partitions:
            yield partition

    def batch_by_type(
        request_types: Iterable[type[_SR]],
    ) -> tuple[tuple[type[_SR], list[FieldSet]], ...]:
        """Pair each request type with every batch of its applicable field sets."""
        return tuple(
            (request_type, field_set_batch)
            for request_type in request_types
            for field_set_batch in batch(
                request_type.field_set_type.create(target)
                for target in targets
                if request_type.field_set_type.is_applicable(target)
            )
        )

    lint_target_requests = (
        request_type(batch) for request_type, batch in batch_by_type(lint_target_request_types)
    )

    fmt_requests: Iterable[FmtRequest] = ()
    if not lint_subsystem.skip_formatters:
        batched_fmt_request_pairs = batch_by_type(fmt_target_request_types)
        # Formatters need a snapshot of the sources in each batch; fetch them
        # concurrently, one `SourceFiles` per batch, in the same order as the pairs.
        all_fmt_source_batches = await MultiGet(
            Get(
                SourceFiles,
                SourceFilesRequest(
                    cast(
                        SourcesField,
                        # Field sets expose either `sources` or `source`; take whichever exists.
                        getattr(field_set, "sources", getattr(field_set, "source", None)),
                    )
                    for field_set in batch
                ),
            )
            for _, batch in batched_fmt_request_pairs
        )
        fmt_requests = (
            request_type(
                batch,
                snapshot=source_files_snapshot.snapshot,
            )
            for (request_type, batch), source_files_snapshot in zip(
                batched_fmt_request_pairs, all_fmt_source_batches
            )
        )

    file_requests = (
        tuple(request_type(specs_paths.files) for request_type in file_request_types)
        if specs_paths.files
        else ()
    )

    # Run every kind of request in one concurrent batch.
    all_requests = [
        *(Get(LintResults, LintTargetsRequest, request) for request in lint_target_requests),
        *(Get(FmtResult, FmtRequest, request) for request in fmt_requests),
        *(Get(LintResults, LintFilesRequest, request) for request in file_requests),
    ]
    all_batch_results = cast(
        "tuple[LintResults | FmtResult, ...]",
        await MultiGet(all_requests),  # type: ignore[arg-type]
    )

    def key_fn(results: LintResults | FmtResult):
        # Group/sort key: the tool's name, regardless of result kind.
        if isinstance(results, FmtResult):
            return results.formatter_name
        return results.linter_name

    # NB: We must pre-sort the data for itertools.groupby() to work properly.
    sorted_all_batch_results = sorted(all_batch_results, key=key_fn)

    formatter_failed = False

    def coerce_to_lintresult(batch_results: LintResults | FmtResult) -> tuple[LintResult, ...]:
        """Adapt a formatter result into lint results; under `lint`, "would reformat" is a failure."""
        if isinstance(batch_results, FmtResult):
            nonlocal formatter_failed
            formatter_failed = formatter_failed or batch_results.did_change
            return (
                LintResult(
                    1 if batch_results.did_change else 0,
                    batch_results.stdout,
                    batch_results.stderr,
                ),
            )
        return batch_results.results

    # We consolidate all results for each linter into a single `LintResults`.
    all_results = tuple(
        sorted(
            (
                LintResults(
                    itertools.chain.from_iterable(
                        coerce_to_lintresult(batch_results) for batch_results in results
                    ),
                    linter_name=linter_name,
                )
                for linter_name, results in itertools.groupby(sorted_all_batch_results, key=key_fn)
            ),
            key=key_fn,
        )
    )

    def get_name(res: LintResults) -> str:
        return res.linter_name

    write_reports(
        all_results,
        workspace,
        dist_dir,
        goal_name=LintSubsystem.name,
        get_name=get_name,
    )

    _print_results(
        console,
        all_results,
        formatter_failed,
    )
    return Lint(_get_error_code(all_results))
async def lint(
    console: Console,
    workspace: Workspace,
    targets: Targets,
    specs_snapshot: SpecsSnapshot,
    lint_subsystem: LintSubsystem,
    union_membership: UnionMembership,
    dist_dir: DistDir,
) -> Lint:
    """Run all selected target-based and file-based linters and summarize per linter.

    Target requests are split into address-ordered batches before execution;
    the per-batch results are then re-merged into one `LintResults` per linter
    for reporting and the console summary.
    """
    target_request_types = cast(
        "Iterable[type[LintTargetsRequest]]", union_membership[LintTargetsRequest]
    )
    file_request_types = union_membership[LintFilesRequest]
    specified_names = determine_specified_tool_names(
        "lint",
        lint_subsystem.only,
        target_request_types,
        extra_valid_names={request.name for request in file_request_types},
    )
    # One request per linter type, holding the field sets of every target it
    # applies to (and that survived the `--lint-only` filter).
    target_requests = tuple(
        request_type(
            request_type.field_set_type.create(target)
            for target in targets
            if (
                request_type.name in specified_names
                and request_type.field_set_type.is_applicable(target)
            )
        )
        for request_type in target_request_types
    )
    # File-based linters run over the raw spec'd file paths, not targets.
    file_requests = (
        tuple(
            request_type(specs_snapshot.snapshot.files)
            for request_type in file_request_types
            if request_type.name in specified_names
        )
        if specs_snapshot.snapshot.files
        else ()
    )

    def address_str(fs: FieldSet) -> str:
        # Batching key: keeps field sets with nearby addresses in the same batch.
        return fs.address.spec

    # Re-wrap each non-empty target request into per-batch requests of the same
    # class, bounded by `--lint-batch-size`, then run everything concurrently.
    all_requests = [
        *(
            Get(LintResults, LintTargetsRequest, request.__class__(field_set_batch))
            for request in target_requests
            if request.field_sets
            for field_set_batch in partition_sequentially(
                request.field_sets,
                key=address_str,
                size_target=lint_subsystem.batch_size,
                size_max=4 * lint_subsystem.batch_size,
            )
        ),
        *(Get(LintResults, LintFilesRequest, request) for request in file_requests),
    ]
    all_batch_results = cast(
        "tuple[LintResults, ...]",
        await MultiGet(all_requests),  # type: ignore[arg-type]
    )

    def key_fn(results: LintResults):
        return results.linter_name

    # NB: We must pre-sort the data for itertools.groupby() to work properly.
    sorted_all_batch_results = sorted(all_batch_results, key=key_fn)
    # We consolidate all results for each linter into a single `LintResults`.
    all_results = tuple(
        sorted(
            (
                LintResults(
                    itertools.chain.from_iterable(
                        batch_results.results for batch_results in all_linter_results
                    ),
                    linter_name=linter_name,
                )
                for linter_name, all_linter_results in itertools.groupby(
                    sorted_all_batch_results, key=key_fn
                )
            ),
            key=key_fn,
        )
    )

    def get_name(res: LintResults) -> str:
        return res.linter_name

    write_reports(
        all_results,
        workspace,
        dist_dir,
        goal_name=LintSubsystem.name,
        get_name=get_name,
    )

    # Console summary: one line per linter; the goal exits with the exit code
    # of the last failing linter (0 if all succeeded or were skipped).
    exit_code = 0
    if all_results:
        console.print_stderr("")
    for results in all_results:
        if results.skipped:
            continue
        elif results.exit_code == 0:
            sigil = console.sigil_succeeded()
            status = "succeeded"
        else:
            sigil = console.sigil_failed()
            status = "failed"
            exit_code = results.exit_code
        console.print_stderr(f"{sigil} {results.linter_name} {status}.")
    return Lint(exit_code)
async def fmt(
    console: Console,
    targets: Targets,
    fmt_subsystem: FmtSubsystem,
    workspace: Workspace,
    union_membership: UnionMembership,
) -> Fmt:
    """Run all selected formatters over the targets and write any edits to the workspace.

    Formatters that apply to the same target must run sequentially (each one's
    output feeds the next), so targets are grouped by the ordered sequence of
    `FmtRequest` types that apply to them, and each group is formatted as a unit.
    """
    request_types = union_membership[FmtRequest]
    specified_names = determine_specified_tool_names("fmt", fmt_subsystem.only, request_types)

    # Group targets by the sequence of FmtRequests that apply to them.
    targets_by_fmt_request_order = defaultdict(list)
    for target in targets:
        fmt_requests = []
        for fmt_request in request_types:
            valid_name = fmt_request.name in specified_names
            if valid_name and fmt_request.field_set_type.is_applicable(target):  # type: ignore[misc]
                fmt_requests.append(fmt_request)
        if fmt_requests:
            targets_by_fmt_request_order[tuple(fmt_requests)].append(target)

    # Spawn sequential formatting per unique sequence of FmtRequests.
    # Each group is further split into address-ordered batches bounded by
    # `--fmt-batch-size`; batches run concurrently with each other.
    per_language_results = await MultiGet(
        Get(
            _LanguageFmtResults,
            _LanguageFmtRequest(fmt_requests, Targets(target_batch)),
        )
        for fmt_requests, targets in targets_by_fmt_request_order.items()
        for target_batch in partition_sequentially(
            targets,
            key=lambda t: t.address.spec,
            size_target=fmt_subsystem.batch_size,
            size_max=4 * fmt_subsystem.batch_size,
        )
    )

    individual_results = list(
        itertools.chain.from_iterable(
            language_result.results for language_result in per_language_results
        )
    )

    # Nothing applied to any target: succeed without touching the workspace.
    if not individual_results:
        return Fmt(exit_code=0)

    changed_digests = tuple(
        language_result.output
        for language_result in per_language_results
        if language_result.did_change
    )
    if changed_digests:
        # NB: this will fail if there are any conflicting changes, which we want to happen rather
        # than silently having one result override the other. In practice, this should never
        # happen due to us grouping each language's formatters into a single digest.
        merged_formatted_digest = await Get(Digest, MergeDigests(changed_digests))
        workspace.write_digest(merged_formatted_digest)

    if individual_results:
        console.print_stderr("")

    # We group all results for the same formatter so that we can give one final status in the
    # summary. This is only relevant if there were multiple results because of
    # `--per-file-caching`.
    formatter_to_results = defaultdict(set)
    for result in individual_results:
        formatter_to_results[result.formatter_name].add(result)

    for formatter, results in sorted(formatter_to_results.items()):
        if any(result.did_change for result in results):
            sigil = console.sigil_succeeded_with_edits()
            status = "made changes"
        elif all(result.skipped for result in results):
            # Every batch for this formatter was skipped: omit it from the summary.
            continue
        else:
            sigil = console.sigil_succeeded()
            status = "made no changes"
        console.print_stderr(f"{sigil} {formatter} {status}.")

    # Since the rules to produce FmtResult should use ExecuteRequest, rather than
    # FallibleProcess, we assume that there were no failures.
    return Fmt(exit_code=0)