示例#1
0
        def determine_model_type():
            """Infer the model type string from the parsed command-line ``args``.

            Resolution order:
              1. An explicit ``model_type`` argument, lowercased.
              2. ``None`` when no ``model_file`` was given.
              3. ``"ckpt"`` when ``ckpt`` is set or ``model_file`` is a directory.
              4. When the ``tf`` or ``trt_legacy`` runner is selected: ``"caffe"`` if a
                 ``caffe_model`` was given, else the type mapped from the file
                 extension, falling back to ``"frozen"``.
              5. Otherwise, the type mapped from the file extension.

            If none of these yields a type, an error is reported via
            ``G_LOGGER.critical``.
            """
            if args_util.get(args, "model_type") is not None:
                return args.model_type.lower()

            if args_util.get(args, "model_file") is None:
                return None

            def use_ext(ext_mapping):
                # Map the model file's extension to a model type; None if unmapped.
                file_ext = os.path.splitext(args.model_file)[-1]
                if file_ext in ext_mapping:
                    return ext_mapping[file_ext]
                return None

            runners = args_util.get(args, "runners", default=[])
            if args_util.get(args, "ckpt") or os.path.isdir(args.model_file):
                return "ckpt"
            elif "tf" in runners or "trt_legacy" in runners:
                # Look up via args_util.get, like every other option here, so a
                # namespace without a ``caffe_model`` attribute doesn't raise
                # AttributeError.
                if args_util.get(args, "caffe_model"):
                    return "caffe"
                return use_ext(ModelArgs.EXT_MODEL_TYPE_MAPPING) or "frozen"
            else:
                model_type = use_ext(ModelArgs.EXT_MODEL_TYPE_MAPPING)
                if model_type:
                    return model_type

            G_LOGGER.critical(
                "Could not automatically determine model type for: {:}\n"
                "Please explicitly specify the type with the --model-type option".format(args.model_file)
            )
示例#2
0
    def parse(self, args):
        """Populate the node description from the parsed command-line ``args``.

        ``attrs`` is converted with ``args_util.parse_dict_with_default`` using
        ``=`` as the key/value separator; the other options are copied unchanged.
        """
        get = args_util.get
        self.op = get(args, "op")
        self.name = get(args, "name")

        raw_attrs = get(args, "attrs")
        self.attrs = args_util.parse_dict_with_default(raw_attrs, sep="=")

        self.inputs, self.outputs = get(args, "inputs"), get(args, "outputs")
示例#3
0
    def parse(self, args):
        """Read shape-inference options from ``args``.

        When ``force_fallback_shape_inference`` is set, ONNX shape inference is
        disabled, since fallback inference will be used instead.
        """
        requested = args_util.get(args, "do_shape_inference")
        force_fallback = args_util.get(args, "force_fallback_shape_inference")

        self.force_fallback = force_fallback
        # Running ONNX shape inference is pointless if fallback inference is forced.
        self.do_shape_inference = False if force_fallback else requested
示例#4
0
    def parse(self, args):
        """Read logging options from ``args`` and apply them immediately.

        ``verbosity_count`` is the number of ``verbose`` flags minus the number of
        ``quiet`` flags. ``log_format`` defaults to an empty list when absent.
        ``self.get_logger()`` is called at the end so the settings take effect
        as soon as parsing finishes.
        """
        # Treat a missing (or None) count as 0 instead of failing with
        # ``TypeError: unsupported operand type(s) for -``.
        verbose = args_util.get(args, "verbose") or 0
        quiet = args_util.get(args, "quiet") or 0
        self.verbosity_count = verbose - quiet
        self.silent = args_util.get(args, "silent")
        self.log_format = args_util.get(args, "log_format", default=[])
        self.log_file = args_util.get(args, "log_file")

        # Enable logger settings immediately on parsing.
        self.get_logger()
