예제 #1
0
 def test_basic(self):
     """Run the identity TF model through a TfRunner and check its activation lifecycle.

     Inside the context the runner must report itself active and pass the model's
     own checks; after the context exits it must be inactive with no cached input
     metadata left behind.
     """
     identity_model = TF_MODELS["identity"]
     session_loader = SessionFromGraph(identity_model.loader)
     with TfRunner(session_loader) as tf_runner:
         assert tf_runner.is_active
         identity_model.check_runner(tf_runner)
     # Leaving the context is expected to deactivate the runner and leave its
     # (private) input-metadata cache unset.
     assert not tf_runner.is_active
     assert tf_runner._cached_input_metadata is None
예제 #2
0
 def test_basic(self):
     """Run the identity TF model through a TfRunner and check activation and timing.

     While active, the runner must pass the model's checks and expose a
     non-None last inference time; once the context exits it must be inactive.
     """
     identity_model = TF_MODELS["identity"]
     session_loader = SessionFromGraph(identity_model.loader)
     with TfRunner(session_loader) as tf_runner:
         assert tf_runner.is_active
         identity_model.check_runner(tf_runner)
         # NOTE(review): presumably check_runner performs at least one inference,
         # which is what populates the timing value checked here — confirm.
         assert tf_runner.last_inference_time() is not None
     assert not tf_runner.is_active
예제 #3
0
 def test_save_timeline(self):
     """Check that a TfRunner configured with save_timeline writes a non-empty file.

     The timeline path is a temporary file; after the model's checks run, the
     file at that path must contain data.
     """
     identity_model = TF_MODELS["identity"]
     # The runner is created after (and torn down before) the temporary file,
     # exactly as with two nested `with` statements.
     # NOTE(review): on Windows a NamedTemporaryFile cannot be reopened by name while
     # it is still open; if the runner reopens the path to write, this test
     # presumably assumes a POSIX platform — confirm.
     with tempfile.NamedTemporaryFile() as timeline_file, TfRunner(
         SessionFromGraph(identity_model.loader),
         allow_growth=True,
         save_timeline=timeline_file.name,
     ) as tf_runner:
         identity_model.check_runner(tf_runner)
         check_file_non_empty(timeline_file.name)
예제 #4
0
 def test_error_on_wrong_dtype_feed_dict(self):
     """Check that infer() rejects a feed_dict whose array has an unexpected dtype.

     An int32 array is fed to "Input:0"; the runner must raise a
     PolygraphyException whose message matches the given pattern.
     """
     identity_model = TF_MODELS["identity"]
     with TfRunner(SessionFromGraph(identity_model.loader)) as tf_runner:
         # pytest.raises applies `match` with re.search, so the trailing "." is a
         # wildcard: it matches any character after "unexpected dtype".
         with pytest.raises(PolygraphyException, match="unexpected dtype."):
             int32_input = np.ones(shape=(1, 15, 25, 30), dtype=np.int32)
             tf_runner.infer({"Input:0": int32_input})
예제 #5
0
 def test_error_on_wrong_name_feed_dict(self, names, err):
     """Check that infer() rejects a feed_dict keyed by unexpected tensor names.

     Args:
         names: Tensor names used as feed_dict keys (supplied by the test's
             parametrization, which is not visible here).
         err: Pattern that the raised PolygraphyException's message must match.
     """
     identity_model = TF_MODELS["identity"]
     with TfRunner(SessionFromGraph(identity_model.loader)) as tf_runner:
         with pytest.raises(PolygraphyException, match=err):
             # Give every name its own float32 array of the same shape.
             feed_dict = {}
             for tensor_name in names:
                 feed_dict[tensor_name] = np.ones(shape=(1, 15, 25, 30), dtype=np.float32)
             tf_runner.infer(feed_dict)
예제 #6
0
    def test_multiple_runners(self):
        """Run the identity model through TF, ONNX-Runtime and TensorRT and compare outputs.

        The TF graph loader is shared by the TF session and the TF->ONNX conversion,
        and the serialized ONNX bytes are shared by the ONNX-Runtime session and the
        TensorRT network/engine builders, so all three runners execute the same model.
        """
        load_tf = TF_MODELS["identity"].loader
        build_tf_session = SessionFromGraph(load_tf)
        load_serialized_onnx = BytesFromOnnx(OnnxFromTfGraph(load_tf))
        build_onnxrt_session = SessionFromOnnxBytes(load_serialized_onnx)
        load_engine = EngineFromNetwork(NetworkFromOnnxBytes(load_serialized_onnx))

        runners = [
            TfRunner(build_tf_session),
            OnnxrtRunner(build_onnxrt_session),
            TrtRunner(load_engine),
        ]

        run_results = Comparator.run(runners)
        # Shape checking is only enabled for TensorRT >= 7.0.
        # NOTE(review): presumably older TensorRT versions produce output shapes that
        # differ from TF/ONNX-Runtime for this model — confirm.
        compare_func = CompareFunc.basic_compare_func(check_shapes=version(trt.__version__) >= version("7.0"))
        assert bool(Comparator.compare_accuracy(run_results, compare_func=compare_func))

        # Take the first entry of run_results (presumably the first runner's results)
        # without materializing every value into a throwaway list.
        first_results = next(iter(run_results.values()))
        # Comparator.run was called without an iteration count, so the default
        # number of iterations (one) is expected.
        assert len(first_results) == 1