def test_restricted_flags(self, trt_config_args):
    """`--trt-safety-restricted` must set the SAFETY_SCOPE builder flag."""
    trt_config_args.parse_args(["--trt-safety-restricted"])

    trt_builder, trt_network = create_network()
    with trt_builder, trt_network:
        with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
            # Looked up by name, exactly as the flag is resolved at assertion time.
            safety_flag = getattr(trt.BuilderFlag, "SAFETY_SCOPE")
            assert cfg.get_flag(safety_flag)
def test_config_script(self):
    """A user script given via `--trt-config-script` provides the builder config.

    The script defines `my_load_config`, which extends `CreateConfig()` and sets
    the FP16 flag; the test checks that both arguments are recorded on the
    argument group and that the resulting config carries that flag.
    """
    arg_group = ArgGroupTestHelper(TrtConfigArgs())

    with tempfile.NamedTemporaryFile("w+", suffix=".py") as f:
        # The script must be a well-formed, multi-line module: each import, the
        # decorator and the function body need their own lines to be loadable.
        script = dedent(
            """
            from polygraphy.backend.trt import CreateConfig
            from polygraphy import func
            import tensorrt as trt

            @func.extend(CreateConfig())
            def my_load_config(config):
                config.set_flag(trt.BuilderFlag.FP16)
            """
        )
        f.write(script)
        # Flush so the content is on disk before the script is read back by path.
        f.flush()

        arg_group.parse_args(
            [
                "--trt-config-script",
                f.name,
                "--trt-config-func-name=my_load_config",
            ]
        )
        assert arg_group.trt_config_script == f.name
        assert arg_group.trt_config_func_name == "my_load_config"

        builder, network = create_network()
        with builder, network, arg_group.create_config(builder, network) as config:
            assert isinstance(config, trt.IBuilderConfig)
            assert config.get_flag(trt.BuilderFlag.FP16)
def test_create_config(self, trt_config_args):
    """With no arguments, `create_config` still yields a `trt.IBuilderConfig`."""
    no_args = []
    trt_config_args.parse_args(no_args)

    trt_builder, trt_network = create_network()
    with trt_builder, trt_network:
        with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
            assert isinstance(cfg, trt.IBuilderConfig)
def test_workspace(self, trt_config_args, workspace, expected):
    """`--workspace` is parsed to `expected` and applied as `max_workspace_size`."""
    cli = ["--workspace", workspace]
    trt_config_args.parse_args(cli)
    assert trt_config_args.workspace == expected

    trt_builder, trt_network = create_network()
    with trt_builder, trt_network:
        with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
            assert cfg.max_workspace_size == expected
def test_calibration_base_class(self, trt_config_args, base_class):
    """`--calibration-base-class` selects the type of the config's INT8 calibrator."""
    cli = ["--int8", "--calibration-base-class", base_class]
    trt_config_args.parse_args(cli)

    # The parsed value is wrapped; unwrapped it is the class name qualified with `trt.`.
    qualified_name = "trt.{:}".format(base_class)
    assert trt_config_args.calibration_base_class.unwrap() == qualified_name

    trt_builder, trt_network = create_network()
    with trt_builder, trt_network:
        with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
            expected_type = getattr(trt, base_class)
            assert isinstance(cfg.int8_calibrator, expected_type)
def test_dla(self, trt_config_args):
    """`--use-dla` makes DLA the default device type, with DLA core 0."""
    trt_config_args.parse_args(["--use-dla"])
    assert trt_config_args.use_dla

    trt_builder, trt_network = create_network()
    with trt_builder, trt_network:
        with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
            assert cfg.default_device_type == trt.DeviceType.DLA
            assert cfg.DLA_core == 0
def test_no_opts(self):
    """A `trt-config` template generated without options exposes a working `load_config`."""
    with util.NamedTemporaryFile("w+", suffix=".py") as template:
        template_args = ["trt-config", "-o", template.name]
        run_polygraphy_template(template_args)

        trt_builder, trt_network = create_network()
        # Invoke the `load_config` function defined in the generated script.
        load_config = InvokeFromScript(template.name, "load_config")
        with trt_builder, trt_network:
            with load_config(trt_builder, trt_network) as cfg:
                assert isinstance(cfg, trt.IBuilderConfig)
def test_tactic_replay(self, trt_config_args):
    """`--tactic-replay` installs a `TacticRecorder`-based selector bound to the given path."""
    with util.NamedTemporaryFile(suffix=".json") as replay_file:
        trt_config_args.parse_args(["--tactic-replay", replay_file.name])

        trt_builder, trt_network = create_network()
        with trt_builder, trt_network:
            with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
                selector = cfg.algorithm_selector
                assert selector.make_func == TacticRecorder
                assert selector.path == replay_file.name
def test_precision_flags(self, trt_config_args, arg, flag):
    """Each precision argument `arg` sets the builder flag named by `flag`."""
    # The version comparison is only evaluated for the TF32 case.
    if flag == "TF32":
        if mod.version(trt.__version__) < mod.version("7.1"):
            pytest.skip("TF32 support was added in 7.1")

    trt_config_args.parse_args([arg])

    trt_builder, trt_network = create_network()
    with trt_builder, trt_network:
        with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
            builder_flag = getattr(trt.BuilderFlag, flag)
            assert cfg.get_flag(builder_flag)
