def test_compare_ort_pytorch_outputs_no_raise_with_acceptable_error_percentage(
    self,
):
    """Comparison must not raise when the mismatch rate is tolerated.

    Exactly one of the four elements differs between the two outputs
    (4.0 on the ONNX Runtime side vs 1.0 on the PyTorch side). That gap is
    far outside ``rtol``/``atol``. With ``acceptable_error_percentage=0.3``
    the call is expected to succeed. Presumably this is because 1/4 = 0.25
    of the elements mismatch, which is below the 0.3 threshold (TODO
    confirm the exact semantics in ``verification``).
    """
    onnx_values = [[1.0, 2.0], [3.0, 4.0]]
    torch_values = [[1.0, 2.0], [3.0, 1.0]]
    # ``np.array`` of Python floats yields float64, while ``torch.tensor``
    # uses the default dtype (normally float32). That is presumably why dtype
    # checking is turned off here.
    tolerance_kwargs = {
        "rtol": 1e-5,
        "atol": 1e-6,
        "check_shape": True,
        "check_dtype": False,
        "acceptable_error_percentage": 0.3,
    }
    verification._compare_ort_pytorch_outputs(
        [np.array(onnx_values)],
        [torch.tensor(torch_values)],
        **tolerance_kwargs,
    )
def test_compare_ort_pytorch_outputs_raise_without_acceptable_error_percentage(
    self,
):
    """Comparison must raise when no mismatch rate is tolerated.

    The data matches the no-raise variant: exactly one of four elements
    differs (4.0 vs 1.0), well beyond ``rtol``/``atol``. With
    ``acceptable_error_percentage=None`` the mismatch is expected to surface
    as an ``AssertionError``.
    """
    onnx_values = [[1.0, 2.0], [3.0, 4.0]]
    torch_values = [[1.0, 2.0], [3.0, 1.0]]
    onnx_side = [np.array(onnx_values)]
    torch_side = [torch.tensor(torch_values)]
    # The float64 (numpy) vs default-dtype (torch) difference is presumably
    # why dtype checking is disabled, so only the value mismatch matters.
    tolerance_kwargs = {
        "rtol": 1e-5,
        "atol": 1e-6,
        "check_shape": True,
        "check_dtype": False,
        "acceptable_error_percentage": None,
    }
    with self.assertRaises(AssertionError):
        verification._compare_ort_pytorch_outputs(
            onnx_side, torch_side, **tolerance_kwargs
        )
def run_test(self, graph_ir, example_inputs):
    """Check that a TorchScript IR graph and its ONNX export agree numerically.

    The IR text is parsed into a graph, which the JIT interpreter executes
    to produce reference outputs. The same graph is then converted to an
    ONNX model for ``self.opset_version`` and run through an ONNX Runtime
    session using ``self.ort_providers`` on the same inputs. The two result
    sets are compared with ``rtol=1e-3`` and ``atol=1e-7``.

    Args:
        graph_ir: TorchScript IR source text, as accepted by ``parse_ir``.
        example_inputs: Inputs given both to the JIT interpreter and to
            ONNX Runtime.

    Raises:
        Whatever ``verification._compare_ort_pytorch_outputs`` raises on a
        mismatch (an ``AssertionError`` is expected; TODO confirm).
    """
    jit_graph = parse_ir(graph_ir)
    # Interpret before exporting. The original order is kept on purpose in
    # case the ONNX conversion modifies the graph in place (unverified).
    reference_outputs = torch._C._jit_interpret_graph(jit_graph, example_inputs)

    model_proto = _jit_graph_to_onnx_model(
        jit_graph,
        torch.onnx.OperatorExportTypes.ONNX,
        self.opset_version,
    )
    session = onnxruntime.InferenceSession(
        model_proto, providers=self.ort_providers
    )
    runtime_outputs = verification._run_ort(session, example_inputs)

    verification._compare_ort_pytorch_outputs(
        runtime_outputs, reference_outputs, rtol=1e-3, atol=1e-7
    )