import collections
from typing import Dict

import torch
import torch.fx


def legalize_graph(gm: torch.fx.GraphModule):
    """
    Replace the graph of the given GraphModule with one that contains the same nodes as the
    original, but in topologically sorted order.

    This is used by the merge_matmul transformation below, which disturbs the topologically
    sorted order of its input GraphModule, so that this order is restored before further
    transformation.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.
    """
    # Build an adjacency list representation of node dependencies in the graph. This also
    # serves as a list of nodes that still need to be inserted into the new, topologically
    # sorted graph.
    dependencies = {node: node.all_input_nodes.copy() for node in gm.graph.nodes}

    # Construct a new graph that will contain all nodes in topologically sorted order.
    new_graph = torch.fx.Graph()
    value_remap: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Copy over all nodes with no dependencies.
    for node, deps in dependencies.items():
        if not deps:
            value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])

    # Remove the copied-over nodes from the adjacency list.
    for copied_node in value_remap.keys():
        del dependencies[copied_node]

    # While there are still nodes to insert into the new graph:
    while dependencies:
        copied_this_round = []

        # Copy over all nodes whose dependencies already exist in the new graph.
        for node, deps in dependencies.items():
            all_deps_copied = True
            for dep in deps:
                if dep not in value_remap:
                    all_deps_copied = False

            if all_deps_copied:
                value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
                copied_this_round.append(node)

        # Delete all nodes copied over in this iteration from dependencies.
        for copied_node in copied_this_round:
            del dependencies[copied_node]

    # Replace the old graph with the new, topologically sorted one.
    gm.graph = new_graph
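# A minimal usage sketch (an illustrative addition, not part of the original pass): build a
# Graph whose node list is deliberately out of topological order by inserting a node before
# its own input, then let legalize_graph restore a sorted order. The graph constructed here
# is hypothetical and exists only for demonstration.
def _demo_legalize_out_of_order_graph():
    g = torch.fx.Graph()
    x = g.placeholder("x")
    y = g.call_function(torch.relu, (x,))
    out = g.output(y)

    # Insert a new node *before* its own input `y`, leaving the node list out of
    # topological order (z uses y but precedes it).
    with g.inserting_before(y):
        z = g.call_function(torch.neg, (y,))
    out.args = (z,)

    gm = torch.fx.GraphModule(torch.nn.Module(), g)
    legalize_graph(gm)   # rebuilds gm.graph so that z appears after y
    print(gm.graph)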
def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Replace the graph of the given GraphModule with one that contains the same nodes as the
    original, but in topologically sorted order.

    This is used by the merge_matmul transformation below, which disturbs the topologically
    sorted order of its input GraphModule, so that this order is restored before further
    transformation.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.

    Returns:
        The graph module, with its graph topologically sorted in place.
    """
    indeg = {node: 0 for node in gm.graph.nodes}
    new_graph = torch.fx.Graph()
    # Track how many unfulfilled dependencies each node has
    for node in gm.graph.nodes:
        for user in node.users:
            indeg[user] += 1
    queue: collections.deque = collections.deque()
    # Add all nodes with no dependencies to the queue
    for node in gm.graph.nodes:
        if indeg[node] == 0:
            queue.append(node)
    env: Dict[torch.fx.Node, torch.fx.Node] = {}
    # Pop nodes from the queue, and add nodes that have had all their
    # dependencies fulfilled
    while len(queue) > 0:
        cur = queue.popleft()
        env[cur] = new_graph.node_copy(cur, lambda x: env[x])
        for user in cur.users:
            indeg[user] -= 1
            if indeg[user] == 0:
                queue.append(user)
    # If the new graph's size is not as large as the old one, then there must be
    # a cycle (i.e. some node's dependencies were not satisfied).
    if len(new_graph.nodes) < len(gm.graph.nodes):
        raise RuntimeError(
            f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}"
        )
    gm.graph = new_graph
    return gm
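# A second illustrative sketch (hypothetical, not from the original source): the Kahn's-algorithm
# variant above detects cycles. Rewiring a node to consume one of its own users leaves some nodes
# with unsatisfied dependencies, so legalize_graph raises a RuntimeError rather than sorting.
def _demo_legalize_cycle_detection():
    g = torch.fx.Graph()
    x = g.placeholder("x")
    a = g.call_function(torch.neg, (x,))
    b = g.call_function(torch.neg, (a,))
    g.output(b)

    # Rewire `a` to consume `b`, forming the dependency cycle a -> b -> a.
    a.args = (b,)

    gm = torch.fx.GraphModule(torch.nn.Module(), g)
    try:
        legalize_graph(gm)
    except RuntimeError as e:
        print(e)   # Input graph has cycles, unable to add [...]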