def test_data_loader_script(self):
    """A user-provided data loader script and function name should drive the data loader.

    Writes a small script defining ``my_load_data`` (five feed_dicts, each with
    an ``"inp"`` array) to a temporary ``.py`` file. It then checks that the
    parsed options round-trip and that the resulting loader yields exactly
    that data.
    """
    helper = ArgGroupTestHelper(DataLoaderArgs())
    script_source = dedent(
        """
        import numpy as np

        def my_load_data():
            for _ in range(5):
                yield {"inp": np.ones((3, 5), dtype=np.float32) * 6.4341}
        """
    )
    expected_inp = np.ones((3, 5), dtype=np.float32) * 6.4341

    with tempfile.NamedTemporaryFile("w+", suffix=".py") as script_file:
        script_file.write(script_source)
        # Flush so the contents are on disk before the script is referenced by path.
        script_file.flush()

        helper.parse_args(
            ["--data-loader-script", script_file.name, "--data-loader-func-name=my_load_data"]
        )
        assert helper.data_loader_script == script_file.name
        assert helper.data_loader_func_name == "my_load_data"

        feed_dicts = list(helper.get_data_loader())
        assert len(feed_dicts) == 5
        for feed_dict in feed_dicts:
            assert np.all(feed_dict["inp"] == expected_inp)
def __init__(self):
    """Register the argument groups for the tool named ``run``."""
    super().__init__("run")
    arg_groups = (
        ModelArgs(),
        TfLoaderArgs(tftrt=True),
        TfConfigArgs(),
        TfRunnerArgs(),
        Tf2OnnxLoaderArgs(),
        OnnxSaveArgs(output="save-onnx", short_opt=None),
        OnnxShapeInferenceArgs(),
        OnnxLoaderArgs(save=True),
        OnnxrtRunnerArgs(),
        PluginRefArgs(),
        # We run calibration with the inference-time data (hence random_data_calib_warning=False).
        TrtConfigArgs(random_data_calib_warning=False),
        TrtPluginLoaderArgs(),
        TrtNetworkLoaderArgs(),
        TrtEngineSaveArgs(output="save-engine", short_opt=None),
        TrtEngineLoaderArgs(save=True),
        TrtRunnerArgs(),
        TrtLegacyArgs(),
        DataLoaderArgs(),
        ComparatorRunArgs(),
        ComparatorCompareArgs(),
    )
    # NOTE(review): groups are subscribed in exactly the listed order; keep it,
    # since later groups may presumably depend on earlier ones.
    for arg_group in arg_groups:
        self.subscribe_args(arg_group)
def __init__(self):
    """Register the argument groups for the tool named ``reduce``."""
    super().__init__("reduce")
    arg_groups = (
        ArtifactSorterArgs("polygraphy_debug.onnx", prefer_artifacts=False),
        ModelArgs(model_required=True, inputs="--model-inputs", model_type="onnx"),
        OnnxSaveArgs(),
        OnnxShapeInferenceArgs(default=True, enable_force_fallback=True),
        OnnxLoaderArgs(output_prefix=None),
        DataLoaderArgs(),  # For fallback shape inference
    )
    # Subscription order matches the original registration sequence.
    for arg_group in arg_groups:
        self.subscribe_args(arg_group)
def test_input_metadata(self):
    """Shapes given via ``--input-shapes`` should determine the generated data's shapes.

    Parses shapes for ``test0`` and ``test1`` and checks every feed_dict
    produced by the resulting data loader.
    """
    arg_group = ArgGroupTestHelper(DataLoaderArgs(), deps=[ModelArgs()])
    arg_group.parse_args(["--input-shapes", "test0:[1,1,1]", "test1:[2,32,2]"])

    data_loader = arg_group.get_data_loader()
    feed_dicts = list(data_loader)
    # Guard against a vacuous pass: if the loader yielded nothing, the loop
    # below would never execute its assertions.
    assert feed_dicts, "data loader yielded no feed_dicts"
    for feed_dict in feed_dicts:
        assert feed_dict["test0"].shape == (1, 1, 1)
        assert feed_dict["test1"].shape == (2, 32, 2)
def test_shape_inference_disabled_on_fallback(self):
    """``--force-fallback-shape-inference`` should turn ``do_shape_inference`` off."""
    helper = ArgGroupTestHelper(
        OnnxShapeInferenceArgs(default=True, enable_force_fallback=True),
        deps=[DataLoaderArgs()],
    )

    # With default=True and no flags, shape inference is enabled.
    helper.parse_args([])
    assert helper.do_shape_inference

    # Forcing fallback shape inference disables the regular path.
    helper.parse_args(["--force-fallback-shape-inference"])
    assert not helper.do_shape_inference
