Ejemplo n.º 1
0
            def load_from_cache():
                """Return the calibration cache contents, or None if they cannot be used.

                No read is attempted when no cache path is configured
                (``self._cache is None``) or ``util.get_file_size`` reports a
                falsy size for it. If ``util.load_file`` raises, a warning that
                includes the underlying error is logged and None is returned.

                Returns:
                    Whatever ``util.load_file`` returns for the cache path, or None.
                """
                if self._cache is None or not util.get_file_size(self._cache):
                    return None

                try:
                    return util.load_file(self._cache, description="calibration cache")
                except Exception as err:
                    # Catch Exception rather than using a bare ``except:`` so that
                    # KeyboardInterrupt / SystemExit are not silently swallowed.
                    # Include the error so the cause of the failure is visible.
                    G_LOGGER.warning(
                        "Could not read from calibration cache: {:}\nNote: Error was: {:}".format(
                            self._cache, err
                        )
                    )
                    return None
Ejemplo n.º 2
0
            def load_from_cache():
                """Load the calibration cache if one is configured and non-empty.

                Returns:
                    The result of ``util.load_file`` on ``self._cache``. Returns
                    None when no path is set, when ``util.get_file_size`` gives
                    a falsy size, or when loading raises. In the last case the
                    failure is logged at error level.
                """
                cache_path = self._cache
                if cache_path is not None and util.get_file_size(cache_path):
                    try:
                        return util.load_file(cache_path, description="calibration cache")
                    except Exception as exc:
                        message = "Could not read from calibration cache: {:}\nNote: Error was: {:}"
                        G_LOGGER.error(message.format(cache_path, exc))
                # Reached when there is no usable cache or it failed to load.
                return None
Ejemplo n.º 3
0
    def parse(self, args):
        """Read TensorRT builder-configuration options from ``args`` into attributes.

        Every option is fetched with ``args_util.get``. A ``None`` result is
        treated as "not provided". NOTE(review): presumably ``args_util.get``
        returns None for options absent from ``args`` -- confirm.

        Attributes set: ``profile_dicts``, ``workspace``, ``tf32``, ``fp16``,
        ``int8``, ``strict_types``, ``restricted``, ``calibration_cache``,
        ``calibration_base_class``, ``quantile``, ``regression_cutoff``,
        ``sparse_weights``, ``timing_cache``, ``load_tactics``,
        ``save_tactics``, ``tactic_sources``, ``trt_config_script`` and
        ``trt_config_func_name``.

        Args:
            args: Parsed arguments object to read the options from.
        """

        def _get(name):
            # Single access point for every option lookup in this method.
            return args_util.get(args, name)

        # Optimization-profile shapes; anything missing becomes an empty list.
        min_shapes, max_shapes, opt_shapes = (
            util.default(_get(key), [])
            for key in ("trt_min_shapes", "trt_max_shapes", "trt_opt_shapes")
        )

        # Input shapes from ModelArgs (if present) are handed to
        # parse_profile_shapes as the default shapes.
        default_shapes = TensorMetadata()
        if self.model_args is not None:
            assert hasattr(self.model_args, "input_shapes"), "ModelArgs must be parsed before TrtConfigArgs!"
            default_shapes = self.model_args.input_shapes

        self.profile_dicts = parse_profile_shapes(default_shapes, min_shapes, opt_shapes, max_shapes)

        workspace = _get("workspace")
        if workspace is not None:
            workspace = int(workspace)
        self.workspace = workspace

        # Precision / type-constraint flags, stored exactly as provided.
        self.tf32 = _get("tf32")
        self.fp16 = _get("fp16")
        self.int8 = _get("int8")
        self.strict_types = _get("strict_types")
        self.restricted = _get("restricted")

        # Calibration: the base class name is validated, then rendered as the
        # expression ``trt.<name>`` via safe()/inline().
        self.calibration_cache = _get("calibration_cache")
        self.calibration_base_class = None
        base_class_name = _get("calibration_base_class")
        if base_class_name is not None:
            identifier = safe(assert_identifier(base_class_name))
            self.calibration_base_class = inline(safe("trt.{:}", inline(identifier)))

        self.quantile = _get("quantile")
        self.regression_cutoff = _get("regression_cutoff")

        self.sparse_weights = _get("sparse_weights")
        self.timing_cache = _get("timing_cache")

        self.load_tactics = _get("load_tactics")
        self.save_tactics = _get("save_tactics")
        tactic_replay = _get("tactic_replay")
        if tactic_replay is not None:
            # Deprecated --tactic-replay overrides --load-tactics/--save-tactics.
            # The path is loaded from when it names an existing non-empty file;
            # otherwise it is saved to.
            # NOTE(review): mod.warn_deprecated may already emit a warning, which
            # would make the G_LOGGER.warning below redundant -- confirm.
            mod.warn_deprecated("--tactic-replay", "--save-tactics or --load-tactics", remove_in="0.35.0")
            G_LOGGER.warning("--tactic-replay is deprecated. Use either --save-tactics or --load-tactics instead.")
            replay_has_data = os.path.exists(tactic_replay) and util.get_file_size(tactic_replay) > 0
            if replay_has_data:
                self.load_tactics = tactic_replay
            else:
                self.save_tactics = tactic_replay

        # Each tactic source name is upper-cased, validated, and rendered as
        # the expression ``trt.TacticSource.<NAME>``.
        self.tactic_sources = None
        requested_sources = _get("tactic_sources")
        if requested_sources is not None:
            self.tactic_sources = [
                inline(safe("trt.TacticSource.{:}", inline(safe(assert_identifier(name.upper())))))
                for name in requested_sources
            ]

        self.trt_config_script = _get("trt_config_script")
        self.trt_config_func_name = _get("trt_config_func_name")