Ejemplo n.º 1
0
    def parse(self, args):
        """Read the logging options from ``args`` and apply them.

        Populates ``verbosity_count`` and ``silent`` directly, and
        ``log_format`` defaulted to ``[]`` through ``misc.default_value``.
        Then calls ``get_logger()`` so the settings are applied right away.
        """
        lookup = tools_util.get
        self.verbosity_count = lookup(args, "verbose")
        self.silent = lookup(args, "silent")

        requested_format = lookup(args, "log_format")
        self.log_format = misc.default_value(requested_format, [])

        # Apply the logger settings as soon as they have been parsed.
        self.get_logger()
Ejemplo n.º 2
0
 def parse(self, args):
     """Read the ONNX loader options from ``args``.

     When ``args`` has a ``no_shape_inference`` attribute, ``do_shape_inference``
     is ``None`` if that option is truthy and ``True`` otherwise. Without that
     attribute, the value of ``shape_inference`` is used as-is.
     """
     self.save_onnx = tools_util.get(args, "save_onnx")

     if not hasattr(args, "no_shape_inference"):
         self.do_shape_inference = tools_util.get(args, "shape_inference")
     elif tools_util.get(args, "no_shape_inference"):
         # NOTE(review): opting out yields None (not False); presumably this
         # means "use the loader's default" - confirm against the consumer.
         self.do_shape_inference = None
     else:
         self.do_shape_inference = True

     self.outputs = tools_util.get_outputs(args, "onnx_outputs")
     self.exclude_outputs = tools_util.get(args, "onnx_exclude_outputs")
Ejemplo n.º 3
0
        def determine_model_type():
            """Infer the model type from the options in ``args``.

            Rules are checked in order:

            1. An explicit ``model_type`` wins (lower-cased).
            2. Without a ``model_file`` there is nothing to infer: ``None``.
            3. A ``ckpt`` option, or a ``model_file`` that is a directory,
               gives ``"ckpt"``.
            4. If a ``tf`` or ``trt_legacy`` runner was requested, the result
               is ``"caffe"`` when ``caffe_model`` is set. Otherwise the file
               extension is mapped, defaulting to ``"frozen"``.
            5. Otherwise the extension is mapped with a table that omits
               extensions that are ambiguous without a framework.

            If no rule matches, ``G_LOGGER.critical`` is called.
            """
            if tools_util.get(args, "model_type") is not None:
                return args.model_type.lower()

            if tools_util.get(args, "model_file") is None:
                return None

            def use_ext(ext_mapping):
                # Map the model file's extension; None when it is not listed.
                file_ext = os.path.splitext(args.model_file)[-1]
                return ext_mapping.get(file_ext)

            runners = misc.default_value(tools_util.get(args, "runners"), [])
            if tools_util.get(args, "ckpt") or os.path.isdir(args.model_file):
                return "ckpt"
            elif "tf" in runners or "trt_legacy" in runners:
                # Read caffe_model through tools_util.get, like every other
                # optional option here. Plain ``args.caffe_model`` raises
                # AttributeError if that option was never registered.
                if tools_util.get(args, "caffe_model"):
                    return "caffe"
                ext_mapping = {
                    ".hdf5": "keras",
                    ".uff": "uff",
                    ".prototxt": "caffe",
                    ".onnx": "onnx",
                    ".engine": "engine",
                    ".plan": "engine"
                }
                return use_ext(ext_mapping) or "frozen"
            else:
                # When no framework is provided, some extensions can be ambiguous
                ext_mapping = {
                    ".hdf5": "keras",
                    ".graphdef": "frozen",
                    ".onnx": "onnx",
                    ".uff": "uff",
                    ".engine": "engine",
                    ".plan": "engine"
                }
                model_type = use_ext(ext_mapping)
                if model_type:
                    return model_type

            G_LOGGER.critical(
                "Could not automatically determine model type for: {:}\n"
                "Please explicitly specify the type with the --model-type option"
                .format(args.model_file))