示例#5
0
    def parse(self, args):
        """Read ONNX save options from ``args``.

        ``save_external_data`` is ``None`` when the option is absent. Otherwise it
        is the first element of the parsed value, with a falsy entry normalized to
        ``""``. ``size_threshold`` is converted with ``args_util.parse_num_bytes``.
        """
        self.path = args_util.get(args, "save_onnx")

        external = args_util.get(args, "save_external_data")
        self.save_external_data = None if external is None else (external[0] or "")

        threshold = args_util.get(args, "external_data_size_threshold")
        self.size_threshold = args_util.parse_num_bytes(threshold)
        self.all_tensors_to_one_file = args_util.get(args, "all_tensors_to_one_file")
示例#6
0
文件: model.py 项目: clayne/TensorRT
    def parse(self, args):
        """Parse model-related command-line options into attributes on ``self``.

        Sets:
            input_shapes: ``TensorMetadata`` parsed from ``input_shapes`` (without
                dtypes), or an empty ``TensorMetadata`` when unset.
            model_file: Absolute path of the model file when one was given;
                otherwise the falsy value read from ``args``.
            model_type: ``ModelArgs.ModelType`` built from ``self._model_type`` if
                that is not ``None``, else from the inferred type; ``None`` when
                the resulting type string is falsy.

        Reports a fatal error via ``G_LOGGER.exit`` when the model type cannot be
        inferred, or when a ``trt-network-script`` model path is missing or does
        not end in ``.py``.
        """

        def determine_model_type():
            # Infer the model type from args; see the resolution order below.
            if args_util.get(args, "model_type") is not None:
                return args.model_type.lower()

            if args_util.get(args, "model_file") is None:
                return None

            def use_ext(ext_mapping):
                # Map the model file's extension to a model type; None if unmapped.
                file_ext = os.path.splitext(args.model_file)[-1]
                if file_ext in ext_mapping:
                    return ext_mapping[file_ext]
                return None

            runners = util.default(args_util.get(args, "runners"), [])
            if args_util.get(args, "ckpt") or os.path.isdir(args.model_file):
                return "ckpt"
            elif "tf" in runners or "trt_legacy" in runners:
                # args_util.get avoids AttributeError if the caffe option is absent.
                if args_util.get(args, "caffe_model"):
                    return "caffe"
                return use_ext(ModelArgs.EXT_MODEL_TYPE_MAPPING) or "frozen"
            else:
                model_type = use_ext(ModelArgs.EXT_MODEL_TYPE_MAPPING)
                if model_type:
                    return model_type

            G_LOGGER.exit(
                "Could not automatically determine model type for: {:}\n"
                "Please explicitly specify the type with the --model-type option"
                .format(args.model_file))

        input_shapes = args_util.get(args, "input_shapes")
        if input_shapes:
            self.input_shapes = args_util.parse_meta(
                input_shapes, includes_dtype=False)  # TensorMetadata
        else:
            self.input_shapes = TensorMetadata()

        self.model_file = args_util.get(args, "model_file")

        if self.model_file:
            G_LOGGER.verbose("Model: {:}".format(self.model_file))
            if not os.path.exists(self.model_file):
                G_LOGGER.warning("Model path does not exist: {:}".format(
                    self.model_file))
            self.model_file = os.path.abspath(self.model_file)

        # Only infer the type when it was not fixed up front. Evaluating
        # determine_model_type() eagerly (as util.default(...) did) could abort
        # via G_LOGGER.exit even though an explicit type was already available.
        if self._model_type is not None:
            model_type_str = self._model_type
        else:
            model_type_str = determine_model_type()
        self.model_type = ModelArgs.ModelType(
            model_type_str) if model_type_str else None

        if self.model_type == "trt-network-script" and (
                not self.model_file or not self.model_file.endswith(".py")):
            G_LOGGER.exit(
                "TensorRT network scripts must exist and have '.py' extensions. "
                "Note: Provided network script path was: {:}".format(
                    self.model_file))