def test_override_input_metadata(self):
    """Metadata passed directly to ``get_data_loader`` should determine the generated shapes.

    No CLI shapes are given; ``user_input_metadata`` declares ``test0`` as a
    float32 tensor of shape (4, 4), and every produced feed_dict is checked.
    """
    arg_group = ArgGroupTestHelper(DataLoaderArgs(), deps=[ModelArgs()])
    arg_group.parse_args([])

    data_loader = arg_group.get_data_loader(
        user_input_metadata=TensorMetadata().add("test0", dtype=np.float32, shape=(4, 4))
    )
    feed_dicts = list(data_loader)
    # Guard against a vacuous pass: if the loader yielded nothing, the loop
    # below would never execute its assertion.
    assert feed_dicts, "data loader yielded no feed_dicts"
    for feed_dict in feed_dicts:
        assert feed_dict["test0"].shape == (4, 4)
def __init__(self, name):
    """Register argument groups; ``name`` is forwarded to the base-class constructor."""
    super().__init__(name)
    arg_groups = (
        DataLoaderArgs(),
        ModelArgs(model_required=True),
        OnnxLoaderArgs(outputs=False),
        TrtLoaderArgs(),
        TrtRunnerArgs(),
        ComparatorRunArgs(iters=False, write=False),
        ComparatorCompareArgs(),
    )
    # Subscription order matches the original registration sequence.
    for arg_group in arg_groups:
        self.subscribe_args(arg_group)
def __init__(self):
    """Register the argument groups for the tool named ``extract``."""
    super().__init__("extract")
    arg_groups = (
        ModelArgs(model_required=True, inputs="--model-inputs", model_type="onnx"),
        DataLoaderArgs(),
        OnnxShapeInferenceArgs(default=False, enable_force_fallback=True),
        OnnxLoaderArgs(output_prefix=None),
        OnnxSaveArgs(required=True),
    )
    # Subscription order matches the original registration sequence.
    for arg_group in arg_groups:
        self.subscribe_args(arg_group)
def test_parsing(self, case):
    """Parsed CLI values should appear on both the arg group and the data loader it builds.

    ``case`` unpacks into CLI args, attribute names, expected arg-group values,
    and (optionally) expected data-loader values.
    """
    helper = ArgGroupTestHelper(DataLoaderArgs())
    cli_args, attrs, expected, expected_dl = util.unpack_args(case, 4)
    # Without loader-specific expectations, the loader must match the arg group.
    if not expected_dl:
        expected_dl = expected

    helper.parse_args(cli_args)
    loader = helper.get_data_loader()
    for attr_name, group_value, loader_value in zip(attrs, expected, expected_dl):
        assert getattr(helper, attr_name) == group_value
        assert getattr(loader, attr_name) == loader_value
def __init__(self):
    """Register the argument groups for the tool named ``sanitize``."""
    super().__init__("sanitize")
    arg_groups = (
        ModelArgs(model_required=True, inputs="--override-inputs", model_type="onnx"),
        DataLoaderArgs(),
        OnnxShapeInferenceArgs(default=True, enable_force_fallback=True),
        OnnxLoaderArgs(output_prefix=""),
        OnnxSaveArgs(infer_shapes=True, required=True),
    )
    # Subscription order matches the original registration sequence.
    for arg_group in arg_groups:
        self.subscribe_args(arg_group)
def __init__(self):
    """Register the argument groups for the tool named ``convert``."""
    super().__init__("convert")
    arg_groups = (
        ModelArgs(model_required=True),
        TfLoaderArgs(artifacts=False),
        Tf2OnnxLoaderArgs(),
        OnnxShapeInferenceArgs(),
        OnnxLoaderArgs(),
        OnnxSaveArgs(output=False),
        DataLoaderArgs(),  # For int8 calibration
        TrtConfigArgs(),
        TrtPluginLoaderArgs(),
        TrtNetworkLoaderArgs(),
        TrtEngineLoaderArgs(),
        TrtEngineSaveArgs(output=False),
    )
    # Subscription order matches the original registration sequence.
    for arg_group in arg_groups:
        self.subscribe_args(arg_group)
def __init__(self):
    """Register the argument groups for the tool named ``extract``."""
    super().__init__("extract")
    model_args = ModelArgs(
        model_required=True,
        inputs="--model-inputs",
        model_type="onnx",
        inputs_doc="Input shapes to use when generating data to run fallback shape inference. "
        "Has no effect if fallback shape inference is not run",
    )
    arg_groups = (
        model_args,
        DataLoaderArgs(),
        OnnxShapeInferenceArgs(default=False, enable_force_fallback=True),
        OnnxLoaderArgs(output_prefix=None),
        OnnxSaveArgs(required=True),
    )
    # Subscription order matches the original registration sequence.
    for arg_group in arg_groups:
        self.subscribe_args(arg_group)
