Ejemplo n.º 1
0
 def __init__(self):
     super().__init__("run")
     self.subscribe_args(ModelArgs())
     self.subscribe_args(TfLoaderArgs(tftrt=True))
     self.subscribe_args(TfConfigArgs())
     self.subscribe_args(TfRunnerArgs())
     self.subscribe_args(Tf2OnnxLoaderArgs())
     self.subscribe_args(OnnxSaveArgs(output="save-onnx", short_opt=None))
     self.subscribe_args(OnnxShapeInferenceArgs())
     self.subscribe_args(OnnxLoaderArgs(save=True))
     self.subscribe_args(OnnxrtRunnerArgs())
     self.subscribe_args(PluginRefArgs())
     self.subscribe_args(
         TrtConfigArgs(random_data_calib_warning=False
                       ))  # We run calibration with the inference-time data
     self.subscribe_args(TrtPluginLoaderArgs())
     self.subscribe_args(TrtNetworkLoaderArgs())
     self.subscribe_args(
         TrtEngineSaveArgs(output="save-engine", short_opt=None))
     self.subscribe_args(TrtEngineLoaderArgs(save=True))
     self.subscribe_args(TrtRunnerArgs())
     self.subscribe_args(TrtLegacyArgs())
     self.subscribe_args(DataLoaderArgs())
     self.subscribe_args(ComparatorRunArgs())
     self.subscribe_args(ComparatorCompareArgs())
Ejemplo n.º 2
0
 def __init__(self, name):
     super().__init__(name)
     self.subscribe_args(DataLoaderArgs())
     self.subscribe_args(ModelArgs(model_required=True))
     self.subscribe_args(OnnxLoaderArgs(outputs=False))
     self.subscribe_args(TrtLoaderArgs())
     self.subscribe_args(TrtRunnerArgs())
     self.subscribe_args(ComparatorRunArgs(iters=False, write=False))
     self.subscribe_args(ComparatorCompareArgs())
Ejemplo n.º 3
0
 def __init__(self):
     super().__init__("run")
     self.subscribe_args(ModelArgs())
     self.subscribe_args(TfLoaderArgs())
     self.subscribe_args(TfConfigArgs())
     self.subscribe_args(TfRunnerArgs())
     self.subscribe_args(Tf2OnnxLoaderArgs())
     self.subscribe_args(OnnxLoaderArgs())
     self.subscribe_args(OnnxrtRunnerArgs())
     self.subscribe_args(OnnxtfRunnerArgs())
     self.subscribe_args(TrtLoaderArgs(network_api=True))
     self.subscribe_args(TrtRunnerArgs())
     self.subscribe_args(TrtLegacyArgs())
     self.subscribe_args(DataLoaderArgs())
     self.subscribe_args(ComparatorRunArgs())
     self.subscribe_args(ComparatorCompareArgs())
Ejemplo n.º 4
0
 def __init__(self):
     super().__init__("run")
     self.subscribe_args(ModelArgs())
     self.subscribe_args(TfLoaderArgs(tftrt=True))
     self.subscribe_args(TfConfigArgs())
     self.subscribe_args(TfRunnerArgs())
     self.subscribe_args(Tf2OnnxLoaderArgs())
     self.subscribe_args(OnnxSaveArgs(output="save-onnx", short_opt=None))
     self.subscribe_args(OnnxShapeInferenceArgs())
     self.subscribe_args(OnnxLoaderArgs(save=True))
     self.subscribe_args(OnnxrtRunnerArgs())
     self.subscribe_args(TrtConfigArgs())
     self.subscribe_args(TrtPluginLoaderArgs())
     self.subscribe_args(TrtNetworkLoaderArgs())
     self.subscribe_args(
         TrtEngineSaveArgs(output="save-engine", short_opt=None))
     self.subscribe_args(TrtEngineLoaderArgs(save=True))
     self.subscribe_args(TrtRunnerArgs())
     self.subscribe_args(TrtLegacyArgs())
     self.subscribe_args(DataLoaderArgs())
     self.subscribe_args(ComparatorRunArgs())
     self.subscribe_args(ComparatorCompareArgs())
Ejemplo n.º 5
0
 def test_invalid_error_stat(self, args):
     with pytest.raises(PolygraphyException, match="Invalid choice"):
         arg_group = ArgGroupTestHelper(ComparatorCompareArgs())
         arg_group.parse_args(["--check-error-stat"] + args)
Ejemplo n.º 6
0
    def test_error_stat_per_output(self, args, expected):
        arg_group = ArgGroupTestHelper(ComparatorCompareArgs())
        arg_group.parse_args(["--check-error-stat"] + args)

        assert arg_group.check_error_stat == expected
Ejemplo n.º 7
0
    def test_error_stat(self, check_error_stat):
        arg_group = ArgGroupTestHelper(ComparatorCompareArgs())
        arg_group.parse_args(
            ["--check-error-stat={:}".format(check_error_stat)])

        assert arg_group.check_error_stat == {"": check_error_stat}