示例#7
0
    def parse(self, args):
        """Read artifact-sorting / iteration options from ``args``.

        - Reports an error via ``G_LOGGER.critical`` if ``iter_artifact`` already
          exists on disk, or if ``artifacts`` are given without ``artifacts_dir``.
        - Warns if no artifacts were specified while ``self._prefer_artifacts`` is
          truthy.
        - ``fail_regexes`` is a list of compiled patterns from ``fail_regex``, or
          ``None`` when that option is absent.
        - ``start_date``/``start_time`` are filesystem-friendly renderings of the
          locale's date (``%x``, "/" -> "-") and time (``%X``, ":" -> "-").
        """
        self.iter_artifact = args_util.get(args, "iter_artifact")

        if self.iter_artifact and os.path.exists(self.iter_artifact):
            G_LOGGER.critical(
                "{:} already exists, refusing to overwrite.\n"
                "Please specify a different path for the intermediate artifact with "
                "--intermediate-artifact".format(self.iter_artifact)
            )

        self.artifacts = util.default(args_util.get(args, "artifacts"), [])
        self.output = args_util.get(args, "artifacts_dir")
        self.show_output = args_util.get(args, "show_output")
        self.remove_intermediate = args_util.get(args, "remove_intermediate")
        self.fail_codes = args_util.get(args, "fail_codes")
        self.ignore_fail_codes = args_util.get(args, "ignore_fail_codes")

        fail_regex = args_util.get(args, "fail_regex")
        self.fail_regexes = None
        if fail_regex is not None:
            self.fail_regexes = [re.compile(regex) for regex in fail_regex]

        if self.artifacts and not self.output:
            G_LOGGER.critical(
                "An output directory must be specified if artifacts are enabled! "
                "Note: Artifacts specified were: {:}".format(self.artifacts)
            )

        if not self.artifacts and self._prefer_artifacts:
            G_LOGGER.warning(
                "`--artifacts` was not specified; No artifacts will be stored during this run! "
                "Is this what you intended?"
            )

        self.iteration_info = args_util.get(args, "iteration_info")

        self.check = args_util.get(args, "check")

        # Capture a single timestamp so the date and time describe the same
        # instant; two separate strftime() calls could straddle midnight.
        now = time.localtime()
        self.start_date = time.strftime("%x", now).replace("/", "-")
        self.start_time = time.strftime("%X", now).replace(":", "-")
示例#8
0
    def parse(self, args):
        """Read data-loader options from ``args`` and validate ``val_range``.

        ``int_range``/``float_range`` are ``(min, max)`` tuples, or ``None`` when
        both bounds are unset. Setting either one triggers a deprecation warning.
        Each ``val_range`` entry must contain exactly two numeric values;
        violations are reported via ``G_LOGGER.critical``.
        """
        def omit_none_tuple(tup):
            # Collapse a tuple whose elements are all None into None.
            return None if all(elem is None for elem in tup) else tup

        self.seed = args_util.get(args, "seed")

        int_bounds = (args_util.get(args, "int_min"), args_util.get(args, "int_max"))
        self.int_range = omit_none_tuple(tup=int_bounds)
        float_bounds = (args_util.get(args, "float_min"), args_util.get(args, "float_max"))
        self.float_range = omit_none_tuple(tup=float_bounds)

        if self.int_range or self.float_range:
            G_LOGGER.warning(
                "The --int-min/--int-max and --float-min/--float-max options are deprecated.\n"
                "Please use `--val-range` instead, which allows you to specify per-input data ranges."
            )

        raw_val_range = args_util.get(args, "val_range")
        self.val_range = args_util.parse_dict_with_default(raw_val_range, cast_to=tuple)
        if self.val_range is not None:
            for input_name, bounds in self.val_range.items():
                if len(bounds) != 2:
                    G_LOGGER.critical(
                        "In --val-range, for input: {:}, expected to receive exactly 2 values, but received {:}.\n"
                        "Note: Option was parsed as: input: {:}, range: {:}".format(
                            input_name, len(bounds), input_name, bounds))

                if not all(isinstance(elem, numbers.Number) for elem in bounds):
                    G_LOGGER.critical(
                        "In --val-range, for input: {:}, one or more elements of the range could not be parsed as a number.\n"
                        "Note: Option was parsed as: input: {:}, range: {:}".format(
                            input_name, input_name, bounds))

        self.iterations = args_util.get(args, "iterations")

        self.load_inputs = args_util.get(args, "load_inputs")
        self.data_loader_script = args_util.get(args, "data_loader_script")
        self.data_loader_func_name = args_util.get(args, "data_loader_func_name")
