Example #1
0
    def test_restricted_flags(self, trt_config_args):
        """`--trt-safety-restricted` must enable the SAFETY_SCOPE builder flag."""
        cli_args = ["--trt-safety-restricted"]
        trt_config_args.parse_args(cli_args)
        trt_builder, trt_network = create_network()

        with trt_builder, trt_network:
            with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
                assert cfg.get_flag(trt.BuilderFlag.SAFETY_SCOPE)
Example #2
0
    def test_config_script(self):
        """A config produced by a user script's custom function must be used as-is.

        The script defines `my_load_config`, which extends `CreateConfig` and
        sets the FP16 flag; the resulting config is checked for that flag.
        """
        helper = ArgGroupTestHelper(TrtConfigArgs())

        # Source of the user-provided config script written to the temp file.
        script_source = dedent("""
                from polygraphy.backend.trt import CreateConfig
                from polygraphy import func
                import tensorrt as trt

                @func.extend(CreateConfig())
                def my_load_config(config):
                    config.set_flag(trt.BuilderFlag.FP16)
            """)

        with tempfile.NamedTemporaryFile("w+", suffix=".py") as script_file:
            script_file.write(script_source)
            # Flush so the content is on disk before the script is loaded by path.
            script_file.flush()

            cli_args = [
                "--trt-config-script",
                script_file.name,
                "--trt-config-func-name=my_load_config",
            ]
            helper.parse_args(cli_args)
            assert helper.trt_config_script == script_file.name
            assert helper.trt_config_func_name == "my_load_config"

            trt_builder, trt_network = create_network()
            with trt_builder, trt_network:
                with helper.create_config(trt_builder, trt_network) as cfg:
                    assert isinstance(cfg, trt.IBuilderConfig)
                    assert cfg.get_flag(trt.BuilderFlag.FP16)
Example #3
0
    def test_create_config(self, trt_config_args):
        """With no CLI arguments, `create_config` must yield a `trt.IBuilderConfig`."""
        no_args = []
        trt_config_args.parse_args(no_args)
        trt_builder, trt_network = create_network()

        with trt_builder, trt_network:
            with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
                assert isinstance(cfg, trt.IBuilderConfig)
Example #4
0
    def test_workspace(self, trt_config_args, workspace, expected):
        """`--workspace` must be parsed to `expected` and applied to the config."""
        cli_args = ["--workspace", workspace]
        trt_config_args.parse_args(cli_args)
        assert trt_config_args.workspace == expected

        trt_builder, trt_network = create_network()
        with trt_builder, trt_network:
            with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
                assert cfg.max_workspace_size == expected
Example #5
0
    def test_calibration_base_class(self, trt_config_args, base_class):
        """`--calibration-base-class` must select the calibrator's TensorRT base type."""
        cli_args = ["--int8", "--calibration-base-class", base_class]
        trt_config_args.parse_args(cli_args)
        # The parsed value is stored wrapped and prefixed with the `trt.` module name.
        parsed_name = trt_config_args.calibration_base_class.unwrap()
        assert parsed_name == "trt.{}".format(base_class)

        trt_builder, trt_network = create_network()
        with trt_builder, trt_network:
            with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
                assert isinstance(cfg.int8_calibrator, getattr(trt, base_class))
Example #6
0
    def test_dla(self, trt_config_args):
        """`--use-dla` must make DLA the default device type, on DLA core 0."""
        cli_args = ["--use-dla"]
        trt_config_args.parse_args(cli_args)
        assert trt_config_args.use_dla

        trt_builder, trt_network = create_network()
        with trt_builder, trt_network:
            with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
                assert cfg.default_device_type == trt.DeviceType.DLA
                assert cfg.DLA_core == 0
Example #7
0
    def test_no_opts(self):
        """A `trt-config` template generated without options must build a config."""
        with util.NamedTemporaryFile("w+", suffix=".py") as script:
            template_args = ["trt-config", "-o", script.name]
            run_polygraphy_template(template_args)

            trt_builder, trt_network = create_network()
            # Load the generated script's `load_config` entry point.
            load_config = InvokeFromScript(script.name, "load_config")
            with trt_builder, trt_network:
                with load_config(trt_builder, trt_network) as cfg:
                    assert isinstance(cfg, trt.IBuilderConfig)
Example #8
0
    def test_tactic_replay(self, trt_config_args):
        """`--tactic-replay` must install a `TacticRecorder` selector for the given path."""
        with util.NamedTemporaryFile(suffix=".json") as replay_file:
            cli_args = ["--tactic-replay", replay_file.name]
            trt_config_args.parse_args(cli_args)
            trt_builder, trt_network = create_network()

            with trt_builder, trt_network:
                with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
                    selector = cfg.algorithm_selector
                    assert selector.make_func == TacticRecorder
                    assert selector.path == replay_file.name
