def run_test(self, graph_ir, example_inputs):
    """Check a textual TorchScript graph against its ONNX Runtime execution.

    Steps:
      1. parse ``graph_ir`` into a graph,
      2. execute it with the TorchScript interpreter on ``example_inputs``,
      3. export that graph to an ONNX model (``OperatorExportTypes.ONNX`` at
         ``self.opset_version``),
      4. run the model in an ONNX Runtime session using ``self.ort_providers``,
      5. compare ORT outputs with the interpreter outputs
         (``rtol=1e-3``, ``atol=1e-7``).

    NOTE(review): failure reporting is delegated to
    ``verification._ort_compare_with_pytorch``; presumably it raises on a
    mismatch — confirm against that helper.
    """
    parsed_graph = parse_ir(graph_ir)

    # Reference results come from the TorchScript interpreter itself.
    reference_outputs = torch._C._jit_interpret_graph(parsed_graph, example_inputs)

    model_proto = _jit_graph_to_onnx_model(
        parsed_graph,
        torch.onnx.OperatorExportTypes.ONNX,
        self.opset_version,
    )
    session = onnxruntime.InferenceSession(
        model_proto,
        providers=self.ort_providers,
    )
    ort_outputs = verification._run_ort(session, example_inputs)

    verification._ort_compare_with_pytorch(
        ort_outputs,
        reference_outputs,
        rtol=1e-3,
        atol=1e-7,
    )
def run_test(self, graph_ir, example_inputs):
    """Check a textual TorchScript graph against its ONNX Runtime execution.

    The IR string is parsed with ``torch._C.parse_ir`` and executed by the
    TorchScript interpreter to obtain reference outputs. The same graph is
    exported to ONNX (``OperatorExportTypes.ONNX`` at
    ``self.opset_version``) and run in an ONNX Runtime session configured
    with ``self.ort_providers``. The two sets of outputs are then compared
    with ``rtol=1e-3`` and ``atol=1e-7``.

    Shape and dtype checking follow ``self.check_shape`` and
    ``self.check_dtype``. ``acceptable_error_percentage`` is explicitly
    passed as ``None``.

    NOTE(review): mismatch handling lives in
    ``verification._compare_ort_pytorch_outputs``; presumably it raises on
    failure — confirm against that helper.
    """
    parsed_graph = torch._C.parse_ir(graph_ir)

    # The interpreter run is the ground truth the ORT results must match.
    reference_outputs = torch._C._jit_interpret_graph(parsed_graph, example_inputs)

    export_type = torch.onnx.OperatorExportTypes.ONNX
    model_proto = _jit_graph_to_onnx_model(parsed_graph, export_type, self.opset_version)

    session = onnxruntime.InferenceSession(model_proto, providers=self.ort_providers)
    ort_outputs = verification._run_ort(session, example_inputs)

    verification._compare_ort_pytorch_outputs(
        ort_outputs,
        reference_outputs,
        rtol=1e-3,
        atol=1e-7,
        check_shape=self.check_shape,
        check_dtype=self.check_dtype,
        acceptable_error_percentage=None,
    )