Ejemplo n.º 1
0
    def test_fuser_util_xfail(self, partition):
        """Check that ``fuse_by_partitions`` rejects the given partitioning.

        ``partition`` is an iterable of node-name groups. Each name is
        resolved to the matching node of the symbolically traced
        ``TestModule``. The resulting lists of nodes are passed to
        ``fuse_by_partitions``, which this test expects to raise.
        """
        traced = symbolic_trace(TestModule())

        name_to_node = {n.name: n for n in traced.graph.nodes}
        # Name lookup happens outside the assertRaises context, so an unknown
        # node name surfaces as a KeyError failure rather than a "pass".
        node_groups = [[name_to_node[nm] for nm in group] for group in partition]

        with self.assertRaises(Exception):
            fuse_by_partitions(traced, node_groups)
Ejemplo n.º 2
0
    def test_fuser_util(self, partition):
        """Check that fusing ``TestModule`` by ``partition`` preserves its output.

        ``partition`` is an iterable of node-name groups. The names are
        resolved to nodes of the symbolically traced module and fused with
        ``fuse_by_partitions``. The fused graph and the original module are
        then run on the same three random length-4 tensors, and their
        results must match.
        """
        module = TestModule()
        traced = symbolic_trace(module)

        name_to_node = {n.name: n for n in traced.graph.nodes}
        node_groups = [[name_to_node[nm] for nm in group] for group in partition]

        fused = fuse_by_partitions(traced, node_groups)

        inputs = (torch.rand(4), torch.rand(4), torch.rand(4))

        # Arguments are evaluated left to right, so the reference module runs
        # before the fused graph, as in a plain expected/result sequence.
        torch.testing.assert_close(module(*inputs), fused(*inputs))
Ejemplo n.º 3
0
 def fuse_partitions(self, partitions: List[Partition]) -> GraphModule:
     """Fuse the given partitions of ``self.graph_module``.

     Args:
         partitions: Partition objects whose ``nodes`` are collected into
             plain lists before being handed to ``fuse_by_partitions``.

     Returns:
         Whatever ``fuse_by_partitions`` returns for ``self.graph_module``
         (annotated here as a ``GraphModule``).
     """
     logging.debug("Fusing partitions...")
     # fuse_by_partitions takes List[List[Node]], e.g.
     # [[node0, node1], [node2, node3]], so unwrap each Partition's nodes.
     node_lists = [list(p.nodes) for p in partitions]
     return fuse_by_partitions(self.graph_module, node_lists)