Esempio n. 1
0
 def __init__(self):
     """Initialize the tool under the name ``"run"`` and subscribe its argument groups.

     NOTE(review): the groups are subscribed in a fixed order which is kept
     exactly as-is, in case the base class registers or parses them in
     subscription order.
     """
     super().__init__("run")
     add_group = self.subscribe_args

     # Model selection and the Tf* groups.
     add_group(ModelArgs())
     add_group(TfLoaderArgs(tftrt=True))
     add_group(TfConfigArgs())
     add_group(TfRunnerArgs())

     # Tf2Onnx / Onnx* groups.
     add_group(Tf2OnnxLoaderArgs())
     add_group(OnnxSaveArgs(output="save-onnx", short_opt=None))
     add_group(OnnxShapeInferenceArgs())
     add_group(OnnxLoaderArgs(save=True))
     add_group(OnnxrtRunnerArgs())

     add_group(PluginRefArgs())

     # Trt* groups. The random-data calibration warning is turned off because
     # we run calibration with the inference-time data.
     add_group(TrtConfigArgs(random_data_calib_warning=False))
     add_group(TrtPluginLoaderArgs())
     add_group(TrtNetworkLoaderArgs())
     add_group(TrtEngineSaveArgs(output="save-engine", short_opt=None))
     add_group(TrtEngineLoaderArgs(save=True))
     add_group(TrtRunnerArgs())
     add_group(TrtLegacyArgs())

     # Data loading and Comparator* groups.
     add_group(DataLoaderArgs())
     add_group(ComparatorRunArgs())
     add_group(ComparatorCompareArgs())
Esempio n. 2
0
def engine_loader_args():
    """Return an ``ArgGroupTestHelper`` wrapping a fresh ``TrtEngineLoaderArgs``.

    The helper receives ``ModelArgs``, ``OnnxLoaderArgs``, ``TrtConfigArgs``,
    ``TrtPluginLoaderArgs`` and ``TrtNetworkLoaderArgs`` instances (in that
    order) as its ``deps``.
    """
    # Build the group under test first, matching the original evaluation order.
    engine_group = TrtEngineLoaderArgs()
    dependency_groups = [
        ModelArgs(),
        OnnxLoaderArgs(),
        TrtConfigArgs(),
        TrtPluginLoaderArgs(),
        TrtNetworkLoaderArgs(),
    ]
    return ArgGroupTestHelper(engine_group, deps=dependency_groups)
Esempio n. 3
0
 def __init__(self):
     """Initialize the tool under the name ``"model"`` and subscribe its argument groups.

     NOTE(review): subscription order is preserved exactly, in case the base
     class depends on it.
     """
     super().__init__("model")
     add_group = self.subscribe_args
     add_group(ModelArgs(model_required=True, inputs=None))
     add_group(TfLoaderArgs(artifacts=False, outputs=False))
     add_group(OnnxShapeInferenceArgs())
     add_group(OnnxLoaderArgs(output_prefix=None))
     add_group(TrtPluginLoaderArgs())
     add_group(TrtNetworkLoaderArgs(outputs=False))
     add_group(TrtEngineLoaderArgs())
Esempio n. 4
0
 def __init__(self):
     """Initialize the tool under the name ``"convert"`` and subscribe its argument groups.

     NOTE(review): subscription order is preserved exactly, in case the base
     class depends on it.
     """
     super().__init__("convert")
     add_group = self.subscribe_args

     # Model selection and Tf* / Tf2Onnx groups.
     add_group(ModelArgs(model_required=True))
     add_group(TfLoaderArgs(artifacts=False))
     add_group(Tf2OnnxLoaderArgs())

     # Onnx* groups.
     add_group(OnnxShapeInferenceArgs())
     add_group(OnnxLoaderArgs())
     add_group(OnnxSaveArgs(output=False))

     # Data loading is subscribed for int8 calibration.
     add_group(DataLoaderArgs())

     # Trt* groups.
     add_group(TrtConfigArgs())
     add_group(TrtPluginLoaderArgs())
     add_group(TrtNetworkLoaderArgs())
     add_group(TrtEngineLoaderArgs())
     add_group(TrtEngineSaveArgs(output=False))
Esempio n. 5
0
 def __init__(self, name, strict_types_default=None, prefer_artifacts=True):
     """Initialize the tool and subscribe its argument groups.

     Args:
         name: Passed straight through to the base-class initializer.
         strict_types_default: Forwarded to ``TrtConfigArgs`` as
             ``strict_types_default``.
         prefer_artifacts: Forwarded to ``ArtifactSorterArgs`` as
             ``prefer_artifacts``.

     NOTE(review): subscription order is preserved exactly, in case the base
     class depends on it.
     """
     super().__init__(name)
     add_group = self.subscribe_args

     add_group(ArtifactSorterArgs("polygraphy_debug.engine",
                                  prefer_artifacts=prefer_artifacts))
     add_group(ModelArgs(model_required=True, inputs=None))

     # Onnx* groups.
     add_group(OnnxShapeInferenceArgs())
     add_group(OnnxLoaderArgs(output_prefix=None))

     # Data loading is subscribed for int8 calibration.
     add_group(DataLoaderArgs())

     # Trt* groups.
     add_group(TrtConfigArgs(strict_types_default=strict_types_default))
     add_group(TrtPluginLoaderArgs())
     add_group(TrtNetworkLoaderArgs())
     add_group(TrtEngineLoaderArgs())
     add_group(TrtEngineSaveArgs(output=False))
Esempio n. 6
0
 def __init__(self):
     """Initialize the tool under the name ``"run"`` and subscribe its argument groups.

     NOTE(review): the groups are subscribed in a fixed order which is kept
     exactly as-is, in case the base class registers or parses them in
     subscription order.
     """
     super().__init__("run")
     add_group = self.subscribe_args

     # Model selection and the Tf* groups.
     add_group(ModelArgs())
     add_group(TfLoaderArgs(tftrt=True))
     add_group(TfConfigArgs())
     add_group(TfRunnerArgs())

     # Tf2Onnx / Onnx* groups.
     add_group(Tf2OnnxLoaderArgs())
     add_group(OnnxSaveArgs(output="save-onnx", short_opt=None))
     add_group(OnnxShapeInferenceArgs())
     add_group(OnnxLoaderArgs(save=True))
     add_group(OnnxrtRunnerArgs())

     # Trt* groups.
     add_group(TrtConfigArgs())
     add_group(TrtPluginLoaderArgs())
     add_group(TrtNetworkLoaderArgs())
     add_group(TrtEngineSaveArgs(output="save-engine", short_opt=None))
     add_group(TrtEngineLoaderArgs(save=True))
     add_group(TrtRunnerArgs())
     add_group(TrtLegacyArgs())

     # Data loading and Comparator* groups.
     add_group(DataLoaderArgs())
     add_group(ComparatorRunArgs())
     add_group(ComparatorCompareArgs())