def test_op_blacklist(self):
    """Test Glow fuser op kind blacklisting mechanism.

    Blacklists ``aten::add``, traces ``f`` and checks that no Glow fusion
    group in the resulting graph contains an ``aten::add`` node, while
    ``aten::sub`` is still fused. The blacklist is cleared in a ``finally``
    clause so a failing assertion does not leak the global fuser setting
    into other tests.
    """

    def f(a, b):
        return (a + b) * (a - b)

    # NOTE(review): the fusion pass is left enabled after this test (as in
    # the original); consider restoring the previous state.
    torch_glow.enableFusionPass_DO_NOT_USE_THIS()
    torch_glow.setFusionBlacklist(["aten::add"])
    try:
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)

        jit_f = torch.jit.trace(f, (a, b))
        jit_f_graph = jit_f.graph_for(a, b)

        fused_add = False
        fused_sub = False
        for node in jit_f_graph.nodes():
            if node.kind() == GLOW_FUSION_GROUP:
                glow_subgraph = node.g(SUBGRAPH_ATTR)
                # Distinct name so the outer loop variable is not shadowed.
                for sub_node in glow_subgraph.nodes():
                    if sub_node.kind() == "aten::add":
                        fused_add = True
                    if sub_node.kind() == "aten::sub":
                        fused_sub = True

        assert not fused_add, "Expected aten::add to be blacklisted"
        assert fused_sub, "Expected aten::sub to not be blacklisted"
    finally:
        # Always undo the global blacklist, even when an assertion fails.
        torch_glow.clearFusionBlacklist()
def ephemeral_torchglow_settings(
    fp16=False, backend=DEFAULT_BACKEND, fusion=False, blocklist=None
):
    """Temporarily apply torch_glow settings around a single ``yield``.

    Saves the current fp16 conversion, fp16 clipping, fused-to-fp16
    conversion, fusion-pass, backend and fusion-blacklist settings. It then
    applies the requested ones, yields once, and restores the saved values
    in a ``finally`` clause. This is presumably used as a context manager via
    ``contextlib.contextmanager`` applied outside this view; TODO confirm.

    Args:
        fp16: If true, enable convert-to-fp16, convert-fused-to-fp16 and
            fp16 clipping; otherwise disable all three.
        backend: Glow backend name passed to ``setGlowBackend``.
        fusion: If true, enable the fusion pass; otherwise disable it.
        blocklist: Iterable of op kinds to blacklist from fusion. ``None``
            clears the blacklist.
    """
    old_fp16 = torch_glow.get_convert_to_fp16()
    old_clip = torch_glow.get_clip_fp16()
    old_convert_fused = torch_glow.get_convert_fused_to_fp16()
    old_backend = torch_glow.getGlowBackendName()
    old_blocklist = torch_glow.getFusionBlacklist()
    old_fusion = torch_glow.getFusionPassEnabled()
    try:
        if fusion:
            torch_glow.enableFusionPass()
        else:
            torch_glow.disableFusionPass()

        if fp16:
            torch_glow.enable_convert_to_fp16()
            torch_glow.enable_convert_fused_to_fp16()
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()

        if blocklist is None:
            torch_glow.clearFusionBlacklist()
        else:
            # Copy into a list so generators or sets are accepted as well.
            torch_glow.setFusionBlacklist(list(blocklist))

        torch_glow.setGlowBackend(backend)
        yield
    finally:
        # Restore every saved setting. Use plain if/else statements rather
        # than conditional expressions evaluated only for their side effects.
        if old_fp16:
            torch_glow.enable_convert_to_fp16()
        else:
            torch_glow.disable_convert_to_fp16()

        if old_clip:
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_clip_fp16()

        if old_convert_fused:
            torch_glow.enable_convert_fused_to_fp16()
        else:
            torch_glow.disable_convert_fused_to_fp16()

        if old_fusion:
            torch_glow.enableFusionPass()
        else:
            torch_glow.disableFusionPass()

        torch_glow.setGlowBackend(old_backend)
        torch_glow.setFusionBlacklist(old_blocklist)
def test_op_blacklist_allowlist(self):
    """Test Glow fuser allowlist overwrites blacklist mechanism.

    Blacklists both ``aten::add`` and ``aten::sub``, then puts ``aten::sub``
    on the override allowlist. Checks that ``aten::add`` is not fused while
    ``aten::sub`` is. Both lists are cleared in a ``finally`` clause so a
    failing assertion does not leak global fuser settings into other tests.
    """

    def f(a, b):
        return (a + b) * (a - b)

    # NOTE(review): the fusion pass is left enabled after this test (as in
    # the original); consider restoring the previous state.
    torch_glow.enableFusionPass()
    torch_glow.setFusionBlacklist(["aten::add", "aten::sub"])
    torch_glow.setFusionOverrideAllowlist(["aten::sub"])
    try:
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)

        jit_f = torch.jit.trace(f, (a, b))
        jit_f_graph = jit_f.graph_for(a, b)

        fused_add = False
        fused_sub = False
        for node in jit_f_graph.nodes():
            if node.kind() == GLOW_NODE_NAME:
                glow_subgraph = node.g(SUBGRAPH_ATTR)
                # Distinct name so the outer loop variable is not shadowed.
                for sub_node in glow_subgraph.nodes():
                    if sub_node.kind() == "aten::add":
                        fused_add = True
                    if sub_node.kind() == "aten::sub":
                        fused_sub = True

        assert not fused_add, "Expected aten::add to be blacklisted"
        assert fused_sub, "Expected aten::sub to not be blacklisted"
    finally:
        # Always undo the global blacklist/allowlist, even on failure.
        torch_glow.clearFusionBlacklist()
        torch_glow.clearFusionOverrideAllowlist()
def test_min_graph_size():
    """Test Glow fuser minimum fusion group size mechanism.

    Blacklists ``aten::div`` and sets the minimum fusion group size to 3. It
    then runs ``glowCustomFuseDebug_`` on the traced graph of ``f`` and
    expects exactly two Glow fusion nodes. The blacklist and the minimum
    group size are reset in a ``finally`` clause so a failing assertion does
    not leak global fuser settings into other tests.
    """
    # The automatic fusion pass is disabled; fusion is triggered explicitly
    # below via glowCustomFuseDebug_.
    torch_glow.disableFusionPass()

    # Disable aten::div so that each group of aten::mul nodes will be forced
    # into separate subgraphs
    torch_glow.setFusionBlacklist(["aten::div"])

    # Set minimum fusion group size to 3 nodes so that the smallest group which
    # contains only 2 nodes will not be created
    torch_glow.setMinFusionGroupSize(3)

    try:
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        c = torch.randn(5, 5)

        # NOTE(review): `f` is not defined in this function; it must be a
        # three-argument function defined elsewhere (e.g. module scope).
        # TODO confirm it exists and matches the expected graph shape.
        jit_f = torch.jit.trace(f, (a, b, c))
        jit_f_graph = jit_f.graph_for(a, b, c)

        # Presumably fuses jit_f_graph in place: fusion nodes are counted on
        # the same graph object afterwards.
        torch_glow.glowCustomFuseDebug_(jit_f_graph)

        fusion_nodes = sum(
            1 for node in jit_f_graph.nodes() if node.kind() == GLOW_NODE_NAME
        )
        assert fusion_nodes == 2, "Expected smallest fusion group to not be created"
    finally:
        # Always restore the global fuser settings, even on failure.
        torch_glow.clearFusionBlacklist()
        torch_glow.setMinFusionGroupSize(0)