def test_get_all_users_of(self):
    """Check ``GraphManipulation.get_all_users_of`` on a hand-built graph.

    The graph computes ``output(add(linear_mod(x), y_attr))``. For every
    position in ``graph.nodes`` the returned user indexes are compared
    with a fixed expectation table.
    """
    graph: torch.fx.Graph = torch.fx.Graph()

    # Nodes are created in index order 0..3; ``graph.output`` adds index 4.
    x_node: torch.fx.Node = graph.create_node('placeholder', 'x')
    linear_node: torch.fx.Node = graph.create_node(
        'call_module', 'linear_mod', args=(x_node, ))
    attr_node: torch.fx.Node = graph.create_node('get_attr', 'y_attr')
    sum_node: torch.fx.Node = graph.create_node(
        'call_function', operator.add, args=(linear_node, attr_node))
    graph.output(sum_node)

    # Root mapping that supplies the targets referenced by the graph.
    root = {
        'linear_mod': torch.nn.Linear(3, 4),
        'y_attr': torch.rand(3, 4),
    }
    gm: torch.fx.GraphModule = torch.fx.GraphModule(root, graph)

    # node index -> indexes of the nodes that consume it
    expected_uses: Dict[int, List[int]] = {
        0: [1],
        1: [3],
        2: [3],
        3: [4],
        4: [],
    }

    for position, _node in enumerate(graph.nodes):
        actual = GraphManipulation.get_all_users_of(gm, position)
        assert actual == expected_uses[position]
def get_output_nodes(self) -> List[Node]:
    """Return the nodes of this partition that have no user inside it.

    A node is an output node when none of the users reported by
    ``GraphManipulation.get_all_users_of`` for it belongs to
    ``self.nodes``. Results keep the order of ``self.nodes``.

    Raises:
        ValueError: if a node in ``self.nodes`` is not found in
            ``self.graph_module.graph.nodes`` (raised by ``.index``).
    """
    graph_nodes = self.graph_module.graph.nodes
    # Loop-invariant, so build the membership set once up front.
    partition_nodes = set(self.nodes)

    output_nodes: List[Node] = []
    for node in self.nodes:
        index = graph_nodes.index(node)
        user_indexes = GraphManipulation.get_all_users_of(
            self.graph_module, index)
        # Output node <=> no user of this node lies inside the partition.
        if partition_nodes.isdisjoint(graph_nodes[i] for i in user_indexes):
            output_nodes.append(node)
    return output_nodes