def test_shape_inference_unsupported_symbols_skip_fusion_group(self):
    """Test Glow shape inference unsupported symbols including skipping of
    symbols after a secondary fusion group.

    Fusion is restricted with setFusionStartIndex(3)/setFusionEndIndex(6) so
    that only part of ``f`` is fused. The trailing ``torch.chain_matmul``
    call follows the last fusion node. It is expected to be reported as an
    unsupported symbol when ``skip_last_fusion_node=False`` and suppressed
    when ``skip_last_fusion_node=True``.
    """

    def f(a, b):
        x1 = a * b
        x2 = x1 * b
        x3 = x2 * a
        x4 = x3 / b
        x5 = x4 / a
        x6 = x5 / b
        x7 = x6 * a
        x8 = x7 * b
        return x8 * torch.chain_matmul(x8, x8)

    torch_glow.enableFusionPass_DO_NOT_USE_THIS()
    torch_glow.setFusionStartIndex(3)
    torch_glow.setFusionEndIndex(6)
    try:
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        jit_f = torch.jit.trace(f, (a, b))
        jit_f_graph = jit_f.graph_for(a, b)
    finally:
        # The fusion indices are process-global; always reset them so a
        # failure while tracing cannot leak the restriction into later tests.
        torch_glow.clearFusionIndices()

    args = (a, b)

    # Don't skip nodes after the last fusion node.
    # In this case, one of the nodes (chain_matmul) following the last fusion
    # node is not supported, and should be reported.
    actual = torch_glow.glow_shape_inference_find_unsupported_symbols(
        jit_f_graph, args, skip_last_fusion_node=False
    )
    expected = [
        "aten::chain_matmul",
    ]
    self.assertEqual(set(expected), set(actual))

    # DO skip nodes after the last fusion node.
    # In this case, one of the nodes (chain_matmul) following the last fusion
    # node is not supported, but is suppressed due to the
    # skip_last_fusion_node flag.
    actual = torch_glow.glow_shape_inference_find_unsupported_symbols(
        jit_f_graph, args, skip_last_fusion_node=True
    )
    expected = []
    self.assertEqual(set(expected), set(actual))
def test_op_index_blacklist_allowlist(self):
    """Test Glow fuser allowlist overwrites index blacklisting mechanism.

    The fusion index window (start 5, end 6) admits only one of the three
    ``aten::div`` nodes. Putting ``aten::div`` on the override allowlist is
    expected to make all three divs fused, while no ``aten::mul`` is fused.
    """

    def f(a, b):
        x1 = a * b
        x2 = x1 * b
        x3 = x2 * a
        x4 = x3 / b
        x5 = x4 / a
        x6 = x5 / b
        x7 = x6 * a
        x8 = x7 * b
        return x8

    # Use the same fusion-pass entry point as the sibling tests in this file.
    torch_glow.enableFusionPass_DO_NOT_USE_THIS()
    # Only one div is allowed by index
    torch_glow.setFusionStartIndex(5)
    torch_glow.setFusionEndIndex(6)
    # But all divs are allowed by allowlist
    torch_glow.setFusionOverrideAllowlist(["aten::div"])
    try:
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        jit_f = torch.jit.trace(f, (a, b))
        jit_f_graph = jit_f.graph_for(a, b)
    finally:
        # Fusion indices and the override allowlist are process-global; reset
        # them even if tracing fails so later tests start from a clean state.
        torch_glow.clearFusionIndices()
        torch_glow.clearFusionOverrideAllowlist()

    # Count mul/div nodes that ended up inside Glow fusion groups.
    fused_muls = 0
    fused_divs = 0
    for node in jit_f_graph.nodes():
        if node.kind() != GLOW_FUSION_GROUP:
            continue
        glow_subgraph = node.g(SUBGRAPH_ATTR)
        for sub_node in glow_subgraph.nodes():
            kind = sub_node.kind()
            if kind == "aten::mul":
                fused_muls += 1
            elif kind == "aten::div":
                fused_divs += 1

    self.assertEqual(fused_muls, 0, "Expected no aten::muls to be fused")
    self.assertEqual(fused_divs, 3, "Expected all 3 aten::divs to be fused")
def test_op_index_blacklist(self):
    """Test Glow fuser index blacklisting mechanism.

    Fusion is restricted with setFusionStartIndex(3)/setFusionEndIndex(6).
    The test expects exactly the three ``aten::div`` nodes of ``f`` to land
    in Glow fusion groups and none of the ``aten::mul`` nodes.
    """

    def f(a, b):
        x1 = a * b
        x2 = x1 * b
        x3 = x2 * a
        x4 = x3 / b
        x5 = x4 / a
        x6 = x5 / b
        x7 = x6 * a
        x8 = x7 * b
        return x8

    torch_glow.enableFusionPass_DO_NOT_USE_THIS()
    torch_glow.setFusionStartIndex(3)
    torch_glow.setFusionEndIndex(6)
    try:
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        jit_f = torch.jit.trace(f, (a, b))
        jit_f_graph = jit_f.graph_for(a, b)
    finally:
        # The fusion indices are process-global; always reset them so a
        # failure while tracing cannot leak the restriction into later tests.
        torch_glow.clearFusionIndices()

    # Count mul/div nodes that ended up inside Glow fusion groups.
    fused_muls = 0
    fused_divs = 0
    for node in jit_f_graph.nodes():
        if node.kind() != GLOW_FUSION_GROUP:
            continue
        glow_subgraph = node.g(SUBGRAPH_ATTR)
        for sub_node in glow_subgraph.nodes():
            kind = sub_node.kind()
            if kind == "aten::mul":
                fused_muls += 1
            elif kind == "aten::div":
                fused_divs += 1

    self.assertEqual(fused_muls, 0, "Expected no aten::muls to be fused")
    self.assertEqual(fused_divs, 3, "Expected all 3 aten::divs to be fused")