def __init__(self, name, strict_types_default=None, prefer_artifacts=True):
    """Register argument groups for an engine-producing tool.

    Args:
        name: Forwarded to the base-class constructor.
        strict_types_default: Passed through to ``TrtConfigArgs``.
        prefer_artifacts: Passed through to ``ArtifactSorterArgs``.
    """
    super().__init__(name)
    arg_groups = (
        ArtifactSorterArgs("polygraphy_debug.engine", prefer_artifacts=prefer_artifacts),
        ModelArgs(model_required=True, inputs=None),
        OnnxShapeInferenceArgs(),
        OnnxLoaderArgs(output_prefix=None),
        DataLoaderArgs(),  # For int8 calibration
        TrtConfigArgs(strict_types_default=strict_types_default),
        TrtPluginLoaderArgs(),
        TrtNetworkLoaderArgs(),
        TrtEngineLoaderArgs(),
        TrtEngineSaveArgs(output=False),
    )
    # Subscription order matches the original registration sequence.
    for arg_group in arg_groups:
        self.subscribe_args(arg_group)
def __init__(self):
    """Register the argument groups for the tool named ``run``."""
    super().__init__("run")
    arg_groups = (
        ModelArgs(),
        TfLoaderArgs(),
        TfConfigArgs(),
        TfRunnerArgs(),
        Tf2OnnxLoaderArgs(),
        OnnxLoaderArgs(),
        OnnxrtRunnerArgs(),
        OnnxtfRunnerArgs(),
        TrtLoaderArgs(network_api=True),
        TrtRunnerArgs(),
        TrtLegacyArgs(),
        DataLoaderArgs(),
        ComparatorRunArgs(),
        ComparatorCompareArgs(),
    )
    # Subscription order matches the original registration sequence.
    for arg_group in arg_groups:
        self.subscribe_args(arg_group)
def __init__(self):
    """Register the argument groups for the tool named ``run``."""
    super().__init__("run")
    arg_groups = (
        ModelArgs(),
        TfLoaderArgs(tftrt=True),
        TfConfigArgs(),
        TfRunnerArgs(),
        Tf2OnnxLoaderArgs(),
        OnnxSaveArgs(output="save-onnx", short_opt=None),
        OnnxShapeInferenceArgs(),
        OnnxLoaderArgs(save=True),
        OnnxrtRunnerArgs(),
        TrtConfigArgs(),
        TrtPluginLoaderArgs(),
        TrtNetworkLoaderArgs(),
        TrtEngineSaveArgs(output="save-engine", short_opt=None),
        TrtEngineLoaderArgs(save=True),
        TrtRunnerArgs(),
        TrtLegacyArgs(),
        DataLoaderArgs(),
        ComparatorRunArgs(),
        ComparatorCompareArgs(),
    )
    # Subscription order matches the original registration sequence.
    for arg_group in arg_groups:
        self.subscribe_args(arg_group)
def __init__(self):
    """Register the argument groups for the tool named ``trt-config``."""
    super().__init__("trt-config")
    for arg_group in (ModelArgs(model_required=False), DataLoaderArgs(), TrtConfigArgs()):
        self.subscribe_args(arg_group)
def test_val_range_errors(self, opts, expected_err):
    """Parsing ``opts`` must raise PolygraphyException whose message matches ``expected_err``."""
    helper = ArgGroupTestHelper(DataLoaderArgs())
    with pytest.raises(PolygraphyException, match=expected_err):
        helper.parse_args(opts)
def __init__(self, name, inputs=None, data=False, shape_inference_default=None):
    """Register ONNX model/loader argument groups, plus data-loader args on request.

    Args:
        name: Forwarded to the base-class constructor.
        inputs: Passed through to ``ModelArgs``.
        data: When truthy, ``DataLoaderArgs`` is subscribed as well.
        shape_inference_default: Passed through to ``OnnxLoaderArgs``.
    """
    super().__init__(name)
    arg_groups = [
        ModelArgs(model_required=True, inputs=inputs, model_type="onnx"),
        OnnxLoaderArgs(write=False, outputs=False, shape_inference_default=shape_inference_default),
    ]
    if data:
        arg_groups.append(DataLoaderArgs())
    for arg_group in arg_groups:
        self.subscribe_args(arg_group)
def trt_config_args():
    """Build a test helper for ``TrtConfigArgs`` with ``ModelArgs`` and ``DataLoaderArgs`` as dependencies."""
    config_group = TrtConfigArgs()
    dependencies = [ModelArgs(), DataLoaderArgs()]
    return ArgGroupTestHelper(config_group, deps=dependencies)