示例#9
0
 def parse(self, args):
     """Copy TensorFlow loader options from ``args`` onto ``self``.

     ``outputs`` is read with ``args_util.get_outputs``; every other option is a
     plain ``args_util.get`` lookup stored under the same attribute name.
     """
     self.ckpt = args_util.get(args, "ckpt")
     self.outputs = args_util.get_outputs(args, "tf_outputs")
     for option in ("save_pb", "save_tensorboard", "freeze_graph", "tftrt",
                    "minimum_segment_size", "dynamic_op"):
         setattr(self, option, args_util.get(args, option))
示例#10
0
 def parse(self, args):
     """Copy legacy TensorRT (UFF/Caffe) loader options from ``args`` onto ``self``.

     Each option is stored unchanged under an attribute of the same name.
     """
     option_names = ("trt_outputs", "caffe_model", "batch_size",
                     "save_uff", "uff_order", "preprocessor")
     for option in option_names:
         setattr(self, option, args_util.get(args, option))
示例#11
0
    def parse(self, args):
        """Read data-loader options from ``args``.

        ``int_range``/``float_range`` are ``(min, max)`` tuples built from the
        corresponding ``*_min``/``*_max`` options, or ``None`` when both bounds
        are unset. ``val_range`` is parsed with ``parse_dict_with_default`` with
        values cast to tuples.
        """
        def bounds(lower_key, upper_key):
            # (min, max) pair, or None when neither bound was provided.
            pair = (args_util.get(args, lower_key), args_util.get(args, upper_key))
            return None if all(elem is None for elem in pair) else pair

        self.seed = args_util.get(args, "seed")

        self.int_range = bounds("int_min", "int_max")
        self.float_range = bounds("float_min", "float_max")
        self.val_range = args_util.parse_dict_with_default(
            args_util.get(args, "val_range"), cast_to=tuple)

        self.iterations = args_util.get(args, "iterations")

        self.load_inputs = args_util.get(args, "load_inputs")
        self.data_loader_script = args_util.get(args, "data_loader_script")
        self.data_loader_func_name = args_util.get(args, "data_loader_func_name")
示例#12
0
    def parse(self, args):
        """Read output-comparison options from ``args``.

        ``rtol``, ``atol``, ``top_k`` and ``check_error_stat`` are parsed with
        ``args_util.parse_dict_with_default``. Every ``check_error_stat`` value
        must be one of ``max``/``mean``/``median``/``elemwise``; invalid values
        are reported via ``G_LOGGER.critical``.
        """
        def dict_option(key):
            return args_util.parse_dict_with_default(args_util.get(args, key))

        self.no_shape_check = args_util.get(args, "no_shape_check")
        self.rtol = dict_option("rtol")
        self.atol = dict_option("atol")
        self.validate = args_util.get(args, "validate")
        self.load_results = args_util.get(args, "load_results")
        self.fail_fast = args_util.get(args, "fail_fast")
        self.top_k = dict_option("top_k")
        self.check_error_stat = dict_option("check_error_stat")

        if self.check_error_stat:
            # Kept as a list so the error message renders exactly as before.
            valid_stats = ["max", "mean", "median", "elemwise"]
            bad_stats = (s for s in self.check_error_stat.values() if s not in valid_stats)
            for stat in bad_stats:
                G_LOGGER.critical(
                    "Invalid choice for check_error_stat: {:}.\n"
                    "Note: Valid choices are: {:}".format(stat, valid_stats))

        # FIXME: This should be a proper dependency from a RunnerArgs
        self.runners = util.default(args_util.get(args, "runners"), [])
