async def fmt(
    console: Console,
    targets: Targets,
    fmt_subsystem: FmtSubsystem,
    workspace: Workspace,
    union_membership: UnionMembership,
) -> Fmt:
    """Format `targets` with every registered `LanguageFmtTargets` implementation.

    Targets are split per language (via each union member's `belongs_to_language`),
    narrowed to those that `TargetsWithSources` reports as having sources, and then
    formatted either one target at a time (when `fmt_subsystem.per_file_caching` is set)
    or as one batch per language. Changed outputs are merged into a single digest and
    written to the workspace, and a per-formatter summary is printed to stderr.

    Always returns `Fmt(exit_code=0)`; see the note at the end of the body.
    """
    language_target_collection_types = union_membership[LanguageFmtTargets]
    language_target_collections = tuple(
        language_target_collection_type(
            Targets(
                target
                for target in targets
                if language_target_collection_type.belongs_to_language(target)
            )
        )
        for language_target_collection_type in language_target_collection_types
    )
    targets_with_sources = await MultiGet(
        Get(
            TargetsWithSources,
            TargetsWithSourcesRequest(language_target_collection.targets),
        )
        for language_target_collection in language_target_collections
    )

    # NB: We must convert the generic TargetsWithSources objects back into their
    # corresponding LanguageFmtTargets, e.g. back to PythonFmtTargets, in order for the union
    # rule to work.
    valid_language_target_collections = []
    for collection_cls, collection, collection_with_sources in zip(
        language_target_collection_types, language_target_collections, targets_with_sources
    ):
        # A language with no source-bearing targets is dropped entirely.
        if not collection_with_sources:
            continue
        # Build the membership set once per language; an `in` test against the collection
        # itself would scan it linearly for every target.
        with_sources = set(collection_with_sources)
        valid_language_target_collections.append(
            collection_cls(
                Targets(
                    target for target in collection.targets if target in with_sources
                )
            )
        )

    if fmt_subsystem.per_file_caching:
        # One request per target, re-wrapped in its language's collection type.
        per_language_results = await MultiGet(
            Get(
                LanguageFmtResults,
                LanguageFmtTargets,
                language_target_collection.__class__(Targets([target])),
            )
            for language_target_collection in valid_language_target_collections
            for target in language_target_collection.targets
        )
    else:
        per_language_results = await MultiGet(
            Get(LanguageFmtResults, LanguageFmtTargets, language_target_collection)
            for language_target_collection in valid_language_target_collections
        )

    individual_results = list(
        itertools.chain.from_iterable(
            language_result.results for language_result in per_language_results
        )
    )

    if not individual_results:
        return Fmt(exit_code=0)

    changed_digests = tuple(
        language_result.output
        for language_result in per_language_results
        if language_result.did_change
    )
    if changed_digests:
        # NB: this will fail if there are any conflicting changes, which we want to happen rather
        # than silently having one result override the other. In practice, this should never
        # happen due to us grouping each language's formatters into a single digest.
        merged_formatted_digest = await Get(Digest, MergeDigests(changed_digests))
        workspace.write_digest(merged_formatted_digest)

    # `individual_results` is non-empty here (see the early return above), so the separator
    # line is always printed.
    console.print_stderr("")

    # We group all results for the same formatter so that we can give one final status in the
    # summary. This is only relevant if there were multiple results because of
    # `--per-file-caching`.
    formatter_to_results = defaultdict(set)
    for result in individual_results:
        formatter_to_results[result.formatter_name].add(result)

    for formatter, results in sorted(formatter_to_results.items()):
        if any(result.did_change for result in results):
            sigil = console.sigil_succeeded_with_edits()
            status = "made changes"
        elif all(result.skipped for result in results):
            sigil = console.sigil_skipped()
            status = "skipped"
        else:
            sigil = console.sigil_succeeded()
            status = "made no changes"
        console.print_stderr(f"{sigil} {formatter} {status}.")

    # Since the rules to produce FmtResult should use ExecuteRequest, rather than
    # FallibleProcess, we assume that there were no failures.
    return Fmt(exit_code=0)
