def ephemeral_torchglow_settings(
    fp16=False, backend=DEFAULT_BACKEND, fusion=False, blocklist=None
):
    """Temporarily apply the requested torch_glow settings, then restore.

    Captures the current global torch_glow configuration (fp16 flags,
    fusion pass, backend, fusion blacklist), applies the values given by
    the arguments, yields to the caller, and restores the captured state
    in a ``finally`` block so restoration happens even if the body raises.

    NOTE(review): the body uses ``yield``, so this is presumably wrapped
    with ``contextlib.contextmanager`` at the definition site — the
    decorator is not visible in this chunk; confirm.
    """
    # Snapshot the global state so it can be restored on exit.
    saved_fp16 = torch_glow.get_convert_to_fp16()
    saved_clip = torch_glow.get_clip_fp16()
    saved_fused = torch_glow.get_convert_fused_to_fp16()
    saved_backend = torch_glow.getGlowBackendName()
    saved_blocklist = torch_glow.getFusionBlacklist()
    saved_fusion = torch_glow.getFusionPassEnabled()
    try:
        if fusion:
            torch_glow.enableFusionPass()
        else:
            torch_glow.disableFusionPass()
        if fp16:
            torch_glow.enable_convert_to_fp16()
            torch_glow.enable_convert_fused_to_fp16()
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
        if blocklist is None:
            torch_glow.clearFusionBlacklist()
        else:
            torch_glow.setFusionBlacklist(list(blocklist))
        torch_glow.setGlowBackend(backend)
        yield
    finally:
        # Restore every captured setting, in the same order as the
        # original conditional-expression statements.
        if saved_fp16:
            torch_glow.enable_convert_to_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
        if saved_clip:
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_clip_fp16()
        if saved_fused:
            torch_glow.enable_convert_fused_to_fp16()
        else:
            torch_glow.disable_convert_fused_to_fp16()
        if saved_fusion:
            torch_glow.enableFusionPass()
        else:
            torch_glow.disableFusionPass()
        torch_glow.setGlowBackend(saved_backend)
        torch_glow.setFusionBlacklist(saved_blocklist)
def test_set_glow_backend():
    """Test setting the Glow backend type"""
    # Remember the current backend so the global state can be put back.
    prev_name = torch_glow.getGlowBackendName()
    prev_num_devices = torch_glow.getGlowBackendNumDevices()

    torch_glow.setGlowBackend("CPU", 4)
    assert torch_glow.getGlowBackendName() == "CPU"
    assert torch_glow.getGlowBackendNumDevices() == 4

    # reset everything
    torch_glow.setGlowBackend(prev_name, prev_num_devices)
def test_quantized_cut(self):
    """Test cut quantized chunk in the middle."""
    # Use the legacy (non-profiling) executor so graph_for reflects fusion.
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_set_profiling_mode(False)

    def fun(a, b, c, d):
        q = torch.nn.quantized.Quantize(
            scale=1.0 / 21, zero_point=0, dtype=torch.quint8
        )
        dq = torch.nn.quantized.DeQuantize()
        qa, qb, qc, qd = q(a), q(b), q(c), q(d)
        left = torch.ops.quantized.add(qa, qb, scale=1.0 / 17, zero_point=5)
        right = torch.ops.quantized.add(qc, qd, scale=1.0 / 14, zero_point=4)
        out = torch.ops.quantized.add_relu(
            left, right, scale=1.0 / 18, zero_point=6
        )
        out = torch.ops.quantized.add(out, out, scale=1.0 / 13, zero_point=7)
        return dq(out)

    with torch.no_grad():
        inputs = tuple(torch.randn([5, 5]) for _ in range(4))
        res_torch = fun(*inputs)

        torch_glow.enableFusionPass()
        # Cut using blacklist functionality
        blacklist = ["quantized::add_relu"]
        torch_glow.setFusionBlacklist(blacklist)
        torch_glow.setGlowBackend("Interpreter")
        traced_model = torch.jit.trace(fun, inputs)

        for node in traced_model.graph_for(*inputs).nodes():
            kind = node.kind()
            # Make sure the blacklist is working
            assert (
                kind == GLOW_FUSION_GROUP
                or kind in blacklist
                or kind == "prim::Constant"
            )

        res_glow = traced_model(*inputs)
        print(res_torch)
        print(res_glow)
        assert torch.allclose(res_torch, res_glow)
def traceVsGlow(
    f_torch,
    f_glow,
    check_trace,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
    use_fp16=False,
    backend_name=None,
):
    """Trace ``f_torch`` without Glow and ``f_glow`` with the Glow fusion
    pass enabled, then compare the resulting graphs and outputs.

    Args:
        f_torch: callable traced with the fusion pass disabled (reference).
        f_glow: callable traced with the fusion pass enabled.
        check_trace: forwarded to ``torch.jit.trace``.
        atol, rtol: tolerances forwarded to ``checkResult``.
        *inputs: example inputs used for tracing and execution.
        expected_fused_ops: ops expected inside the Glow fusion group
            (checked by ``checkExpectedOps``).
        accept_all_ops: forwarded to ``checkExpectedOps``.
        black_list: op kinds excluded from fusion (defaults to empty).
        use_fp16: enable torch_glow fp16 conversion/clipping for the glow run.
        backend_name: Glow backend to use; defaults to "Interpreter".

    Raises:
        AssertionError: if Glow nodes appear in the reference graph, or the
            fused-op / numerical checks fail.
    """
    if black_list is None:
        black_list = []
    with torch.no_grad():
        # Reference trace: fusion pass must be off so no Glow nodes appear.
        torch_glow.disableFusionPass()
        torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
        torch_res = torch_trace(*inputs)

        # Fix: the explicit settings reset below previously ran only on the
        # success path; if tracing or an assertion raised, the fp16/fusion/
        # backend settings leaked into subsequent tests. try/finally now
        # guarantees the reset.
        try:
            torch_glow.enableFusionPass()
            torch_glow.setFusionBlacklist(black_list)

            if use_fp16:
                torch_glow.enable_convert_to_fp16()
                torch_glow.enable_convert_fused_to_fp16()
                torch_glow.enable_clip_fp16()
            else:
                torch_glow.disable_convert_to_fp16()
                torch_glow.disable_convert_fused_to_fp16()
                torch_glow.disable_clip_fp16()

            if backend_name:
                torch_glow.setGlowBackend(backend_name)
            else:
                torch_glow.setGlowBackend("Interpreter")

            glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
            glow_res = glow_trace(*inputs)

            # check that there are no Glow nodes in the torch graph
            torch_graph = torch_trace.graph_for(*inputs)
            print("torch_graph,", torch_graph)
            num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
            assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
                num_glow_nodes)

            glow_graph = glow_trace.graph_for(*inputs)
            print("glow_graph,", glow_graph)
        finally:
            # need to explicitly clear settings to avoid carry-over static settings
            torch_glow.disableFusionPass()
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
            torch_glow.setGlowBackend("Interpreter")

        checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
        checkResult(torch_res, glow_res, atol, rtol)
def pytest_sessionstart(session):
    """Pytest hook: select the Glow backend from the --backend option.

    Leaves the backend untouched when the option is absent or empty.
    """
    requested = session.config.getoption("--backend")
    if requested:
        torch_glow.setGlowBackend(requested)