Ejemplo n.º 4
0
 def parse(self, args):
     """Read the TensorFlow loader options from ``args``.

     ``outputs`` comes from ``tools_util.get_outputs`` (``tf_outputs``); every
     other attribute is copied from the option of the same name.
     """
     get = tools_util.get

     self.ckpt = get(args, "ckpt")
     self.outputs = tools_util.get_outputs(args, "tf_outputs")

     # Export / graph-handling options.
     self.save_pb = get(args, "save_pb")
     self.save_tensorboard = get(args, "save_tensorboard")
     self.freeze_graph = get(args, "freeze_graph")

     # TF-TRT related options.
     self.tftrt = get(args, "tftrt")
     self.minimum_segment_size = get(args, "minimum_segment_size")
     self.dynamic_op = get(args, "dynamic_op")
Ejemplo n.º 5
0
 def parse(self, args):
     """Copy six options from ``args`` onto attributes of the same name."""
     read = tools_util.get
     self.trt_outputs = read(args, "trt_outputs")
     self.caffe_model = read(args, "caffe_model")
     self.batch_size = read(args, "batch_size")
     self.save_uff = read(args, "save_uff")
     self.uff_order = read(args, "uff_order")
     self.preprocessor = read(args, "preprocessor")
Ejemplo n.º 6
0
    def parse(self, args):
        """Read the data-loader options from ``args``.

        ``int_range`` and ``float_range`` are ``(min, max)`` tuples. They become
        ``None`` only when both bounds are ``None``; a single missing bound stays
        ``None`` inside the tuple.
        """
        def omit_none_tuple(tup):
            # Collapse an all-None tuple into a single None.
            return None if all(elem is None for elem in tup) else tup

        get = tools_util.get
        self.seed = get(args, "seed")

        int_bounds = (get(args, "int_min"), get(args, "int_max"))
        self.int_range = omit_none_tuple(tup=int_bounds)

        float_bounds = (get(args, "float_min"), get(args, "float_max"))
        self.float_range = omit_none_tuple(tup=float_bounds)

        self.iterations = get(args, "iterations")
        self.load_inputs = get(args, "load_inputs")
Ejemplo n.º 7
0
    def parse(self, args):
        """Read the result-comparison options from ``args``.

        ``rtol`` and ``atol`` are converted from lists of ``"[name,]value"``
        strings into ``{name: float(value)}`` dicts; ``None`` stays ``None``.
        Each entry is split on its last comma, so an entry with no comma is
        stored under the empty-string key.
        """
        def parse_tol(tol_arg):
            if tol_arg is None:
                return None

            # rpartition splits on the *last* comma: "out,1e-5" -> ("out", ",", "1e-5").
            # A later entry for the same name overwrites an earlier one.
            return {
                out_name: float(tol)
                for out_name, _, tol in (item.rpartition(",") for item in tol_arg)
            }

        get = tools_util.get
        self.no_shape_check = get(args, "no_shape_check")
        self.rtol = parse_tol(get(args, "rtol"))
        self.atol = parse_tol(get(args, "atol"))
        self.validate = get(args, "validate")
        self.load_results = get(args, "load_results")
        self.fail_fast = get(args, "fail_fast")
        self.top_k = get(args, "top_k")
        # FIXME: runners should come from a proper RunnerArgs dependency instead of being read here.
        self.runners = get(args, "runners")
