Exemplo n.º 1
0
 def __init__(self):
     super().__init__("run")
     self.subscribe_args(ModelArgs())
     self.subscribe_args(TfLoaderArgs(tftrt=True))
     self.subscribe_args(TfConfigArgs())
     self.subscribe_args(TfRunnerArgs())
     self.subscribe_args(Tf2OnnxLoaderArgs())
     self.subscribe_args(OnnxSaveArgs(output="save-onnx", short_opt=None))
     self.subscribe_args(OnnxShapeInferenceArgs())
     self.subscribe_args(OnnxLoaderArgs(save=True))
     self.subscribe_args(OnnxrtRunnerArgs())
     self.subscribe_args(PluginRefArgs())
     self.subscribe_args(
         TrtConfigArgs(random_data_calib_warning=False
                       ))  # We run calibration with the inference-time data
     self.subscribe_args(TrtPluginLoaderArgs())
     self.subscribe_args(TrtNetworkLoaderArgs())
     self.subscribe_args(
         TrtEngineSaveArgs(output="save-engine", short_opt=None))
     self.subscribe_args(TrtEngineLoaderArgs(save=True))
     self.subscribe_args(TrtRunnerArgs())
     self.subscribe_args(TrtLegacyArgs())
     self.subscribe_args(DataLoaderArgs())
     self.subscribe_args(ComparatorRunArgs())
     self.subscribe_args(ComparatorCompareArgs())
Exemplo n.º 2
0
 def __init__(self):
     super().__init__("trt-network")
     self.subscribe_args(ModelArgs(model_required=False, inputs=None))
     self.subscribe_args(TfLoaderArgs(artifacts=False))
     self.subscribe_args(Tf2OnnxLoaderArgs())
     self.subscribe_args(OnnxLoaderArgs())
     self.subscribe_args(TrtPluginLoaderArgs())
     self.subscribe_args(TrtNetworkLoaderArgs())
Exemplo n.º 3
0
def engine_loader_args():
    return ArgGroupTestHelper(TrtEngineLoaderArgs(),
                              deps=[
                                  ModelArgs(),
                                  OnnxLoaderArgs(),
                                  TrtConfigArgs(),
                                  TrtPluginLoaderArgs(),
                                  TrtNetworkLoaderArgs()
                              ])
Exemplo n.º 4
0
 def __init__(self):
     super().__init__("model")
     self.subscribe_args(ModelArgs(model_required=True, inputs=None))
     self.subscribe_args(TfLoaderArgs(artifacts=False, outputs=False))
     self.subscribe_args(OnnxShapeInferenceArgs())
     self.subscribe_args(OnnxLoaderArgs(output_prefix=None))
     self.subscribe_args(TrtPluginLoaderArgs())
     self.subscribe_args(TrtNetworkLoaderArgs(outputs=False))
     self.subscribe_args(TrtEngineLoaderArgs())
Exemplo n.º 5
0
 def __init__(self):
     super().__init__("convert")
     self.subscribe_args(ModelArgs(model_required=True))
     self.subscribe_args(TfLoaderArgs(artifacts=False))
     self.subscribe_args(Tf2OnnxLoaderArgs())
     self.subscribe_args(OnnxShapeInferenceArgs())
     self.subscribe_args(OnnxLoaderArgs())
     self.subscribe_args(OnnxSaveArgs(output=False))
     self.subscribe_args(DataLoaderArgs())  # For int8 calibration
     self.subscribe_args(TrtConfigArgs())
     self.subscribe_args(TrtPluginLoaderArgs())
     self.subscribe_args(TrtNetworkLoaderArgs())
     self.subscribe_args(TrtEngineLoaderArgs())
     self.subscribe_args(TrtEngineSaveArgs(output=False))
Exemplo n.º 6
0
    def test_load_network(self):
        arg_group = ArgGroupTestHelper(
            TrtNetworkLoaderArgs(),
            deps=[ModelArgs(),
                  OnnxLoaderArgs(),
                  TrtPluginLoaderArgs()])
        arg_group.parse_args([
            ONNX_MODELS["identity_identity"].path,
            "--trt-outputs=identity_out_0"
        ])

        builder, network, parser = arg_group.load_network()
        with builder, network:
            assert network.num_outputs == 1
            assert network.get_output(0).name == "identity_out_0"
Exemplo n.º 7
0
 def __init__(self, name, strict_types_default=None, prefer_artifacts=True):
     super().__init__(name)
     self.subscribe_args(
         ArtifactSorterArgs("polygraphy_debug.engine",
                            prefer_artifacts=prefer_artifacts))
     self.subscribe_args(ModelArgs(model_required=True, inputs=None))
     self.subscribe_args(OnnxShapeInferenceArgs())
     self.subscribe_args(OnnxLoaderArgs(output_prefix=None))
     self.subscribe_args(DataLoaderArgs())  # For int8 calibration
     self.subscribe_args(
         TrtConfigArgs(strict_types_default=strict_types_default))
     self.subscribe_args(TrtPluginLoaderArgs())
     self.subscribe_args(TrtNetworkLoaderArgs())
     self.subscribe_args(TrtEngineLoaderArgs())
     self.subscribe_args(TrtEngineSaveArgs(output=False))
Exemplo n.º 8
0
 def __init__(self):
     super().__init__("run")
     self.subscribe_args(ModelArgs())
     self.subscribe_args(TfLoaderArgs(tftrt=True))
     self.subscribe_args(TfConfigArgs())
     self.subscribe_args(TfRunnerArgs())
     self.subscribe_args(Tf2OnnxLoaderArgs())
     self.subscribe_args(OnnxSaveArgs(output="save-onnx", short_opt=None))
     self.subscribe_args(OnnxShapeInferenceArgs())
     self.subscribe_args(OnnxLoaderArgs(save=True))
     self.subscribe_args(OnnxrtRunnerArgs())
     self.subscribe_args(TrtConfigArgs())
     self.subscribe_args(TrtPluginLoaderArgs())
     self.subscribe_args(TrtNetworkLoaderArgs())
     self.subscribe_args(
         TrtEngineSaveArgs(output="save-engine", short_opt=None))
     self.subscribe_args(TrtEngineLoaderArgs(save=True))
     self.subscribe_args(TrtRunnerArgs())
     self.subscribe_args(TrtLegacyArgs())
     self.subscribe_args(DataLoaderArgs())
     self.subscribe_args(ComparatorRunArgs())
     self.subscribe_args(ComparatorCompareArgs())