def test_getattr(self):
    """Test fusion of the PyTorch prim::GetAttr Node into the Glow subgraph."""
    with torch.no_grad():

        class Model(torch.nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.linear = torch.nn.Linear(2, 1)

            def forward(self, x):
                return self.linear(x)

        x = torch.tensor([2.0, 3.0])

        torch_glow.enableFusionPass_DO_NOT_USE_THIS()

        m = Model()
        jit_m = torch.jit.trace(m, x)
        jit_m_graph = jit_m.graph_for(x)

        # Ensure all prim::GetAttrs were fused and none were left out.
        # GLOW_FUSION_GROUP and SUBGRAPH_ATTR are module-level constants
        # naming the fusion-group node kind and its subgraph attribute.
        found_getattrs = False
        for node in jit_m_graph.nodes():
            kind = node.kind()
            assert (
                kind != "prim::GetAttr"
            ), "Expected all prim::GetAttrs to be in the Glow subgraph"
            if kind == GLOW_FUSION_GROUP:
                glow_subgraph = node.g(SUBGRAPH_ATTR)
                for subgraph_node in glow_subgraph.nodes():
                    if subgraph_node.kind() == "prim::GetAttr":
                        found_getattrs = True

        assert found_getattrs, "Expected to find prim::GetAttrs in the Glow subgraph"
def test_op_blacklist(self):
    """Test Glow fuser op kind blacklisting mechanism."""

    def f(a, b):
        return (a + b) * (a - b)

    torch_glow.enableFusionPass_DO_NOT_USE_THIS()
    torch_glow.setFusionBlacklist(["aten::add"])

    a = torch.randn(5, 5)
    b = torch.randn(5, 5)

    jit_f = torch.jit.trace(f, (a, b))
    jit_f_graph = jit_f.graph_for(a, b)

    fused_add = False
    fused_sub = False
    for node in jit_f_graph.nodes():
        if node.kind() == GLOW_FUSION_GROUP:
            glow_subgraph = node.g(SUBGRAPH_ATTR)
            for subgraph_node in glow_subgraph.nodes():
                if subgraph_node.kind() == "aten::add":
                    fused_add = True
                if subgraph_node.kind() == "aten::sub":
                    fused_sub = True

    assert not fused_add, "Expected aten::add to be blacklisted"
    assert fused_sub, "Expected aten::sub to not be blacklisted"

    torch_glow.clearFusionBlacklist()
def run_model(model, image, use_glow, backend, print_graph):
    if use_glow:
        torch_glow.enableFusionPass_DO_NOT_USE_THIS()
        if backend:
            torch_glow.setGlowBackend(backend)

    with torch.no_grad():
        traced = torch.jit.trace(model, image)
        if print_graph:
            print(traced.graph_for(image))
        all_outputs = traced(image)
        topk = all_outputs.topk(5)
        # topk returns (values, indices); return them as (indices, values).
        return (topk[1], topk[0])
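# A minimal usage sketch for run_model. The stand-in model, input shape, and
# the "Interpreter" backend below are illustrative assumptions, not part of
# the original tests.
example_model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(3 * 224 * 224, 1000),  # placeholder for a real classifier
).eval()
example_image = torch.randn(1, 3, 224, 224)
top5_indices, top5_values = run_model(
    example_model, example_image, use_glow=True, backend="Interpreter", print_graph=False
)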
def test_print_jit_indices(self):
    def test_f(a, b):
        c = a.add(b)
        return c.add(c)

    x = torch.randn(4)
    y = torch.randn(4)

    torch_glow.enableFusionPass_DO_NOT_USE_THIS()
    torch_glow.enable_printing_jit_node_indices()

    graph = torch.jit.trace(test_f, (x, y), check_trace=False)
    graph(x, y)
def test_shape_inference_unsupported_symbols_skip_fusion_group(self):
    """Test Glow shape inference unsupported symbols, including skipping
    symbols after a secondary fusion group."""

    def f(a, b):
        x1 = a * b
        x2 = x1 * b
        x3 = x2 * a
        x4 = x3 / b
        x5 = x4 / a
        x6 = x5 / b
        x7 = x6 * a
        x8 = x7 * b
        return x8 * torch.chain_matmul(x8, x8)

    torch_glow.enableFusionPass_DO_NOT_USE_THIS()
    torch_glow.setFusionStartIndex(3)
    torch_glow.setFusionEndIndex(6)

    a = torch.randn(5, 5)
    b = torch.randn(5, 5)

    jit_f = torch.jit.trace(f, (a, b))
    jit_f_graph = jit_f.graph_for(a, b)

    torch_glow.clearFusionIndices()

    args = (a, b)

    # Don't skip nodes after the last fusion node. In this case, one of the
    # nodes (chain_matmul) following the last fusion node is not supported,
    # and should be reported.
    actual = torch_glow.glow_shape_inference_find_unsupported_symbols(
        jit_f_graph, args, skip_last_fusion_node=False
    )
    expected = [
        "aten::chain_matmul",
    ]
    self.assertEqual(set(expected), set(actual))

    # DO skip nodes after the last fusion node. In this case, one of the
    # nodes (chain_matmul) following the last fusion node is not supported,
    # but is suppressed due to the skip_last_fusion_node flag.
    actual = torch_glow.glow_shape_inference_find_unsupported_symbols(
        jit_f_graph, args, skip_last_fusion_node=True
    )
    expected = []
    self.assertEqual(set(expected), set(actual))
def test_quantized_cut(self):
    """Test cut quantized chunk in the middle."""
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_set_profiling_mode(False)

    def fun(a, b, c, d):
        q = torch.nn.quantized.Quantize(
            scale=1.0 / 21, zero_point=0, dtype=torch.quint8
        )
        dq = torch.nn.quantized.DeQuantize()
        a = q(a)
        b = q(b)
        c = q(c)
        d = q(d)
        adds = torch.ops.quantized.add(a, b, scale=1.0 / 17, zero_point=5)
        adds2 = torch.ops.quantized.add(c, d, scale=1.0 / 14, zero_point=4)
        res = torch.ops.quantized.add_relu(
            adds, adds2, scale=1.0 / 18, zero_point=6
        )
        res = torch.ops.quantized.add(res, res, scale=1.0 / 13, zero_point=7)
        res = dq(res)
        return res

    with torch.no_grad():
        a = torch.randn([5, 5])
        b = torch.randn([5, 5])
        c = torch.randn([5, 5])
        d = torch.randn([5, 5])
        res_torch = fun(a, b, c, d)

        torch_glow.enableFusionPass_DO_NOT_USE_THIS()

        # Cut using blacklist functionality
        blacklist = ["quantized::add_relu"]
        torch_glow.setFusionBlacklist(blacklist)
        torch_glow.setGlowBackend("Interpreter")

        traced_model = torch.jit.trace(fun, (a, b, c, d))
        for node in traced_model.graph_for(a, b, c, d).nodes():
            kind = node.kind()
            # Make sure the blacklist is working
            assert (
                kind == GLOW_FUSION_GROUP
                or kind in blacklist
                or kind == "prim::Constant"
            )

        res_glow = traced_model(a, b, c, d)
        print(res_torch)
        print(res_glow)
        assert torch.allclose(res_torch, res_glow)
# Requires `from contextlib import contextmanager` at module top.
@contextmanager
def ephemeral_torchglow_settings(
    fp16=False,
    backend=DEFAULT_BACKEND,
    fusion=False,
    blocklist=None,
    accept_all_layouts=False,
):
    # Capture the current global torch_glow settings so they can be restored.
    old_fp16 = torch_glow.get_convert_to_fp16()
    old_clip = torch_glow.get_clip_fp16()
    old_convert_fused = torch_glow.get_convert_fused_to_fp16()
    old_backend = torch_glow.getGlowBackendName()
    old_blocklist = torch_glow.getFusionBlacklist()
    old_fusion = torch_glow.getFusionPassEnabled()
    try:
        if fusion:
            torch_glow.enableFusionPass_DO_NOT_USE_THIS()
        else:
            torch_glow.disableFusionPass()
        if fp16:
            torch_glow.enable_convert_to_fp16()
            torch_glow.enable_convert_fused_to_fp16()
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
        if blocklist is None:
            torch_glow.clearFusionBlacklist()
        else:
            torch_glow.setFusionBlacklist(list(blocklist))
        if accept_all_layouts:
            torch_glow.enable_accept_all_layout()
        else:
            torch_glow.disable_accept_all_layout()
        torch_glow.setGlowBackend(backend)
        yield
    finally:
        # Restore the settings captured on entry.
        if old_fp16:
            torch_glow.enable_convert_to_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
        if old_clip:
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_clip_fp16()
        if old_convert_fused:
            torch_glow.enable_convert_fused_to_fp16()
        else:
            torch_glow.disable_convert_fused_to_fp16()
        if old_fusion:
            torch_glow.enableFusionPass_DO_NOT_USE_THIS()
        else:
            torch_glow.disableFusionPass()
        torch_glow.setGlowBackend(old_backend)
        torch_glow.setFusionBlacklist(old_blocklist)
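# Usage sketch for ephemeral_torchglow_settings (illustrative, not from the
# original tests): settings apply only inside the `with` block, and the
# previous global torch_glow state is restored on exit, even if the body
# raises.
def _example_fn(a, b):
    return a + b

_ex_a, _ex_b = torch.randn(4), torch.randn(4)
with ephemeral_torchglow_settings(fusion=True, backend="Interpreter"):
    _traced = torch.jit.trace(_example_fn, (_ex_a, _ex_b))
    _traced(_ex_a, _ex_b)  # runs with the Glow fusion pass enabled
# Prior fusion, fp16, blocklist, and backend settings are restored here.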
def test_backend_specific_options(self):
    """Test loading backend specific options from YAML file."""

    def test_f(a, b):
        return a.add(b)

    x = torch.randn(4)
    y = torch.randn(4)

    # Create YAML file with backend options
    with tempfile.NamedTemporaryFile() as options_fd:
        options_fd.write(b"interpreter-memory: 4194304\n")
        options_fd.flush()

        # Run Glow
        torch_glow.loadBackendSpecificOptions(options_fd.name)
        torch_glow.enableFusionPass_DO_NOT_USE_THIS()
        glow_trace = torch.jit.trace(test_f, (x, y), check_trace=False)
        glow_trace(x, y)
def test_op_index_blacklist(self):
    """Test Glow fuser index blacklisting mechanism."""

    def f(a, b):
        x1 = a * b
        x2 = x1 * b
        x3 = x2 * a
        x4 = x3 / b
        x5 = x4 / a
        x6 = x5 / b
        x7 = x6 * a
        x8 = x7 * b
        return x8

    torch_glow.enableFusionPass_DO_NOT_USE_THIS()
    torch_glow.setFusionStartIndex(3)
    torch_glow.setFusionEndIndex(6)

    a = torch.randn(5, 5)
    b = torch.randn(5, 5)

    jit_f = torch.jit.trace(f, (a, b))
    jit_f_graph = jit_f.graph_for(a, b)

    torch_glow.clearFusionIndices()

    fused_muls = 0
    fused_divs = 0
    for node in jit_f_graph.nodes():
        if node.kind() == GLOW_FUSION_GROUP:
            glow_subgraph = node.g(SUBGRAPH_ATTR)
            for subgraph_node in glow_subgraph.nodes():
                if subgraph_node.kind() == "aten::mul":
                    fused_muls += 1
                if subgraph_node.kind() == "aten::div":
                    fused_divs += 1

    assert fused_muls == 0, "Expected no aten::muls to be fused"
    assert fused_divs == 3, "Expected all 3 aten::divs to be fused"
# Example inputs (the original fragment referenced x and y without defining
# them; these shapes are assumed).
x = torch.randn(4)
y = torch.randn(4)

@torch.jit.script
def foo(a, b):
    c = a.mul(b)
    a = c.mul(c)
    a = c.mul(a)
    d = c.div(a)
    return d

print("original jit ir")
print(foo.graph_for(x, y))

jit_res = foo(x, y)

torch_glow.enableFusionPass_DO_NOT_USE_THIS()

@torch.jit.script
def foo_glow(a, b):
    return foo(a, b)

print("glow jit ir")
print(foo_glow.graph_for(x, y))

jit_glow_res = foo_glow(x, y)

print("jit_res")
print(jit_res)
print("jit_glow_res")
print(jit_glow_res)