예제 #1
0
    def test_data_loader_script(self):
        """Check that a data loader can be obtained from a user-supplied script.

        A temporary ``.py`` file defining ``my_load_data`` is written, its path
        and the function name are passed as command-line options, and the
        loader returned by ``get_data_loader()`` is expected to yield five
        feed dicts whose ``"inp"`` arrays match what the script generates.
        """
        helper = ArgGroupTestHelper(DataLoaderArgs())

        with tempfile.NamedTemporaryFile("w+", suffix=".py") as script_file:
            script_source = dedent("""
                import numpy as np

                def my_load_data():
                    for _ in range(5):
                        yield {"inp": np.ones((3, 5), dtype=np.float32) * 6.4341}
            """)
            script_file.write(script_source)
            # Push the buffered text out before the file's path is handed to the parser.
            script_file.flush()

            cli_args = [
                "--data-loader-script", script_file.name,
                "--data-loader-func-name=my_load_data",
            ]
            helper.parse_args(cli_args)

            assert helper.data_loader_script == script_file.name
            assert helper.data_loader_func_name == "my_load_data"

            batches = list(helper.get_data_loader())
            assert len(batches) == 5

            # Same expression the script uses, so an exact comparison is intended.
            expected_inp = np.ones((3, 5), dtype=np.float32) * 6.4341
            for batch in batches:
                assert np.all(batch["inp"] == expected_inp)
예제 #2
0
 def __init__(self):
     """Initialise the base class as "run" and subscribe each argument group in order."""
     super().__init__("run")
     subscribe = self.subscribe_args

     subscribe(ModelArgs())
     subscribe(TfLoaderArgs(tftrt=True))
     subscribe(TfConfigArgs())
     subscribe(TfRunnerArgs())
     subscribe(Tf2OnnxLoaderArgs())
     subscribe(OnnxSaveArgs(output="save-onnx", short_opt=None))
     subscribe(OnnxShapeInferenceArgs())
     subscribe(OnnxLoaderArgs(save=True))
     subscribe(OnnxrtRunnerArgs())
     subscribe(PluginRefArgs())
     # We run calibration with the inference-time data
     subscribe(TrtConfigArgs(random_data_calib_warning=False))
     subscribe(TrtPluginLoaderArgs())
     subscribe(TrtNetworkLoaderArgs())
     subscribe(TrtEngineSaveArgs(output="save-engine", short_opt=None))
     subscribe(TrtEngineLoaderArgs(save=True))
     subscribe(TrtRunnerArgs())
     subscribe(TrtLegacyArgs())
     subscribe(DataLoaderArgs())
     subscribe(ComparatorRunArgs())
     subscribe(ComparatorCompareArgs())
예제 #3
0
 def __init__(self):
     """Initialise the base class as "reduce" and subscribe each argument group in order."""
     super().__init__("reduce")
     subscribe = self.subscribe_args

     subscribe(ArtifactSorterArgs("polygraphy_debug.onnx", prefer_artifacts=False))
     subscribe(
         ModelArgs(
             model_required=True,
             inputs="--model-inputs",
             model_type="onnx",
         )
     )
     subscribe(OnnxSaveArgs())
     subscribe(
         OnnxShapeInferenceArgs(default=True, enable_force_fallback=True)
     )
     subscribe(OnnxLoaderArgs(output_prefix=None))
     # For fallback shape inference
     subscribe(DataLoaderArgs())
예제 #4
0
    def test_input_metadata(self):
        """Check that shapes given via ``--input-shapes`` appear in the generated feed dicts."""
        helper = ArgGroupTestHelper(DataLoaderArgs(), deps=[ModelArgs()])
        helper.parse_args(
            ["--input-shapes", "test0:[1,1,1]", "test1:[2,32,2]"]
        )

        # Expected shape per input name, mirroring the command-line values above.
        expected_shapes = {
            "test0": (1, 1, 1),
            "test1": (2, 32, 2),
        }

        for feed_dict in helper.get_data_loader():
            for input_name, shape in expected_shapes.items():
                assert feed_dict[input_name].shape == shape
