def test_op_blacklist(self):
    """Test Glow fuser op kind blacklisting mechanism.

    Blacklists ``aten::add``, traces ``(a + b) * (a - b)`` and checks that no
    ``aten::add`` node ends up inside a Glow fusion group while an
    ``aten::sub`` node does. The blacklist is cleared in a ``finally`` so a
    failing assertion cannot leak it into later tests.
    """

    def f(a, b):
        return (a + b) * (a - b)

    torch_glow.enableFusionPass_DO_NOT_USE_THIS()
    try:
        torch_glow.setFusionBlacklist(["aten::add"])

        a = torch.randn(5, 5)
        b = torch.randn(5, 5)

        jit_f = torch.jit.trace(f, (a, b))
        jit_f_graph = jit_f.graph_for(a, b)

        fused_add = False
        fused_sub = False
        for node in jit_f_graph.nodes():
            if node.kind() == GLOW_FUSION_GROUP:
                glow_subgraph = node.g(SUBGRAPH_ATTR)
                # Separate loop variable so the outer ``node`` is not shadowed.
                for fused_node in glow_subgraph.nodes():
                    if fused_node.kind() == "aten::add":
                        fused_add = True
                    if fused_node.kind() == "aten::sub":
                        fused_sub = True

        assert not fused_add, "Expected aten::add to be blacklisted"
        assert fused_sub, "Expected aten::sub to not be blacklisted"
    finally:
        # Always reset the global blacklist, even when an assertion fails.
        torch_glow.clearFusionBlacklist()
def traceVsGlow(
    f_torch,
    f_glow,
    check_trace,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
):
    """Trace ``f_torch`` without and ``f_glow`` with the Glow fuser and compare.

    Args:
        f_torch: callable traced with the fusion pass disabled (reference).
        f_glow: callable traced with the fusion pass enabled.
        check_trace: forwarded to ``torch.jit.trace``.
        atol, rtol: tolerances forwarded to ``checkResult``.
        *inputs: example inputs used for tracing and execution.
        expected_fused_ops: forwarded to ``checkExpectedOps``.
        accept_all_ops: forwarded to ``checkExpectedOps``.
        black_list: op kinds passed to ``torch_glow.setFusionBlacklist``;
            ``None`` means an empty list.

    The fusion pass is disabled again in a ``finally`` so the global setting
    does not carry over into whatever runs next, even when tracing, execution
    or the assertion below fails.
    """
    if black_list is None:
        black_list = []
    with torch.no_grad():
        # Reference run: fusion disabled so the trace stays pure PyTorch.
        torch_glow.disableFusionPass()
        torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
        torch_res = torch_trace(*inputs)

        try:
            torch_glow.enableFusionPass()
            torch_glow.setFusionBlacklist(black_list)

            glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
            glow_res = glow_trace(*inputs)

            # check that there are no Glow nodes in the torch graph
            torch_graph = torch_trace.graph_for(*inputs)
            print("torch_graph,", torch_graph)

            num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
            assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
                num_glow_nodes
            )

            glow_graph = glow_trace.graph_for(*inputs)
            print("glow_graph,", glow_graph)
        finally:
            # Avoid carrying the enabled fusion pass over into later code.
            torch_glow.disableFusionPass()

    checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
    checkResult(torch_res, glow_res, atol, rtol)
def test_op_blacklist_allowlist(self):
    """Test Glow fuser allowlist overwrites blacklist mechanism.

    Both ``aten::add`` and ``aten::sub`` are blacklisted, then ``aten::sub``
    is put on the override allowlist. After tracing ``(a + b) * (a - b)``,
    only ``aten::sub`` is expected inside a Glow fusion group. Blacklist and
    allowlist are cleared in a ``finally`` so a failure cannot leak them.
    """

    def f(a, b):
        return (a + b) * (a - b)

    torch_glow.enableFusionPass()
    try:
        torch_glow.setFusionBlacklist(["aten::add", "aten::sub"])
        torch_glow.setFusionOverrideAllowlist(["aten::sub"])

        a = torch.randn(5, 5)
        b = torch.randn(5, 5)

        jit_f = torch.jit.trace(f, (a, b))
        jit_f_graph = jit_f.graph_for(a, b)

        fused_add = False
        fused_sub = False
        for node in jit_f_graph.nodes():
            if node.kind() == GLOW_NODE_NAME:
                glow_subgraph = node.g(SUBGRAPH_ATTR)
                # Separate loop variable so the outer ``node`` is not shadowed.
                for fused_node in glow_subgraph.nodes():
                    if fused_node.kind() == "aten::add":
                        fused_add = True
                    if fused_node.kind() == "aten::sub":
                        fused_sub = True

        assert not fused_add, "Expected aten::add to be blacklisted"
        assert fused_sub, "Expected aten::sub to not be blacklisted"
    finally:
        # Always reset both global lists, even when an assertion fails.
        torch_glow.clearFusionBlacklist()
        torch_glow.clearFusionOverrideAllowlist()