示例#13
0
 def parse(self, args):
     """Copy run-control options (warm-up, subprocess use, input/result saving) from ``args``."""
     fetch = args_util.get
     self.warm_up = fetch(args, "warm_up")
     self.use_subprocess = fetch(args, "use_subprocess")
     self.save_inputs, self.save_results = fetch(args, "save_inputs"), fetch(args, "save_results")
示例#14
0
    def parse(self, args):
        """Read TensorRT builder-configuration options from ``args``.

        Optimization profiles are built by ``parse_profile_shapes`` from the
        ``trt_min_shapes``/``trt_opt_shapes``/``trt_max_shapes`` options. When
        ``self.model_args`` is set, its ``input_shapes`` supply the default
        shapes; otherwise an empty ``TensorMetadata`` is used.
        ``calibration_base_class`` and each tactic source are checked with
        ``assert_identifier`` and rendered as inline ``trt.*`` expressions. The
        deprecated ``tactic_replay`` option is routed to ``load_tactics`` when
        the path exists and is non-empty, and to ``save_tactics`` otherwise.
        """
        def shape_list(key):
            # Missing shape options are treated as an empty list.
            return util.default(args_util.get(args, key), [])

        trt_min_shapes = shape_list("trt_min_shapes")
        trt_max_shapes = shape_list("trt_max_shapes")
        trt_opt_shapes = shape_list("trt_opt_shapes")

        if self.model_args is None:
            default_shapes = TensorMetadata()
        else:
            assert hasattr(self.model_args, "input_shapes"), "ModelArgs must be parsed before TrtConfigArgs!"
            default_shapes = self.model_args.input_shapes

        self.profile_dicts = parse_profile_shapes(
            default_shapes, trt_min_shapes, trt_opt_shapes, trt_max_shapes)

        workspace = args_util.get(args, "workspace")
        self.workspace = None if workspace is None else int(workspace)

        # Precision and type-constraint flags are copied through unchanged.
        self.tf32 = args_util.get(args, "tf32")
        self.fp16 = args_util.get(args, "fp16")
        self.int8 = args_util.get(args, "int8")
        self.strict_types = args_util.get(args, "strict_types")
        self.restricted = args_util.get(args, "restricted")

        self.calibration_cache = args_util.get(args, "calibration_cache")
        self.calibration_base_class = None
        base_class_name = args_util.get(args, "calibration_base_class")
        if base_class_name is not None:
            checked_name = safe(assert_identifier(base_class_name))
            self.calibration_base_class = inline(safe("trt.{:}", inline(checked_name)))

        self.quantile = args_util.get(args, "quantile")
        self.regression_cutoff = args_util.get(args, "regression_cutoff")

        self.sparse_weights = args_util.get(args, "sparse_weights")
        self.timing_cache = args_util.get(args, "timing_cache")

        tactic_replay = args_util.get(args, "tactic_replay")
        self.load_tactics = args_util.get(args, "load_tactics")
        self.save_tactics = args_util.get(args, "save_tactics")
        if tactic_replay is not None:
            mod.warn_deprecated("--tactic-replay",
                                "--save-tactics or --load-tactics",
                                remove_in="0.35.0")
            G_LOGGER.warning(
                "--tactic-replay is deprecated. Use either --save-tactics or --load-tactics instead."
            )
            # An existing, non-empty replay file is loaded; otherwise tactics are saved to it.
            replay_has_data = os.path.exists(tactic_replay) and util.get_file_size(tactic_replay) > 0
            if replay_has_data:
                self.load_tactics = tactic_replay
            else:
                self.save_tactics = tactic_replay

        requested_sources = args_util.get(args, "tactic_sources")
        self.tactic_sources = None
        if requested_sources is not None:
            self.tactic_sources = []
            for source_name in requested_sources:
                checked_source = safe(assert_identifier(source_name.upper()))
                self.tactic_sources.append(
                    inline(safe("trt.TacticSource.{:}", inline(checked_source))))

        self.trt_config_script = args_util.get(args, "trt_config_script")
        self.trt_config_func_name = args_util.get(args, "trt_config_func_name")
