Example #1
0
 def call_impl(self):
     """
     Returns:
         (trt.Builder, trt.INetworkDefinition, trt.OnnxParser):
                 A newly created network and builder, along with an ONNX parser
                 attached to the network.
     """
     created = create_network(explicit_precision=self.explicit_precision,
                              explicit_batch=self.explicit_batch)
     with util.FreeOnException(created) as (builder, network):
         onnx_parser = trt.OnnxParser(network, trt_util.get_trt_logger())
         return builder, network, onnx_parser
Example #2
0
 def call_impl(self):
     """
     Returns:
         (trt.IBuilder, trt.INetworkDefinition, trt.OnnxParser):
                 A TensorRT network, as well as the builder used to create it, and the parser
                 used to populate it.
     """
     with util.FreeOnException(super().call_impl()) as (builder, network, parser):
         model_bytes, _ = util.invoke_if_callable(self._model_bytes)
         status = parser.parse(model_bytes)
         trt_util.check_onnx_parser_errors(parser, status)
         return builder, network, parser
Example #3
0
    def call_impl(self):
        """
        Returns:
            trt.ICudaEngine: The engine that was saved.
        """
        engine, owns_engine = util.invoke_if_callable(self._engine)

        with contextlib.ExitStack() as stack:
            # Only free the engine on failure if we created it ourselves.
            if owns_engine:
                stack.enter_context(util.FreeOnException([engine]))

            serialized = bytes_from_engine(engine)
            util.save_file(contents=serialized, dest=self.path, description="engine")
            return engine
Example #4
0
    def call_impl(self):
        """
        Returns:
            bytes: The serialized engine.
        """
        engine, owns_engine = util.invoke_if_callable(self._engine)

        with contextlib.ExitStack() as stack:
            # Only free the engine on failure if we created it ourselves.
            if owns_engine:
                stack.enter_context(util.FreeOnException([engine]))

            # serialize() yields a buffer that must be freed; copy it out first.
            with engine.serialize() as serialized:
                return bytes(serialized)
Example #5
0
    def call_impl(self):
        """
        Returns:
            (trt.Builder, trt.INetworkDefinition, trt.OnnxParser, int):
                    The builder, network, and parser, along with the first
                    dimension of the first network input's shape.
        """
        from polygraphy.backend.onnx import util as onnx_util

        with util.FreeOnException(super().call_impl()) as (builder, network, parser):
            model, _ = util.invoke_if_callable(self.onnx_loader)
            input_meta = onnx_util.get_input_metadata(model.graph)
            _, input_shape = list(input_meta.values())[0]

            status = parser.parse(model.SerializeToString())
            trt_util.check_onnx_parser_errors(parser, status)

            return builder, network, parser, input_shape[0]
Example #6
0
 def call_impl(self):
     """
     Returns:
         (trt.Builder, trt.INetworkDefinition): The builder and empty network.
     """
     with util.FreeOnException([trt.Builder(trt_util.get_trt_logger())]) as (builder, ):
         # Accumulate creation flags as a bitmask.
         flags = 0
         if self.explicit_batch:
             flags |= 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
         if self.explicit_precision:
             flags |= 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION)

         network = builder.create_network(flags=flags)
         if network is None:
             G_LOGGER.critical("Invalid network. See logging output above for details.")
         return builder, network
Example #7
0
 def call_impl(self):
     """
     Returns:
         (trt.IBuilder, trt.INetworkDefinition, trt.OnnxParser):
                 A TensorRT network, as well as the builder used to create it, and the parser
                 used to populate it.
     """
     path = util.invoke_if_callable(self.path)[0]

     if mod.version(trt.__version__) < mod.version("7.1"):
         # Older parser versions cannot parse directly from a file; fall back to
         # loading the file contents and parsing the bytes.
         from polygraphy.backend.common import bytes_from_path
         return network_from_onnx_bytes(bytes_from_path(path), self.explicit_precision)

     with util.FreeOnException(super().call_impl()) as (builder, network, parser):
         # We need to use parse_from_file for the ONNX parser to keep track of the location of the ONNX file for
         # potentially parsing any external weights.
         status = parser.parse_from_file(path)
         trt_util.check_onnx_parser_errors(parser, status)
         return builder, network, parser
