def __init__(self):
    """Initialize the tool named "run" and subscribe its argument groups.

    Groups are subscribed one at a time, in a fixed order: model, TensorFlow,
    ONNX, ONNX-Runtime, plugin-reference, TensorRT, legacy TensorRT, data
    loading and comparator groups.
    """
    super().__init__("run")
    subscribe = self.subscribe_args

    # Model / TensorFlow groups.
    subscribe(ModelArgs())
    subscribe(TfLoaderArgs(tftrt=True))
    subscribe(TfConfigArgs())
    subscribe(TfRunnerArgs())

    # ONNX conversion, saving, loading and ONNX-Runtime groups.
    subscribe(Tf2OnnxLoaderArgs())
    subscribe(OnnxSaveArgs(output="save-onnx", short_opt=None))
    subscribe(OnnxShapeInferenceArgs())
    subscribe(OnnxLoaderArgs(save=True))
    subscribe(OnnxrtRunnerArgs())

    subscribe(PluginRefArgs())

    # TensorRT groups. The random-data calibration warning is disabled
    # because calibration runs with the inference-time data.
    subscribe(TrtConfigArgs(random_data_calib_warning=False))
    subscribe(TrtPluginLoaderArgs())
    subscribe(TrtNetworkLoaderArgs())
    subscribe(TrtEngineSaveArgs(output="save-engine", short_opt=None))
    subscribe(TrtEngineLoaderArgs(save=True))
    subscribe(TrtRunnerArgs())
    subscribe(TrtLegacyArgs())

    # Input data and comparison groups.
    subscribe(DataLoaderArgs())
    subscribe(ComparatorRunArgs())
    subscribe(ComparatorCompareArgs())
def __init__(self):
    """Initialize the tool named "trt-network" and subscribe its argument groups.

    The model argument is optional here (``model_required=False``) and no
    input-shape argument group is requested (``inputs=None``).
    """
    super().__init__("trt-network")
    subscribe = self.subscribe_args

    subscribe(ModelArgs(model_required=False, inputs=None))
    subscribe(TfLoaderArgs(artifacts=False))
    subscribe(Tf2OnnxLoaderArgs())
    subscribe(OnnxLoaderArgs())
    subscribe(TrtPluginLoaderArgs())
    subscribe(TrtNetworkLoaderArgs())
def engine_loader_args():
    """Return an ``ArgGroupTestHelper`` wrapping ``TrtEngineLoaderArgs``.

    The helper is given the model, ONNX loader, TensorRT config, plugin
    loader and network loader groups as dependencies.
    """
    # Construct the group under test first, then its dependencies, matching
    # the original evaluation order.
    engine_args = TrtEngineLoaderArgs()
    dependencies = [
        ModelArgs(),
        OnnxLoaderArgs(),
        TrtConfigArgs(),
        TrtPluginLoaderArgs(),
        TrtNetworkLoaderArgs(),
    ]
    return ArgGroupTestHelper(engine_args, deps=dependencies)
def __init__(self):
    """Initialize the tool named "model" and subscribe its argument groups.

    A model is required. No input-shape group is requested (``inputs=None``),
    and output-selection options are disabled for the TensorFlow loader and
    the TensorRT network loader (``outputs=False``).
    """
    super().__init__("model")
    subscribe = self.subscribe_args

    subscribe(ModelArgs(model_required=True, inputs=None))
    subscribe(TfLoaderArgs(artifacts=False, outputs=False))
    subscribe(OnnxShapeInferenceArgs())
    subscribe(OnnxLoaderArgs(output_prefix=None))
    subscribe(TrtPluginLoaderArgs())
    subscribe(TrtNetworkLoaderArgs(outputs=False))
    subscribe(TrtEngineLoaderArgs())
