Example #1
File: base.py Project: clayne/TensorRT
    def run(self, args):
        G_LOGGER.start("Starting iterations")

        builder, network, parser = util.unpack_args(
            self.arg_groups[TrtNetworkLoaderArgs].load_network(), 3)

        with contextlib.ExitStack() as stack:
            stack.enter_context(builder)
            stack.enter_context(network)
            if parser:
                stack.enter_context(parser)

            self.setup(args, network)

            num_passed = 0
            num_total = 0

            success = True
            MAX_COUNT = 100000  # We don't want to loop forever. This many iterations ought to be enough for anybody.
            for iteration in range(MAX_COUNT):
                G_LOGGER.start("RUNNING | Iteration {:}".format(iteration + 1))

                self.process_network(network, success)

                # Don't need to keep the engine around in memory - just serialize to disk and free it.
                with self.arg_groups[TrtEngineLoaderArgs].build_engine(
                    (builder, network)) as engine:
                    self.arg_groups[TrtEngineSaveArgs].save_engine(
                        engine,
                        self.arg_groups[ArtifactSorterArgs].iter_artifact)

                success = self.arg_groups[ArtifactSorterArgs].sort_artifacts(
                    iteration + 1)

                num_total += 1
                if success:
                    num_passed += 1

                if self.stop(iteration, success):
                    break
            else:
                G_LOGGER.warning(
                    "Maximum number of iterations reached: {:}.\n"
                    "Iteration has been halted to prevent an infinite loop!".
                    format(MAX_COUNT))

        G_LOGGER.finish(
            "Finished {:} iteration(s) | Passed: {:}/{:} | Pass Rate: {:}%".
            format(iteration + 1, num_passed, num_total,
                   float(num_passed) * 100 / float(num_total)))
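
Example #1 enters a variable number of TensorRT context managers through contextlib.ExitStack, since parser may be None, and its for ... else ensures the warning fires only when all MAX_COUNT iterations complete without a break. Below is a minimal, self-contained sketch of the ExitStack pattern with stand-in context managers; the names are illustrative and not part of the example.

import contextlib

@contextlib.contextmanager
def resource(name):
    # Stand-in for the builder/network/parser context managers above.
    print("acquired", name)
    try:
        yield name
    finally:
        print("freed", name)

parser_needed = False  # pretend the network was built without a parser

with contextlib.ExitStack() as stack:
    builder = stack.enter_context(resource("builder"))
    network = stack.enter_context(resource("network"))
    if parser_needed:
        parser = stack.enter_context(resource("parser"))
    # Everything entered above is freed here, in reverse order,
    # even if an exception is raised inside the block.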
Example #2
    def call_impl(self):
        """
        Returns:
            trt.INetworkDefinition: The modified network.
        """
        ret, owns_network = util.invoke_if_callable(self._network)
        builder, network, parser = util.unpack_args(ret, num=3)

        with contextlib.ExitStack() as stack:
            if owns_network:
                stack.enter_context(util.FreeOnException([builder, network, parser]))

            if self.outputs == constants.MARK_ALL:
                trt_util.mark_layerwise(network)
            elif self.outputs is not None:
                trt_util.mark_outputs(network, self.outputs)

            if self.exclude_outputs is not None:
                trt_util.unmark_outputs(network, self.exclude_outputs)

            if parser is None:
                return builder, network
            return builder, network, parser
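
Example #2 accepts either a ready-made (builder, network[, parser]) tuple or a callable that produces one: util.invoke_if_callable reports whether the loader owns the result, and util.unpack_args(ret, num=3) pads a missing parser with None. A minimal sketch of that calling convention, assuming util is the Polygraphy utility module imported by these examples:

from polygraphy import util  # import path assumed from recent Polygraphy releases

def make_network():
    # Stand-in loader; a real one would return TensorRT objects.
    return ("builder", "network")

# The argument is callable, so it is invoked and the result is owned here.
ret, owns = util.invoke_if_callable(make_network)
builder, network, parser = util.unpack_args(ret, num=3)
assert owns and parser is None  # only two values were returned, so parser is padded

# A plain tuple is passed through unchanged and is not owned by this code.
ret, owns = util.invoke_if_callable(("builder", "network", "parser"))
builder, network, parser = util.unpack_args(ret, num=3)
assert not owns and parser == "parser"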
Example #3
    def inspect_trt(self, args):
        if self.arg_groups[ModelArgs].model_type == "engine":
            if args.mode != "none":
                G_LOGGER.warning(
                    "Displaying layer information for TensorRT engines is not currently supported"
                )

            with self.arg_groups[TrtEngineLoaderArgs].load_serialized_engine(
            ) as engine:
                engine_str = trt_util.str_from_engine(engine)
                G_LOGGER.info(
                    "==== TensorRT Engine ====\n{:}".format(engine_str))
        else:
            builder, network, parser = util.unpack_args(
                self.arg_groups[TrtNetworkLoaderArgs].load_network(), 3)
            with contextlib.ExitStack() as stack:
                stack.enter_context(builder)
                stack.enter_context(network)
                if parser:
                    stack.enter_context(parser)
                network_str = trt_util.str_from_network(
                    network, mode=args.mode).strip()
                G_LOGGER.info(
                    "==== TensorRT Network ====\n{:}".format(network_str))
Example #4
    def run(self, args):
        G_LOGGER.start("Starting iterations")

        builder, network, parser = util.unpack_args(
            self.arg_groups[TrtNetworkLoaderArgs].load_network(), 3)

        with contextlib.ExitStack() as stack:
            stack.enter_context(builder)
            stack.enter_context(network)
            if parser:
                stack.enter_context(parser)

            self.setup(args, network)

            num_passed = 0
            num_total = 0

            success = True
            MAX_COUNT = 100000  # We don't want to loop forever. This many iterations ought to be enough for anybody.
            for iteration in range(MAX_COUNT):
                remaining = self.remaining()
                G_LOGGER.start("RUNNING | Iteration {:}{:}".format(
                    iteration + 1,
                    " | Approximately {:} iteration(s) remaining".format(
                        remaining) if remaining is not None else "",
                ))

                self.process_network(network, success)

                try:
                    engine = self.arg_groups[TrtEngineLoaderArgs].build_engine(
                        (builder, network))
                except Exception as err:
                    G_LOGGER.warning(
                        "Failed to create network or engine, continuing to the next iteration.\n"
                        "Note: Error was: {:}".format(err))
                    G_LOGGER.internal_error(
                        "Failed to create network or engine. See warning above for details."
                    )
                    success = False
                else:
                    # Don't need to keep the engine around in memory - just serialize to disk and free it.
                    with engine:
                        self.arg_groups[TrtEngineSaveArgs].save_engine(
                            engine,
                            self.arg_groups[ArtifactSorterArgs].iter_artifact)
                    success = self.arg_groups[
                        ArtifactSorterArgs].sort_artifacts(iteration + 1)

                num_total += 1
                if success:
                    num_passed += 1

                if self.stop(iteration, success):
                    break
            else:
                G_LOGGER.warning(
                    "Maximum number of iterations reached: {:}.\n"
                    "Iteration has been halted to prevent an infinite loop!".
                    format(MAX_COUNT))

        G_LOGGER.finish(
            "Finished {:} iteration(s) | Passed: {:}/{:} | Pass Rate: {:}%".
            format(iteration + 1, num_passed, num_total,
                   float(num_passed) * 100 / float(num_total)))
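
Compared with Example #1, Example #4 wraps engine building in try/except/else: the else clause runs only when build_engine raised no exception, so artifacts are saved and sorted only for engines that were actually built. A minimal sketch of that control flow with a stand-in builder function:

def build(should_fail):
    # Stand-in for engine building; raises on failure.
    if should_fail:
        raise RuntimeError("builder error")
    return "engine"

for should_fail in (False, True):
    try:
        engine = build(should_fail)
    except RuntimeError as err:
        success = False
        print("Failed to build, continuing to the next iteration:", err)
    else:
        # Runs only when no exception was raised in the try block.
        print("Saving and sorting artifacts for", engine)
        success = True
    print("iteration success:", success)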
Example #5
    def call_impl(self):
        """
        Returns:
            bytes: The serialized engine that was created.
        """
        # If network is a callable, then we own its return value
        ret, owns_network = util.invoke_if_callable(self._network)
        builder, network, parser = util.unpack_args(ret, num=3)

        if builder is None or network is None:
            G_LOGGER.critical("Expected to recevie a (builder, network) tuple for the `network` parameter, "
                              "but received: ({:}, {:})".format(builder, network))

        with contextlib.ExitStack() as stack:
            if owns_network:
                stack.enter_context(builder)
                stack.enter_context(network)
                if parser is not None:
                    stack.enter_context(parser)
            else:
                provided = "Builder and Network" if parser is None else "Builder, Network, and Parser"
                G_LOGGER.verbose("{:} were provided directly instead of via a Callable. This loader will not assume ownership. "
                                 "Please ensure that they are freed.".format(provided))

            config, owns_config = util.invoke_if_callable(self._config, builder, network)
            if owns_config:
                stack.enter_context(config)
            else:
                G_LOGGER.verbose("Builder configuration was provided directly instead of via a Callable. This loader will not assume "
                                 "ownership. Please ensure it is freed.")

            try:
                config.int8_calibrator.__enter__ # Polygraphy calibrator frees device buffers on exit.
            except AttributeError:
                pass
            else:
                stack.enter_context(config.int8_calibrator)

            network_log_mode = "full" if G_LOGGER.severity <= G_LOGGER.ULTRA_VERBOSE else "attrs"
            G_LOGGER.super_verbose(lambda: ("Displaying TensorRT Network:\n" + trt_util.str_from_network(network, mode=network_log_mode)))

            G_LOGGER.start("Building engine with configuration:\n{:}".format(trt_util.str_from_config(config)))

            try:
                engine_bytes = builder.build_serialized_network(network, config)
            except AttributeError:
                engine = builder.build_engine(network, config)
                if not engine:
                    G_LOGGER.critical("Invalid Engine. Please ensure the engine was built correctly")
                stack.enter_context(engine)
                engine_bytes = engine.serialize()

            if not engine_bytes:
                G_LOGGER.critical("Invalid Engine. Please ensure the engine_bytes was built correctly")

            try:
                timing_cache = config.get_timing_cache()
            except AttributeError:
                if self.timing_cache_path:
                    trt_util.fail_unavailable("save_timing_cache in EngineBytesFromNetwork")
            else:
                if timing_cache and self.timing_cache_path:
                    with timing_cache.serialize() as buffer:
                        util.save_file(buffer, self.timing_cache_path, description="tactic timing cache")

            return engine_bytes
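
Example #5 probes for builder.build_serialized_network (available in newer TensorRT releases) and falls back to builder.build_engine plus engine.serialize() when only the older API exists, using AttributeError as the version check. A minimal sketch of just that compatibility pattern, assuming builder, network, and config have already been created as in the example:

# Sketch: prefer the newer one-shot serialization API, fall back to the older one.
try:
    # Newer TensorRT: build and serialize in a single call.
    engine_bytes = builder.build_serialized_network(network, config)
except AttributeError:
    # Older TensorRT: build an engine object, serialize it, then free it.
    with builder.build_engine(network, config) as engine:
        engine_bytes = engine.serialize()

if not engine_bytes:
    raise RuntimeError("Engine building failed")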
Example #6
    def call_impl(self):
        """
        Returns:
            onnx.ModelProto: The ONNX-like, but **not** valid ONNX, representation of the TensorRT network.
        """
        ret, owns_network = util.invoke_if_callable(self._network)
        builder, network, parser = util.unpack_args(ret, num=3)

        if builder is None or network is None:
            G_LOGGER.critical(
                "Expected to recevie a (builder, network) tuple for the `network` parameter, "
                "but received: ({:}, {:})".format(builder, network))

        with contextlib.ExitStack() as stack:
            if owns_network:
                stack.enter_context(builder)
                stack.enter_context(network)
                if parser is not None:
                    stack.enter_context(parser)

            tensor_map = {}

            def tensors_from_meta(meta):
                nonlocal tensor_map
                tensors = []
                for name, (dtype, shape) in meta.items():
                    if name not in tensor_map:
                        tensor_map[name] = gs.Variable(name=name,
                                                       dtype=dtype,
                                                       shape=shape)
                    tensors.append(tensor_map[name])
                return tensors

            nodes = []
            graph_inputs = tensors_from_meta(
                trt_util.get_network_input_metadata(network))
            graph_outputs = tensors_from_meta(
                trt_util.get_network_output_metadata(network))

            LAYER_TYPE_CLASS_MAPPING = trt_util.get_layer_class_mapping()

            for layer in network:
                op_name = layer.type.name
                if layer.type in LAYER_TYPE_CLASS_MAPPING:
                    layer.__class__ = LAYER_TYPE_CLASS_MAPPING[layer.type]

                node_inputs = tensors_from_meta(
                    trt_util.get_layer_input_metadata(layer))
                node_outputs = tensors_from_meta(
                    trt_util.get_layer_output_metadata(layer))
                attrs = {}
                attr_names = trt_util.get_layer_attribute_names(layer)
                for name in attr_names:
                    with G_LOGGER.verbosity():
                        attr = getattr(layer, name)

                    if util.is_sequence(attr) or any(
                            isinstance(attr, cls)
                            for cls in [trt.Dims, trt.Permutation]):
                        try:
                            attr = list(attr)
                        except ValueError:  # Invalid dims
                            attr = []

                    if hasattr(attr, "__entries"):  # TensorRT Enums
                        attr = attr.name

                    if isinstance(attr, trt.ILoop):
                        attr = attr.name

                    VALID_TYPES = [np.ndarray, list, int, str, bool, float]
                    if not any(isinstance(attr, cls) for cls in VALID_TYPES):
                        G_LOGGER.internal_error(
                            "Unknown type: {:} for layer attribute: {:}.\n"
                            "Note: Layer was: {:}".format(
                                type(attr), attr, layer))
                        try:
                            attr = str(attr)
                        except:
                            attr = "<error during conversion>"

                    attrs[name] = attr

                nodes.append(
                    gs.Node(name=layer.name,
                            op=op_name,
                            attrs=attrs,
                            inputs=node_inputs,
                            outputs=node_outputs))

            graph = gs.Graph(name=network.name,
                             inputs=graph_inputs,
                             outputs=graph_outputs,
                             nodes=nodes)

            return gs.export_onnx(graph)
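
Example #6 rebuilds the TensorRT network as an ONNX-GraphSurgeon graph purely for inspection; the result is ONNX-like but not valid ONNX. A minimal, self-contained sketch of the same gs.Variable / gs.Node / gs.Graph / gs.export_onnx construction, with made-up tensor names and shapes:

import numpy as np
import onnx_graphsurgeon as gs

# Illustrative tensors standing in for the network's input/output metadata.
inp = gs.Variable(name="input", dtype=np.float32, shape=(1, 3, 8, 8))
out = gs.Variable(name="output", dtype=np.float32, shape=(1, 3, 8, 8))

# One node standing in for a TensorRT layer; op mirrors layer.type.name.
node = gs.Node(op="Identity", name="layer_0", inputs=[inp], outputs=[out])

graph = gs.Graph(nodes=[node], inputs=[inp], outputs=[out], name="sketch")
model = gs.export_onnx(graph)  # onnx.ModelProto, usable with ONNX inspection tools
print(model.graph.node)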
Example #7
def test_unpack_args(case):
    args, num, expected = case
    assert util.unpack_args(args, num) == expected
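
The case argument above is supplied by a pytest.mark.parametrize decorator that is not shown here. A plausible reconstruction, with illustrative cases (not necessarily the project's exact list) chosen to cover the pad-with-None and truncation behavior the examples above depend on:

import pytest
from polygraphy import util  # import path assumed from recent Polygraphy releases

@pytest.mark.parametrize(
    "case",
    [
        ((0, 1, 2), 3, (0, 1, 2)),  # exact length: returned unchanged
        ((0, 1), 3, (0, 1, None)),  # too short: padded with None
        ((0, 1, 2), 2, (0, 1)),     # too long: truncated
    ],
)
def test_unpack_args(case):
    args, num, expected = case
    assert util.unpack_args(args, num) == expected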