コード例 #1
0
ファイル: loader.py プロジェクト: celidos/TensorRT_study
    def __call__(self):
        """
        Builds a TensorRT engine.

        Returns:
            trt.ICudaEngine: The engine that was created.
        """
        # If network is a callable, then we own its return value
        ret, owns_network = misc.try_call(self._network)
        # ret holds (builder, network, parser); parser may be None when the
        # network was not created via a parser-based loader.
        builder, network, parser = misc.unpack_args(ret, num=3)

        with contextlib.ExitStack() as stack:
            # Objects we own are registered on the stack so they are freed when
            # the block exits, including on build failures. Caller-provided
            # objects are deliberately left alone.
            if owns_network:
                stack.enter_context(builder)
                stack.enter_context(network)
                if parser is not None:
                    stack.enter_context(parser)
            else:
                provided = "Builder and Network" if parser is None else "Builder, Network, and Parser"
                G_LOGGER.verbose(
                    "{:} were provided directly instead of via a Callable. This loader will not assume ownership. "
                    "Please ensure that they are freed.".format(provided))

            # The config loader receives the builder and network so it can
            # construct a configuration appropriate for them.
            config, owns_config = misc.try_call(self._config, builder, network)
            if owns_config:
                stack.enter_context(config)
            else:
                G_LOGGER.verbose(
                    "Builder configuration was provided directly instead of via a Callable. This loader will not assume "
                    "ownership. Please ensure it is freed.")

            # Log the full network only at ULTRA_VERBOSE severity; otherwise
            # log just its attributes.
            network_log_mode = "full" if G_LOGGER.severity <= G_LOGGER.ULTRA_VERBOSE else "attrs"
            # The lambda defers stringification of the network until the logger
            # decides the message will actually be emitted.
            G_LOGGER.super_verbose(
                lambda: ("Displaying TensorRT Network:\n" + trt_util.
                         str_from_network(network, mode=network_log_mode)))

            G_LOGGER.info("Building engine with configuration: {:}".format(
                trt_util.str_from_config(config)))

            # Older TensorRT builds the engine directly; newer versions build a
            # serialized network and deserialize it through EngineFromBytes.
            if misc.version(trt.__version__) < misc.version("7.3"):
                engine = builder.build_engine(network, config)
            else:
                engine = func.invoke(
                    EngineFromBytes(
                        builder.build_serialized_network(network, config)))

            if hasattr(config.int8_calibrator, "free"):
                # Must go before engine check to ensure calibrator is freed on failures too.
                config.int8_calibrator.free()

            if not engine:
                # NOTE(review): G_LOGGER.critical presumably raises/aborts;
                # otherwise a failed build would return None here — confirm.
                G_LOGGER.critical(
                    "Invalid Engine. Please ensure the engine was built correctly"
                )
            return engine
コード例 #2
0
ファイル: test_runner.py プロジェクト: celidos/TensorRT_study
 def test_multiple_profiles(self):
     """An engine built with two optimization profiles runs inference under the second one."""
     model = ONNX_MODELS["dynamic_identity"]
     shapes = [(1, 2, 4, 4), (1, 2, 8, 8), (1, 2, 16, 16)]
     # Profile 0 covers small shapes only; profile 1 covers every shape tested below.
     small_profile = Profile().add("X", (1, 2, 1, 1), (1, 2, 2, 2), (1, 2, 4, 4))
     large_profile = Profile().add("X", *shapes)
     engine_loader = EngineFromNetwork(
         NetworkFromOnnxBytes(model.loader),
         CreateConfig(profiles=[small_profile, large_profile]),
     )
     with TrtRunner(engine_loader) as runner:
         # Select profile index 1; the API for doing so changed in TRT 7.3.
         if misc.version(trt.__version__) < misc.version("7.3"):
             runner.context.active_optimization_profile = 1
         else:
             runner.context.set_optimization_profile_async(1, runner.stream.address())
         for shape in shapes:
             model.check_runner(runner, {"X": shape})
コード例 #3
0
    def test_calibrator_outside_polygraphy(self, identity_builder_network):
        """INT8 calibration succeeds when a Calibrator is attached to a plain builder config."""
        builder, network = identity_builder_network
        NUM_BATCHES = 2

        def generate_data():
            # Feed the same all-ones batch NUM_BATCHES times to drive calibration.
            batch = np.ones((1, 1, 2, 2), dtype=np.float32)
            for _ in range(NUM_BATCHES):
                yield {"x": batch}

        calibrator = Calibrator(generate_data())

        config = builder.create_builder_config()
        config.set_flag(trt.BuilderFlag.INT8)
        config.int8_calibrator = calibrator

        # Older TensorRT builds the engine directly; newer versions go through
        # a serialized network which EngineFromBytes deserializes.
        if misc.version(trt.__version__) < misc.version("7.3"):
            engine = builder.build_engine(network, config)
        else:
            serialized = builder.build_serialized_network(network, config)
            engine = func.invoke(EngineFromBytes(serialized))

        with engine:
            assert engine
コード例 #4
0
    def __call__(self):
        """
        Parses an ONNX model.

        Returns:
            (trt.IBuilder, trt.INetworkDefinition, trt.OnnxParser):
                    A TensorRT network, as well as the builder used to create it, and the parser
                    used to populate it.
        """
        builder, network, parser = super().__call__()
        # _model_bytes may itself be a callable; try_call returns (value, owns_value).
        model_bytes, _ = misc.try_call(self._model_bytes)
        parser.parse(model_bytes)
        # Surface any parse errors before handing the network back to the caller.
        trt_util.check_onnx_parser_errors(parser)
        return builder, network, parser


if misc.version(trt.__version__) >= misc.version("7.1"):
    class NetworkFromOnnxPath(BaseNetworkFromOnnx):
        def __init__(self, path, explicit_precision=None):
            """
            Functor that parses an ONNX model to create a trt.INetworkDefinition.
            This loader supports models with weights stored in an external location.

            Args:
                path (str): The path from which to load the model.
                explicit_precision (bool):
                        Forwarded unchanged to the BaseNetworkFromOnnx constructor;
                        presumably toggles explicit-precision network creation —
                        confirm against the base class.
            """
            super().__init__(explicit_precision)
            # Path is stored (not loaded here); parsing happens in __call__.
            self.path = path


        def __call__(self):
            """