Ejemplo n.º 8
0
    def run(self, args):
        """Validate the options, then either write the generated script or execute it.

        If ``gen_script`` is set, the script from ``build_script`` is written
        there. For a real file (not a TTY or a standard stream), execute
        permission is also added. Otherwise the script is run with ``exec``.

        Returns:
            int: Always 0; invalid option combinations are reported through
            ``G_LOGGER.critical``.
        """
        network_api = self.makers[TrtLoaderArgs].network_api
        model_file = self.makers[ModelArgs].model_file
        # Read gen_script the same way everywhere in this method.
        gen_script = tools_util.get(args, "gen_script")

        if network_api and not gen_script:
            G_LOGGER.critical(
                "Cannot use the --network-api option if --gen/--gen-script is not being used."
            )
        elif network_api and "trt" not in args.runners:
            # --network-api requires a "trt" runner, so add one if it is missing.
            args.runners.append("trt")

        # Use truthiness (not `is None`) to match the checks above, so a
        # False-valued network_api does not silently disable this guard.
        if model_file is None and args.runners and not network_api:
            G_LOGGER.critical(
                "One or more runners was specified, but no model file was provided. Make sure you've specified the model path, "
                "and also that it's not being consumed as an argument for another parameter"
            )

        misc.log_module_info(polygraphy)

        script = self.build_script(args)

        if gen_script:
            with gen_script:
                gen_script.write(script)

                path = gen_script.name
                # Somehow, piping fools isatty, e.g. `polygraphy run --gen-script - | cat`,
                # so the standard streams are also excluded by name.
                if not gen_script.isatty() and path not in [
                        "<stdout>", "<stderr>"
                ]:
                    G_LOGGER.info("Writing script to: {:}".format(path))
                    # Add execute permission for user, group and others.
                    os.chmod(path, os.stat(path).st_mode | 0o111)
        else:
            # `script` is produced by self.build_script above, not taken from user input.
            exec(script)

        return 0
Ejemplo n.º 9
0
    def parse(self, args):
        """Read the TensorRT loader/builder options from ``args``.

        The three shape options default to ``[]`` (via ``misc.default_value``).
        ``workspace`` is converted to ``int`` when given and left as ``None``
        otherwise.
        """
        get = tools_util.get

        self.plugins = get(args, "plugins")
        self.outputs = tools_util.get_outputs(args, "trt_outputs")
        self.network_api = get(args, "network_api")
        self.ext = get(args, "ext")
        self.explicit_precision = get(args, "explicit_precision")
        self.exclude_outputs = get(args, "trt_exclude_outputs")

        def shape_list(option_name):
            # A fresh empty list is used whenever the option is unset.
            return misc.default_value(get(args, option_name), [])

        self.trt_min_shapes = shape_list("trt_min_shapes")
        self.trt_max_shapes = shape_list("trt_max_shapes")
        self.trt_opt_shapes = shape_list("trt_opt_shapes")

        workspace = get(args, "workspace")
        if workspace is not None:
            workspace = int(workspace)
        self.workspace = workspace

        self.tf32 = get(args, "tf32")
        self.fp16 = get(args, "fp16")
        self.int8 = get(args, "int8")

        self.calibration_cache = get(args, "calibration_cache")
        self.strict_types = get(args, "strict_types")
Ejemplo n.º 10
0
 def parse(self, args):
     """Read the opset and constant-folding options from ``args``.

     ``fold_constant`` is ``False`` when ``no_const_folding`` is truthy and
     ``None`` otherwise.
     """
     self.opset = tools_util.get(args, "opset")
     if tools_util.get(args, "no_const_folding"):
         self.fold_constant = False
     else:
         self.fold_constant = None
Ejemplo n.º 11
0
 def parse(self, args):
     """Read ``gpu_memory_fraction``, ``allow_growth`` and ``xla`` from ``args``."""
     option = tools_util.get
     self.gpu_memory_fraction = option(args, "gpu_memory_fraction")
     self.allow_growth = option(args, "allow_growth")
     self.xla = option(args, "xla")
Ejemplo n.º 12
0
 def parse(self, args):
     """Read the warm-up, subprocess, and input/result-saving options from ``args``."""
     option = tools_util.get
     self.warm_up = option(args, "warm_up")
     self.use_subprocess = option(args, "use_subprocess")
     self.save_inputs = option(args, "save_inputs")
     self.save_results = option(args, "save_results")
