예제 #1
0
    def __call__(self):
        """
        Modifies a TensorRT ``INetworkDefinition``.

        Returns:
            trt.INetworkDefinition: The modified network.
        """
        result, owns_network = misc.try_call(self._network)
        builder, network, parser = misc.unpack_args(result, num=3)

        with contextlib.ExitStack() as stack:
            # Only register the TensorRT objects for cleanup-on-failure when this
            # loader owns them (i.e. they were created by the wrapped callable).
            if owns_network:
                stack.enter_context(
                    misc.FreeOnException([builder, network, parser]))

            # Output marking: either mark every layer's outputs, or just the requested ones.
            if self.outputs == constants.MARK_ALL:
                trt_util.mark_layerwise(network)
            elif self.outputs is not None:
                trt_util.mark_outputs(network, self.outputs)

            if self.exclude_outputs is not None:
                trt_util.unmark_outputs(network, self.exclude_outputs)

            # Preserve the original tuple arity: no parser means a 2-tuple.
            if parser is None:
                return builder, network
            return builder, network, parser
예제 #2
0
 def __call__(self):
     """
     Creates a new TensorRT network and attaches an ONNX parser to it.

     Returns:
         (trt.Builder, trt.INetworkDefinition, trt.OnnxParser):
                 The builder, the network it created, and a parser bound to that network.
     """
     network_loader = CreateNetwork(explicit_precision=self.explicit_precision,
                                    explicit_batch=self.explicit_batch)
     # Free the builder/network if parser construction raises.
     with misc.FreeOnException(func.invoke(network_loader)) as (builder, network):
         parser = trt.OnnxParser(network, trt_util.TRT_LOGGER)
         return builder, network, parser
예제 #3
0
    def __call__(self, builder, network):
        """
        Creates a TensorRT IBuilderConfig that can be used by the EngineFromNetwork.

        Args:
            builder (trt.Builder):
                    The TensorRT builder to use to create the configuration.
            network (trt.INetworkDefinition):
                    The TensorRT network for which to create the config. The network is used to
                    automatically create a default optimization profile if none are provided.

        Returns:
            trt.IBuilderConfig: The TensorRT builder configuration.
        """
        # Free the config if anything below raises before we can return it.
        with misc.FreeOnException([builder.create_builder_config()
                                   ]) as (config, ):
            # NOTE: after these branches, calibration_profile holds the LAST profile
            # that was added (or the default profile when none were supplied); it is
            # reused below for int8 calibration.
            calibration_profile = None
            for profile in self.profiles:
                calibration_profile = trt_util.build_profile(
                    builder, network, profile)
                config.add_optimization_profile(calibration_profile)
            if not self.profiles:
                # No user-supplied profiles: derive a default one from the network inputs.
                calibration_profile = trt_util.build_default_profile(
                    builder, network)
                config.add_optimization_profile(calibration_profile)

            if self.profiles:
                G_LOGGER.info("Configuring with profiles: {:}".format(
                    self.profiles))

            config.max_workspace_size = int(self.max_workspace_size)

            if self.strict_types:
                config.set_flag(trt.BuilderFlag.STRICT_TYPES)
            if not self.tf32:
                # TF32 is enabled by default where supported; clear it when disabled.
                # Older TensorRT versions lack the TF32 flag, hence the suppress.
                with contextlib.suppress(AttributeError):
                    config.clear_flag(trt.BuilderFlag.TF32)
            if self.fp16:
                config.set_flag(trt.BuilderFlag.FP16)
            if self.int8:
                config.set_flag(trt.BuilderFlag.INT8)
                # Explicit-precision networks carry their own dynamic ranges, so
                # calibration only applies to implicit-precision networks.
                if not network.has_explicit_precision:
                    if self.calibrator is not None:
                        input_metadata = trt_util.get_input_metadata_from_profile(
                            calibration_profile, network)
                        # reset() is optional on the calibrator interface — best effort.
                        with contextlib.suppress(AttributeError):
                            self.calibrator.reset(input_metadata)
                        config.int8_calibrator = self.calibrator
                        # set_calibration_profile is not available in all TensorRT
                        # versions — presumably older releases; best effort.
                        with contextlib.suppress(AttributeError):
                            config.set_calibration_profile(calibration_profile)
                    else:
                        G_LOGGER.warning(
                            "Network does not have explicit precision and no calibrator was provided. Please ensure "
                            "that tensors in the network have dynamic ranges set, or provide a calibrator in order to use int8 mode."
                        )
            return config