def ephemeral_torchglow_settings(
    fp16=False, backend=DEFAULT_BACKEND, fusion=False, blocklist=None
):
    """Temporarily apply torch_glow global settings and restore them afterwards.

    This is a single-yield generator. It is presumably wrapped by
    ``contextlib.contextmanager`` via a decorator just above this ``def``
    (not visible here — confirm).

    Args:
        fp16: when truthy, enable fp16 conversion, fused-fp16 conversion and
            fp16 clipping; otherwise disable all three.
        backend: name passed to ``torch_glow.setGlowBackend``.
        fusion: when truthy, enable the fusion pass; otherwise disable it.
        blocklist: iterable of op kinds to blacklist; ``None`` clears the
            blacklist.

    On exit (normal or exceptional), the fp16 flags, the fusion pass, the
    backend and the blacklist are restored to the values read on entry.
    """

    def _switch(enabled, turn_on, turn_off):
        # Invoke exactly one of the two setters depending on ``enabled``.
        if enabled:
            turn_on()
        else:
            turn_off()

    # Snapshot the current global state before touching anything.
    saved_fp16 = torch_glow.get_convert_to_fp16()
    saved_clip = torch_glow.get_clip_fp16()
    saved_convert_fused = torch_glow.get_convert_fused_to_fp16()
    saved_backend = torch_glow.getGlowBackendName()
    saved_blocklist = torch_glow.getFusionBlacklist()
    saved_fusion = torch_glow.getFusionPassEnabled()

    try:
        _switch(fusion, torch_glow.enableFusionPass, torch_glow.disableFusionPass)

        _switch(
            fp16,
            torch_glow.enable_convert_to_fp16,
            torch_glow.disable_convert_to_fp16,
        )
        _switch(
            fp16,
            torch_glow.enable_convert_fused_to_fp16,
            torch_glow.disable_convert_fused_to_fp16,
        )
        _switch(fp16, torch_glow.enable_clip_fp16, torch_glow.disable_clip_fp16)

        if blocklist is not None:
            torch_glow.setFusionBlacklist(list(blocklist))
        else:
            torch_glow.clearFusionBlacklist()

        torch_glow.setGlowBackend(backend)
        yield
    finally:
        _switch(
            saved_fp16,
            torch_glow.enable_convert_to_fp16,
            torch_glow.disable_convert_to_fp16,
        )
        _switch(saved_clip, torch_glow.enable_clip_fp16, torch_glow.disable_clip_fp16)
        _switch(
            saved_convert_fused,
            torch_glow.enable_convert_fused_to_fp16,
            torch_glow.disable_convert_fused_to_fp16,
        )
        _switch(
            saved_fusion, torch_glow.enableFusionPass, torch_glow.disableFusionPass
        )
        torch_glow.setGlowBackend(saved_backend)
        torch_glow.setFusionBlacklist(saved_blocklist)
def test_min_graph_size():
    """Test Glow fuser minimum fusion group size mechanism.

    The automatic fusion pass is disabled and fusion is applied explicitly
    with ``torch_glow.glowCustomFuseDebug_``. With ``aten::div`` blacklisted
    and a minimum group size of 3, exactly two Glow fusion nodes are
    expected. The blacklist and minimum group size are reset in a
    ``finally`` so a failing assertion cannot leak them.

    NOTE(review): ``f`` (taking three tensors) is not defined in this
    function and must come from module scope — confirm.
    """
    torch_glow.disableFusionPass()
    try:
        # Disable aten::div so that each group of aten::mul nodes will be forced
        # into separate subgraphs
        torch_glow.setFusionBlacklist(["aten::div"])

        # Set minimum fusion group size to 3 nodes so that the smallest group which
        # contains only 2 nodes will not be created
        torch_glow.setMinFusionGroupSize(3)

        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        c = torch.randn(5, 5)

        jit_f = torch.jit.trace(f, (a, b, c))
        jit_f_graph = jit_f.graph_for(a, b, c)

        # Fuse manually, since the automatic fusion pass is disabled above.
        torch_glow.glowCustomFuseDebug_(jit_f_graph)

        fusion_nodes = sum(
            1 for node in jit_f_graph.nodes() if node.kind() == GLOW_NODE_NAME
        )

        assert fusion_nodes == 2, "Expected smallest fusion group to not be created"
    finally:
        # Always reset global fuser settings, even when the assertion fails.
        torch_glow.clearFusionBlacklist()
        torch_glow.setMinFusionGroupSize(0)
