def _transform_remove_duplicates(module: GraphModule, debug: bool) -> GraphModule:
    """Removes duplicate modules by giving every repeated call its own copy.

    This is necessary because BackPACK saves input/output which is overwritten
    if the module is called multiple times.

    Every graph node whose target occurs more than once in the traced graph is
    redirected to a fresh deep copy of the targeted submodule. The copies are
    registered on ``module`` itself, so ``module`` is modified in place.

    Args:
        module: container module to transform
        debug: whether to print debug messages

    Returns:
        equivalent transformed module

    Raises:
        NotImplementedError: if a duplicate module has parameters
    """
    # Function-scope import so the block does not depend on the file header.
    from collections import Counter

    if debug:
        print("\tBegin transformation: remove duplicates")

    graph: Graph = BackpackTracer().trace(module)

    # Tally all targets in one pass instead of calling ``list.count`` per node,
    # which was quadratic in the number of graph nodes.
    usage = Counter(n.target for n in graph.nodes)
    nodes = [n for n in graph.nodes if usage[n.target] > 1]

    for node in nodes:
        target = node.target
        # NOTE(review): assumes every repeated target names a submodule of
        # ``module`` -- confirm against the tracer / earlier transformations.
        original_module = module.get_submodule(target)

        # Modules that own parameters are rejected rather than copied
        # (presumably because a deep copy would no longer share them -- confirm).
        if any(True for _ in original_module.parameters()):
            raise NotImplementedError(
                f"Cycle with parameters detected: module {original_module} with target"
                f" {target} has parameters and is used {usage[target]} times."
            )

        new_module = deepcopy(original_module)
        new_target = _get_free_name(module, target)
        module.add_submodule(new_target, new_module)
        node.target = new_target

    graph.lint()

    if debug:
        print(f"\tDuplicates removed: {len(nodes)}")

    return GraphModule(module, graph)
def remove_duplicate_output_args(
    top_level: fx.GraphModule, target_subnets: t.Collection[str]
) -> t.Mapping[str, "RemoveDuplicateResult"]:
    """Deduplicates the output args of the target subnets and patches their uses.

    For every ``call_module`` node of ``top_level`` whose name is listed in
    ``target_subnets``, the called subnet has its duplicate output args removed
    and the node's users in the top level module are re-pointed at the
    surviving output indices. This pass must be called after acc split on the
    top-level net and subsequent calls to the acc trace on the subnets.

    Both the subnets and the top level module are modified.

    Args:
        top_level: module whose graph calls the subnets
        target_subnets: names of the call nodes whose subnets are processed

    Returns:
        a mapping of the target subnet name to its deduplicate result
    """
    results = {}
    for node in top_level.graph.nodes:  # type: fx.Node
        # Only calls into one of the requested subnets are of interest.
        if node.op != "call_module" or node.name not in target_subnets:
            continue

        assert isinstance(node.target, str)
        subnet = top_level.get_submodule(node.target)
        assert isinstance(subnet, fx.GraphModule)

        outcome = _remove_duplicate_output_args(subnet)
        results[node.name] = outcome

        index_map = outcome.replacement_map
        if index_map is None:
            # No replacement map reported: leave subnet and callers as they are.
            continue
        subnet.recompile()

        top_level_changed = False
        # Snapshot the users first: rewriting args mutates ``node.users``.
        for consumer in tuple(node.users):
            old_index = _ensure_proper_output_use(consumer, node)
            new_index = index_map[old_index]
            if new_index != old_index:
                consumer.args = (consumer.args[0], new_index)
                top_level_changed = True

        if top_level_changed:
            top_level.recompile()
    return results