예제 #4
0
    def __call__(self):
        """
        Parses an ONNX model and also reports the leading dimension of its first input.

        Returns:
            (trt.IBuilder, trt.INetworkDefinition, trt.OnnxParser, object):
                    The builder, network, parser, and the first element of the
                    first input's shape.
        """
        from polygraphy.backend.onnx import util as onnx_util

        with misc.FreeOnException(super().__call__()) as (builder, network,
                                                          parser):
            onnx_model, _ = misc.try_call(self.onnx_loader)
            # Pull the shape of the first network input from the ONNX model's metadata.
            input_values = list(onnx_util.get_input_metadata(onnx_model.graph).values())
            _, shape = input_values[0]

            parser.parse(onnx_model.SerializeToString())
            trt_util.check_onnx_parser_errors(parser)

            return builder, network, parser, shape[0]
예제 #5
0
    def __call__(self):
        """
        Parses an ONNX model.

        Returns:
            (trt.IBuilder, trt.INetworkDefinition, trt.OnnxParser):
                    A TensorRT network, as well as the builder used to create it, and the parser
                    used to populate it.
        """
        with misc.FreeOnException(super().__call__()) as (builder, network,
                                                          parser):
            # _model_bytes may be a callable loader; try_call resolves it either way.
            model_bytes, _ = misc.try_call(self._model_bytes)
            parser.parse(model_bytes)
            trt_util.check_onnx_parser_errors(parser)
            return builder, network, parser
예제 #6
0
    def __call__(self):
        """
        Saves an engine to the provided path.

        Returns:
            trt.ICudaEngine: The engine that was saved.
        """
        engine, owns_engine = misc.try_call(self._engine)

        with contextlib.ExitStack() as stack:
            # Only free the engine on failure when this loader created it.
            if owns_engine:
                stack.enter_context(misc.FreeOnException([engine]))

            # lazy_write defers serialization until the write actually happens.
            misc.lazy_write(contents=engine.serialize, path=self.path)
            return engine
예제 #7
0
        def __call__(self):
            """
            Parses an ONNX model from a file.

            Returns:
                (trt.IBuilder, trt.INetworkDefinition, trt.OnnxParser):
                        A TensorRT network, as well as the builder used to create it, and the parser
                        used to populate it.
            """
            with misc.FreeOnException(super().__call__()) as (builder, network,
                                                              parser):
                # parse_from_file (rather than parse) lets the ONNX parser record the
                # model's location so it can resolve any external weight files.
                model_path, _ = misc.try_call(self.path)
                parser.parse_from_file(model_path)
                trt_util.check_onnx_parser_errors(parser)
                return builder, network, parser
예제 #8
0
    def __call__(self):
        """
        Creates an empty TensorRT network.

        Returns:
            (trt.Builder, trt.INetworkDefinition): The builder and empty network.
        """
        with misc.FreeOnException([trt.Builder(trt_util.TRT_LOGGER)
                                   ]) as (builder, ):
            # Accumulate creation flags as a bitmask.
            flags = 0
            if self.explicit_batch:
                flags |= 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
            if self.explicit_precision:
                flags |= 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION)

            network = builder.create_network(flags=flags)
            if network is None:
                G_LOGGER.critical(
                    "Invalid network. See logging output above for details.")
            return builder, network