def test_quantized_cut(self):
    """Test cut quantized chunk in the middle.

    Builds a chain of quantized adds and blacklists ``quantized::add_relu`` so
    the chain is split around it. It then checks three things:

    * every top-level node is a Glow fusion node, a blacklisted op, or a
      ``prim::Constant``;
    * the traced result matches eager execution;
    * the global JIT profiling flags and torch_glow settings are restored
      afterwards.
    """
    # Both setters return the previous value; they are restored in ``finally``.
    old_profiling_executor = torch._C._jit_set_profiling_executor(False)
    old_profiling_mode = torch._C._jit_set_profiling_mode(False)

    def fun(a, b, c, d):
        q = torch.nn.quantized.Quantize(
            scale=1.0 / 128, zero_point=0, dtype=torch.quint8
        )
        dq = torch.nn.quantized.DeQuantize()
        a = q(a)
        b = q(b)
        c = q(c)
        d = q(d)
        adds = torch.ops.quantized.add(a, b, scale=1.0 / 121, zero_point=5)
        adds2 = torch.ops.quantized.add(c, d, scale=1.0 / 122, zero_point=4)
        res = torch.ops.quantized.add_relu(
            adds, adds2, scale=1.0 / 120, zero_point=6
        )
        res = torch.ops.quantized.add(res, res, scale=1.0 / 128, zero_point=7)
        res = dq(res)
        return res

    try:
        with torch.no_grad():
            a = torch.randn([5, 5])
            b = torch.randn([5, 5])
            c = torch.randn([5, 5])
            d = torch.randn([5, 5])
            res_torch = fun(a, b, c, d)

            torch_glow.enableFusionPass()
            # Cut using blacklist functionality
            blacklist = ["quantized::add_relu"]
            torch_glow.setFusionBlacklist(blacklist)
            traced_model = torch.jit.trace(fun, (a, b, c, d))
            for node in traced_model.graph_for(a, b, c, d).nodes():
                kind = node.kind()
                # Make sure the blacklist is working
                assert (
                    kind == GLOW_NODE_NAME
                    or kind in blacklist
                    or kind == "prim::Constant"
                )
            res_glow = traced_model(a, b, c, d)
            print(res_torch)
            print(res_glow)
            assert torch.allclose(res_torch, res_glow)
    finally:
        # Undo every global change so later tests start from the old state.
        torch_glow.clearFusionBlacklist()
        torch_glow.disableFusionPass()
        torch._C._jit_set_profiling_mode(old_profiling_mode)
        torch._C._jit_set_profiling_executor(old_profiling_executor)
def scriptVsGlow(
    f,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
    use_fp16=False,
    backend_name=None,
):
    """Run ``f`` eagerly and via ``torch.jit.script`` with Glow fusion, then compare.

    Args:
        f: scriptable callable under test.
        atol, rtol: tolerances forwarded to ``checkResult``.
        *inputs: inputs passed to both runs.
        expected_fused_ops: forwarded to ``checkExpectedOps``.
        accept_all_ops: forwarded to ``checkExpectedOps``.
        black_list: op kinds passed to ``torch_glow.setFusionBlacklist``;
            ``None`` means an empty list.
        use_fp16: enable (truthy) or disable the three fp16 conversion/clip
            settings for the Glow run.
        backend_name: Glow backend for the run; falsy means "Interpreter".

    The global torch_glow settings are reset in a ``finally`` so they do not
    carry over even when scripting or execution raises.
    """
    if black_list is None:
        black_list = []
    with torch.no_grad():
        torch_res = f(*inputs)

        try:
            torch_glow.enableFusionPass()
            torch_glow.setFusionBlacklist(black_list)

            if use_fp16:
                torch_glow.enable_convert_to_fp16()
                torch_glow.enable_convert_fused_to_fp16()
                torch_glow.enable_clip_fp16()
            else:
                torch_glow.disable_convert_to_fp16()
                torch_glow.disable_convert_fused_to_fp16()
                torch_glow.disable_clip_fp16()

            if backend_name:
                torch_glow.setGlowBackend(backend_name)
            else:
                torch_glow.setGlowBackend("Interpreter")

            glow_trace = torch.jit.script(f)
            glow_res = glow_trace(*inputs)

            glow_graph = glow_trace.graph_for(*inputs)
            print("glow_graph,", glow_graph)
        finally:
            # need to explicitly clear settings to avoid carry-over static settings,
            # also when scripting or running the Glow graph fails
            torch_glow.disableFusionPass()
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
            torch_glow.setGlowBackend("Interpreter")

    checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
    checkResult(torch_res, glow_res, atol, rtol)
