コード例 #1
0
ファイル: test_fx.py プロジェクト: yangsenius/pytorch
 def test_get_all_users_of(self):
     """Check the user indexes reported for every node of a small graph.

     The graph computes ``linear_mod(x) + y_attr``. Nodes are numbered in
     creation order: 0 placeholder, 1 call_module, 2 get_attr,
     3 call_function, 4 output.
     """
     fx_graph = torch.fx.Graph()
     x_node = fx_graph.create_node('placeholder', 'x')
     linear_node = fx_graph.create_node('call_module', 'linear_mod', args=(x_node, ))
     attr_node = fx_graph.create_node('get_attr', 'y_attr')
     sum_node = fx_graph.create_node(
         'call_function', operator.add, args=(linear_node, attr_node))
     fx_graph.output(sum_node)

     # Targets referenced by the call_module / get_attr nodes above.
     submodules = {
         'linear_mod': torch.nn.Linear(3, 4),
         'y_attr': torch.rand(3, 4),
     }
     traced = torch.fx.GraphModule(submodules, fx_graph)

     # Maps a node index to the indexes of the nodes that consume it.
     users_by_index = {0: [1], 1: [3], 2: [3], 3: [4], 4: []}
     for position, _ in enumerate(fx_graph.nodes):
         found = GraphManipulation.get_all_users_of(traced, position)
         assert found == users_by_index[position]
コード例 #2
0
 def get_output_nodes(self) -> List[Node]:
     """Return the nodes of this partition that have no user inside it.

     For every node in ``self.nodes`` the indexes of its users are obtained
     from ``GraphManipulation.get_all_users_of`` and mapped back to nodes of
     ``self.graph_module.graph``. A node is an output node when none of
     those users is itself a member of ``self.nodes``.

     Returns:
         The output nodes, in the iteration order of ``self.nodes``.
     """
     # Loop-invariant: build the membership set once, not once per node.
     partition_nodes = set(self.nodes)
     output_nodes: List[Node] = []
     for node in self.nodes:
         graph_nodes = self.graph_module.graph.nodes
         index = graph_nodes.index(node)
         user_indexes = GraphManipulation.get_all_users_of(
             self.graph_module, index)
         user_nodes = {graph_nodes[i] for i in user_indexes}
         # No user inside the partition means the value leaves it.
         if partition_nodes.isdisjoint(user_nodes):
             output_nodes.append(node)
     return output_nodes