Example No. 1
    def setup(self, args, network):
        try:
            self.until = int(args.until) - 1
        except ValueError:
            self.until = args.until
            if self.until not in ["good", "bad"]:
                G_LOGGER.exit("--until value must be an integer, 'good', or 'bad', but was: {:}".format(args.until))
Example No. 2
    def run(self, args):
        if not args.convert_to:
            _, ext = os.path.splitext(args.output)
            if ext not in ModelArgs.EXT_MODEL_TYPE_MAPPING:
                G_LOGGER.exit(
                    "Could not automatically determine model type based on output path: {:}\n"
                    "Please specify the desired output format with --convert-to"
                    .format(args.output))
            convert_type = ModelArgs.ModelType(
                ModelArgs.EXT_MODEL_TYPE_MAPPING[ext])
        else:
            CONVERT_TO_MODEL_TYPE_MAPPING = {"onnx": "onnx", "trt": "engine"}
            convert_type = ModelArgs.ModelType(
                CONVERT_TO_MODEL_TYPE_MAPPING[args.convert_to])

        if convert_type.is_onnx():
            model = self.arg_groups[OnnxLoaderArgs].load_onnx()
            if args.fp_to_fp16:
                model = onnx_backend.convert_to_fp16(model)
            self.arg_groups[OnnxSaveArgs].save_onnx(model, args.output)
        elif convert_type.is_trt():
            with self.arg_groups[TrtEngineLoaderArgs].build_engine() as engine:
                self.arg_groups[TrtEngineSaveArgs].save_engine(
                    engine, args.output)
        else:
            G_LOGGER.exit(
                "Cannot convert to model type: {:}".format(convert_type))
Example No. 3
        def determine_model_type():
            if args_util.get(args, "model_type") is not None:
                return args.model_type.lower()

            if args_util.get(args, "model_file") is None:
                return None

            def use_ext(ext_mapping):
                file_ext = os.path.splitext(args.model_file)[-1]
                if file_ext in ext_mapping:
                    return ext_mapping[file_ext]

            runners = util.default(args_util.get(args, "runners"), [])
            if args_util.get(args, "ckpt") or os.path.isdir(args.model_file):
                return "ckpt"
            elif "tf" in runners or "trt_legacy" in runners:
                if args.caffe_model:
                    return "caffe"
                return use_ext(ModelArgs.EXT_MODEL_TYPE_MAPPING) or "frozen"
            else:
                model_type = use_ext(ModelArgs.EXT_MODEL_TYPE_MAPPING)
                if model_type:
                    return model_type

            G_LOGGER.exit(
                "Could not automatically determine model type for: {:}\n"
                "Please explicitly specify the type with the --model-type option"
                .format(args.model_file))
Example No. 4
    def add_onnx_loader(self,
                        script,
                        disable_custom_outputs=None,
                        suffix=None):
        model_type = self.model_args.model_type
        if model_type.is_onnx():
            script.add_import(imports=["OnnxFromPath"],
                              frm="polygraphy.backend.onnx")
            loader_str = make_invocable(
                "OnnxFromPath",
                self.model_args.model_file,
                external_data_dir=self.load_external_data)
            loader_name = script.add_loader(loader_str,
                                            "load_onnx",
                                            suffix=suffix)
        elif model_type.is_tf():
            if self.tf2onnx_loader_args is None:
                G_LOGGER.exit(
                    "Could not load: {:}. Is it an ONNX model?".format(
                        self.model_args.model_file))
            loader_name = self.tf2onnx_loader_args.add_to_script(script)
        else:
            G_LOGGER.exit(
                "Model type: {:} cannot be converted to ONNX.".format(
                    model_type))

        loader_name = self._get_modify_onnx_loader(
            script, loader_name, disable_custom_outputs=disable_custom_outputs)

        if self.onnx_save_args is not None:
            loader_name = self.onnx_save_args.add_save_onnx(
                script, loader_name)

        return loader_name
Example No. 5
    def pop_meta(name):
        nonlocal tensor_meta_arg
        tensor_meta_arg, _, val = tensor_meta_arg.rpartition(SEP)
        if not tensor_meta_arg:
            G_LOGGER.exit("Could not parse {:} from argument: {:}. Is it separated by a comma "
                          "(,) from the tensor name?".format(name, orig_tensor_meta_arg))
        if val.lower() == "auto":
            val = None
        return val
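
For context, a minimal sketch of the right-to-left parsing idea behind pop_meta(); the SEP value of "," and the "inp0,1x3x224x224,float32" argument layout are illustrative assumptions, not taken from the source.

SEP = ","  # assumed separator between the tensor name and its metadata fields

# str.rpartition() splits off the right-most field, so repeated calls walk
# the argument from the end toward the tensor name:
arg = "inp0,1x3x224x224,float32"
arg, _, dtype = arg.rpartition(SEP)   # arg == "inp0,1x3x224x224", dtype == "float32"
arg, _, shape = arg.rpartition(SEP)   # arg == "inp0",             shape == "1x3x224x224"
name = arg                            # whatever remains is the tensor name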
Example No. 6
    def parse(self, args):
        def determine_model_type():
            if args_util.get(args, "model_type") is not None:
                return args.model_type.lower()

            if args_util.get(args, "model_file") is None:
                return None

            def use_ext(ext_mapping):
                file_ext = os.path.splitext(args.model_file)[-1]
                if file_ext in ext_mapping:
                    return ext_mapping[file_ext]

            runners = util.default(args_util.get(args, "runners"), [])
            if args_util.get(args, "ckpt") or os.path.isdir(args.model_file):
                return "ckpt"
            elif "tf" in runners or "trt_legacy" in runners:
                if args.caffe_model:
                    return "caffe"
                return use_ext(ModelArgs.EXT_MODEL_TYPE_MAPPING) or "frozen"
            else:
                model_type = use_ext(ModelArgs.EXT_MODEL_TYPE_MAPPING)
                if model_type:
                    return model_type

            G_LOGGER.exit(
                "Could not automatically determine model type for: {:}\n"
                "Please explicitly specify the type with the --model-type option"
                .format(args.model_file))

        if args_util.get(args, "input_shapes"):
            self.input_shapes = args_util.parse_meta(
                args_util.get(args, "input_shapes"),
                includes_dtype=False)  # TensorMetadata
        else:
            self.input_shapes = TensorMetadata()

        self.model_file = args_util.get(args, "model_file")

        if self.model_file:
            G_LOGGER.verbose("Model: {:}".format(self.model_file))
            if not os.path.exists(self.model_file):
                G_LOGGER.warning("Model path does not exist: {:}".format(
                    self.model_file))
            self.model_file = os.path.abspath(self.model_file)

        model_type_str = util.default(self._model_type, determine_model_type())
        self.model_type = ModelArgs.ModelType(
            model_type_str) if model_type_str else None

        if self.model_type == "trt-network-script" and (
                not self.model_file or not self.model_file.endswith(".py")):
            G_LOGGER.exit(
                "TensorRT network scripts must exist and have '.py' extensions. "
                "Note: Provided network script path was: {:}".format(
                    self.model_file))
Example No. 7
def assert_identifier(inp):
    """
    Checks if the argument can be a valid Python identifier.

    Raises a PolygraphyException if it can't.
    """
    if not inp.isidentifier():
        G_LOGGER.exit("This argument must be a valid identifier. "
                      "Provided argument cannot be a Python identifier: {:}".format(inp))
    return inp
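
A quick usage note for the helper above; the names passed in are made up for illustration and the snippet relies on assert_identifier() as defined above.

# Hypothetical usage: validate user-supplied names before splicing them into
# generated Python code.
assert_identifier("custom_loader")   # returns "custom_loader" unchanged
assert_identifier("my loader!")      # calls G_LOGGER.exit() because it is not a valid identifier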
Example No. 8
    def run(self, args):
        func = None

        if self.arg_groups[ModelArgs].model_type.is_tf():
            func = self.inspect_tf

        if self.arg_groups[ModelArgs].model_type.is_onnx():
            func = self.inspect_onnx

        if self.arg_groups[ModelArgs].model_type.is_trt() or args.display_as == "trt":
            func = self.inspect_trt

        if func is None:
            G_LOGGER.exit(
                "Could not determine how to display this model. Maybe you need to specify --display-as?"
            )

        func(args)
Example No. 9
    def run(self, args):
        try:
            until = int(args.until) - 1
        except ValueError:
            until = args.until
            if until not in ["good", "bad"]:
                G_LOGGER.exit("--until value must be an integer, 'good', or 'bad', but was: {:}".format(args.until))


        def stop(index, success):
            if until == "good":
                return success
            elif until == "bad":
                return not success

            return index >= until


        G_LOGGER.start("Starting iterations")

        num_passed = 0
        num_total = 0

        success = True
        MAX_COUNT = 100000 # We don't want to loop forever. This many iterations ought to be enough for anybody.
        for iteration in range(MAX_COUNT):
            G_LOGGER.start("RUNNING | Iteration {:}".format(iteration + 1))

            success = self.arg_groups[ArtifactSorterArgs].sort_artifacts(iteration + 1)

            num_total += 1
            if success:
                num_passed += 1

            if stop(iteration, success):
                break
        else:
            G_LOGGER.warning("Maximum number of iterations reached: {:}.\n"
                                "Iteration has been halted to prevent an infinite loop!".format(MAX_COUNT))

        G_LOGGER.finish("Finished {:} iteration(s) | Passed: {:}/{:} | Pass Rate: {:}%".format(
                            iteration + 1, num_passed, num_total, float(num_passed) * 100 / float(num_total)))
Example No. 10
    def parse(self, args):
        self.iter_artifact = args_util.get(args, "iter_artifact")

        if self.iter_artifact and os.path.exists(self.iter_artifact):
            G_LOGGER.exit(
                "{:} already exists, refusing to overwrite.\n"
                "Please specify a different path for the intermediate artifact with "
                "--intermediate-artifact".format(self.iter_artifact))

        self.artifacts = util.default(args_util.get(args, "artifacts"), [])
        self.output = args_util.get(args, "artifacts_dir")
        self.show_output = args_util.get(args, "show_output")
        self.remove_intermediate = args_util.get(args, "remove_intermediate")
        self.fail_codes = args_util.get(args, "fail_codes")

        self.fail_regexes = None
        fail_regex = args_util.get(args, "fail_regex")
        if fail_regex is not None:
            self.fail_regexes = []
            for regex in fail_regex:
                self.fail_regexes.append(re.compile(regex))

        if self.artifacts and not self.output:
            G_LOGGER.exit(
                "An output directory must be specified if artifacts are enabled! "
                "Note: Artifacts specified were: {:}".format(self.artifacts))

        if not self.artifacts and self._prefer_artifacts:
            G_LOGGER.warning(
                "`--artifacts` was not specified; No artifacts will be stored during this run! "
                "Is this what you intended?")

        self.iteration_info = args_util.get(args, "iteration_info")

        self.check = args_util.get(args, "check")

        self.start_date = time.strftime("%x").replace("/", "-")
        self.start_time = time.strftime("%X").replace(":", "-")
Example No. 11
    def setup(self, args, network):
        self.precision = {
            "fp32": trt.float32,
            "fp16": trt.float16
        }[args.precision]

        if self.precision == trt.float16 and not self.arg_groups[TrtConfigArgs].fp16:
            G_LOGGER.exit(
                "Cannot mark layers to run in fp16 if it is not enabled in the builder configuration.\n"
                "Please also specify `--fp16` as a command-line option")

        if self.precision == trt.float16 and not self.arg_groups[TrtConfigArgs].int8:
            G_LOGGER.warning(
                "Using fp16 as the higher precision, but fp16 is also the lowest precision available. "
                "Did you mean to set --int8 as well?")

        if not any([
                self.arg_groups[TrtConfigArgs].tf32,
                self.arg_groups[TrtConfigArgs].fp16,
                self.arg_groups[TrtConfigArgs].int8
        ]):
            G_LOGGER.exit(
                "Please enable at least one precision besides fp32 (e.g. --int8, --fp16, --tf32)"
            )

        if self.arg_groups[ModelArgs].model_type == "engine":
            G_LOGGER.exit(
                "The precision tool cannot work with engines, as they cannot be modified. "
                "Please provide a different format, such as an ONNX or TensorFlow model."
            )

        G_LOGGER.start("Using {:} as higher precision".format(self.precision))

        if args.mode == "linear":
            self.layer_marker = LinearMarker(len(network), args.direction)
        elif args.mode == "bisect":
            self.layer_marker = BisectMarker(len(network), args.direction)
Example No. 12
def get_tensor(name):
    if name not in TENSOR_MAP:
        G_LOGGER.exit(
            "Tensor: {:} does not exist in the model.".format(name))
    return TENSOR_MAP[name]
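
A minimal sketch of how TENSOR_MAP might be built before get_tensor() is used; loading the graph with onnx-graphsurgeon and the "model.onnx"/"input_0" names are assumptions for illustration.

# Hypothetical setup for get_tensor(): map every tensor name in the graph to
# its tensor object (onnx-graphsurgeon's Graph.tensors() returns such a dict).
import onnx
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load("model.onnx"))
TENSOR_MAP = graph.tensors()

inp = get_tensor("input_0")  # exits with an error if "input_0" is not in the model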
Example No. 13
    def __call__(self, args):
        G_LOGGER.exit("Encountered an error when loading this tool:\n{:}".format(self.err))
Example No. 14
    def run(self, args):
        if args.dir is None and (args.good is None or args.bad is None):
            G_LOGGER.exit(
                "Either `--dir`, or both `--good` and `--bad` must be specified."
            )

        def load_tactics(dir):
            """
            Load all tactic replays from the specified directory into a single dictionary.

            Args:
                dir (str): Directory containing zero or more tactic replay files.

            Returns:
                dict[str, Set[polygraphy.backend.trt.algorithm_selector.Algorithm]]:
                        Maps layer names to the set of algorithms present in the tactic replays.
            """
            def try_load_replay(path):
                try:
                    return algorithm_selector.TacticReplayData.load(path)
                except:
                    return None

            tactics = defaultdict(set)
            replay_paths = []
            for path in glob.iglob(os.path.join(dir, "**"), recursive=True):
                replay = try_load_replay(path)
                if replay is None:
                    G_LOGGER.verbose(
                        "{:} does not look like a tactic replay file, skipping."
                        .format(path))
                    continue

                replay_paths.append(path)
                for name, algo in replay.items():
                    tactics[name].add(algo)
            return tactics, replay_paths

        good_dir = util.default(args.good, os.path.join(args.dir, "good"))
        good_tactics, good_paths = load_tactics(good_dir)
        G_LOGGER.info("Loaded {:} good tactic replays.".format(
            len(good_paths)))
        G_LOGGER.verbose("Good tactic replays: {:}".format(good_paths))

        bad_dir = util.default(args.bad, os.path.join(args.dir, "bad"))
        bad_tactics, bad_paths = load_tactics(bad_dir)
        G_LOGGER.info("Loaded {:} bad tactic replays.".format(len(bad_paths)))
        G_LOGGER.verbose("Bad tactic replays: {:}".format(bad_paths))

        # Walk bad tactics and remove all the known good tactics.
        potential_bad_tactics = OrderedDict()
        for name, algo_set in bad_tactics.items():
            if name in good_tactics:
                algo_set -= good_tactics[name]

            if algo_set:
                potential_bad_tactics[name] = algo_set

        if potential_bad_tactics:
            G_LOGGER.info("Found potentially bad tactics:")
            for name, algo_set in potential_bad_tactics.items():
                algo_set_str = list(map(str, algo_set))
                G_LOGGER.info("Layer: {:}\n\tAlgorithms: {:}".format(
                    name, algo_set_str))
        else:
            G_LOGGER.info(
                "Could not determine potentially bad tactics. Try generating more tactic replay files?"
            )
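
The set-difference step above can be restated as a small self-contained sketch; the layer and algorithm names are made up.

# Toy illustration: anything that only ever appears in "bad" replays is a suspect.
good = {"conv1": {"algo_a", "algo_b"}, "fc1": {"algo_c"}}
bad = {"conv1": {"algo_a", "algo_d"}, "fc1": {"algo_c"}}

suspects = {}
for layer, algos in bad.items():
    remaining = algos - good.get(layer, set())
    if remaining:
        suspects[layer] = remaining

print(suspects)  # {'conv1': {'algo_d'}}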
Example No. 15
def check_subprocess(status):
    if status.returncode:
        G_LOGGER.exit(status.stdout + status.stderr)
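
A minimal usage sketch for check_subprocess(), assuming the caller captures text-mode output so that stdout and stderr can be concatenated; the command shown is illustrative and the snippet relies on the definition above.

# Hypothetical caller: run a command, capture its output as text, and let
# check_subprocess() exit with the combined output if the command failed.
import subprocess

status = subprocess.run(["polygraphy", "run", "model.onnx", "--trt"],
                        stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                        universal_newlines=True)
check_subprocess(status)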