def _test_output(self, model, x, kind_in_graph=None, kind_not_in_graph=None):
    """Check scripted/traced outputs against eager output, without and with JIT fusion.

    First runs ``model`` eagerly and via plain ``torch.jit.script`` /
    ``torch.jit.trace`` with JIT optimization and mixed bf16/fp32 disabled,
    asserting all outputs are equal. Then enables JIT optimization,
    re-scripts/re-traces the model, asserts the fused outputs still equal the
    eager result, and optionally checks node kinds in the optimized graphs.

    Args:
        model: module under test; moved to ``device`` and put in eval mode.
        x: input tensor; moved to ``device``.
        kind_in_graph: if not None, a node kind (e.g. a fused op) that must
            appear among the top-level nodes of both the scripted and the
            traced graph.
        kind_not_in_graph: if not None, a node kind that must not appear among
            the top-level nodes of either graph.

    Note:
        Leaves JIT optimization enabled and mixed bf16/fp32 disabled on return.
    """
    def check_graph_kinds(*graphs):
        # Only top-level nodes (graph.nodes()) are inspected, not nested blocks.
        for graph in graphs:
            kinds = [node.kind() for node in graph.nodes()]
            if kind_in_graph is not None:
                self.assertIn(kind_in_graph, kinds)
            if kind_not_in_graph is not None:
                self.assertNotIn(kind_not_in_graph, kinds)

    core.disable_jit_opt()
    core.disable_mix_bf16_fp32()
    model = model.to(device).eval()
    x = x.to(device)
    with torch.no_grad():
        result = model(x)

    # Baseline: scripted/traced models without JIT fusion must match eager.
    script_model = torch.jit.script(model)
    script_model.eval()
    trace_model = torch.jit.trace(model, x)
    trace_model.eval()
    with torch.no_grad():
        sresult = script_model(x)
        tresult = trace_model(x)
    self.assertEqual(result, sresult)
    self.assertEqual(result, tresult)

    core.enable_jit_opt()
    script_fused_model = torch.jit.script(model)
    trace_fused_model = torch.jit.trace(model, x)
    with torch.no_grad():
        # conv relu fusion, conv sum fusion or conv sum relu fusion
        script_graph = script_fused_model.graph_for(x)
        fused_sresult = script_fused_model(x)
        trace_graph = trace_fused_model.graph_for(x)
        fused_tresult = trace_fused_model(x)
    self.assertEqual(result, fused_sresult)
    self.assertEqual(result, fused_tresult)

    check_graph_kinds(script_graph, trace_graph)
def _test_output_int8(self, model, x, kind_in_graph=None, kind_not_in_graph=None, prec=None):
    """Compare int8 outputs of the native, JIT-script and JIT-trace paths.

    Each path is calibrated with ``int8_calibration`` (writing
    ``configure.json``) and then run under ``ipex.AutoMixPrecision`` with an
    int8 ``ipex.AmpConf`` built from that file. The scripted and traced
    results are compared against the native result, and node kinds in their
    graphs are optionally checked.

    Args:
        model: module under test; moved to ``ipex.DEVICE`` and put in eval mode.
        x: input tensor; moved to ``ipex.DEVICE``.
        kind_in_graph: if not None, a node kind that must appear among the
            top-level nodes of both the scripted and the traced graph.
        kind_not_in_graph: if not None, a node kind that must not appear among
            the top-level nodes of either graph.
        prec: tolerance forwarded to ``self.assertEqual``.

    Note:
        Enables auto DNNL and JIT optimization and does not restore them.
        The calibration file is removed even if a path raises.
    """
    conf_file = "configure.json"

    def check_graph_kinds(*graphs):
        # Only top-level nodes (graph.nodes()) are inspected, not nested blocks.
        for graph in graphs:
            kinds = [node.kind() for node in graph.nodes()]
            if kind_in_graph is not None:
                self.assertIn(kind_in_graph, kinds)
            if kind_not_in_graph is not None:
                self.assertNotIn(kind_not_in_graph, kinds)

    def calibrate_and_run_int8(runner, inp):
        # Calibrate `runner` on `inp`, then run it once under int8 AMP.
        int8_calibration(runner, [inp], conf_file)
        int8_conf = ipex.AmpConf(torch.int8, conf_file)
        with ipex.AutoMixPrecision(int8_conf):
            return runner(inp)

    core.enable_auto_dnnl()
    core.enable_jit_opt()
    model = model.to(ipex.DEVICE).eval()
    x = x.to(ipex.DEVICE)
    # Separate input copies for the script and trace paths.
    x2 = x.clone()
    x3 = x.clone()
    script_fused_model = torch.jit.script(copy.deepcopy(model))
    trace_fused_model = torch.jit.trace(copy.deepcopy(model), x3)

    try:
        with torch.no_grad():
            # int8, native path
            result = calibrate_and_run_int8(model, x)
            # int8, jit script path (graph captured before calibration)
            script_graph = script_fused_model.graph_for(x2)
            fused_sresult = calibrate_and_run_int8(script_fused_model, x2)
            # int8, jit trace path (graph captured before calibration)
            trace_graph = trace_fused_model.graph_for(x3)
            fused_tresult = calibrate_and_run_int8(trace_fused_model, x3)
    finally:
        # Always clean up, even on failure, so the file doesn't leak into the
        # working directory. Guard on existence so a failure before the file
        # is written isn't masked by FileNotFoundError.
        if os.path.exists(conf_file):
            os.remove(conf_file)

    self.assertEqual(fused_sresult, result, prec=prec)
    self.assertEqual(fused_tresult, result, prec=prec)

    check_graph_kinds(script_graph, trace_graph)
