コード例 #1
0
ファイル: insert.py プロジェクト: stjordanis/TensorRT
 def __init__(self):
     super().__init__("insert")
     self.subscribe_args(OnnxNodeArgs())
     self.subscribe_args(ModelArgs(model_required=True, inputs="--model-inputs", model_type="onnx"))
     self.subscribe_args(OnnxShapeInferenceArgs())
     self.subscribe_args(OnnxLoaderArgs(output_prefix=None))
     self.subscribe_args(OnnxSaveArgs(infer_shapes=True, required=True))
コード例 #2
0
ファイル: test_loader.py プロジェクト: clayne/TensorRT
    def test_shape_inference_disabled_on_fallback(self):
        arg_group = ArgGroupTestHelper(OnnxShapeInferenceArgs(default=True, enable_force_fallback=True), deps=[DataLoaderArgs()])
        arg_group.parse_args([])
        assert arg_group.do_shape_inference

        arg_group.parse_args(["--force-fallback-shape-inference"])
        assert not arg_group.do_shape_inference
コード例 #3
0
ファイル: run.py プロジェクト: phongphuhanam/TensorRT
 def __init__(self):
     super().__init__("run")
     self.subscribe_args(ModelArgs())
     self.subscribe_args(TfLoaderArgs(tftrt=True))
     self.subscribe_args(TfConfigArgs())
     self.subscribe_args(TfRunnerArgs())
     self.subscribe_args(Tf2OnnxLoaderArgs())
     self.subscribe_args(OnnxSaveArgs(output="save-onnx", short_opt=None))
     self.subscribe_args(OnnxShapeInferenceArgs())
     self.subscribe_args(OnnxLoaderArgs(save=True))
     self.subscribe_args(OnnxrtRunnerArgs())
     self.subscribe_args(PluginRefArgs())
     self.subscribe_args(
         TrtConfigArgs(random_data_calib_warning=False
                       ))  # We run calibration with the inference-time data
     self.subscribe_args(TrtPluginLoaderArgs())
     self.subscribe_args(TrtNetworkLoaderArgs())
     self.subscribe_args(
         TrtEngineSaveArgs(output="save-engine", short_opt=None))
     self.subscribe_args(TrtEngineLoaderArgs(save=True))
     self.subscribe_args(TrtRunnerArgs())
     self.subscribe_args(TrtLegacyArgs())
     self.subscribe_args(DataLoaderArgs())
     self.subscribe_args(ComparatorRunArgs())
     self.subscribe_args(ComparatorCompareArgs())
コード例 #4
0
ファイル: reduce.py プロジェクト: phongphuhanam/TensorRT
 def __init__(self):
     super().__init__("reduce")
     self.subscribe_args(ArtifactSorterArgs("polygraphy_debug.onnx", prefer_artifacts=False))
     self.subscribe_args(ModelArgs(model_required=True, inputs="--model-inputs", model_type="onnx"))
     self.subscribe_args(OnnxSaveArgs())
     self.subscribe_args(OnnxShapeInferenceArgs(default=True, enable_force_fallback=True))
     self.subscribe_args(OnnxLoaderArgs(output_prefix=None))
     self.subscribe_args(DataLoaderArgs())  # For fallback shape inference
コード例 #5
0
 def __init__(self):
     super().__init__("model")
     self.subscribe_args(ModelArgs(model_required=True, inputs=None))
     self.subscribe_args(TfLoaderArgs(artifacts=False, outputs=False))
     self.subscribe_args(OnnxShapeInferenceArgs())
     self.subscribe_args(OnnxLoaderArgs(output_prefix=None))
     self.subscribe_args(TrtPluginLoaderArgs())
     self.subscribe_args(TrtNetworkLoaderArgs(outputs=False))
     self.subscribe_args(TrtEngineLoaderArgs())
コード例 #6
0
ファイル: sanitize.py プロジェクト: stjordanis/TensorRT
 def __init__(self):
     super().__init__("sanitize")
     self.subscribe_args(
         ModelArgs(model_required=True,
                   inputs="--override-inputs",
                   model_type="onnx"))
     self.subscribe_args(DataLoaderArgs())
     self.subscribe_args(
         OnnxShapeInferenceArgs(default=True, enable_force_fallback=True))
     self.subscribe_args(OnnxLoaderArgs(output_prefix=""))
     self.subscribe_args(OnnxSaveArgs(infer_shapes=True, required=True))
コード例 #7
0
 def __init__(self):
     super().__init__("extract")
     self.subscribe_args(
         ModelArgs(model_required=True,
                   inputs="--model-inputs",
                   model_type="onnx"))
     self.subscribe_args(DataLoaderArgs())
     self.subscribe_args(
         OnnxShapeInferenceArgs(default=False, enable_force_fallback=True))
     self.subscribe_args(OnnxLoaderArgs(output_prefix=None))
     self.subscribe_args(OnnxSaveArgs(required=True))
コード例 #8
0
ファイル: test_loader.py プロジェクト: phongphuhanam/TensorRT
    def test_shape_inference(self):
        # When using shape inference, we should load directly from the path
        arg_group = ArgGroupTestHelper(OnnxLoaderArgs(), deps=[ModelArgs(), OnnxShapeInferenceArgs()])
        model = ONNX_MODELS["identity"]
        arg_group.parse_args([model.path, "--shape-inference"])

        assert arg_group.should_use_onnx_loader()

        script = Script()
        arg_group.add_onnx_loader(script)

        expected_loader = "InferShapes({:})".format(repr(model.path))
        assert expected_loader in str(script)
