Example #1
0
    def test_load_graph(self):
        """Load the "identity" TF model as a frozen graph and check the result.

        Expects ``load_graph`` to return a ``tf.Graph`` together with the
        list of output tensor names.
        """
        helper = ArgGroupTestHelper(TfLoaderArgs(), deps=[ModelArgs()])
        model_path = TF_MODELS["identity"].path
        helper.parse_args([model_path, "--model-type=frozen"])

        loaded_graph, output_names = helper.load_graph()
        assert isinstance(loaded_graph, tf.Graph)
        assert output_names == ["Identity_2:0"]
Example #2
0
 def __init__(self):
     """Initialize the "model" tool and subscribe its argument groups.

     Groups are subscribed one at a time, in the order listed below, with
     the keyword options configured for this tool.
     """
     super().__init__("model")
     # NOTE(review): every group is constructed before the first one is
     # subscribed; this assumes the group constructors do not depend on
     # prior subscribe_args calls.
     for arg_group in (
         ModelArgs(model_required=True, inputs=None),
         TfLoaderArgs(tftrt=False, artifacts=False, outputs=False),
         OnnxLoaderArgs(outputs=False),
         TrtLoaderArgs(config=False, outputs=False),
     ):
         self.subscribe_args(arg_group)
Example #3
0
 def __init__(self):
     """Initialize the "run" tool and subscribe its argument groups.

     The groups cover TF, ONNX, ONNX-Runtime, plugin-ref and TensorRT
     loading/running, plus data loading and comparison. They are subscribed
     in the order listed below.
     """
     super().__init__("run")
     # NOTE(review): every group is constructed before the first one is
     # subscribed; this assumes the group constructors do not depend on
     # prior subscribe_args calls.
     for arg_group in (
         ModelArgs(),
         TfLoaderArgs(tftrt=True),
         TfConfigArgs(),
         TfRunnerArgs(),
         Tf2OnnxLoaderArgs(),
         OnnxSaveArgs(output="save-onnx", short_opt=None),
         OnnxShapeInferenceArgs(),
         OnnxLoaderArgs(save=True),
         OnnxrtRunnerArgs(),
         PluginRefArgs(),
         # We run calibration with the inference-time data
         TrtConfigArgs(random_data_calib_warning=False),
         TrtPluginLoaderArgs(),
         TrtNetworkLoaderArgs(),
         TrtEngineSaveArgs(output="save-engine", short_opt=None),
         TrtEngineLoaderArgs(save=True),
         TrtRunnerArgs(),
         TrtLegacyArgs(),
         DataLoaderArgs(),
         ComparatorRunArgs(),
         ComparatorCompareArgs(),
     ):
         self.subscribe_args(arg_group)
Example #4
0
 def __init__(self):
     """Initialize the "trt-network" tool and subscribe its argument groups.

     A model path is optional here (``model_required=False``). Groups are
     subscribed in the order listed below.
     """
     super().__init__("trt-network")
     # NOTE(review): every group is constructed before the first one is
     # subscribed; this assumes the group constructors do not depend on
     # prior subscribe_args calls.
     for arg_group in (
         ModelArgs(model_required=False, inputs=None),
         TfLoaderArgs(artifacts=False),
         Tf2OnnxLoaderArgs(),
         OnnxLoaderArgs(),
         TrtPluginLoaderArgs(),
         TrtNetworkLoaderArgs(),
     ):
         self.subscribe_args(arg_group)
Example #5
0
 def __init__(self):
     """Initialize the "model" tool and subscribe its argument groups.

     A model path is required (``model_required=True``). Groups are
     subscribed in the order listed below.
     """
     super().__init__("model")
     # NOTE(review): every group is constructed before the first one is
     # subscribed; this assumes the group constructors do not depend on
     # prior subscribe_args calls.
     for arg_group in (
         ModelArgs(model_required=True, inputs=None),
         TfLoaderArgs(artifacts=False, outputs=False),
         OnnxShapeInferenceArgs(),
         OnnxLoaderArgs(output_prefix=None),
         TrtPluginLoaderArgs(),
         TrtNetworkLoaderArgs(outputs=False),
         TrtEngineLoaderArgs(),
     ):
         self.subscribe_args(arg_group)
Example #6
0
 def __init__(self):
     """Initialize the "convert" tool and subscribe its argument groups.

     A model path is required (``model_required=True``). Groups are
     subscribed in the order listed below.
     """
     super().__init__("convert")
     # NOTE(review): every group is constructed before the first one is
     # subscribed; this assumes the group constructors do not depend on
     # prior subscribe_args calls.
     for arg_group in (
         ModelArgs(model_required=True),
         TfLoaderArgs(artifacts=False),
         Tf2OnnxLoaderArgs(),
         OnnxShapeInferenceArgs(),
         OnnxLoaderArgs(),
         OnnxSaveArgs(output=False),
         # For int8 calibration
         DataLoaderArgs(),
         TrtConfigArgs(),
         TrtPluginLoaderArgs(),
         TrtNetworkLoaderArgs(),
         TrtEngineLoaderArgs(),
         TrtEngineSaveArgs(output=False),
     ):
         self.subscribe_args(arg_group)
Example #7
0
 def __init__(self):
     """Initialize the "run" tool and subscribe its argument groups.

     The groups cover TF, ONNX, ONNX-Runtime, onnx-tf and TensorRT
     loading/running, plus data loading and comparison. They are subscribed
     in the order listed below.
     """
     super().__init__("run")
     # NOTE(review): every group is constructed before the first one is
     # subscribed; this assumes the group constructors do not depend on
     # prior subscribe_args calls.
     for arg_group in (
         ModelArgs(),
         TfLoaderArgs(),
         TfConfigArgs(),
         TfRunnerArgs(),
         Tf2OnnxLoaderArgs(),
         OnnxLoaderArgs(),
         OnnxrtRunnerArgs(),
         OnnxtfRunnerArgs(),
         TrtLoaderArgs(network_api=True),
         TrtRunnerArgs(),
         TrtLegacyArgs(),
         DataLoaderArgs(),
         ComparatorRunArgs(),
         ComparatorCompareArgs(),
     ):
         self.subscribe_args(arg_group)
Example #8
0
 def __init__(self):
     """Initialize the "run" tool and subscribe its argument groups.

     The groups cover TF, ONNX, ONNX-Runtime and TensorRT loading/running,
     including ONNX and engine saving, plus data loading and comparison.
     They are subscribed in the order listed below.
     """
     super().__init__("run")
     # NOTE(review): every group is constructed before the first one is
     # subscribed; this assumes the group constructors do not depend on
     # prior subscribe_args calls.
     for arg_group in (
         ModelArgs(),
         TfLoaderArgs(tftrt=True),
         TfConfigArgs(),
         TfRunnerArgs(),
         Tf2OnnxLoaderArgs(),
         OnnxSaveArgs(output="save-onnx", short_opt=None),
         OnnxShapeInferenceArgs(),
         OnnxLoaderArgs(save=True),
         OnnxrtRunnerArgs(),
         TrtConfigArgs(),
         TrtPluginLoaderArgs(),
         TrtNetworkLoaderArgs(),
         TrtEngineSaveArgs(output="save-engine", short_opt=None),
         TrtEngineLoaderArgs(save=True),
         TrtRunnerArgs(),
         TrtLegacyArgs(),
         DataLoaderArgs(),
         ComparatorRunArgs(),
         ComparatorCompareArgs(),
     ):
         self.subscribe_args(arg_group)