async def fmt(
    console: Console,
    targets: Targets,
    fmt_subsystem: FmtSubsystem,
    workspace: Workspace,
    union_membership: UnionMembership,
) -> Fmt:
    """Format `targets` with every applicable, selected `FmtRequest` implementation.

    Which request types run is decided by `determine_specified_tool_names` from
    `fmt_subsystem.only`. Targets are grouped by the ordered tuple of request types
    applicable to them, each group is partitioned into batches sized by
    `fmt_subsystem.batch_size`, and each batch is formatted via `_LanguageFmtRequest`.
    Changed outputs are merged and written to the workspace, and a per-formatter summary
    is printed to stderr. Formatters whose results were all skipped are left out of it.

    Always returns `Fmt(exit_code=0)`; see the note at the end of the body.
    """
    request_types = union_membership[FmtRequest]
    specified_names = determine_specified_tool_names("fmt", fmt_subsystem.only, request_types)

    # Name selection does not depend on the target, so filter once up front; the union's
    # order is preserved, so the per-target request sequences are unchanged.
    selected_request_types = [
        fmt_request for fmt_request in request_types if fmt_request.name in specified_names
    ]

    # Group targets by the sequence of FmtRequests that apply to them.
    targets_by_fmt_request_order = defaultdict(list)
    for target in targets:
        fmt_requests = [
            fmt_request
            for fmt_request in selected_request_types
            if fmt_request.field_set_type.is_applicable(target)  # type: ignore[misc]
        ]
        if fmt_requests:
            targets_by_fmt_request_order[tuple(fmt_requests)].append(target)

    # Spawn sequential formatting per unique sequence of FmtRequests.
    per_language_results = await MultiGet(
        Get(
            _LanguageFmtResults,
            _LanguageFmtRequest(request_group, Targets(target_batch)),
        )
        for request_group, grouped_targets in targets_by_fmt_request_order.items()
        for target_batch in partition_sequentially(
            grouped_targets,
            key=lambda t: t.address.spec,
            size_target=fmt_subsystem.batch_size,
            size_max=4 * fmt_subsystem.batch_size,
        )
    )

    individual_results = list(
        itertools.chain.from_iterable(
            language_result.results for language_result in per_language_results
        )
    )

    if not individual_results:
        return Fmt(exit_code=0)

    changed_digests = tuple(
        language_result.output
        for language_result in per_language_results
        if language_result.did_change
    )
    if changed_digests:
        # NB: this will fail if there are any conflicting changes, which we want to happen rather
        # than silently having one result override the other. In practice, this should never
        # happen due to us grouping each language's formatters into a single digest.
        merged_formatted_digest = await Get(Digest, MergeDigests(changed_digests))
        workspace.write_digest(merged_formatted_digest)

    # `individual_results` is non-empty here (see the early return above), so the separator
    # line is always printed.
    console.print_stderr("")

    # We group all results for the same formatter so that we can give one final status in the
    # summary. A formatter can produce several results when its targets were split across
    # multiple batches (see `batch_size`) or multiple FmtRequest groups.
    formatter_to_results = defaultdict(set)
    for result in individual_results:
        formatter_to_results[result.formatter_name].add(result)

    for formatter, results in sorted(formatter_to_results.items()):
        if any(result.did_change for result in results):
            sigil = console.sigil_succeeded_with_edits()
            status = "made changes"
        elif all(result.skipped for result in results):
            # Fully skipped formatters are intentionally omitted from the summary.
            continue
        else:
            sigil = console.sigil_succeeded()
            status = "made no changes"
        console.print_stderr(f"{sigil} {formatter} {status}.")

    # Since the rules to produce FmtResult should use ExecuteRequest, rather than
    # FallibleProcess, we assume that there were no failures.
    return Fmt(exit_code=0)