Example #1
def check_outputs_not_found(not_found, available_outputs):
    if not_found:
        available_outputs = util.unique_list(available_outputs)
        G_LOGGER.critical(
            "The following outputs were not found: {:}.\n"
            "Note: Available tensors:\n\t{:}".format(not_found, "\n\t".join(available_outputs))
        )
Example #2
        def run_const_fold_pass(model):
            graph = gs_from_onnx(model)
            del model

            try:
                graph.fold_constants(fold_shapes=self.fold_shapes,
                                     partitioning=self.partitioning)
            except TypeError as err:  # Using an old version of ONNX-GS
                if self.partitioning:
                    G_LOGGER.critical(
                        "This version of ONNX-GraphSurgeon may not support partitioning the graph.\n"
                        "Please upgrade to a newer version of ONNX-GraphSurgeon or disable partitioning.\n"
                        "Note: Error was:\n{:}".format(err))
                if self.fold_shapes:
                    G_LOGGER.critical(
                        "This version of ONNX-GraphSurgeon may not support folding shapes.\n"
                        "Please upgrade to a newer version of ONNX-GraphSurgeon or disable shape folding.\n"
                        "Note: Error was:\n{:}".format(err))

                graph.fold_constants()

            model = gs.export_onnx(graph.cleanup(), do_type_check=False)
            del graph

            if self.fold_shapes and self.do_shape_inference:
                model = infer_shapes(model)
            return model
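The fallback above retries plain constant folding when the installed ONNX-GraphSurgeon predates the fold_shapes/partitioning keywords. A minimal standalone sketch of the same pass, assuming the onnx and onnx-graphsurgeon packages and a hypothetical model.onnx input:

import onnx
import onnx_graphsurgeon as gs

model = onnx.load("model.onnx")  # hypothetical input path
graph = gs.import_onnx(model)
try:
    # Newer ONNX-GraphSurgeon accepts these keywords; older versions raise TypeError.
    graph.fold_constants(fold_shapes=True, partitioning="basic")
except TypeError:
    graph.fold_constants()  # fall back to plain folding
onnx.save(gs.export_onnx(graph.cleanup()), "folded.onnx")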
Example #3
    def check_decoded(obj):
        if not isinstance(obj, cls):
            G_LOGGER.critical(
                "Provided JSON cannot be decoded into a {:}.\n"
                "Note: JSON was decoded into a {:}:\n{:}".format(cls.__name__, type(obj), obj)
            )
        return obj
Example #4
    def __init__(self, mode=None, prefix=None, suffix=None):
        """
        Args:
            mode (str): The mode to use when opening the file.
            prefix (str): The prefix to use for the file path.
            suffix (str): The suffix to use for the file path.
        """
        self.mode = default(mode, "wb+")
        prefix = default(prefix, "")
        suffix = default(suffix, "")

        def rand_path():
            return os.path.join(tempfile.gettempdir(), "{:}{:}{:}".format(prefix, os.urandom(24).hex(), suffix))

        # In the unlikely event the path exists, generate a new one. Only try 100 times so
        # we don't end up in an infinite loop.
        path = rand_path()
        for _ in range(100):
            if not os.path.exists(path):
                break
            path = rand_path()
        else:
            G_LOGGER.critical("Could not create a temporary file under: {:}".format(tempfile.gettempdir()))

        self.name = path  # Use 'name' to be compatible with tempfile.NamedTemporaryFile
        open(self.name, "x").close()
        self._fhandle = None
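The for/else retry loop above is easy to misread: the else clause runs only when all 100 attempts are exhausted without a break. A standalone sketch of the same pattern using only the standard library (the helper name is illustrative):

import os
import tempfile

def make_unique_path(prefix="", suffix=""):  # hypothetical helper for illustration
    def rand_path():
        return os.path.join(tempfile.gettempdir(), prefix + os.urandom(24).hex() + suffix)

    path = rand_path()
    for _ in range(100):
        if not os.path.exists(path):
            break  # found an unused path; the else clause is skipped
        path = rand_path()
    else:
        # Reached only if every attempt collided with an existing path.
        raise RuntimeError("Could not create a temporary file under: " + tempfile.gettempdir())
    return path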
Example #5
    def call_impl(self):
        uff_model, input_names, input_shapes, output_names = self.uff_loader()

        builder = trt.Builder(get_trt_logger())
        network = builder.create_network()
        parser = trt.UffParser()
        # Input names should come from the converter, as a preprocessing script may have been applied to the frozen model.
        for name, shape in zip(input_names, input_shapes):
            # Default order is NCHW, only set to NHWC if we're reasonably certain that it is.
            input_order = self.uff_order
            if not self.uff_order:
                input_order = trt.UffInputOrder.NCHW
                if FormatManager.determine_format(shape) == DataFormat.NHWC:
                    input_order = trt.UffInputOrder.NHWC
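            # Drop the batch dimension; UFF inputs are registered without it.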
            shape = shape[1:]
            G_LOGGER.verbose(
                "Registering UFF input: {:} with shape: {:} and input order: {:}"
                .format(name, shape, input_order))
            parser.register_input(name, shape, input_order)

        if output_names and output_names != constants.MARK_ALL:
            for name in output_names:
                G_LOGGER.verbose("Registering UFF output: " + str(name))
                parser.register_output(name)

        G_LOGGER.info(
            "Parsing UFF model with inputs: {:} and outputs: {:}".format(
                input_names, output_names))
        success = parser.parse_buffer(uff_model, network)
        if not success:
            G_LOGGER.critical("Could not parse UFF correctly")
        return builder, network, parser, input_shapes[0][0]
Example #6
    def run(self, args):
        if not args.convert_to:
            _, ext = os.path.splitext(args.output)
            if ext not in ModelArgs.EXT_MODEL_TYPE_MAPPING:
                G_LOGGER.critical(
                    "Could not automatically determine model type based on output path: {:}\n"
                    "Please specify the desired output format with --convert-to"
                    .format(args.output))
            convert_type = ModelArgs.ModelType(
                ModelArgs.EXT_MODEL_TYPE_MAPPING[ext])
        elif args.convert_to == "onnx-like-trt-network":
            convert_type = "onnx-like-trt-network"
        else:
            CONVERT_TO_MODEL_TYPE_MAPPING = {"onnx": "onnx", "trt": "engine"}
            convert_type = ModelArgs.ModelType(
                CONVERT_TO_MODEL_TYPE_MAPPING[args.convert_to])

        if convert_type == "onnx-like-trt-network":
            onnx_like = trt_backend.onnx_like_from_network(
                self.arg_groups[TrtNetworkLoaderArgs].get_network_loader())
            onnx_backend.save_onnx(onnx_like, args.output)
        elif convert_type.is_onnx():
            model = self.arg_groups[OnnxLoaderArgs].load_onnx()
            if args.fp_to_fp16:
                model = onnx_backend.convert_to_fp16(model)
            self.arg_groups[OnnxSaveArgs].save_onnx(model, args.output)
        elif convert_type.is_trt():
            with self.arg_groups[TrtEngineLoaderArgs].build_engine() as engine:
                self.arg_groups[TrtEngineSaveArgs].save_engine(
                    engine, args.output)
        else:
            G_LOGGER.critical(
                "Cannot convert to model type: {:}".format(convert_type))
Example #7
def load_graph(path):
    """
    Loads a TensorFlow frozen model.

    Args:
        path (Union[str, tf.Graph, tf.GraphDef]):
                A path to the frozen model, or a frozen TensorFlow graph or graphdef.

    Returns:
        tf.Graph: The TensorFlow graph
    """
    if isinstance(path, tf.Graph):
        return path

    if isinstance(path, str):
        graphdef = tf.compat.v1.GraphDef()

        import google.protobuf.message

        try:
            graphdef.ParseFromString(
                util.load_file(path, description="GraphDef"))
        except google.protobuf.message.DecodeError:
            G_LOGGER.backtrace()
            G_LOGGER.critical(
                "Could not import TensorFlow GraphDef from: {:}. Is this a valid TensorFlow model?"
                .format(path))
    elif isinstance(path, tf.compat.v1.GraphDef):
        graphdef = path

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graphdef, name="")
        return graph
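The string branch above is essentially the raw protobuf loading path; a self-contained sketch, assuming a TensorFlow 1.x-style frozen model at a hypothetical path:

import tensorflow as tf

graphdef = tf.compat.v1.GraphDef()
with open("frozen_model.pb", "rb") as f:  # hypothetical path
    graphdef.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graphdef, name="")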
Example #8
File: serde.py Project: clayne/TensorRT
        def register_impl(func):
            def add(key, val):
                if key in cls.polygraphy_registered:
                    G_LOGGER.critical("Duplicate serialization function for type: {:}.\n"
                                      "Note: Existing function: {:}, New function: {:}".format(
                                        key, cls.polygraphy_registered[key], func))
                cls.polygraphy_registered[key] = val


            if cls == Encoder:
                def wrapped(obj):
                    dct = func(obj)
                    dct[str_from_type(typ)] = constants.TYPE_MARKER
                    return dct

                add(typ, wrapped)
                return wrapped
            elif cls == Decoder:
                def wrapped(dct):
                    del dct[str_from_type(typ)]
                    return func(dct)

                add(str_from_type(typ), wrapped)
            else:
                G_LOGGER.critical("Cannot register for unrecognized class type: ")
Example #9
    def add_onnx_loader(self, script, disable_custom_outputs=None, suffix=None):
        model_type = self.model_args.model_type
        if model_type.is_onnx():
            loader_name = self.model_args.model_file
            if self.onnx_shape_inference_args is not None:
                loader_name = self.onnx_shape_inference_args.add_to_script(script, loader_name)

            if loader_name == self.model_args.model_file:  # Shape inference loader isn't being used, have to load.
                script.add_import(imports=["OnnxFromPath"], frm="polygraphy.backend.onnx")
                loader_str = make_invocable(
                    "OnnxFromPath", self.model_args.model_file, external_data_dir=self.load_external_data
                )
                loader_name = script.add_loader(loader_str, "load_onnx", suffix=suffix)
        elif model_type.is_tf():
            if self.tf2onnx_loader_args is None:
                G_LOGGER.critical("Could not load: {:}. Is it an ONNX model?".format(self.model_args.model_file))
            loader_name = self.tf2onnx_loader_args.add_to_script(script)
        else:
            G_LOGGER.critical("Model type: {:} cannot be converted to ONNX.".format(model_type))

        loader_name = self._get_modify_onnx_loader(script, loader_name, disable_custom_outputs=disable_custom_outputs)

        if self.onnx_save_args is not None:
            loader_name = self.onnx_save_args.add_save_onnx(script, loader_name)

        return loader_name
Example #10
        def determine_model_type():
            if args_util.get(args, "model_type") is not None:
                return args.model_type.lower()

            if args_util.get(args, "model_file") is None:
                return None

            def use_ext(ext_mapping):
                file_ext = os.path.splitext(args.model_file)[-1]
                if file_ext in ext_mapping:
                    return ext_mapping[file_ext]

            runners = args_util.get(args, "runners", default=[])
            if args_util.get(args, "ckpt") or os.path.isdir(args.model_file):
                return "ckpt"
            elif "tf" in runners or "trt_legacy" in runners:
                if args.caffe_model:
                    return "caffe"
                return use_ext(ModelArgs.EXT_MODEL_TYPE_MAPPING) or "frozen"
            else:
                model_type = use_ext(ModelArgs.EXT_MODEL_TYPE_MAPPING)
                if model_type:
                    return model_type

            G_LOGGER.critical(
                "Could not automatically determine model type for: {:}\n"
                "Please explicitly specify the type with the --model-type option".format(args.model_file)
            )
Example #11
    def infer_impl(self, feed_dict):
        start = time.time()
        for name, buffer in feed_dict.items():
            self.input_buffers[name].device.copy_from(buffer, self.stream)
        # We will not run with smaller batch sizes than whatever the builder chose.
        bindings = [buf.device.ptr for buf in self.input_buffers.values()] + [
            buf.device.ptr for buf in self.output_buffers.values()
        ]
        status = self.context.execute_async(
            batch_size=self.context.engine.max_batch_size,
            bindings=bindings,
            stream_handle=self.stream.ptr)
        if not status:
            G_LOGGER.critical(
                "Model execution failed. Please see the log messages above for details"
            )

        for out in self.output_buffers.values():
            out.host = out.device.copy_to(out.host, self.stream)

        self.stream.synchronize()
        end = time.time()

        out_dict = OrderedDict()
        for (name, out) in self.output_buffers.items():
            out_dict[name] = out.host
        self.inference_time = end - start
        return out_dict
Example #12
def parse_num_bytes(num_bytes_arg):
    """
    Parses an argument that indicates a number of bytes. The argument may use scientific notation,
    or contain a `K`, `M`, or `G` suffix (case-insensitive), indicating `KiB`, `MiB`, or `GiB` respectively.
    If the number is fractional, it is truncated to an integer value.

    If the provided argument is `None`, `None` is returned.

    Args:
        num_bytes_arg (str): The argument indicating the number of bytes.

    Returns:
        int: The number of bytes.
    """
    if num_bytes_arg is None:
        return None

    num_component = num_bytes_arg  # Numerical component of the argument
    multiplier = 1

    suffix_multiplier = {"K": 1 << 10, "M": 1 << 20, "G": 1 << 30}
    for suffix, mult in suffix_multiplier.items():
        if num_bytes_arg.upper().endswith(suffix):
            num_component = num_bytes_arg.upper().rstrip(suffix)
            multiplier = mult
            break

    try:
        return int(float(num_component) * multiplier)
    except:
        G_LOGGER.critical(
            "Could not convert {:} to a number of bytes. "
            "Please use either an integer (e.g. 16000000), scientific notation (e.g. 16e6), "
            "or a number with a valid suffix: K, M, or G (e.g. 16M).".format(
                num_bytes_arg))
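Given the logic above, representative inputs parse as follows (a sketch exercising the function defined above):

assert parse_num_bytes("16000000") == 16000000          # plain integer
assert parse_num_bytes("16e6") == 16000000              # scientific notation
assert parse_num_bytes("16M") == 16 * (1 << 20)         # suffix: 16 MiB
assert parse_num_bytes("2.5k") == int(2.5 * (1 << 10))  # case-insensitive suffix; fraction truncated
assert parse_num_bytes(None) is None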
Example #13
    def set_profile(self, index):
        """
        Sets the active optimization profile for this runner.
        The runner must already be active (see ``__enter__()`` or ``activate()``).

        This only applies if your engine was built with multiple
        optimization profiles.

        In TensorRT 8.0 and newer, the profile will be set asynchronously
        using this runner's CUDA stream (``runner.stream``).

        By default, the runner uses the first profile (profile 0).

        Args:
            index (int):
                    The index of the optimization profile to use.
        """
        if not self.is_active:
            G_LOGGER.critical("{:35} | Must be activated prior to calling set_profile()".format(self.name))

        try:
            self.context.set_optimization_profile_async
        except AttributeError:
            self.context.active_optimization_profile = index
        else:
            self.context.set_optimization_profile_async(index, self.stream.ptr)
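A sketch of how this might be used with a multi-profile engine, assuming Polygraphy's TrtRunner and a pre-built engine with at least two optimization profiles (details vary by version):

from polygraphy.backend.trt import TrtRunner

with TrtRunner(engine) as runner:      # engine: an ICudaEngine with >= 2 profiles (assumed)
    runner.set_profile(1)              # switch from the default profile 0 to profile 1
    outputs = runner.infer(feed_dict)  # feed_dict: input name -> array (assumed)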
Example #14
    def setup(self, args, network):
        try:
            self.until = int(args.until) - 1
        except:
            self.until = args.until
            if self.until not in ["good", "bad"]:
                G_LOGGER.critical("--until value must be an integer, 'good', or 'bad', but was: {:}".format(args.until))
Example #15
        def extended_func(*args, **kwargs):
            extend_func_retval = extend_func(*args, **kwargs)
            extend_func_ret_tuple = make_iterable(extend_func_retval)

            func_args = inspect.signature(func).parameters
            # Special case for when the extended function does not return anything
            if len(func_args) == 0 and len(extend_func_ret_tuple) == 1 and extend_func_ret_tuple[0] is None:
                func_retval = func()
            elif len(extend_func_ret_tuple) == len(func_args):
                func_retval = func(*extend_func_ret_tuple)
            else:

                def try_get_name(fn):
                    try:
                        return fn.__name__
                    except:
                        return fn

                G_LOGGER.critical(
                    "Function: {:} accepts {:} parameter(s), but "
                    "needs to accept {:} parameter(s) from: {:} instead.\nNote: Parameters should be: {:}".format(
                        try_get_name(func),
                        len(func_args),
                        len(extend_func_ret_tuple),
                        try_get_name(extend_func),
                        tuple(map(type, extend_func_ret_tuple)),
                    )
                )

            if func_retval is not None:
                return func_retval
            return extend_func_retval
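The wrapper above implements a decorator-style "extend" pattern: whatever the extended function returns is unpacked into the parameters of the overriding function. A simplified self-contained sketch of the idea (names are illustrative):

import inspect

def extend(extend_func):
    def decorator(func):
        def extended(*args, **kwargs):
            retval = extend_func(*args, **kwargs)
            ret_tuple = retval if isinstance(retval, tuple) else (retval,)
            num_params = len(inspect.signature(func).parameters)
            if num_params == 0 and ret_tuple == (None,):
                result = func()            # extended function returned nothing
            else:
                result = func(*ret_tuple)  # unpack return values into parameters
            return result if result is not None else retval
        return extended
    return decorator

def base():
    return 1, 2

@extend(base)
def plus(a, b):
    return a + b

assert plus() == 3  # base() runs first; its (1, 2) feeds plus(a, b)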
Example #16
def parse_profile_shapes(default_shapes, min_args, opt_args, max_args):
    """
    Parses TensorRT profile options from command-line arguments.

    Args:
        default_shapes (TensorMetadata): The inference input shapes.

    Returns:
        List[Tuple[OrderedDict[str, Shape]]]:
            A list of profiles with each profile comprised of three dictionaries
            (min, opt, max) mapping input names to shapes.
    """
    def get_shapes(lst, idx):
        nonlocal default_shapes
        default_shapes = copy.copy(default_shapes)
        if idx < len(lst):
            default_shapes.update(
                args_util.parse_meta(lst[idx], includes_dtype=False))

        # Don't care about dtype, and need to override dynamic dimensions
        shapes = {
            name: util.override_dynamic_shape(shape)
            for name, (_, shape) in default_shapes.items()
        }

        for name, shape in shapes.items():
            if tuple(default_shapes[name].shape) != tuple(shape):
                G_LOGGER.warning(
                    "Input tensor: {:} | For TensorRT profile, overriding dynamic shape: {:} to: {:}"
                    .format(name, default_shapes[name].shape, shape),
                    mode=LogMode.ONCE,
                )

        return shapes

    num_profiles = max(len(min_args), len(opt_args), len(max_args))

    # For cases where input shapes are provided, we have to generate a profile
    if not num_profiles and default_shapes:
        num_profiles = 1

    profiles = []
    for idx in range(num_profiles):
        min_shapes = get_shapes(min_args, idx)
        opt_shapes = get_shapes(opt_args, idx)
        max_shapes = get_shapes(max_args, idx)
        if sorted(min_shapes.keys()) != sorted(opt_shapes.keys()):
            G_LOGGER.critical(
                "Mismatch in input names between minimum shapes ({:}) and optimum shapes "
                "({:})".format(list(min_shapes.keys()),
                               list(opt_shapes.keys())))
        elif sorted(opt_shapes.keys()) != sorted(max_shapes.keys()):
            G_LOGGER.critical(
                "Mismatch in input names between optimum shapes ({:}) and maximum shapes "
                "({:})".format(list(opt_shapes.keys()),
                               list(max_shapes.keys())))

        profiles.append((min_shapes, opt_shapes, max_shapes))
    return profiles
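For orientation, a sketch of the expected inputs and output structure; the shape strings below use Polygraphy's name:[dims] CLI syntax, and all values are illustrative:

min_args = ["X:[1,3,224,224]"]  # one list entry per profile
opt_args = ["X:[4,3,224,224]"]
max_args = ["X:[8,3,224,224]"]

# default_shapes: a TensorMetadata of inference input shapes, assumed built elsewhere.
profiles = parse_profile_shapes(default_shapes, min_args, opt_args, max_args)
# profiles -> [({"X": (1, 3, 224, 224)}, {"X": (4, 3, 224, 224)}, {"X": (8, 3, 224, 224)})]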
Example #17
File: script.py Project: clayne/TensorRT
    def __iadd__(self, other):
        if config.INTERNAL_CORRECTNESS_CHECKS:
            if not isinstance(other, Script.String):
                G_LOGGER.critical("Cannot concatenate str and Script.String. Note: str was: {:}".format(other))
            elif self.safe != other.safe:
                G_LOGGER.critical("Cannot concatenate unsafe string ({:}) to safe string ({:})!".format(other, self.s))
        self.s += other.s
        return self
Example #18
def check_onnx_parser_errors(parser, success):
    if parser.num_errors > 0:
        for index in range(parser.num_errors):
            G_LOGGER.error(parser.get_error(index))
        G_LOGGER.critical("Could not parse ONNX correctly")

    if not success:
        G_LOGGER.critical("Failed to parse ONNX model. Does the model file exist and contain a valid ONNX model?")
Example #19
    def add(key, val):
        if key in cls.polygraphy_registered:
            G_LOGGER.critical(
                "Duplicate serialization function for type: {:}.\n"
                "Note: Existing function: {:}, New function: {:}".format(
                    key, cls.polygraphy_registered[key], func
                )
            )
        cls.polygraphy_registered[key] = val
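The duplicate check above is a common guard for decorator-based registries; a generic standalone sketch with illustrative names:

REGISTRY = {}

def register(key):
    def decorator(func):
        if key in REGISTRY:
            raise ValueError(
                "Duplicate serialization function for type: {}.\n"
                "Note: Existing function: {}, New function: {}".format(key, REGISTRY[key], func)
            )
        REGISTRY[key] = func
        return func
    return decorator

@register("Algorithm")
def encode_algorithm(obj):
    return {"name": obj.name}  # illustrative encoder

# Registering a second function under "Algorithm" would raise ValueError.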
Example #20
    def validate_meta(meta):
        try:
            for (fmt, dtype) in meta:
                assert isinstance(fmt, trt.TensorFormat)
                assert isinstance(dtype, trt.DataType)
        except:
            G_LOGGER.critical("Could not validate input/output metadata: {:}. "
                              "Is it a list of tuples containing (trt.TensorFormat, trt.DataType)?".format(meta))
        return meta
Example #21
    def __getitem__(self, key):
        if isinstance(key, int):
            return self.lst[key]

        for name, iteration_results in self.lst:
            if name == key:
                return iteration_results

        G_LOGGER.critical("{:35} does not exist in this RunResults instance. Note: Available runners: {:}".format(
                        key, list(self.keys())))
Example #22
    def pop_meta(name):
        nonlocal tensor_meta_arg
        tensor_meta_arg, _, val = tensor_meta_arg.rpartition(SEP)
        if not tensor_meta_arg:
            G_LOGGER.critical(
                "Could not parse {:} from argument: {:}. Is it separated by a comma "
                "(,) from the tensor name?".format(name, orig_tensor_meta_arg))
        if val.lower() == "auto":
            val = None
        return val
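str.rpartition makes it natural to peel fields off the right end of a comma-separated argument, so that whatever remains on the left is the tensor name and needs no further parsing. A standalone sketch with an illustrative format:

SEP = ","

def parse_tensor_meta(arg):
    # Peel fields off the right of a "name,[shape],dtype" style argument (illustrative format).
    remainder = arg

    def pop():
        nonlocal remainder
        remainder, _, val = remainder.rpartition(SEP)
        return val

    dtype = pop()      # rightmost field
    shape = pop()      # next field from the right
    name = remainder   # whatever remains is the tensor name
    return name, shape, dtype

assert parse_tensor_meta("input:0,[1x3x224x224],float32") == ("input:0", "[1x3x224x224]", "float32")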
Example #23
def assert_identifier(inp):
    """
    Checks if the argument can be a valid Python identifier.

    Raises a PolygraphyException if it can't.
    """
    if not inp.isidentifier():
        G_LOGGER.critical(
            "This argument must be a valid identifier. "
            "Provided argument cannot be a Python identifier: {:}".format(inp))
    return inp
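str.isidentifier is the standard-library check being relied on here; for example:

assert "my_var".isidentifier()
assert not "2fast".isidentifier()  # cannot start with a digit
assert not "a-b".isidentifier()    # '-' is not allowed
assert not "".isidentifier()       # the empty string is not an identifier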
Example #24
    def get_static_shape(name, shape):
        static_shape = shape
        if util.is_shape_dynamic(shape):
            static_shape = util.override_dynamic_shape(shape)
            if static_shape != shape and name not in self.user_input_metadata:
                if not util.is_valid_shape_override(static_shape, shape):
                    G_LOGGER.critical("Input tensor: {:} | Cannot override original shape: {:} to {:}".format(name, shape, static_shape))
                G_LOGGER.warning("Input tensor: {:} | Will generate data of shape: {:}.\n"
                                 "If this is incorrect, please set input_metadata "
                                 "or provide a custom data loader.".format(name, static_shape), mode=LogMode.ONCE)
        return static_shape
Example #25
    def parse(self, args):
        def determine_model_type():
            if args_util.get(args, "model_type") is not None:
                return args.model_type.lower()

            if args_util.get(args, "model_file") is None:
                return None

            def use_ext(ext_mapping):
                file_ext = os.path.splitext(args.model_file)[-1]
                if file_ext in ext_mapping:
                    return ext_mapping[file_ext]

            runners = args_util.get(args, "runners", default=[])
            if args_util.get(args, "ckpt") or os.path.isdir(args.model_file):
                return "ckpt"
            elif "tf" in runners or "trt_legacy" in runners:
                if args.caffe_model:
                    return "caffe"
                return use_ext(ModelArgs.EXT_MODEL_TYPE_MAPPING) or "frozen"
            else:
                model_type = use_ext(ModelArgs.EXT_MODEL_TYPE_MAPPING)
                if model_type:
                    return model_type

            G_LOGGER.critical(
                "Could not automatically determine model type for: {:}\n"
                "Please explicitly specify the type with the --model-type option".format(args.model_file)
            )

        if args_util.get(args, "input_shapes"):
            self.input_shapes = args_util.parse_meta(
                args_util.get(args, "input_shapes"), includes_dtype=False
            )  # TensorMetadata
        else:
            self.input_shapes = TensorMetadata()

        self.model_file = args_util.get(args, "model_file")

        if self.model_file:
            G_LOGGER.verbose("Model: {:}".format(self.model_file))
            if not os.path.exists(self.model_file):
                G_LOGGER.warning("Model path does not exist: {:}".format(self.model_file))
            self.model_file = os.path.abspath(self.model_file)

        model_type_str = self._model_type if self._model_type else determine_model_type()
        self.model_type = ModelArgs.ModelType(model_type_str) if model_type_str else None

        if self.model_type == "trt-network-script" and (not self.model_file or not self.model_file.endswith(".py")):
            G_LOGGER.critical(
                "TensorRT network scripts must exist and have '.py' extensions.\n"
                "Note: Provided network script path was: {:}".format(self.model_file)
            )
Example #26
        def select_algorithms(self, context, choices):
            """
            Selects an algorithm based on ``self.data`` if possible. Otherwise, returns
            default tactics.

            Args:
                context (trt.IAlgorithmContext):
                        The TensorRT algorithm context.
                choices (List[trt.IAlgorithm]):
                        A list of TensorRT algorithm choices.

            Returns:
                List[int]:
                        The indices of selected tactics. If ``self.data`` includes the layer and
                        TensorRT provides a matching tactic, this will always be of length 1.

            Raises:
                PolygraphyException:
                        If a tactic is set for a layer in ``self.data`` but is not provided by
                        TensorRT as a choice for that layer.
            """
            default_choices = super().select_algorithms(context, choices)

            if not self.data:  # No replay data, we are in recording mode.
                return default_choices

            if context.name not in self.data:
                G_LOGGER.warning(
                    "Layer: {:} was not found in the tactic replay. Falling back to default tactics.".format(
                        context.name
                    )
                )
                G_LOGGER.warning(
                    "Has the network changed since the tactic replay file was generated?\n"
                    "Note: Layers in the tactic replay are:\n\t{:}".format("\n\t".join(self.data.keys())),
                    mode=LogMode.ONCE,
                )
                return default_choices

            # Need to find the index of the tactic we want.
            to_select = self.data[context.name]
            tactic_choices = [Algorithm.from_trt(context, algo) for algo in choices]

            if to_select not in tactic_choices:
                G_LOGGER.critical(
                    "Layer: {:} | Tactic in replay was not provided by TensorRT as a choice for this layer.\n"
                    "Has the network or builder configuration changed since the replay file was generated?\n"
                    "Note: Tactic in replay was:\n\t{:}\nProvided choices were:\n\t{:}".format(
                        context.name, to_select, "\n\t".join(map(str, tactic_choices))
                    )
                )

            return [tactic_choices.index(to_select)]
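Stripped of the TensorRT types, the replay logic reduces to a dictionary lookup followed by an index search; a standalone sketch with illustrative data:

def select_from_replay(replay, layer_name, choices, default_indices):
    if not replay:                # no replay data: recording mode
        return default_indices
    if layer_name not in replay:  # layer missing: network changed since recording
        return default_indices
    wanted = replay[layer_name]
    if wanted not in choices:
        raise RuntimeError("Tactic in replay was not provided as a choice for layer: " + layer_name)
    return [choices.index(wanted)]  # always length 1 on a successful match

assert select_from_replay({"conv1": "tactic_b"}, "conv1", ["tactic_a", "tactic_b"], [0]) == [1]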
Example #27
    def find(self):
        def run(indices):
            self.mark_layers(indices)
            return self.check_network("-".join(map(str, indices)))

        # Finds num worst indices in acc_results
        def find_worst(num, acc_results):
            acc_mapping = list(acc_results.values())[0][0]  # First iteration of first runner-pair.

            # Compute for each layer: atol / prev_atol, to determine which layers contribute the greatest error.
            # It is not enough to simply find the max(atol), because that doesn't account for error introduced
            # by previous layers.
            items = list(acc_mapping.items())
            ratios = []
            for (_, prev_tols), (outname, cur_tols) in zip(items[:-1], items[1:]):
                ratio = cur_tols.max_absdiff / prev_tols.max_absdiff
                ratios.append((ratio, outname))

            # Mark more layers on each iteration
            ratios = sorted(ratios, reverse=True)[:num]
            G_LOGGER.verbose(
                "Found worst {:} layers (Format: (error ratio, tensor name)): {:}"
                .format(num, ratios))
            return [output_mapping[outname] for (ratio, outname) in ratios]

        if not self.makers[TrtLoaderArgs].outputs:
            G_LOGGER.critical(
                "worst-first requires all outputs to be marked as network outputs in order to determine where errors are being introduced. "
                "Please enable --trt-outputs mark all, and ensure that your golden outputs also include layer-wise results"
            )

        output_mapping = {}  # Maps output tensor names to producer layer indices
        for layer_index, layer in enumerate(self.network):
            for out_index in range(layer.num_outputs):
                output_mapping[layer.get_output(out_index).name] = layer_index

        indices = []
        acc_results = run(indices)
        max_outputs = len(list(acc_results.values())[0][0]) - 1

        iter_num = 0
        # indices will be at most one less than the number of layers, since we're comparing layers against subsequent ones.
        while not bool(acc_results) and len(indices) < max_outputs:
            iter_num += 1
            indices = find_worst(self.args.top * iter_num, acc_results)
            acc_results = run(indices)

        if bool(acc_results):
            return indices
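The "worst-first" ranking above scores each layer by how much it amplifies the error already present at its input, rather than by raw max(atol). A minimal numeric sketch of the ratio computation, with illustrative tolerances:

# Ordered mapping: output tensor name -> max absolute error observed at that tensor.
max_absdiff = {"layer0": 1e-6, "layer1": 1e-5, "layer2": 1.2e-5, "layer3": 4e-3}

items = list(max_absdiff.items())
ratios = sorted(
    ((cur / prev, name) for (_, prev), (name, cur) in zip(items[:-1], items[1:])),
    reverse=True,
)
# ratios -> approx. [(333.3, 'layer3'), (10.0, 'layer1'), (1.2, 'layer2')]
# layer2 mostly propagates existing error (ratio near 1); layer1 and layer3 introduce new error.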
Example #28
    def run(self, args):
        if self.arg_groups[ModelArgs].model_file is None and args.runners:
            G_LOGGER.critical(
                "One or more runners was specified, but no model file was provided. Make sure you've specified the model path, "
                "and also that it's not being consumed as an argument for another parameter"
            )

        script = self.build_script(args)

        if args.gen_script:
            script.save(args.gen_script)
        else:
            exec(str(script))
Example #29
    def validate_meta(meta):
        for (fmt, dtype) in meta:
            if not isinstance(fmt, trt.TensorFormat):
                G_LOGGER.critical(
                    "'format' must be an instance of trt.TensorFormat, but is: {:}.\n"
                    "Note: Provided input/output metadata was: {:}".format(fmt, meta)
                )
            if not isinstance(dtype, trt.DataType):
                G_LOGGER.critical(
                    "'dtype' must be an instance of trt.DataType, but is: {:}.\n"
                    "Note: Provided input/output metadata was: {:}".format(dtype, meta)
                )
        return meta
Example #30
    def __getitem__(self, key):
        """
        Retrieves the shapes registered for a given input name.

        Returns:
            ShapeTuple:
                    A named tuple including ``min``, ``opt``, and ``max`` members for the shapes
                    corresponding to the input.
        """
        if key not in self:
            G_LOGGER.critical(
                "Binding: {:} does not have shapes set in this profile".format(
                    key))
        return super().__getitem__(key)
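A sketch of the lookup in use, assuming Polygraphy's Profile API where add() returns the profile for chaining:

from polygraphy.backend.trt import Profile

profile = Profile().add("x", min=(1, 3, 224, 224), opt=(4, 3, 224, 224), max=(8, 3, 224, 224))

shapes = profile["x"]  # ShapeTuple with .min, .opt, and .max members
print(shapes.min, shapes.opt, shapes.max)

profile["y"]  # not registered: G_LOGGER.critical raises a PolygraphyException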