def scriptVsGlow(
    f,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
):
    """Run ``f`` eagerly and via ``torch.jit.script`` with Glow fusion, then compare.

    Args:
        f: scriptable callable under test.
        atol, rtol: tolerances forwarded to ``checkResult``.
        *inputs: inputs passed to both runs.
        expected_fused_ops: forwarded to ``checkExpectedOps``.
        accept_all_ops: forwarded to ``checkExpectedOps``.
        black_list: op kinds passed to ``torch_glow.setFusionBlacklist``;
            ``None`` means an empty list.

    The fusion pass is disabled again in a ``finally`` so the global setting
    does not carry over, even when scripting or execution raises.
    """
    if black_list is None:
        black_list = []
    with torch.no_grad():
        torch_res = f(*inputs)

        try:
            torch_glow.enableFusionPass()
            torch_glow.setFusionBlacklist(black_list)

            glow_trace = torch.jit.script(f)
            glow_res = glow_trace(*inputs)

            glow_graph = glow_trace.graph_for(*inputs)
            print("glow_graph,", glow_graph)
        finally:
            # Avoid carrying the enabled fusion pass over into later code.
            torch_glow.disableFusionPass()

    checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
    checkResult(torch_res, glow_res, atol, rtol)
def traceVsGlow(
    f_torch,
    f_glow,
    check_trace,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
    use_fp16=False,
    backend_name=None,
):
    """Trace ``f_torch`` without and ``f_glow`` with the Glow fuser and compare.

    Args:
        f_torch: callable traced with the fusion pass disabled (reference).
        f_glow: callable traced with the fusion pass enabled.
        check_trace: forwarded to ``torch.jit.trace``.
        atol, rtol: tolerances forwarded to ``checkResult``.
        *inputs: example inputs used for tracing and execution.
        expected_fused_ops: forwarded to ``checkExpectedOps``.
        accept_all_ops: forwarded to ``checkExpectedOps``.
        black_list: op kinds passed to ``torch_glow.setFusionBlacklist``;
            ``None`` means an empty list.
        use_fp16: enable (truthy) or disable the three fp16 conversion/clip
            settings for the Glow run.
        backend_name: Glow backend for the run; falsy means "Interpreter".

    Asserts that the reference graph contains no Glow nodes. Global
    torch_glow settings are reset in a ``finally`` so they do not carry over
    even when tracing, execution or that assertion fails.
    """
    if black_list is None:
        black_list = []
    with torch.no_grad():
        # Reference run: fusion disabled so the trace stays pure PyTorch.
        torch_glow.disableFusionPass()
        torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
        torch_res = torch_trace(*inputs)

        try:
            torch_glow.enableFusionPass()
            torch_glow.setFusionBlacklist(black_list)

            if use_fp16:
                torch_glow.enable_convert_to_fp16()
                torch_glow.enable_convert_fused_to_fp16()
                torch_glow.enable_clip_fp16()
            else:
                torch_glow.disable_convert_to_fp16()
                torch_glow.disable_convert_fused_to_fp16()
                torch_glow.disable_clip_fp16()

            if backend_name:
                torch_glow.setGlowBackend(backend_name)
            else:
                torch_glow.setGlowBackend("Interpreter")

            glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
            glow_res = glow_trace(*inputs)

            # check that there are no Glow nodes in the torch graph
            torch_graph = torch_trace.graph_for(*inputs)
            print("torch_graph,", torch_graph)

            num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
            assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
                num_glow_nodes
            )

            glow_graph = glow_trace.graph_for(*inputs)
            print("glow_graph,", glow_graph)
        finally:
            # need to explicitly clear settings to avoid carry-over static settings,
            # also when tracing, execution or the assertion above fails
            torch_glow.disableFusionPass()
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
            torch_glow.setGlowBackend("Interpreter")

    checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
    checkResult(torch_res, glow_res, atol, rtol)