Example #9
0
    def test_precision_flags(self, trt_config_args, arg, flag):
        """The precision option `arg` must set the builder flag named `flag`."""
        if flag == "TF32":
            # Only compare versions when the TF32 flag is actually under test.
            if mod.version(trt.__version__) < mod.version("7.1"):
                pytest.skip("TF32 support was added in 7.1")

        cli_args = [arg]
        trt_config_args.parse_args(cli_args)

        trt_builder, trt_network = create_network()
        with trt_builder, trt_network:
            with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
                assert cfg.get_flag(getattr(trt.BuilderFlag, flag))
Example #10
0
    def test_tactics(self, trt_config_args, opt, cls):
        """The tactic option `opt` must install a selector built from `cls` for the given path."""
        with util.NamedTemporaryFile("w+", suffix=".json") as replay_file:
            if opt == "--load-tactics":
                # A default-constructed replay is saved first; presumably so that
                # loading has valid content to read.
                TacticReplayData().save(replay_file)

            cli_args = [opt, replay_file.name]
            trt_config_args.parse_args(cli_args)
            trt_builder, trt_network = create_network()
            with trt_builder, trt_network:
                with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
                    selector = cfg.algorithm_selector
                    assert selector.make_func == cls
                    assert selector.path == replay_file.name
Example #11
0
    def test_opts_basic(self):
        """A template generated with `--fp16 --int8` must set both builder flags."""
        with util.NamedTemporaryFile("w+", suffix=".py") as script:
            template_args = ["trt-config", "--fp16", "--int8", "-o", script.name]
            run_polygraphy_template(template_args)

            trt_builder, trt_network = create_network()
            # Load the generated script's `load_config` entry point.
            load_config = InvokeFromScript(script.name, "load_config")
            with trt_builder, trt_network:
                with load_config(trt_builder, trt_network) as cfg:
                    assert isinstance(cfg, trt.IBuilderConfig)
                    assert cfg.get_flag(trt.BuilderFlag.FP16)
                    assert cfg.get_flag(trt.BuilderFlag.INT8)
Example #12
0
    def test_build_engine_custom_network(self, engine_loader_args):
        """An engine built from a hand-made identity network must expose its two bindings."""
        engine_loader_args.parse_args([])

        trt_builder, trt_network = create_network()

        # Minimal network: a single input routed through an identity layer.
        input_tensor = trt_network.add_input("input", dtype=trt.float32, shape=(1, 1))
        identity_layer = trt_network.add_identity(input_tensor)
        output_tensor = identity_layer.get_output(0)
        output_tensor.name = "output"
        trt_network.mark_output(output_tensor)

        with trt_builder, trt_network:
            with engine_loader_args.build_engine(network=(trt_builder, trt_network)) as built:
                assert isinstance(built, trt.ICudaEngine)
                assert len(built) == 2
                assert built[0] == "input"
                assert built[1] == "output"
Example #13
0
    def test_legacy_calibrator_params(self, trt_config_args):
        """`--quantile` / `--regression-cutoff` must reach the legacy INT8 calibrator."""
        quantile, regression_cutoff = 0.25, 0.9
        cli_args = [
            "--int8",
            "--calibration-base-class=IInt8LegacyCalibrator",
            "--quantile",
            str(quantile),
            "--regression-cutoff",
            str(regression_cutoff),
        ]
        trt_config_args.parse_args(cli_args)
        assert trt_config_args.quantile == quantile
        assert trt_config_args.regression_cutoff == regression_cutoff

        trt_builder, trt_network = create_network()
        with trt_builder, trt_network:
            with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
                assert cfg.int8_calibrator.get_quantile() == quantile
                assert cfg.int8_calibrator.get_regression_cutoff() == regression_cutoff
Example #14
0
    def test_no_deps_profiles_int8(self):
        """Shape options plus `--int8` must yield one profile and the INT8 flag."""
        helper = ArgGroupTestHelper(TrtConfigArgs())
        cli_args = [
            "--trt-min-shapes=input:[1,25,25]",
            "--trt-opt-shapes=input:[2,25,25]",
            "--trt-max-shapes=input:[4,25,25]",
            "--int8",
        ]
        helper.parse_args(cli_args)

        for profile in helper.profile_dicts:
            lo_shapes, best_shapes, hi_shapes = profile
            assert lo_shapes["input"] == [1, 25, 25]
            assert best_shapes["input"] == [2, 25, 25]
            assert hi_shapes["input"] == [4, 25, 25]

        trt_builder, trt_network = create_network()

        with trt_builder, trt_network:
            with helper.create_config(trt_builder, network=trt_network) as cfg:
                assert isinstance(cfg, trt.IBuilderConfig)
                # Unfortunately there is no API to check the contents of the profile in a config.
                # The checks above will have to do.
                assert cfg.num_optimization_profiles == 1
                assert cfg.get_flag(trt.BuilderFlag.INT8)
Example #15
0
 def test_tactic_sources(self, trt_config_args, opt, expected):
     """The config's tactic sources must equal `expected` for the argument list `opt`."""
     # `opt` is handed to `parse_args` as the complete argument list.
     trt_config_args.parse_args(opt)
     trt_builder, trt_network = create_network()
     with trt_builder, trt_network:
         with trt_config_args.create_config(trt_builder, network=trt_network) as cfg:
             assert cfg.get_tactic_sources() == expected