def test_tactics(self, trt_config_args, opt, cls):
    """The tactic option `opt` yields an algorithm selector built from `cls` at the given path."""
    with util.NamedTemporaryFile("w+", suffix=".json") as tactics_file:
        # For the load case, save a freshly constructed TacticReplayData into the
        # file first (presumably so there is valid replay data to read back).
        if opt == "--load-tactics":
            TacticReplayData().save(tactics_file)

        trt_config_args.parse_args([opt, tactics_file.name])

        trt_builder, trt_network = create_network()
        with trt_builder, trt_network:
            with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
                selector = cfg.algorithm_selector
                assert selector.make_func == cls
                assert selector.path == tactics_file.name
def test_opts_basic(self):
    """A `trt-config` template generated with `--fp16 --int8` sets both builder flags."""
    with util.NamedTemporaryFile("w+", suffix=".py") as template:
        template_args = ["trt-config", "--fp16", "--int8", "-o", template.name]
        run_polygraphy_template(template_args)

        trt_builder, trt_network = create_network()
        # Invoke the `load_config` function defined in the generated script.
        load_config = InvokeFromScript(template.name, "load_config")
        with trt_builder, trt_network:
            with load_config(trt_builder, trt_network) as cfg:
                assert isinstance(cfg, trt.IBuilderConfig)
                assert cfg.get_flag(trt.BuilderFlag.FP16)
                assert cfg.get_flag(trt.BuilderFlag.INT8)
def test_build_engine_custom_network(self, engine_loader_args):
    """`build_engine` accepts a caller-supplied `(builder, network)` pair.

    The network is a single identity layer from "input" to "output"; the built
    engine must expose exactly those two bindings, in that order.
    """
    engine_loader_args.parse_args([])

    trt_builder, trt_network = create_network()

    # Hand-build the network: input -> identity -> output.
    input_tensor = trt_network.add_input("input", dtype=trt.float32, shape=(1, 1))
    identity_layer = trt_network.add_identity(input_tensor)
    output_tensor = identity_layer.get_output(0)
    output_tensor.name = "output"
    trt_network.mark_output(output_tensor)

    with trt_builder, trt_network:
        with engine_loader_args.build_engine(network=(trt_builder, trt_network)) as engine:
            assert isinstance(engine, trt.ICudaEngine)
            assert len(engine) == 2
            assert engine[0] == "input"
            assert engine[1] == "output"
def test_legacy_calibrator_params(self, trt_config_args):
    """`--quantile` and `--regression-cutoff` are parsed and reach the legacy calibrator."""
    quantile = 0.25
    regression_cutoff = 0.9

    cli = [
        "--int8",
        "--calibration-base-class=IInt8LegacyCalibrator",
        "--quantile",
        str(quantile),
        "--regression-cutoff",
        str(regression_cutoff),
    ]
    trt_config_args.parse_args(cli)

    assert trt_config_args.quantile == quantile
    assert trt_config_args.regression_cutoff == regression_cutoff

    trt_builder, trt_network = create_network()
    with trt_builder, trt_network:
        with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
            assert cfg.int8_calibrator.get_quantile() == quantile
            assert cfg.int8_calibrator.get_regression_cutoff() == regression_cutoff
def test_no_deps_profiles_int8(self):
    """Shape-profile arguments combined with `--int8` work on a bare `TrtConfigArgs` group."""
    arg_group = ArgGroupTestHelper(TrtConfigArgs())
    cli = [
        "--trt-min-shapes=input:[1,25,25]",
        "--trt-opt-shapes=input:[2,25,25]",
        "--trt-max-shapes=input:[4,25,25]",
        "--int8",
    ]
    arg_group.parse_args(cli)

    for profile in arg_group.profile_dicts:
        shapes_min, shapes_opt, shapes_max = profile
        assert shapes_min["input"] == [1, 25, 25]
        assert shapes_opt["input"] == [2, 25, 25]
        assert shapes_max["input"] == [4, 25, 25]

    trt_builder, trt_network = create_network()
    with trt_builder, trt_network:
        with arg_group.create_config(trt_builder, network=trt_network) as cfg:
            assert isinstance(cfg, trt.IBuilderConfig)
            # Unfortunately there is no API to check the contents of the profile in a config.
            # The checks above will have to do.
            assert cfg.num_optimization_profiles == 1
            assert cfg.get_flag(trt.BuilderFlag.INT8)
def test_tactic_sources(self, trt_config_args, opt, expected):
    """The argument list `opt` produces a config whose tactic sources equal `expected`."""
    # `opt` is passed through as-is: it is already the full argument list.
    trt_config_args.parse_args(opt)

    trt_builder, trt_network = create_network()
    with trt_builder, trt_network:
        with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
            tactic_sources = cfg.get_tactic_sources()
            assert tactic_sources == expected