Example #1
0
 def test_compare_ort_pytorch_outputs_no_raise_with_acceptable_error_percentage(
     self,
 ):
     """A small share of mismatching elements is tolerated when allowed.

     Only the last element differs (4.0 in the ORT output vs 1.0 in the
     PyTorch output), i.e. one of four elements. The comparison is expected
     to pass because ``acceptable_error_percentage=0.3`` is given
     (presumably a fraction, i.e. 30% of elements -- confirm against
     ``_compare_ort_pytorch_outputs``).
     """
     expected_from_ort = [np.arange(1.0, 5.0).reshape(2, 2)]
     actual_from_torch = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
     # Tight tolerances, so the 4.0 vs 1.0 element is a genuine mismatch.
     tolerances = {"rtol": 1e-5, "atol": 1e-6}
     verification._compare_ort_pytorch_outputs(
         expected_from_ort,
         actual_from_torch,
         check_shape=True,
         check_dtype=False,
         acceptable_error_percentage=0.3,
         **tolerances,
     )
Example #2
0
 def test_compare_ort_pytorch_outputs_raise_without_acceptable_error_percentage(
     self,
 ):
     """Any mismatch fails the comparison when no error share is allowed.

     The last element differs (4.0 vs 1.0) far beyond ``rtol``/``atol``.
     With ``acceptable_error_percentage=None`` the comparison is expected
     to raise ``AssertionError``.
     """
     ort_reference = [np.arange(1.0, 5.0).reshape(2, 2)]
     torch_result = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
     compare_options = {
         "rtol": 1e-5,
         "atol": 1e-6,
         "check_shape": True,
         "check_dtype": False,
         "acceptable_error_percentage": None,
     }
     with self.assertRaises(AssertionError):
         verification._compare_ort_pytorch_outputs(
             ort_reference, torch_result, **compare_options
         )
Example #3
0
    def run_test(self, graph_ir, example_inputs, *, rtol=1e-3, atol=1e-7):
        """Check a TorchScript IR graph against its ONNX export run in ONNX Runtime.

        The graph parsed from ``graph_ir`` is executed on ``example_inputs``
        twice: once by the TorchScript interpreter, and once by ONNX Runtime
        after conversion to an ONNX model at ``self.opset_version`` using
        ``self.ort_providers``. The two sets of outputs are then compared.

        Args:
            graph_ir: TorchScript IR text passed to ``parse_ir``.
            example_inputs: Inputs fed to both the interpreter and ONNX Runtime.
            rtol: Relative tolerance for the output comparison (keyword-only,
                default ``1e-3``).
            atol: Absolute tolerance for the output comparison (keyword-only,
                default ``1e-7``).

        Raises:
            AssertionError: From ``verification._compare_ort_pytorch_outputs``
                when the outputs disagree beyond the given tolerances.
        """
        graph = parse_ir(graph_ir)
        # Reference outputs: run the graph directly with the TorchScript interpreter.
        jit_outs = torch._C._jit_interpret_graph(graph, example_inputs)

        # Export the same graph to ONNX and execute it with ONNX Runtime.
        onnx_proto = _jit_graph_to_onnx_model(
            graph, torch.onnx.OperatorExportTypes.ONNX, self.opset_version)
        ort_sess = onnxruntime.InferenceSession(onnx_proto,
                                                providers=self.ort_providers)
        ort_outs = verification._run_ort(ort_sess, example_inputs)

        verification._compare_ort_pytorch_outputs(ort_outs,
                                                  jit_outs,
                                                  rtol=rtol,
                                                  atol=atol)