Beispiel #1
0
 def _create_converter(self, trt_convert_params: trt.TrtConversionParams):
     """Build a `trt.TrtGraphConverterV2` for the configured SavedModel.

     Args:
         trt_convert_params: Conversion parameters; every field (as given by
             `_asdict()`) is forwarded to the converter as a keyword argument.

     Returns:
         The `trt.TrtGraphConverterV2` constructed from the SavedModel
         directory, tags and signature key read from `self.model_config`.
     """
     config = self.model_config
     # Where the SavedModel lives and which signature to convert.
     saved_model_kwargs = {
         "input_saved_model_dir": config.saved_model_dir,
         "input_saved_model_tags": config.saved_model_tags,
         "input_saved_model_signature_key": config.saved_model_signature_key,
     }
     return trt.TrtGraphConverterV2(
         **saved_model_kwargs, **trt_convert_params._asdict())
Beispiel #2
0
    def _run_impl(
        self,
        default_trt_converter_params: trt.TrtConversionParams,
        trt_converter_params_updater: Callable[
            [trt.TrtConversionParams], Iterable[trt.TrtConversionParams]],
    ):
        """Convert and run every configured model, logging each result.

        Args:
            default_trt_converter_params: Baseline conversion parameters. For
                each model, `max_batch_size` is replaced with that model's
                `default_batch_size` before the manager is built.
            trt_converter_params_updater: Callable handed to the model handler
                manager together with the per-model baseline parameters.
        """
        for config in self._configs:
            # Each model gets its own batch size on top of the shared defaults.
            per_model_params = default_trt_converter_params._replace(
                max_batch_size=config.default_batch_size)
            handler_manager = self._model_handler_manager_cls(
                model_config=config,
                default_trt_convert_params=per_model_params,
                trt_convert_params_updater=trt_converter_params_updater)

            # The inputs are random, so the very same batch is used both for
            # conversion (as calibration data, to get reliable dynamic ranges)
            # and for the subsequent inference run.
            random_inputs = handler_manager.generate_random_inputs()
            handler_manager.convert(random_inputs)
            collection = handler_manager.run(random_inputs)

            logging.info("Model information: %s", repr(handler_manager))
            for entry in collection.results:
                # Falsy parameters are reported with a placeholder label.
                params_label = (
                    entry.trt_convert_params or "Not a TensorRT Model")
                logging.info("TensorRT parameters: %s", params_label)
                logging.info("Mean latency: %f ms", _get_mean_latency(entry))
Beispiel #3
0
 def trt_converter_params_updater(params: trt.TrtConversionParams):
     """Yield one copy of `params` per precision mode: FP32, FP16, then INT8.

     Each copy has `precision_mode` set to the mode and `use_calibration` set
     to whether that mode equals INT8.
     """
     modes = trt.TrtPrecisionMode
     for mode in (modes.FP32, modes.FP16, modes.INT8):
         # Calibration is requested only for the INT8 variant.
         needs_calibration = mode == modes.INT8
         yield params._replace(
             precision_mode=mode, use_calibration=needs_calibration)
Beispiel #4
0
 def trt_converter_params_updater(params: trt.TrtConversionParams):
     """Yield a copy of `params` for FP32 precision, then one for FP16."""
     modes = trt.TrtPrecisionMode
     yield from (
         params._replace(precision_mode=mode)
         for mode in (modes.FP32, modes.FP16))