Example #8
0
    def call_impl(self):
        """
        Returns:
            trt.INetworkDefinition: The modified network.
        """
        ret, owns_network = util.invoke_if_callable(self._network)
        builder, network, parser = util.unpack_args(ret, num=3)

        with contextlib.ExitStack() as stack:
            # Only free the network objects on failure if we created them ourselves.
            if owns_network:
                stack.enter_context(util.FreeOnException([builder, network, parser]))

            # Mark outputs: either every layer, or just the requested tensors.
            if self.outputs == constants.MARK_ALL:
                trt_util.mark_layerwise(network)
            elif self.outputs is not None:
                trt_util.mark_outputs(network, self.outputs)

            if self.exclude_outputs is not None:
                trt_util.unmark_outputs(network, self.exclude_outputs)

            # Preserve the shape of the unpacked tuple: omit the parser when absent.
            return (builder, network) if parser is None else (builder, network, parser)
Example #9
0
    def call_impl(self, builder, network):
        """
        Args:
            builder (trt.Builder):
                    The TensorRT builder to use to create the configuration.
            network (trt.INetworkDefinition):
                    The TensorRT network for which to create the config. The network is used to
                    automatically create a default optimization profile if none are provided.

        Returns:
            trt.IBuilderConfig: The TensorRT builder configuration.
        """
        with util.FreeOnException([builder.create_builder_config()]) as (config, ):
            # Runs `func`, converting an AttributeError (feature missing from this
            # TensorRT version) into a fatal "unavailable" message that names `name`.
            def try_run(func, name):
                try:
                    return func()
                except AttributeError:
                    trt_util.fail_unavailable("{:} in CreateConfig".format(name))


            # Sets a builder flag by name, failing loudly if this TensorRT
            # version does not define it.
            def try_set_flag(flag_name):
                return try_run(lambda: config.set_flag(getattr(trt.BuilderFlag, flag_name)), flag_name.lower())


            with G_LOGGER.indent():
                G_LOGGER.verbose("Setting TensorRT Optimization Profiles")
                # Deep-copy so fill_defaults() does not mutate the profiles stored on self.
                profiles = copy.deepcopy(self.profiles)
                for profile in profiles:
                    # Last trt_profile is used for set_calibration_profile.
                    trt_profile = profile.fill_defaults(network).to_trt(builder, network)
                    config.add_optimization_profile(trt_profile)
                G_LOGGER.info("Configuring with profiles: {:}".format(profiles))

            config.max_workspace_size = int(self.max_workspace_size)

            if self.strict_types:
                try_set_flag("STRICT_TYPES")

            if self.tf32:
                try_set_flag("TF32")
            else: # TF32 is on by default
                with contextlib.suppress(AttributeError):
                    config.clear_flag(trt.BuilderFlag.TF32)

            if self.fp16:
                try_set_flag("FP16")

            if self.int8:
                try_set_flag("INT8")
                # Explicit-precision networks carry their own precision information,
                # so the calibrator setup below is skipped for them.
                if not network.has_explicit_precision:
                    if self.calibrator is not None:
                        # trt_profile here is the one from the last loop iteration above.
                        input_metadata = trt_util.get_input_metadata_from_profile(trt_profile, network)
                        with contextlib.suppress(AttributeError): # Polygraphy calibrator has a reset method
                            self.calibrator.reset(input_metadata)
                        config.int8_calibrator = self.calibrator
                        try:
                            config.set_calibration_profile(trt_profile)
                        except:
                            G_LOGGER.extra_verbose("Cannot set calibration profile on TensorRT 7.0 and older.")
                    else:
                        G_LOGGER.warning("Network does not have explicit precision and no calibrator was provided. Please ensure "
                                         "that tensors in the network have dynamic ranges set, or provide a calibrator in order to use int8 mode.")

            if self.sparse_weights:
                try_set_flag("SPARSE_WEIGHTS")

            if self.tactic_sources is not None:
                # Combine the requested tactic sources into a single bitmask.
                tactic_sources_flag = 0
                for source in self.tactic_sources:
                    tactic_sources_flag |= (1 << int(source))
                try_run(lambda: config.set_tactic_sources(tactic_sources_flag), name="tactic_sources")

            try:
                if self.timing_cache_path:
                    timing_cache_data = util.load_file(self.timing_cache_path, description="tactic timing cache")
                    cache = config.create_timing_cache(timing_cache_data)
                else:
                    # Create an empty timing cache by default so it will be populated during engine build.
                    # This way, consumers of CreateConfig have the option to use the cache later.
                    cache = config.create_timing_cache(b"")
            except AttributeError:
                # Timing caches are unavailable in this TensorRT version; this is only
                # fatal if the caller explicitly asked to load one.
                if self.timing_cache_path:
                    trt_util.fail_unavailable("load_timing_cache in CreateConfig")
            else:
                # Only reached if cache creation succeeded.
                config.set_timing_cache(cache, ignore_mismatch=False)

            if self.algorithm_selector is not None:
                def set_algo_selector():
                    config.algorithm_selector = self.algorithm_selector
                try_run(set_algo_selector, "algorithm_selector")

            return config