Ejemplo n.º 13
0
    def parse(self, args):
        """Read the model-location options from ``args`` and resolve the model type.

        Sets:
        * ``input_shapes``: parsed with ``tools_util.parse_meta``, or an empty
          ``TensorMetadata`` when no shapes were given.
        * ``model_file``: made absolute when provided.
        * ``model_type``: ``self._model_type`` when that is not ``None``;
          otherwise inferred by ``determine_model_type``.
        """
        def determine_model_type():
            """Infer the model type from the options in ``args``.

            Rules are checked in order:

            1. An explicit ``model_type`` wins (lower-cased).
            2. Without a ``model_file`` there is nothing to infer: ``None``.
            3. A ``ckpt`` option, or a ``model_file`` that is a directory,
               gives ``"ckpt"``.
            4. If a ``tf`` or ``trt_legacy`` runner was requested, the result
               is ``"caffe"`` when ``caffe_model`` is set. Otherwise the file
               extension is mapped, defaulting to ``"frozen"``.
            5. Otherwise the extension is mapped with a table that omits
               extensions that are ambiguous without a framework.

            If no rule matches, ``G_LOGGER.critical`` is called.
            """
            if tools_util.get(args, "model_type") is not None:
                return args.model_type.lower()

            if tools_util.get(args, "model_file") is None:
                return None

            def use_ext(ext_mapping):
                # Map the model file's extension; None when it is not listed.
                file_ext = os.path.splitext(args.model_file)[-1]
                return ext_mapping.get(file_ext)

            runners = misc.default_value(tools_util.get(args, "runners"), [])
            if tools_util.get(args, "ckpt") or os.path.isdir(args.model_file):
                return "ckpt"
            elif "tf" in runners or "trt_legacy" in runners:
                # Read caffe_model through tools_util.get, like every other
                # optional option here. Plain ``args.caffe_model`` raises
                # AttributeError if that option was never registered.
                if tools_util.get(args, "caffe_model"):
                    return "caffe"
                ext_mapping = {
                    ".hdf5": "keras",
                    ".uff": "uff",
                    ".prototxt": "caffe",
                    ".onnx": "onnx",
                    ".engine": "engine",
                    ".plan": "engine"
                }
                return use_ext(ext_mapping) or "frozen"
            else:
                # When no framework is provided, some extensions can be ambiguous
                ext_mapping = {
                    ".hdf5": "keras",
                    ".graphdef": "frozen",
                    ".onnx": "onnx",
                    ".uff": "uff",
                    ".engine": "engine",
                    ".plan": "engine"
                }
                model_type = use_ext(ext_mapping)
                if model_type:
                    return model_type

            G_LOGGER.critical(
                "Could not automatically determine model type for: {:}\n"
                "Please explicitly specify the type with the --model-type option"
                .format(args.model_file))

        model_file = tools_util.get(args, "model_file")
        if model_file:
            G_LOGGER.verbose("Model: {:}".format(model_file))
            if not os.path.exists(model_file):
                G_LOGGER.warning("Model path does not exist: {:}".format(
                    model_file))
            # Write the absolute path back to `args`: determine_model_type()
            # (and possibly other consumers of `args`) read it from there.
            args.model_file = os.path.abspath(model_file)

        input_shapes = tools_util.get(args, "input_shapes")
        if input_shapes:
            self.input_shapes = tools_util.parse_meta(
                input_shapes, includes_dtype=False)  # TensorMetadata
        else:
            self.input_shapes = TensorMetadata()

        # Read through tools_util.get, matching the guard above, so a missing
        # model_file option yields None instead of an AttributeError.
        self.model_file = tools_util.get(args, "model_file")

        # Infer the type only when no fixed type was supplied. Evaluating
        # determine_model_type() eagerly could reach G_LOGGER.critical for an
        # unrecognized extension even though _model_type already decides it.
        if self._model_type is not None:
            self.model_type = self._model_type
        else:
            self.model_type = determine_model_type()
Ejemplo n.º 14
0
 def parse(self, args):
     """Store the ``save_engine`` option from ``args`` on ``self.save_engine``."""
     save_engine_opt = tools_util.get(args, "save_engine")
     self.save_engine = save_engine_opt
Ejemplo n.º 15
0
 def parse(self, args):
     """Store the ``save_timeline`` option from ``args`` as ``self.timeline_path``."""
     timeline_opt = tools_util.get(args, "save_timeline")
     self.timeline_path = timeline_opt