예제 #1
0
    def run_test(self, graph_ir, example_inputs):
        """Compare ONNX Runtime against the TorchScript interpreter on an IR graph.

        Steps:
        1. Parse ``graph_ir`` with ``parse_ir``.
        2. Run the parsed graph through ``torch._C._jit_interpret_graph`` to get
           reference outputs.
        3. Export the same graph to ONNX at ``self.opset_version``.
        4. Execute the exported model in an ``onnxruntime.InferenceSession``
           using ``self.ort_providers``.
        5. Compare the two sets of outputs with rtol=1e-3 and atol=1e-7.

        Args:
            graph_ir: Textual TorchScript IR accepted by ``parse_ir``.
            example_inputs: Inputs fed to both the interpreter and ONNX Runtime.

        NOTE(review): how a mismatch is reported is decided by
        ``verification._ort_compare_with_pytorch``, which is not visible here.
        Presumably it raises; confirm.
        """
        parsed_graph = parse_ir(graph_ir)
        # Reference outputs come from interpreting the graph directly.
        expected = torch._C._jit_interpret_graph(parsed_graph, example_inputs)

        # Export the very same graph object that was just interpreted.
        export_type = torch.onnx.OperatorExportTypes.ONNX
        exported_model = _jit_graph_to_onnx_model(
            parsed_graph, export_type, self.opset_version
        )

        session = onnxruntime.InferenceSession(
            exported_model, providers=self.ort_providers
        )
        actual = verification._run_ort(session, example_inputs)

        verification._ort_compare_with_pytorch(
            actual, expected, rtol=1e-3, atol=1e-7
        )
예제 #2
0
    def run_test(self, graph_ir, example_inputs):
        """Check that ONNX Runtime reproduces TorchScript results for an IR graph.

        The IR text is parsed with ``torch._C.parse_ir``. The parsed graph is
        evaluated by the TorchScript interpreter to get reference outputs. It
        is then exported to ONNX at ``self.opset_version`` and run in an
        ``onnxruntime.InferenceSession`` with ``self.ort_providers``. Both sets
        of outputs go to ``verification._compare_ort_pytorch_outputs`` with
        rtol=1e-3, atol=1e-7, the instance's ``check_shape`` / ``check_dtype``
        settings, and ``acceptable_error_percentage=None``.

        Args:
            graph_ir: Textual TorchScript IR accepted by ``torch._C.parse_ir``.
            example_inputs: Inputs fed to both the interpreter and ONNX Runtime.

        NOTE(review): failure reporting is delegated to the comparison helper,
        which is not visible here. Presumably it raises on mismatch; confirm.
        """
        parsed_graph = torch._C.parse_ir(graph_ir)
        # Reference outputs come from interpreting the graph directly.
        expected = torch._C._jit_interpret_graph(parsed_graph, example_inputs)

        # Export the very same graph object that was just interpreted.
        export_type = torch.onnx.OperatorExportTypes.ONNX
        exported_model = _jit_graph_to_onnx_model(
            parsed_graph, export_type, self.opset_version
        )
        session = onnxruntime.InferenceSession(
            exported_model, providers=self.ort_providers
        )
        actual = verification._run_ort(session, example_inputs)

        # Tolerances and check flags are forwarded as keyword arguments.
        comparison_options = {
            "rtol": 1e-3,
            "atol": 1e-7,
            "check_shape": self.check_shape,
            "check_dtype": self.check_dtype,
            "acceptable_error_percentage": None,
        }
        verification._compare_ort_pytorch_outputs(
            actual, expected, **comparison_options
        )