Beispiel #1
0
    def test_no_opts(self):
        """Render a ``trt-config`` template with no extra options and load it.

        The generated script's ``load_config`` entry point is invoked with a
        freshly created builder/network pair, and the object it yields must be
        a TensorRT ``IBuilderConfig``.
        """
        with util.NamedTemporaryFile("w+", suffix=".py") as script:
            run_polygraphy_template(["trt-config", "-o", script.name])

            trt_builder, trt_network = create_network()
            config_loader = InvokeFromScript(script.name, "load_config")
            # Nested form of the multi-item `with`: builder and network are
            # entered first, so they are released even if the loader fails.
            with trt_builder, trt_network:
                with config_loader(trt_builder, trt_network) as config:
                    assert isinstance(config, trt.IBuilderConfig)
Beispiel #2
0
    def test_no_model_file(self):
        """Render a ``trt-network`` template without a model path and load it.

        The generated script's ``load_network`` entry point is expected to
        return a ``(builder, network)`` pair of TensorRT objects.
        """
        with tempfile.NamedTemporaryFile("w+", suffix=".py") as script:
            run_polygraphy_template(["trt-network", "-o", script.name])

            network_loader = InvokeFromScript(script.name, "load_network")
            trt_builder, trt_network = network_loader()
            with trt_builder, trt_network:
                assert isinstance(trt_builder, trt.Builder)
                assert isinstance(trt_network, trt.INetworkDefinition)
Beispiel #3
0
    def test_with_model_file(self):
        """Render a ``trt-network`` template for an ONNX model and load it.

        When a model path is supplied, the generated ``load_network`` entry
        point is expected to return ``(builder, network, parser)``, with the
        parser being a TensorRT ``OnnxParser``.
        """
        with tempfile.NamedTemporaryFile("w+", suffix=".py") as script:
            model_path = ONNX_MODELS["identity"].path
            run_polygraphy_template(["trt-network", model_path, "-o", script.name])

            network_loader = InvokeFromScript(script.name, "load_network")
            trt_builder, trt_network, onnx_parser = network_loader()
            with trt_builder, trt_network, onnx_parser:
                assert isinstance(trt_builder, trt.Builder)
                assert isinstance(trt_network, trt.INetworkDefinition)
                assert isinstance(onnx_parser, trt.OnnxParser)
Beispiel #4
0
    def test_opts_basic(self):
        """Render a ``trt-config`` template with ``--fp16 --int8`` and load it.

        The config yielded by the generated ``load_config`` must be an
        ``IBuilderConfig`` with both the FP16 and INT8 builder flags set.
        """
        with util.NamedTemporaryFile("w+", suffix=".py") as script:
            cli_args = ["trt-config", "--fp16", "--int8", "-o", script.name]
            run_polygraphy_template(cli_args)

            trt_builder, trt_network = create_network()
            config_loader = InvokeFromScript(script.name, "load_config")
            with trt_builder, trt_network:
                with config_loader(trt_builder, trt_network) as config:
                    assert isinstance(config, trt.IBuilderConfig)
                    # Checked in the same order as the CLI flags were given.
                    for expected_flag in (trt.BuilderFlag.FP16, trt.BuilderFlag.INT8):
                        assert config.get_flag(expected_flag)