def __init__(self):
    """Initialize the tool named "convert" and subscribe its argument groups.

    A model is required. ONNX and TensorRT-engine save groups are subscribed
    with ``output=False``.
    """
    super().__init__("convert")
    subscribe = self.subscribe_args

    # Model and TensorFlow -> ONNX groups.
    subscribe(ModelArgs(model_required=True))
    subscribe(TfLoaderArgs(artifacts=False))
    subscribe(Tf2OnnxLoaderArgs())
    subscribe(OnnxShapeInferenceArgs())
    subscribe(OnnxLoaderArgs())
    subscribe(OnnxSaveArgs(output=False))

    # Data loader is needed for int8 calibration.
    subscribe(DataLoaderArgs())

    # TensorRT groups.
    subscribe(TrtConfigArgs())
    subscribe(TrtPluginLoaderArgs())
    subscribe(TrtNetworkLoaderArgs())
    subscribe(TrtEngineLoaderArgs())
    subscribe(TrtEngineSaveArgs(output=False))
def test_load_network(self):
    """Loading an ONNX model with ``--trt-outputs`` keeps only that output."""
    helper = ArgGroupTestHelper(
        TrtNetworkLoaderArgs(),
        deps=[ModelArgs(), OnnxLoaderArgs(), TrtPluginLoaderArgs()],
    )
    model_path = ONNX_MODELS["identity_identity"].path
    helper.parse_args([model_path, "--trt-outputs=identity_out_0"])

    # NOTE(review): `parser` is not used below, but it stays bound for the
    # duration of the test, presumably so it outlives `network` - confirm
    # before removing.
    builder, network, parser = helper.load_network()
    with builder, network:
        assert network.num_outputs == 1
        assert network.get_output(0).name == "identity_out_0"
def __init__(self, name, strict_types_default=None, prefer_artifacts=True):
    """Initialize the tool and subscribe the argument groups for engine building.

    Args:
        name: Name forwarded to the base-class initializer.
        strict_types_default: Forwarded to ``TrtConfigArgs`` as its
            ``strict_types_default``.
        prefer_artifacts: Forwarded to ``ArtifactSorterArgs`` as its
            ``prefer_artifacts``.
    """
    super().__init__(name)
    subscribe = self.subscribe_args

    subscribe(
        ArtifactSorterArgs(
            "polygraphy_debug.engine", prefer_artifacts=prefer_artifacts
        )
    )
    subscribe(ModelArgs(model_required=True, inputs=None))
    subscribe(OnnxShapeInferenceArgs())
    subscribe(OnnxLoaderArgs(output_prefix=None))

    # Data loader is needed for int8 calibration.
    subscribe(DataLoaderArgs())

    subscribe(TrtConfigArgs(strict_types_default=strict_types_default))
    subscribe(TrtPluginLoaderArgs())
    subscribe(TrtNetworkLoaderArgs())
    subscribe(TrtEngineLoaderArgs())
    subscribe(TrtEngineSaveArgs(output=False))
def __init__(self):
    """Initialize the tool named "run" and subscribe its argument groups.

    Groups are subscribed one at a time, in a fixed order: model, TensorFlow,
    ONNX, ONNX-Runtime, TensorRT, legacy TensorRT, data loading and
    comparator groups.
    """
    super().__init__("run")
    subscribe = self.subscribe_args

    # Model / TensorFlow groups.
    subscribe(ModelArgs())
    subscribe(TfLoaderArgs(tftrt=True))
    subscribe(TfConfigArgs())
    subscribe(TfRunnerArgs())

    # ONNX conversion, saving, loading and ONNX-Runtime groups.
    subscribe(Tf2OnnxLoaderArgs())
    subscribe(OnnxSaveArgs(output="save-onnx", short_opt=None))
    subscribe(OnnxShapeInferenceArgs())
    subscribe(OnnxLoaderArgs(save=True))
    subscribe(OnnxrtRunnerArgs())

    # TensorRT groups.
    subscribe(TrtConfigArgs())
    subscribe(TrtPluginLoaderArgs())
    subscribe(TrtNetworkLoaderArgs())
    subscribe(TrtEngineSaveArgs(output="save-engine", short_opt=None))
    subscribe(TrtEngineLoaderArgs(save=True))
    subscribe(TrtRunnerArgs())
    subscribe(TrtLegacyArgs())

    # Input data and comparison groups.
    subscribe(DataLoaderArgs())
    subscribe(ComparatorRunArgs())
    subscribe(ComparatorCompareArgs())