예제 #5
0
    def test_shape_inference_disabled_on_fallback(self):
        """Check that ``--force-fallback-shape-inference`` turns ``do_shape_inference`` off.

        With no options the flag is expected to be truthy (the group is built
        with ``default=True``); after parsing the force-fallback option it is
        expected to be falsy.
        """
        shape_inference_args = OnnxShapeInferenceArgs(
            default=True,
            enable_force_fallback=True,
        )
        helper = ArgGroupTestHelper(shape_inference_args, deps=[DataLoaderArgs()])

        no_options = []
        helper.parse_args(no_options)
        assert helper.do_shape_inference

        force_fallback = ["--force-fallback-shape-inference"]
        helper.parse_args(force_fallback)
        assert not helper.do_shape_inference
예제 #6
0
    def test_override_input_metadata(self):
        """Check that ``user_input_metadata`` passed to ``get_data_loader`` sets the input shape."""
        helper = ArgGroupTestHelper(DataLoaderArgs(), deps=[ModelArgs()])
        helper.parse_args([])

        # Metadata supplied directly by the caller rather than via command-line options.
        override = TensorMetadata().add("test0", dtype=np.float32, shape=(4, 4))
        loader = helper.get_data_loader(user_input_metadata=override)

        for feed_dict in loader:
            assert feed_dict["test0"].shape == (4, 4)
예제 #7
0
 def __init__(self, name):
     """Initialise the base class with ``name`` and subscribe each argument group in order.

     Args:
         name: Forwarded unchanged to the base-class initialiser.
     """
     super().__init__(name)
     subscribe = self.subscribe_args

     subscribe(DataLoaderArgs())
     subscribe(ModelArgs(model_required=True))
     subscribe(OnnxLoaderArgs(outputs=False))
     subscribe(TrtLoaderArgs())
     subscribe(TrtRunnerArgs())
     subscribe(ComparatorRunArgs(iters=False, write=False))
     subscribe(ComparatorCompareArgs())
예제 #8
0
 def __init__(self):
     """Initialise the base class as "extract" and subscribe each argument group in order."""
     super().__init__("extract")

     model_args = ModelArgs(
         model_required=True,
         inputs="--model-inputs",
         model_type="onnx",
     )
     self.subscribe_args(model_args)

     data_loader_args = DataLoaderArgs()
     self.subscribe_args(data_loader_args)

     shape_inference_args = OnnxShapeInferenceArgs(
         default=False,
         enable_force_fallback=True,
     )
     self.subscribe_args(shape_inference_args)

     onnx_loader_args = OnnxLoaderArgs(output_prefix=None)
     self.subscribe_args(onnx_loader_args)

     onnx_save_args = OnnxSaveArgs(required=True)
     self.subscribe_args(onnx_save_args)
예제 #9
0
    def test_parsing(self, case):
        """Check that parsed options land on both the argument group and its data loader.

        ``case`` is unpacked into four items: the command-line arguments, the
        attribute names to inspect, the values expected on the argument group,
        and the values expected on the data loader. When the last item is
        falsy, the argument-group expectations are reused for the data loader.
        """
        helper = ArgGroupTestHelper(DataLoaderArgs())
        cli_args, attr_names, group_values, loader_values = util.unpack_args(case, 4)
        if not loader_values:
            loader_values = group_values

        helper.parse_args(cli_args)
        loader = helper.get_data_loader()

        expectations = zip(attr_names, group_values, loader_values)
        for attr_name, want_on_group, want_on_loader in expectations:
            assert getattr(helper, attr_name) == want_on_group
            assert getattr(loader, attr_name) == want_on_loader
