def test_basic(self):
    """
    Exercise the activation lifecycle of a ``TfRunner`` on the "identity" model.

    Inside the ``with`` block the runner must report itself active and pass
    the model's own ``check_runner`` validation. After the block exits, the
    runner must be inactive and its cached input metadata must be ``None``.
    """
    # NOTE(review): another ``test_basic`` is defined right after this one.
    # If both live in the same class, the later definition shadows this
    # method and it is never collected -- confirm and rename if so.
    identity = TF_MODELS["identity"]
    build_session = SessionFromGraph(identity.loader)

    with TfRunner(build_session) as tf_runner:
        assert tf_runner.is_active
        identity.check_runner(tf_runner)

    # Both post-conditions are only meaningful once the context has exited.
    assert not tf_runner.is_active
    assert tf_runner._cached_input_metadata is None
def test_basic(self):
    """
    Exercise a ``TfRunner`` on the "identity" model from activation to teardown.

    While the context is open the runner must be active, must pass the
    model's ``check_runner`` validation, and must afterwards report a
    non-``None`` last inference time. Once the context exits the runner
    must be inactive.
    """
    # NOTE(review): another ``test_basic`` is defined just before this one.
    # If both live in the same class, this definition shadows the earlier
    # one -- confirm and rename one of them if so.
    identity = TF_MODELS["identity"]
    build_session = SessionFromGraph(identity.loader)

    with TfRunner(build_session) as tf_runner:
        assert tf_runner.is_active
        identity.check_runner(tf_runner)
        assert tf_runner.last_inference_time() is not None

    assert not tf_runner.is_active
def test_save_timeline(self):
    """
    Run the "identity" model with ``save_timeline`` pointed at a temporary file.

    After the model's ``check_runner`` validation has run, the timeline
    path is handed to ``check_file_non_empty``.
    """
    identity = TF_MODELS["identity"]

    with tempfile.NamedTemporaryFile() as timeline_file:
        profiled_runner = TfRunner(
            SessionFromGraph(identity.loader),
            allow_growth=True,
            save_timeline=timeline_file.name,
        )
        with profiled_runner as tf_runner:
            identity.check_runner(tf_runner)
            # Checked while the temporary file still exists: it is removed
            # as soon as the outer context exits.
            check_file_non_empty(timeline_file.name)
def test_error_on_wrong_dtype_feed_dict(self):
    """
    Feeding "Input:0" an ``int32`` array must be rejected by ``infer``.

    The call is expected to raise ``PolygraphyException`` with a message
    matching "unexpected dtype.".
    """
    identity = TF_MODELS["identity"]
    build_session = SessionFromGraph(identity.loader)
    expect_dtype_error = pytest.raises(PolygraphyException, match="unexpected dtype.")

    # Single ``with`` keeps the original nesting: the runner is entered
    # first and exited last, the ``raises`` guard sits inside it.
    with TfRunner(build_session) as tf_runner, expect_dtype_error:
        bad_feed = {"Input:0": np.ones(shape=(1, 15, 25, 30), dtype=np.int32)}
        tf_runner.infer(bad_feed)
def test_error_on_wrong_name_feed_dict(self, names, err):
    """
    A feed dict keyed by ``names`` must make ``infer`` raise an error.

    Args:
        names: Input names used as feed-dict keys; each one is mapped to
            its own ``float32`` array of ones with shape (1, 15, 25, 30).
        err: Pattern the ``PolygraphyException`` message must match.
    """
    identity = TF_MODELS["identity"]
    build_session = SessionFromGraph(identity.loader)
    expect_name_error = pytest.raises(PolygraphyException, match=err)

    # Single ``with`` keeps the original nesting: the runner is entered
    # first and exited last, the ``raises`` guard sits inside it.
    with TfRunner(build_session) as tf_runner, expect_name_error:
        feed = {
            input_name: np.ones(shape=(1, 15, 25, 30), dtype=np.float32)
            for input_name in names
        }
        tf_runner.infer(feed)
def test_multiple_runners(self):
    """
    Run the "identity" model through TF, ONNX-Runtime and TensorRT runners.

    The TF graph loader feeds the TF session directly; the same graph is
    converted to serialized ONNX bytes, which feed both the ONNX-Runtime
    session and the TensorRT engine. The results of the three runners must
    compare as accurate, and the first runner's entry must contain exactly
    one iteration.
    """
    graph_loader = TF_MODELS["identity"].loader
    tf_session_loader = SessionFromGraph(graph_loader)
    onnx_bytes_loader = BytesFromOnnx(OnnxFromTfGraph(graph_loader))
    onnxrt_session_loader = SessionFromOnnxBytes(onnx_bytes_loader)
    engine_loader = EngineFromNetwork(NetworkFromOnnxBytes(onnx_bytes_loader))

    results = Comparator.run(
        [
            TfRunner(tf_session_loader),
            OnnxrtRunner(onnxrt_session_loader),
            TrtRunner(engine_loader),
        ]
    )

    # Shape checking is only enabled when the installed TensorRT is >= 7.0.
    shapes_comparable = version(trt.__version__) >= version("7.0")
    comparison = CompareFunc.basic_compare_func(check_shapes=shapes_comparable)
    accuracy = Comparator.compare_accuracy(results, compare_func=comparison)
    assert bool(accuracy)

    per_runner_iterations = list(results.values())
    assert len(per_runner_iterations[0]) == 1  # Default number of iterations