コード例 #9
0
ファイル: convert.py プロジェクト: clayne/TensorRT
 def __init__(self):
     super().__init__("convert")
     self.subscribe_args(ModelArgs(model_required=True))
     self.subscribe_args(TfLoaderArgs(artifacts=False))
     self.subscribe_args(Tf2OnnxLoaderArgs())
     self.subscribe_args(OnnxShapeInferenceArgs())
     self.subscribe_args(OnnxLoaderArgs())
     self.subscribe_args(OnnxSaveArgs(output=False))
     self.subscribe_args(DataLoaderArgs())  # For int8 calibration
     self.subscribe_args(TrtConfigArgs())
     self.subscribe_args(TrtPluginLoaderArgs())
     self.subscribe_args(TrtNetworkLoaderArgs())
     self.subscribe_args(TrtEngineLoaderArgs())
     self.subscribe_args(TrtEngineSaveArgs(output=False))
コード例 #10
0
ファイル: base.py プロジェクト: phongphuhanam/TensorRT
 def __init__(self, name, strict_types_default=None, prefer_artifacts=True):
     super().__init__(name)
     self.subscribe_args(
         ArtifactSorterArgs("polygraphy_debug.engine",
                            prefer_artifacts=prefer_artifacts))
     self.subscribe_args(ModelArgs(model_required=True, inputs=None))
     self.subscribe_args(OnnxShapeInferenceArgs())
     self.subscribe_args(OnnxLoaderArgs(output_prefix=None))
     self.subscribe_args(DataLoaderArgs())  # For int8 calibration
     self.subscribe_args(
         TrtConfigArgs(strict_types_default=strict_types_default))
     self.subscribe_args(TrtPluginLoaderArgs())
     self.subscribe_args(TrtNetworkLoaderArgs())
     self.subscribe_args(TrtEngineLoaderArgs())
     self.subscribe_args(TrtEngineSaveArgs(output=False))
コード例 #11
0
ファイル: test_loader.py プロジェクト: phongphuhanam/TensorRT
    def test_shape_inference_ext_data(self):
        arg_group = ArgGroupTestHelper(OnnxLoaderArgs(), deps=[ModelArgs(), OnnxShapeInferenceArgs()])
        model = ONNX_MODELS["ext_weights"]
        arg_group.parse_args([model.path, "--external-data-dir", model.ext_data, "--shape-inference"])

        assert arg_group.should_use_onnx_loader()

        script = Script()
        arg_group.add_onnx_loader(script)

        expected_loader = "InferShapes({:}, external_data_dir={:})".format(repr(model.path), repr(model.ext_data))
        assert expected_loader in str(script)

        model = arg_group.load_onnx()
        _check_ext_weights_model(model)
コード例 #12
0
 def __init__(self):
     super().__init__("capability")
     self.subscribe_args(
         ModelArgs(model_required=True, inputs=None, model_type="onnx"))
     self.subscribe_args(OnnxShapeInferenceArgs(default=True))
     self.subscribe_args(OnnxLoaderArgs(output_prefix=None))
     # Disallow ext data path since we're writing multiple models - otherwise, it'll be clobbered each time.
     self.subscribe_args(
         OnnxSaveArgs(
             allow_ext_data_path=False,
             custom_help=
             "Directory to write out supported and unsupported subgraphs. "
             "Defaults to 'polygraphy_capability_dumps' in the current directory",
             default_output_path="polygraphy_capability_dumps",
         ))
コード例 #13
0
ファイル: extract.py プロジェクト: phongphuhanam/TensorRT
 def __init__(self):
     super().__init__("extract")
     self.subscribe_args(
         ModelArgs(
             model_required=True,
             inputs="--model-inputs",
             model_type="onnx",
             inputs_doc="Input shapes to use when generating data to run fallback shape inference. "
             "Has no effect if fallback shape inference is not run",
         )
     )
     self.subscribe_args(DataLoaderArgs())
     self.subscribe_args(OnnxShapeInferenceArgs(default=False, enable_force_fallback=True))
     self.subscribe_args(OnnxLoaderArgs(output_prefix=None))
     self.subscribe_args(OnnxSaveArgs(required=True))
コード例 #14
0
 def __init__(self):
     super().__init__("run")
     self.subscribe_args(ModelArgs())
     self.subscribe_args(TfLoaderArgs(tftrt=True))
     self.subscribe_args(TfConfigArgs())
     self.subscribe_args(TfRunnerArgs())
     self.subscribe_args(Tf2OnnxLoaderArgs())
     self.subscribe_args(OnnxSaveArgs(output="save-onnx", short_opt=None))
     self.subscribe_args(OnnxShapeInferenceArgs())
     self.subscribe_args(OnnxLoaderArgs(save=True))
     self.subscribe_args(OnnxrtRunnerArgs())
     self.subscribe_args(TrtConfigArgs())
     self.subscribe_args(TrtPluginLoaderArgs())
     self.subscribe_args(TrtNetworkLoaderArgs())
     self.subscribe_args(
         TrtEngineSaveArgs(output="save-engine", short_opt=None))
     self.subscribe_args(TrtEngineLoaderArgs(save=True))
     self.subscribe_args(TrtRunnerArgs())
     self.subscribe_args(TrtLegacyArgs())
     self.subscribe_args(DataLoaderArgs())
     self.subscribe_args(ComparatorRunArgs())
     self.subscribe_args(ComparatorCompareArgs())