예제 #10
0
 def __init__(self):
     """Initialise the base class as "sanitize" and subscribe each argument group in order."""
     super().__init__("sanitize")
     subscribe = self.subscribe_args

     subscribe(
         ModelArgs(
             model_required=True,
             inputs="--override-inputs",
             model_type="onnx",
         )
     )
     subscribe(DataLoaderArgs())
     subscribe(
         OnnxShapeInferenceArgs(
             default=True,
             enable_force_fallback=True,
         )
     )
     subscribe(OnnxLoaderArgs(output_prefix=""))
     subscribe(OnnxSaveArgs(infer_shapes=True, required=True))
예제 #11
0
파일: convert.py 프로젝트: clayne/TensorRT
 def __init__(self):
     """Initialise the base class as "convert" and subscribe each argument group in order."""
     super().__init__("convert")
     subscribe = self.subscribe_args

     subscribe(ModelArgs(model_required=True))
     subscribe(TfLoaderArgs(artifacts=False))
     subscribe(Tf2OnnxLoaderArgs())
     subscribe(OnnxShapeInferenceArgs())
     subscribe(OnnxLoaderArgs())
     subscribe(OnnxSaveArgs(output=False))
     # For int8 calibration
     subscribe(DataLoaderArgs())
     subscribe(TrtConfigArgs())
     subscribe(TrtPluginLoaderArgs())
     subscribe(TrtNetworkLoaderArgs())
     subscribe(TrtEngineLoaderArgs())
     subscribe(TrtEngineSaveArgs(output=False))
예제 #12
0
 def __init__(self):
     """Initialise the base class as "extract" and subscribe each argument group in order."""
     super().__init__("extract")
     subscribe = self.subscribe_args

     # Help text passed as ``inputs_doc``; the two literals concatenate into one string.
     model_inputs_help = (
         "Input shapes to use when generating data to run fallback shape inference. "
         "Has no effect if fallback shape inference is not run"
     )
     subscribe(
         ModelArgs(
             model_required=True,
             inputs="--model-inputs",
             model_type="onnx",
             inputs_doc=model_inputs_help,
         )
     )
     subscribe(DataLoaderArgs())
     subscribe(
         OnnxShapeInferenceArgs(default=False, enable_force_fallback=True)
     )
     subscribe(OnnxLoaderArgs(output_prefix=None))
     subscribe(OnnxSaveArgs(required=True))
예제 #13
0
 def __init__(self, name, strict_types_default=None, prefer_artifacts=True):
     """Initialise the base class with ``name`` and subscribe each argument group in order.

     Args:
         name: Forwarded unchanged to the base-class initialiser.
         strict_types_default: Forwarded to ``TrtConfigArgs`` under the same keyword.
         prefer_artifacts: Forwarded to ``ArtifactSorterArgs`` under the same keyword.
     """
     super().__init__(name)
     subscribe = self.subscribe_args

     subscribe(
         ArtifactSorterArgs(
             "polygraphy_debug.engine",
             prefer_artifacts=prefer_artifacts,
         )
     )
     subscribe(ModelArgs(model_required=True, inputs=None))
     subscribe(OnnxShapeInferenceArgs())
     subscribe(OnnxLoaderArgs(output_prefix=None))
     # For int8 calibration
     subscribe(DataLoaderArgs())
     subscribe(TrtConfigArgs(strict_types_default=strict_types_default))
     subscribe(TrtPluginLoaderArgs())
     subscribe(TrtNetworkLoaderArgs())
     subscribe(TrtEngineLoaderArgs())
     subscribe(TrtEngineSaveArgs(output=False))
