def test_fuser_util_xfail(self, partition):
    """Check that fusing the given partitioning of TestModule raises.

    Args:
        partition: Iterable of groups, each group an iterable of node names
            taken from the symbolically traced TestModule graph.
    """
    module = TestModule()
    traced = symbolic_trace(module)

    # Resolve names to nodes *before* entering assertRaises, so an unknown
    # node name surfaces as a test error instead of satisfying the assertion.
    node_lookup = {}
    for graph_node in traced.graph.nodes:
        node_lookup[graph_node.name] = graph_node
    partitions = [[node_lookup[name] for name in group] for group in partition]

    with self.assertRaises(Exception):
        fuse_by_partitions(traced, partitions)
def test_fuser_util(self, partition):
    """Check that fusing the given partitioning leaves the output unchanged.

    Traces TestModule, fuses the named node groups, then compares the fused
    graph's output with the eager module's output on random inputs.

    Args:
        partition: Iterable of groups, each group an iterable of node names
            taken from the symbolically traced TestModule graph.
    """
    module = TestModule()
    traced = symbolic_trace(module)

    node_lookup = {}
    for graph_node in traced.graph.nodes:
        node_lookup[graph_node.name] = graph_node
    partitions = [[node_lookup[name] for name in group] for group in partition]

    fused_graph = fuse_by_partitions(traced, partitions)

    # Three length-4 random inputs, drawn in the same order as before.
    a = torch.rand(4)
    b = torch.rand(4)
    c = torch.rand(4)

    expected = module(a, b, c)
    result = fused_graph(a, b, c)
    torch.testing.assert_close(expected, result)
def fuse_partitions(self, partitions: List[Partition]) -> GraphModule:
    """Fuse each partition's nodes in ``self.graph_module``.

    Args:
        partitions: Partitions whose ``nodes`` collections are to be fused.

    Returns:
        The result of ``fuse_by_partitions`` applied to ``self.graph_module``.
    """
    logging.debug("Fusing partitions...")

    # fuse_by_partitions wants plain node lists, one per partition:
    #   [[node0, node1], [node2, node3]]
    node_groups = [list(group.nodes) for group in partitions]
    return fuse_by_partitions(self.graph_module, node_groups)