示例#15
0
 def parse(self, args):
     """Read TensorRT network options from ``args``.

     ``outputs`` is read with ``args_util.get_outputs``; the rest are plain lookups.
     """
     self.outputs = args_util.get_outputs(args, "trt_outputs")
     for attr, key in (("explicit_precision", "explicit_precision"),
                       ("exclude_outputs", "trt_exclude_outputs"),
                       ("trt_network_func_name", "trt_network_func_name")):
         setattr(self, attr, args_util.get(args, key))
示例#16
0
 def parse(self, args):
     """Read ONNX export options from ``args``.

     ``fold_constant`` is ``False`` when ``no_const_folding`` is truthy and
     ``None`` otherwise.
     """
     self.opset = args_util.get(args, "opset")
     if args_util.get(args, "no_const_folding"):
         self.fold_constant = False
     else:
         self.fold_constant = None
示例#17
0
文件: runner.py 项目: clayne/TensorRT
 def parse(self, args):
     """Store the ``save_timeline`` argument (possibly ``None``) as ``timeline_path``."""
     timeline = args_util.get(args, "save_timeline")
     self.timeline_path = timeline
示例#18
0
文件: config.py 项目: clayne/TensorRT
 def parse(self, args):
     """Copy TensorFlow session-config options from ``args`` onto same-named attributes."""
     for option in ("gpu_memory_fraction", "allow_growth", "xla"):
         setattr(self, option, args_util.get(args, option))
示例#19
0
    def parse(self, args):
        """Read legacy TensorRT loader, calibration and DLA options from ``args``.

        ``trt_outputs`` is read with ``args_util.get_outputs``.
        ``calibration_base_class`` is validated with ``assert_identifier`` and
        rendered as an inline ``trt.<name>`` expression; it stays ``None`` when
        unset.
        """
        self.trt_outputs = args_util.get_outputs(args, "trt_outputs")
        for option in ("caffe_model", "batch_size", "save_uff", "uff_order", "preprocessor"):
            setattr(self, option, args_util.get(args, option))

        self.calibration_cache = args_util.get(args, "calibration_cache")
        self.calibration_base_class = None
        base_name = args_util.get(args, "calibration_base_class")
        if base_name is not None:
            checked = safe(assert_identifier(base_name))
            self.calibration_base_class = inline(safe("trt.{:}", inline(checked)))

        self.quantile = args_util.get(args, "quantile")
        self.regression_cutoff = args_util.get(args, "regression_cutoff")

        self.use_dla = args_util.get(args, "use_dla")
        self.allow_gpu_fallback = args_util.get(args, "allow_gpu_fallback")
示例#20
0
 def parse(self, args):
     """Read ONNX loader options: requested/excluded outputs and external-data loading."""
     get = args_util.get
     self.outputs = args_util.get_outputs(args, "onnx_outputs")
     self.exclude_outputs = get(args, "onnx_exclude_outputs")
     self.load_external_data = get(args, "load_external_data")
示例#21
0
 def parse(self, args):
     """Store the ``save_engine`` argument (possibly ``None``) as ``path``."""
     engine_path = args_util.get(args, "save_engine")
     self.path = engine_path
示例#22
0
 def parse(self, args):
     """Store the ``plugins`` argument (possibly ``None``) unchanged."""
     plugin_list = args_util.get(args, "plugins")
     self.plugins = plugin_list
示例#23
0
 def parse(self, args):
     """Read the ONNX output path and the raw ``save_external_data`` option from ``args``."""
     self.path, self.save_external_data = (
         args_util.get(args, "save_onnx"),
         args_util.get(args, "save_external_data"),
     )