def test_only_tensor_outputs(self):
    """Test that Glow fuser only produces tensor outputs."""

    def f(a, b):
        x = (a + b).size(0)
        c = a.reshape(x, -1)
        return a + c

    torch_glow.disableFusionPass()

    a = torch.randn(5, 5)
    b = torch.randn(5, 5)

    jit_f = torch.jit.trace(f, (a, b))
    jit_f_graph = jit_f.graph_for(a, b)

    # By creating a graph with an aten::size (supported) feeding into an
    # unsupported op (prim::ListConstruct), we see that even if an op is
    # supported, if it produces a non-tensor output to the fusion group it
    # would not be fused.
    torch_glow.glowCustomFuseDebug_(
        jit_f_graph,
        ["prim::Constant", "aten::add", "aten::size", "aten::reshape"],
    )

    fusion_nodes = 0
    aten_sizes = 0
    for node in jit_f_graph.nodes():
        if node.kind() == GLOW_NODE_NAME:
            fusion_nodes += 1
        if node.kind() == "aten::size":
            aten_sizes += 1

    assert (
        fusion_nodes == 2
    ), "Expected two fusion nodes to be split up with aten::size between them"
    assert aten_sizes == 1, "Expected aten::size not to be fused"

def test_max_fusion_merge_size_zero(self):
    """Test Glow fuser maximum fusion merge size mechanism set to zero."""

    def f(a):
        return a * a * a * a * a * a

    torch_glow.disableFusionPass()

    # Set maximum fusion merge size to 0 so that there is
    # no limit to fusion
    torch_glow.setMaxFusionMergeSize(0)

    a = torch.randn(5, 5)

    jit_f = torch.jit.trace(f, (a,))
    jit_f_graph = jit_f.graph_for(a)

    # print("before: ", jit_f_graph)

    torch_glow.glowCustomFuseDebug_(jit_f_graph)

    # print("after: ", jit_f_graph)

    fusion_nodes = 0
    for node in jit_f_graph.nodes():
        if node.kind() == GLOW_FUSION_GROUP:
            fusion_nodes += 1

    assert fusion_nodes == 1, "Expected just one fusion group to be created"

    torch_glow.setMaxFusionMergeSize(0)

def test_max_fusion_merge_size(self):
    """Test Glow fuser maximum fusion merge size mechanism."""

    def f(a):
        return a * a * a * a * a * a

    torch_glow.disableFusionPass()

    # Set maximum fusion merge size to 3 nodes so that the
    # graph will not fit into a single fusion group
    torch_glow.setMaxFusionMergeSize(3)

    a = torch.randn(5, 5)

    jit_f = torch.jit.trace(f, (a,))
    jit_f_graph = jit_f.graph_for(a)

    # print("before: ", jit_f_graph)

    torch_glow.glowCustomFuseDebug_(jit_f_graph)

    # print("after: ", jit_f_graph)

    fusion_nodes = 0
    for node in jit_f_graph.nodes():
        if node.kind() == GLOW_NODE_NAME:
            fusion_nodes += 1

    assert fusion_nodes > 1, "Expected more than one fusion group to be created"

    # Reset the maximum fusion merge size to its default of 0 (no limit)
    torch_glow.setMaxFusionMergeSize(0)

def test_fuse_parallel_branches_without_fusible_root(self):
    r"""Test GlowFuser fusing parallel branches without a common fusible root

    x = add(x, x)      y = add(y, y)
          |                  |
    b1 = add(x, x)     b2 = add(y, y)
            \               /
        res = TupleConstruct(b1, b2)

    This should be fused as

        glow::FusionGroup_0
                |
         TupleConstruct
    """

    def test_fuser(x, y):
        x = x + x
        y = y + y
        branch1 = x + x
        branch2 = y + y
        res = (branch1, branch2)
        return res

    inputs = (torch.randn(2, 4), torch.randn(2, 4))
    traced = torch.jit.trace(test_fuser, inputs)
    torch_glow.glowCustomFuseDebug_(traced.graph)

    count = 0
    for node in traced.graph.nodes():
        if node.kind() == "glow::FusionGroup":
            count += 1
    assert count == 1, f"Expect 1 glow::FusionGroup, found {count}."

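# The test below references a function f that is not defined in this section.
# The definition here is a sketch reconstructed from the test's comments and
# assertion (an assumption, not the original source): three independent groups
# of aten::mul nodes separated by aten::div, with the smallest group containing
# only two nodes, so that blacklisting aten::div and setting the minimum fusion
# group size to 3 should leave exactly two fusion groups.
def f(a, b, c):
    return (a * a * a * a) / (b * b * b * b) / (c * c * c)
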
def test_min_graph_size():
    """Test Glow fuser minimum fusion group size mechanism."""

    torch_glow.disableFusionPass()

    # Disable aten::div so that each group of aten::mul nodes will be forced
    # into separate subgraphs
    torch_glow.setFusionBlacklist(["aten::div"])

    # Set minimum fusion group size to 3 nodes so that the smallest group which
    # contains only 2 nodes will not be created
    torch_glow.setMinFusionGroupSize(3)

    a = torch.randn(5, 5)
    b = torch.randn(5, 5)
    c = torch.randn(5, 5)

    jit_f = torch.jit.trace(f, (a, b, c))
    jit_f_graph = jit_f.graph_for(a, b, c)

    # print("before: ", jit_f_graph)

    torch_glow.glowCustomFuseDebug_(jit_f_graph)

    # print("after: ", jit_f_graph)

    fusion_nodes = 0
    for node in jit_f_graph.nodes():
        if node.kind() == GLOW_NODE_NAME:
            fusion_nodes += 1

    assert fusion_nodes == 2, "Expected smallest fusion group to not be created"

    torch_glow.clearFusionBlacklist()
    torch_glow.setMinFusionGroupSize(0)

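# The Model used by the test below is not defined in this section. This is a
# minimal sketch consistent with the custom fuse list passed to
# glowCustomFuseDebug_: a convolution (which lowers to aten::_convolution and
# is deliberately left unfused) feeding a linear layer (which lowers to
# aten::t / aten::matmul / aten::add_ for a 4D input). The layer shapes are
# assumptions chosen to fit the (1, 3, 5, 5) input used by the test.
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # Produces aten::_convolution, which is not in the fuse list below
        self.conv = torch.nn.Conv2d(3, 3, kernel_size=3)
        # Applied over the last dim of the 4D conv output, so tracing yields
        # aten::t, aten::matmul and aten::add_ rather than aten::addmm
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(self.conv(x))
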
def test_fuse_necessary_getattrs_only():
    m = Model()
    x = torch.randn(1, 3, 5, 5)

    torch_glow.disableFusionPass()

    jit_m = torch.jit.trace(m, x)
    jit_m_graph = jit_m.graph_for(x)

    # don't fuse aten::_convolutions
    torch_glow.glowCustomFuseDebug_(
        jit_m_graph,
        ["prim::Constant", "prim::GetAttr", "aten::t", "aten::matmul", "aten::add_"],
    )

    return m(x)