def _jit_vs_glow_check_fusion(glow_graph, expected_fused_ops, accept_all_ops):
    """Assert that the op kinds in *expected_fused_ops* were fused into Glow.

    Checks that no top-level (non-Glow) node has a kind listed in
    *expected_fused_ops*, that at least one node was fused, and that every
    listed kind appears inside some Glow fusion group. With
    *accept_all_ops* truthy, only the "at least one node fused" check
    applies.
    """
    # Compare as sets so duplicate entries in the caller's list are harmless.
    expected = set(expected_fused_ops)
    expected_fused_ops_seen = set()
    # Whether or not at least one node was fused to Glow.
    nodes_were_fused = False

    for node in glow_graph.nodes():
        kind = node.kind()
        if kind != GLOW_NODE_NAME:
            # If the node is not a Glow fusion group, check that it is
            # *not* in expected_fused_ops
            assert accept_all_ops or kind not in expected, \
                "Expected {} to be fused".format(kind)
            continue

        # The node is a Glow fusion group: record which expected ops it holds.
        glow_group = node.g(SUBGRAPH_ATTR)
        for fused_node in glow_group.nodes():
            nodes_were_fused = True
            fused_node_kind = fused_node.kind()
            if accept_all_ops or fused_node_kind in expected:
                expected_fused_ops_seen.add(fused_node_kind)

    assert nodes_were_fused, "Expected some nodes to be fused to Glow"

    # If the sets differ, some ops in expected_fused_ops are not in the graph
    assert accept_all_ops or expected == expected_fused_ops_seen, \
        "Expected all of expected_fused_ops to be in the graph"


def _jit_vs_glow_check_results(torch_res, glow_res, atol, rtol):
    """Assert that the reference and Glow outputs are close (tensor or tuple)."""
    if isinstance(torch_res, tuple) or isinstance(glow_res, tuple):
        assert isinstance(torch_res, tuple) and isinstance(glow_res, tuple)
        assert len(torch_res) == len(glow_res)
        for torch_item, glow_item in zip(torch_res, glow_res):
            print("torch shape: {}".format(torch_item.shape), file=sys.stderr)
            print("glow shape: {}".format(glow_item.shape), file=sys.stderr)
            assert torch.allclose(torch_item, glow_item, atol=atol, rtol=rtol)
        return

    print("torch shape: {}".format(torch_res.shape), file=sys.stderr)
    print("glow shape: {}".format(glow_res.shape), file=sys.stderr)
    is_all_close = torch.allclose(torch_res, glow_res, atol=atol, rtol=rtol)
    if not is_all_close:
        print("torch_res\n", torch_res)
        print("glow_res\n", glow_res)
        print("diff\n", torch.abs(glow_res - torch_res))
    assert is_all_close


def jitVsGlow_(
    f_torch,
    f_glow,
    check_trace,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
):
    """Trace ``f_torch`` without and ``f_glow`` with the Glow fuser and compare.

    Args:
        f_torch: callable traced with the fusion pass disabled (reference).
        f_glow: callable traced with the fusion pass enabled.
        check_trace: forwarded to ``torch.jit.trace``.
        atol, rtol: tolerances for ``torch.allclose``.
        *inputs: example inputs used for tracing and execution.
        expected_fused_ops: op kinds that must be fused into Glow and must
            not appear outside a fusion group; ``None`` means an empty list
            (previously this raised ``TypeError``).
        accept_all_ops: if truthy, accept any set of fused ops.
        black_list: op kinds passed to ``torch_glow.setFusionBlacklist``;
            ``None`` means an empty list.

    The fusion pass is disabled again in a ``finally`` so it does not carry
    over into later code.
    """
    if black_list is None:
        black_list = []
    if expected_fused_ops is None:
        expected_fused_ops = []

    with torch.no_grad():
        # Reference run: fusion disabled so the trace stays pure PyTorch.
        torch_glow.disableFusionPass()
        torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
        torch_res = torch_trace(*inputs)

        try:
            torch_glow.enableFusionPass()
            torch_glow.setFusionBlacklist(black_list)
            glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
            glow_res = glow_trace(*inputs)

            # check that there are no Glow nodes in the torch graph
            torch_graph = torch_trace.graph_for(*inputs)
            print("torch_graph,", torch_graph)

            num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
            assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
                num_glow_nodes
            )

            glow_graph = glow_trace.graph_for(*inputs)
            print("glow_graph,", glow_graph)
        finally:
            # Avoid carrying the enabled fusion pass over into later code.
            torch_glow.disableFusionPass()

    _jit_vs_glow_check_fusion(glow_graph, expected_fused_ops, accept_all_ops)
    _jit_vs_glow_check_results(torch_res, glow_res, atol, rtol)