コード例 #1
0
ファイル: run.py プロジェクト: phongphuhanam/TensorRT
 def __init__(self):
     super().__init__("run")
     self.subscribe_args(ModelArgs())
     self.subscribe_args(TfLoaderArgs(tftrt=True))
     self.subscribe_args(TfConfigArgs())
     self.subscribe_args(TfRunnerArgs())
     self.subscribe_args(Tf2OnnxLoaderArgs())
     self.subscribe_args(OnnxSaveArgs(output="save-onnx", short_opt=None))
     self.subscribe_args(OnnxShapeInferenceArgs())
     self.subscribe_args(OnnxLoaderArgs(save=True))
     self.subscribe_args(OnnxrtRunnerArgs())
     self.subscribe_args(PluginRefArgs())
     self.subscribe_args(
         TrtConfigArgs(random_data_calib_warning=False
                       ))  # We run calibration with the inference-time data
     self.subscribe_args(TrtPluginLoaderArgs())
     self.subscribe_args(TrtNetworkLoaderArgs())
     self.subscribe_args(
         TrtEngineSaveArgs(output="save-engine", short_opt=None))
     self.subscribe_args(TrtEngineLoaderArgs(save=True))
     self.subscribe_args(TrtRunnerArgs())
     self.subscribe_args(TrtLegacyArgs())
     self.subscribe_args(DataLoaderArgs())
     self.subscribe_args(ComparatorRunArgs())
     self.subscribe_args(ComparatorCompareArgs())
コード例 #2
0
ファイル: precision.py プロジェクト: celidos/TensorRT_study
 def __init__(self, name):
     super().__init__(name)
     self.subscribe_args(DataLoaderArgs())
     self.subscribe_args(ModelArgs(model_required=True))
     self.subscribe_args(OnnxLoaderArgs(outputs=False))
     self.subscribe_args(TrtLoaderArgs())
     self.subscribe_args(TrtRunnerArgs())
     self.subscribe_args(ComparatorRunArgs(iters=False, write=False))
     self.subscribe_args(ComparatorCompareArgs())
コード例 #3
0
 def __init__(self):
     super().__init__("run")
     self.subscribe_args(ModelArgs())
     self.subscribe_args(TfLoaderArgs())
     self.subscribe_args(TfConfigArgs())
     self.subscribe_args(TfRunnerArgs())
     self.subscribe_args(Tf2OnnxLoaderArgs())
     self.subscribe_args(OnnxLoaderArgs())
     self.subscribe_args(OnnxrtRunnerArgs())
     self.subscribe_args(OnnxtfRunnerArgs())
     self.subscribe_args(TrtLoaderArgs(network_api=True))
     self.subscribe_args(TrtRunnerArgs())
     self.subscribe_args(TrtLegacyArgs())
     self.subscribe_args(DataLoaderArgs())
     self.subscribe_args(ComparatorRunArgs())
     self.subscribe_args(ComparatorCompareArgs())
コード例 #4
0
 def __init__(self):
     super().__init__("run")
     self.subscribe_args(ModelArgs())
     self.subscribe_args(TfLoaderArgs(tftrt=True))
     self.subscribe_args(TfConfigArgs())
     self.subscribe_args(TfRunnerArgs())
     self.subscribe_args(Tf2OnnxLoaderArgs())
     self.subscribe_args(OnnxSaveArgs(output="save-onnx", short_opt=None))
     self.subscribe_args(OnnxShapeInferenceArgs())
     self.subscribe_args(OnnxLoaderArgs(save=True))
     self.subscribe_args(OnnxrtRunnerArgs())
     self.subscribe_args(TrtConfigArgs())
     self.subscribe_args(TrtPluginLoaderArgs())
     self.subscribe_args(TrtNetworkLoaderArgs())
     self.subscribe_args(
         TrtEngineSaveArgs(output="save-engine", short_opt=None))
     self.subscribe_args(TrtEngineLoaderArgs(save=True))
     self.subscribe_args(TrtRunnerArgs())
     self.subscribe_args(TrtLegacyArgs())
     self.subscribe_args(DataLoaderArgs())
     self.subscribe_args(ComparatorRunArgs())
     self.subscribe_args(ComparatorCompareArgs())
コード例 #5
0
ファイル: test_comparator.py プロジェクト: clayne/TensorRT
 def test_invalid_error_stat(self, args):
     with pytest.raises(PolygraphyException, match="Invalid choice"):
         arg_group = ArgGroupTestHelper(ComparatorCompareArgs())
         arg_group.parse_args(["--check-error-stat"] + args)
コード例 #6
0
ファイル: test_comparator.py プロジェクト: clayne/TensorRT
    def test_error_stat_per_output(self, args, expected):
        arg_group = ArgGroupTestHelper(ComparatorCompareArgs())
        arg_group.parse_args(["--check-error-stat"] + args)

        assert arg_group.check_error_stat == expected
コード例 #7
0
ファイル: test_comparator.py プロジェクト: clayne/TensorRT
    def test_error_stat(self, check_error_stat):
        arg_group = ArgGroupTestHelper(ComparatorCompareArgs())
        arg_group.parse_args(
            ["--check-error-stat={:}".format(check_error_stat)])

        assert arg_group.check_error_stat == {"": check_error_stat}