def _test_output_bf16(self, model, x, kind_in_graph=None, kind_not_in_graph=None, prec=None):
    """Compare mixed bf16/fp32 outputs of the native, JIT-script and JIT-trace paths.

    Enables auto DNNL, JIT optimization and mixed bf16/fp32, then runs the
    model natively, scripted and traced. The scripted and traced results are
    compared against the native result, and node kinds in their graphs are
    optionally checked.

    Args:
        model: module under test; moved to ``ipex.DEVICE`` and put in eval mode.
        x: input tensor; moved to ``ipex.DEVICE``.
        kind_in_graph: if not None, a node kind that must appear among the
            top-level nodes of both the scripted and the traced graph.
        kind_not_in_graph: if not None, a node kind that must not appear among
            the top-level nodes of either graph.
        prec: tolerance forwarded to ``self.assertEqual``.

    Note:
        Mixed bf16/fp32 is always disabled again before returning, even if an
        error occurs. Auto DNNL and JIT optimization are left enabled.
    """
    def check_graph_kinds(*graphs):
        # Only top-level nodes (graph.nodes()) are inspected, not nested blocks.
        for graph in graphs:
            kinds = [node.kind() for node in graph.nodes()]
            if kind_in_graph is not None:
                self.assertIn(kind_in_graph, kinds)
            if kind_not_in_graph is not None:
                self.assertNotIn(kind_not_in_graph, kinds)

    core.enable_auto_dnnl()
    core.enable_jit_opt()
    core.enable_mix_bf16_fp32()
    try:
        model = model.to(ipex.DEVICE).eval()
        x = x.to(ipex.DEVICE)
        # Separate input copies for the script and trace paths.
        x2 = x.clone()
        x3 = x.clone()
        script_fused_model = torch.jit.script(copy.deepcopy(model))
        trace_fused_model = torch.jit.trace(copy.deepcopy(model), x3)
        with torch.no_grad():
            # bf16, native path
            result = model(x)
            # bf16, jit script path
            script_graph = script_fused_model.graph_for(x2)
            fused_sresult = script_fused_model(x2)
            # bf16, jit trace path
            trace_graph = trace_fused_model.graph_for(x3)
            fused_tresult = trace_fused_model(x3)
    finally:
        # Disable mix_bf16_fp32 once the calculation is done, also on failure,
        # so this global setting does not affect other tests.
        core.disable_mix_bf16_fp32()

    self.assertEqual(fused_sresult, result, prec=prec)
    self.assertEqual(fused_tresult, result, prec=prec)

    check_graph_kinds(script_graph, trace_graph)