예제 #14
0
 def __init__(self):
     """Initialise the base class as "run" and subscribe each argument group in order."""
     super().__init__("run")
     subscribe = self.subscribe_args

     subscribe(ModelArgs())
     subscribe(TfLoaderArgs())
     subscribe(TfConfigArgs())
     subscribe(TfRunnerArgs())
     subscribe(Tf2OnnxLoaderArgs())
     subscribe(OnnxLoaderArgs())
     subscribe(OnnxrtRunnerArgs())
     subscribe(OnnxtfRunnerArgs())
     subscribe(TrtLoaderArgs(network_api=True))
     subscribe(TrtRunnerArgs())
     subscribe(TrtLegacyArgs())
     subscribe(DataLoaderArgs())
     subscribe(ComparatorRunArgs())
     subscribe(ComparatorCompareArgs())
예제 #15
0
 def __init__(self):
     """Initialise the base class as "run" and subscribe each argument group in order."""
     super().__init__("run")
     subscribe = self.subscribe_args

     subscribe(ModelArgs())
     subscribe(TfLoaderArgs(tftrt=True))
     subscribe(TfConfigArgs())
     subscribe(TfRunnerArgs())
     subscribe(Tf2OnnxLoaderArgs())
     subscribe(OnnxSaveArgs(output="save-onnx", short_opt=None))
     subscribe(OnnxShapeInferenceArgs())
     subscribe(OnnxLoaderArgs(save=True))
     subscribe(OnnxrtRunnerArgs())
     subscribe(TrtConfigArgs())
     subscribe(TrtPluginLoaderArgs())
     subscribe(TrtNetworkLoaderArgs())
     subscribe(TrtEngineSaveArgs(output="save-engine", short_opt=None))
     subscribe(TrtEngineLoaderArgs(save=True))
     subscribe(TrtRunnerArgs())
     subscribe(TrtLegacyArgs())
     subscribe(DataLoaderArgs())
     subscribe(ComparatorRunArgs())
     subscribe(ComparatorCompareArgs())
예제 #16
0
 def __init__(self):
     """Initialise the base class as "trt-config" and subscribe each argument group in order."""
     super().__init__("trt-config")

     model_args = ModelArgs(model_required=False)
     self.subscribe_args(model_args)

     data_loader_args = DataLoaderArgs()
     self.subscribe_args(data_loader_args)

     trt_config_args = TrtConfigArgs()
     self.subscribe_args(trt_config_args)
예제 #17
0
    def test_val_range_errors(self, opts, expected_err):
        """Check that parsing ``opts`` raises ``PolygraphyException`` matching ``expected_err``."""
        helper = ArgGroupTestHelper(DataLoaderArgs())

        # ``match`` is applied by pytest to the text of the raised exception.
        expect_failure = pytest.raises(PolygraphyException, match=expected_err)
        with expect_failure:
            helper.parse_args(opts)
예제 #18
0
 def __init__(self, name, inputs=None, data=False, shape_inference_default=None):
     """Initialise the base class with ``name`` and subscribe the model and ONNX loader groups.

     Args:
         name: Forwarded unchanged to the base-class initialiser.
         inputs: Forwarded to ``ModelArgs`` under the same keyword.
         data: When truthy, a ``DataLoaderArgs`` group is subscribed as well.
         shape_inference_default: Forwarded to ``OnnxLoaderArgs`` under the same keyword.
     """
     super().__init__(name)

     model_args = ModelArgs(
         model_required=True,
         inputs=inputs,
         model_type="onnx",
     )
     self.subscribe_args(model_args)

     onnx_loader_args = OnnxLoaderArgs(
         write=False,
         outputs=False,
         shape_inference_default=shape_inference_default,
     )
     self.subscribe_args(onnx_loader_args)

     if not data:
         return
     self.subscribe_args(DataLoaderArgs())
예제 #19
0
def trt_config_args():
    """Return an ``ArgGroupTestHelper`` wrapping ``TrtConfigArgs``.

    ``ModelArgs`` and ``DataLoaderArgs`` instances are supplied as ``deps``.
    """
    config_group = TrtConfigArgs()
    dependencies = [ModelArgs(), DataLoaderArgs()]
    return ArgGroupTestHelper(config_group, deps=dependencies)