def traceVsGlow(f_torch, f_glow, check_trace, atol, rtol, *inputs,
                expected_fused_ops=None, accept_all_ops=False, black_list=None):
    if black_list is None:
        black_list = []
    with torch.no_grad():
        torch_glow.disableFusionPass()

        torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
        torch_res = torch_trace(*inputs)

        torch_glow.enableFusionPass()
        torch_glow.setFusionBlacklist(black_list)

        glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
        glow_res = glow_trace(*inputs)

        # Check that there are no Glow nodes in the torch graph.
        torch_graph = torch_trace.graph_for(*inputs)
        print("torch_graph,", torch_graph)

        num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
        assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
            num_glow_nodes)

        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

        checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
        checkResult(torch_res, glow_res, atol, rtol)
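# Hypothetical usage sketch (not part of the original tests): traceVsGlow
# compares a traced CPU run against a traced Glow run of the same function.
# The toy function, shapes, tolerances, and expected op set below are
# illustrative assumptions only.
def _example_traceVsGlow():
    def f(a, b):
        return a + b

    x = torch.randn(4)
    y = torch.randn(4)
    traceVsGlow(f, f, True, 1e-6, 1e-6, x, y,
                expected_fused_ops={"aten::add"})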
# Generator-based context manager: the decorator is implied by the
# yield/try/finally structure and is required for use in a `with` block.
@contextmanager
def ephemeral_torchglow_settings(
    fp16=False, backend=DEFAULT_BACKEND, fusion=False, blocklist=None
):
    old_fp16 = torch_glow.get_convert_to_fp16()
    old_clip = torch_glow.get_clip_fp16()
    old_convert_fused = torch_glow.get_convert_fused_to_fp16()
    old_backend = torch_glow.getGlowBackendName()
    old_blocklist = torch_glow.getFusionBlacklist()
    old_fusion = torch_glow.getFusionPassEnabled()
    try:
        if fusion:
            torch_glow.enableFusionPass()
        else:
            torch_glow.disableFusionPass()
        if fp16:
            torch_glow.enable_convert_to_fp16()
            torch_glow.enable_convert_fused_to_fp16()
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
        if blocklist is None:
            torch_glow.clearFusionBlacklist()
        else:
            torch_glow.setFusionBlacklist(list(blocklist))
        torch_glow.setGlowBackend(backend)
        yield
    finally:
        # Restore whatever settings were in effect before entering the block.
        (torch_glow.enable_convert_to_fp16() if old_fp16
         else torch_glow.disable_convert_to_fp16())
        (torch_glow.enable_clip_fp16() if old_clip
         else torch_glow.disable_clip_fp16())
        (torch_glow.enable_convert_fused_to_fp16() if old_convert_fused
         else torch_glow.disable_convert_fused_to_fp16())
        (torch_glow.enableFusionPass() if old_fusion
         else torch_glow.disableFusionPass())
        torch_glow.setGlowBackend(old_backend)
        torch_glow.setFusionBlacklist(old_blocklist)
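# Hypothetical usage sketch (assumption, not from the original file): the
# context manager applies the requested torch_glow settings for the body of
# the `with` block and restores the previous settings on exit, even if the
# body raises.
def _example_ephemeral_settings(model, inputs):
    with ephemeral_torchglow_settings(fusion=True, fp16=False):
        traced = torch.jit.trace(model, inputs)
        return traced(*inputs)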
def test_op_blacklist_allowlist(self):
    """Test that the Glow fuser allowlist overrides the blacklist mechanism."""

    def f(a, b):
        return (a + b) * (a - b)

    torch_glow.enableFusionPass()
    torch_glow.setFusionBlacklist(["aten::add", "aten::sub"])
    torch_glow.setFusionOverrideAllowlist(["aten::sub"])

    a = torch.randn(5, 5)
    b = torch.randn(5, 5)

    jit_f = torch.jit.trace(f, (a, b))
    jit_f_graph = jit_f.graph_for(a, b)

    fused_add = False
    fused_sub = False
    for node in jit_f_graph.nodes():
        if node.kind() == GLOW_NODE_NAME:
            glow_subgraph = node.g(SUBGRAPH_ATTR)
            for node in glow_subgraph.nodes():
                if node.kind() == "aten::add":
                    fused_add = True
                if node.kind() == "aten::sub":
                    fused_sub = True

    assert not fused_add, "Expected aten::add to be blacklisted"
    assert fused_sub, "Expected aten::sub to not be blacklisted"

    torch_glow.clearFusionBlacklist()
    torch_glow.clearFusionOverrideAllowlist()
def test_getattr(self):
    """Test fusion of the PyTorch prim::GetAttr node into the Glow subgraph."""
    with torch.no_grad():

        class Model(torch.nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.linear = torch.nn.Linear(2, 1)

            def forward(self, x):
                return self.linear(x)

        x = torch.tensor([2.0, 3.0])

        torch_glow.enableFusionPass()

        m = Model()
        jit_m = torch.jit.trace(m, x)
        jit_m_graph = jit_m.graph_for(x)

        # Ensure all prim::GetAttrs were fused and none were left out.
        found_getattrs = False
        for node in jit_m_graph.nodes():
            kind = node.kind()
            assert (
                kind != "prim::GetAttr"
            ), "Expected all prim::GetAttrs to be in the Glow subgraph"
            if kind == GLOW_FUSION_GROUP:
                glow_subgraph = node.g(SUBGRAPH_ATTR)
                for node in glow_subgraph.nodes():
                    if node.kind() == "prim::GetAttr":
                        found_getattrs = True

        assert found_getattrs, \
            "Expected to find prim::GetAttrs in the Glow subgraph"
def jitVsGlow_(f_torch, f_glow, *inputs, expected_fused_ops=None,
               accept_all_ops=False):
    with torch.no_grad():
        torch_glow.disableFusionPass()
        torch_trace = torch.jit.trace(f_torch, inputs)
        torch_res = torch_trace(*inputs)

        torch_glow.enableFusionPass()
        glow_trace = torch.jit.trace(f_glow, inputs)
        glow_res = glow_trace(*inputs)

        # Check that there are no Glow nodes in the torch graph.
        torch_graph = torch_trace.graph_for(*inputs)
        print("torch_graph,", torch_graph)

        num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
        assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
            num_glow_nodes)

        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

        expected_fused_ops_seen = set()

        # Check that ops that were *not* fused are *not* in expected_fused_ops.
        for node in glow_graph.nodes():
            kind = node.kind()
            if kind != GLOW_NODE_NAME:
                # If the node is not a Glow fusion group, check that it is
                # *not* in expected_fused_ops.
                assert accept_all_ops or kind not in expected_fused_ops, \
                    "Expected {} to be fused".format(kind)
            else:
                # If the node is a Glow fusion group, record which ops from
                # expected_fused_ops were in it.

                # Get the definition of the fusion group.
                glow_group = node.g(SUBGRAPH_ATTR)

                # Put all nodes that are in the group and in expected_fused_ops
                # into expected_fused_ops_seen.
                for fused_node in glow_group.nodes():
                    fused_node_kind = fused_node.kind()
                    if accept_all_ops or fused_node_kind in expected_fused_ops:
                        expected_fused_ops_seen.add(fused_node_kind)

        # If the sizes of expected_fused_ops and expected_fused_ops_seen are
        # different, some ops in expected_fused_ops are not in the graph at all.
        assert accept_all_ops or len(expected_fused_ops) == len(
            expected_fused_ops_seen
        ), "Expected all of expected_fused_ops to be in the graph"

        assert len(torch_res) == len(glow_res)
        for i in range(len(torch_res)):
            assert torch.allclose(torch_res[i], glow_res[i], atol=1e-6)
def run_model(model, image, use_glow, print_graph):
    if use_glow:
        torch_glow.enableFusionPass()
    with torch.no_grad():
        traced = torch.jit.trace(model, image)
        if print_graph:
            print(traced.graph_for(image))
        all_outputs = traced(image)
        topk = all_outputs.topk(5)
    return (topk[1], topk[0])
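# Hypothetical usage sketch (the model and image arguments are assumptions):
# compare top-5 predictions with and without Glow. Note that run_model only
# ever enables the fusion pass, so the CPU run is done first here to avoid
# running it with fusion already enabled.
def _example_run_model(model, image):
    cpu_indices, _ = run_model(model, image, use_glow=False, print_graph=False)
    glow_indices, _ = run_model(model, image, use_glow=True, print_graph=False)
    assert torch.equal(cpu_indices, glow_indices)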
def test_print_jit_indices(self):
    def test_f(a, b):
        c = a.add(b)
        return c.add(c)

    x = torch.randn(4)
    y = torch.randn(4)

    torch_glow.enableFusionPass()
    torch_glow.enable_printing_jit_node_indices()

    graph = torch.jit.trace(test_f, (x, y), check_trace=False)
    graph(x, y)
def test_quantized_cut(self):
    """Test cutting a quantized chunk in the middle."""
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_set_profiling_mode(False)

    def fun(a, b, c, d):
        q = torch.nn.quantized.Quantize(
            scale=1.0 / 128, zero_point=0, dtype=torch.quint8)
        dq = torch.nn.quantized.DeQuantize()
        a = q(a)
        b = q(b)
        c = q(c)
        d = q(d)
        adds = torch.ops.quantized.add(a, b, scale=1.0 / 121, zero_point=5)
        adds2 = torch.ops.quantized.add(c, d, scale=1.0 / 122, zero_point=4)
        res = torch.ops.quantized.add_relu(
            adds, adds2, scale=1.0 / 120, zero_point=6)
        res = torch.ops.quantized.add(res, res, scale=1.0 / 128, zero_point=7)
        res = dq(res)
        return res

    with torch.no_grad():
        a = torch.randn([5, 5])
        b = torch.randn([5, 5])
        c = torch.randn([5, 5])
        d = torch.randn([5, 5])
        res_torch = fun(a, b, c, d)

        torch_glow.enableFusionPass()
        # Cut using the blacklist functionality.
        blacklist = ["quantized::add_relu"]
        torch_glow.setFusionBlacklist(blacklist)
        traced_model = torch.jit.trace(fun, (a, b, c, d))

        for node in traced_model.graph_for(a, b, c, d).nodes():
            kind = node.kind()
            # Make sure the blacklist is working.
            assert (kind == GLOW_NODE_NAME
                    or kind in blacklist
                    or kind == "prim::Constant")

        res_glow = traced_model(a, b, c, d)
        print(res_torch)
        print(res_glow)
        assert torch.allclose(res_torch, res_glow)
def test_shape_inference_unsupported_symbols_skip_fusion_group(self):
    """Test Glow shape inference unsupported symbols, including skipping of
    symbols after a secondary fusion group."""

    def f(a, b):
        x1 = a * b
        x2 = x1 * b
        x3 = x2 * a
        x4 = x3 / b
        x5 = x4 / a
        x6 = x5 / b
        x7 = x6 * a
        x8 = x7 * b
        return x8 * torch.chain_matmul(x8, x8)

    torch_glow.enableFusionPass()
    torch_glow.setFusionStartIndex(3)
    torch_glow.setFusionEndIndex(6)

    a = torch.randn(5, 5)
    b = torch.randn(5, 5)

    jit_f = torch.jit.trace(f, (a, b))
    jit_f_graph = jit_f.graph_for(a, b)

    torch_glow.clearFusionIndices()

    args = (a, b)

    # Don't skip nodes after the last fusion node.
    # In this case, one of the nodes (chain_matmul) following the last fusion
    # node is not supported, and should be reported.
    actual = torch_glow.glow_shape_inference_find_unsupported_symbols(
        jit_f_graph, args, skip_last_fusion_node=False
    )
    expected = [
        "aten::chain_matmul",
    ]
    self.assertEqual(set(expected), set(actual))

    # DO skip nodes after the last fusion node.
    # In this case, one of the nodes (chain_matmul) following the last fusion
    # node is not supported, but is suppressed due to the
    # skip_last_fusion_node flag.
    actual = torch_glow.glow_shape_inference_find_unsupported_symbols(
        jit_f_graph, args, skip_last_fusion_node=True
    )
    expected = []
    self.assertEqual(set(expected), set(actual))
def scriptVsGlow(
    f,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
    use_fp16=False,
    backend_name=None,
):
    if black_list is None:
        black_list = []
    with torch.no_grad():
        torch_res = f(*inputs)

        torch_glow.enableFusionPass()
        torch_glow.setFusionBlacklist(black_list)

        if use_fp16:
            torch_glow.enable_convert_to_fp16()
            torch_glow.enable_convert_fused_to_fp16()
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()

        if backend_name:
            torch_glow.setGlowBackend(backend_name)
        else:
            torch_glow.setGlowBackend("Interpreter")

        glow_trace = torch.jit.script(f)
        glow_res = glow_trace(*inputs)
        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

        # Need to explicitly clear settings to avoid carrying over static
        # settings.
        torch_glow.disableFusionPass()
        torch_glow.disable_convert_to_fp16()
        torch_glow.disable_convert_fused_to_fp16()
        torch_glow.disable_clip_fp16()
        torch_glow.setGlowBackend("Interpreter")

        checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
        checkResult(torch_res, glow_res, atol, rtol)
def test_backend_specific_options(self):
    """Test loading backend-specific options from a YAML file."""

    def test_f(a, b):
        return a.add(b)

    x = torch.randn(4)
    y = torch.randn(4)

    # Create a YAML file with backend options.
    with tempfile.NamedTemporaryFile() as options_fd:
        options_fd.write(b'interpreter-memory: 4194304\n')
        options_fd.flush()

        # Run Glow.
        torch_glow.loadBackendSpecificOptions(options_fd.name)
        torch_glow.enableFusionPass()
        glow_trace = torch.jit.trace(test_f, (x, y), check_trace=False)
        glow_trace(x, y)
def test_op_index_blacklist_allowlist(self):
    """Test that the Glow fuser allowlist overrides the index blacklisting
    mechanism."""

    def f(a, b):
        x1 = a * b
        x2 = x1 * b
        x3 = x2 * a
        x4 = x3 / b
        x5 = x4 / a
        x6 = x5 / b
        x7 = x6 * a
        x8 = x7 * b
        return x8

    torch_glow.enableFusionPass()
    # Only one div is allowed by index.
    torch_glow.setFusionStartIndex(5)
    torch_glow.setFusionEndIndex(6)
    # But all divs are allowed by the allowlist.
    torch_glow.setFusionOverrideAllowlist(["aten::div"])

    a = torch.randn(5, 5)
    b = torch.randn(5, 5)

    jit_f = torch.jit.trace(f, (a, b))
    jit_f_graph = jit_f.graph_for(a, b)

    torch_glow.clearFusionIndices()
    torch_glow.clearFusionOverrideAllowlist()

    fused_muls = 0
    fused_divs = 0
    for node in jit_f_graph.nodes():
        if node.kind() == GLOW_NODE_NAME:
            glow_subgraph = node.g(SUBGRAPH_ATTR)
            for node in glow_subgraph.nodes():
                if node.kind() == "aten::mul":
                    fused_muls += 1
                if node.kind() == "aten::div":
                    fused_divs += 1

    assert fused_muls == 0, "Expected no aten::muls to be fused"
    assert fused_divs == 3, "Expected all 3 aten::divs to be fused"
def test_op_index_blacklist(self):
    """Test Glow fuser index blacklisting mechanism."""

    def f(a, b):
        x1 = a * b
        x2 = x1 * b
        x3 = x2 * a
        x4 = x3 / b
        x5 = x4 / a
        x6 = x5 / b
        x7 = x6 * a
        x8 = x7 * b
        return x8

    torch_glow.enableFusionPass()
    torch_glow.setFusionStartIndex(3)
    torch_glow.setFusionEndIndex(6)

    a = torch.randn(5, 5)
    b = torch.randn(5, 5)

    jit_f = torch.jit.trace(f, (a, b))
    jit_f_graph = jit_f.graph_for(a, b)

    torch_glow.clearFusionIndices()

    fused_muls = 0
    fused_divs = 0
    for node in jit_f_graph.nodes():
        if node.kind() == GLOW_FUSION_GROUP:
            glow_subgraph = node.g(SUBGRAPH_ATTR)
            for node in glow_subgraph.nodes():
                if node.kind() == "aten::mul":
                    fused_muls += 1
                if node.kind() == "aten::div":
                    fused_divs += 1

    assert fused_muls == 0, "Expected no aten::muls to be fused"
    assert fused_divs == 3, "Expected all 3 aten::divs to be fused"
def scriptVsGlow(f, atol, rtol, *inputs, expected_fused_ops=None,
                 accept_all_ops=False, black_list=None):
    if black_list is None:
        black_list = []
    with torch.no_grad():
        torch_res = f(*inputs)

        torch_glow.enableFusionPass()
        torch_glow.setFusionBlacklist(black_list)

        glow_trace = torch.jit.script(f)
        glow_res = glow_trace(*inputs)

        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

        checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
        checkResult(torch_res, glow_res, atol, rtol)
def jitVsGlow(f, *inputs):
    torch_glow.disableFusionPass()
    torch_trace = torch.jit.trace(f, inputs)
    torch_res = torch_trace(*inputs)

    torch_glow.enableFusionPass()
    glow_trace = torch.jit.trace(f, inputs)
    glow_res = glow_trace(*inputs)

    # Check that there are no Glow nodes in the torch graph.
    torch_graph = torch_trace.graph_for(*inputs)
    print("torch_graph,", torch_graph)

    num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
    assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
        num_glow_nodes)

    # Check that there is exactly 1 Glow node in the glow graph.
    glow_graph = glow_trace.graph_for(*inputs)
    print("glow_graph,", glow_graph)

    num_glow_nodes = len(glow_graph.findAllNodes(GLOW_NODE_NAME))
    assert num_glow_nodes == 1, "Expected exactly 1 Glow node, found {}".format(
        num_glow_nodes)

    assert torch.allclose(torch_res, glow_res, atol=1e-6)
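# Hypothetical usage sketch (the toy function and shapes are assumptions):
# jitVsGlow expects the whole function to be fused into a single Glow node,
# so this only holds if every op in f is supported by the Glow fuser.
def _example_jitVsGlow():
    def f(a, b):
        return a * b + b

    jitVsGlow(f, torch.randn(2, 3), torch.randn(2, 3))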
import torch
import torch.nn as nn
import torch_glow


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)


torch._C._jit_set_profiling_mode(True)
torch_glow.enableFusionPass()

m = Model()
m_jit = torch.jit.script(m)

x = torch.randn(10)

# No Glow fusion node
print("initial jit ir")
print(m_jit.graph_for(x))

m_jit(x)
m_jit(x)
m_jit(x)
def traceVsGlow(
    f_torch,
    f_glow,
    check_trace,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
    use_fp16=False,
    backend_name=None,
):
    if black_list is None:
        black_list = []
    with torch.no_grad():
        torch_glow.disableFusionPass()

        torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
        torch_res = torch_trace(*inputs)

        torch_glow.enableFusionPass()
        torch_glow.setFusionBlacklist(black_list)

        if use_fp16:
            torch_glow.enable_convert_to_fp16()
            torch_glow.enable_convert_fused_to_fp16()
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()

        if backend_name:
            torch_glow.setGlowBackend(backend_name)
        else:
            torch_glow.setGlowBackend("Interpreter")

        glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
        glow_res = glow_trace(*inputs)

        # Check that there are no Glow nodes in the torch graph.
        torch_graph = torch_trace.graph_for(*inputs)
        print("torch_graph,", torch_graph)

        num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
        assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
            num_glow_nodes)

        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

        # Need to explicitly clear settings to avoid carrying over static
        # settings.
        torch_glow.disableFusionPass()
        torch_glow.disable_convert_to_fp16()
        torch_glow.disable_convert_fused_to_fp16()
        torch_glow.disable_clip_fp16()
        torch_glow.setGlowBackend("Interpreter")

        checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
        checkResult(torch_res, glow_res, atol, rtol)
def jitVsGlow_(f_torch, f_glow, check_trace, atol, rtol, *inputs,
               expected_fused_ops=None, accept_all_ops=False, black_list=None):
    if black_list is None:
        black_list = []
    with torch.no_grad():
        torch_glow.disableFusionPass()
        torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
        torch_res = torch_trace(*inputs)

        torch_glow.enableFusionPass()
        torch_glow.setFusionBlacklist(black_list)
        glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
        glow_res = glow_trace(*inputs)

        # Check that there are no Glow nodes in the torch graph.
        torch_graph = torch_trace.graph_for(*inputs)
        print("torch_graph,", torch_graph)

        num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
        assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
            num_glow_nodes)

        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

        expected_fused_ops_seen = set()

        # Whether or not at least one node was fused to Glow.
        nodes_were_fused = False

        # Check that ops that were *not* fused are *not* in expected_fused_ops.
        for node in glow_graph.nodes():
            kind = node.kind()
            if kind != GLOW_NODE_NAME:
                # If the node is not a Glow fusion group, check that it is
                # *not* in expected_fused_ops.
                assert accept_all_ops or kind not in expected_fused_ops, \
                    "Expected {} to be fused".format(kind)
            else:
                # If the node is a Glow fusion group, record which ops from
                # expected_fused_ops were in it.

                # Get the definition of the fusion group.
                glow_group = node.g(SUBGRAPH_ATTR)

                # Put all nodes that are in the group and in expected_fused_ops
                # into expected_fused_ops_seen.
                for fused_node in glow_group.nodes():
                    nodes_were_fused = True
                    fused_node_kind = fused_node.kind()
                    if accept_all_ops or fused_node_kind in expected_fused_ops:
                        expected_fused_ops_seen.add(fused_node_kind)

        assert nodes_were_fused, "Expected some nodes to be fused to Glow"

        # If the sizes of expected_fused_ops and expected_fused_ops_seen are
        # different, some ops in expected_fused_ops are not in the graph at all.
        assert accept_all_ops or len(expected_fused_ops) == len(
            expected_fused_ops_seen
        ), "Expected all of expected_fused_ops to be in the graph"

        if isinstance(torch_res, tuple) or isinstance(glow_res, tuple):
            assert isinstance(torch_res, tuple) and isinstance(glow_res, tuple)
            assert len(torch_res) == len(glow_res)
            for i in range(len(torch_res)):
                print("torch shape: {}".format(torch_res[i].shape),
                      file=sys.stderr)
                print("glow shape: {}".format(glow_res[i].shape),
                      file=sys.stderr)
                assert torch.allclose(torch_res[i], glow_res[i],
                                      atol=atol, rtol=rtol)
        else:
            print("torch shape: {}".format(torch_res.shape), file=sys.stderr)
            print("glow shape: {}".format(glow_res.shape), file=sys.stderr)

            is_all_close = torch.allclose(torch_res, glow_res,
                                          atol=atol, rtol=rtol)
            if not is_all_close:
                print("torch_res\n", torch_res)
                print("glow_res\n", glow_res)
                print("diff\n", torch.abs(glow_res - torch_res))
            assert is_all_close