Beispiel #1
0
def _transform_remove_duplicates(module: GraphModule, debug: bool) -> GraphModule:
    """Removes duplicate modules by creating a copy of the module.

    This is necessary because BackPACK saves input/output which is overwritten
    if the module is called multiple times.

    Every ``call_module`` node whose target submodule is called more than once
    is redirected to a fresh ``deepcopy`` of that submodule, registered on
    ``module`` under a free name. Nodes of other kinds (placeholders, output,
    ``call_function``, ``call_method``, ``get_attr``) are not submodule calls
    and are left untouched.

    Args:
        module: container module to transform (its submodules are added to)
        debug: whether to print debug messages

    Returns:
        equivalent transformed module

    Raises:
        NotImplementedError: if a duplicate module has parameters
    """
    from collections import Counter

    if debug:
        print("\tBegin transformation: remove duplicates")

    graph: Graph = BackpackTracer().trace(module)

    # Only call_module targets are submodule paths that get_submodule can
    # resolve; counting other node kinds would e.g. flag a function used twice
    # (or a placeholder sharing a submodule's name) and then crash below.
    module_nodes = [n for n in graph.nodes if n.op == "call_module"]
    # Counter gives O(n) counting instead of list.count per element (O(n^2)).
    counts = Counter(n.target for n in module_nodes)
    nodes = [n for n in module_nodes if counts[n.target] > 1]

    for node in nodes:
        target = node.target
        original_module = module.get_submodule(target)

        # A deep copy would not share parameters with the original, so
        # duplicated modules with parameters cannot be handled this way.
        if next(original_module.parameters(), None) is not None:
            raise NotImplementedError(
                f"Cycle with parameters detected: module {original_module} with target"
                f" {target} has parameters and is used {counts[target]} times."
            )

        new_module = deepcopy(original_module)
        # The freshly registered name is taken after add_submodule, so the
        # next copy of the same target receives a different free name.
        new_target = _get_free_name(module, target)
        module.add_submodule(new_target, new_module)
        node.target = new_target

    graph.lint()

    if debug:
        print(f"\tDuplicates removed: {len(nodes)}")

    return GraphModule(module, graph)
Beispiel #2
0
def remove_duplicate_output_args(
    top_level: fx.GraphModule,
    target_subnets: t.Collection[str]
) -> t.Mapping[str, "RemoveDuplicateResult"]:
    """Removes duplicate output args.

    This pass removes duplicate output args from the target subnets and fixes
    their uses in the top level module where the subnets are called. This pass
    must be called after acc split on the top-level net and subsequent calls to
    the acc trace on the subnets.

    This pass will change both the subnets and top level module.

    Args:
        top_level: module whose ``call_module`` nodes invoke the subnets; the
            users of those nodes are rewritten in place.
        target_subnets: names (``node.name``) of the top-level ``call_module``
            nodes whose target submodules should be processed.

    Returns:
        a mapping of the target subnet node name to its deduplicate result
    """
    processed_subnets = {}
    top_level_changed = False

    for node in top_level.graph.nodes:  # type: fx.Node
        if node.op != "call_module" or node.name not in target_subnets:
            continue

        assert isinstance(node.target, str)
        sub_gm = top_level.get_submodule(node.target)
        assert isinstance(sub_gm, fx.GraphModule)

        replace_res = _remove_duplicate_output_args(sub_gm)
        processed_subnets[node.name] = replace_res
        # A None replacement map means nothing was deduplicated in this subnet.
        if replace_res.replacement_map is None:
            continue
        sub_gm.recompile()

        # iterate on the copy since we will be changing elements of node.users
        for user in list(node.users):
            # presumably returns the output index `user` selects from `node`
            # (and validates the use pattern) -- see _ensure_proper_output_use
            idx = _ensure_proper_output_use(user, node)
            idx_new = replace_res.replacement_map[idx]
            if idx_new != idx:
                user.args = (user.args[0], idx_new)
                top_level_changed = True

    # Regenerate the top-level code once after all rewrites rather than once
    # per subnet; the loop above only inspects/mutates graph nodes.
    if top_level_changed:
        top_level.recompile()
    return processed_subnets