Esempio n. 1
0
        def __init__(self, data):
            """
            Sets up the selector from either in-memory replay data or some other value.

            Args:
                data:
                        If this is a TacticReplayData instance, it is stored in ``self.data``
                        and ``self.path`` is left as None. Any other value is stored unchanged
                        in ``self.path`` (presumably a path to replay data - confirm against
                        the callers), and ``self.data`` is a freshly constructed TacticReplayData.
            """
            if not ALGO_SELECTOR_ENABLED:
                trt_util.fail_unavailable("Algorithm selector")

            # The parent of a trampoline class has to be initialized explicitly.
            # Omitting this call leads to a hard-to-diagnose segfault.
            IAlgorithmSelector.__init__(self)

            # Begin with the defaults for the "not replay data" case, then swap both
            # values if the caller handed us replay data directly.
            replay_data = TacticReplayData()
            replay_path = data
            if isinstance(data, TacticReplayData):
                replay_path, replay_data = None, data

            self.path = replay_path
            self.data = replay_data
Esempio n. 2
0
 def try_run(func, name):
     """
     Invokes ``func`` and hands back whatever it returns.

     If the call raises an AttributeError, ``trt_util.fail_unavailable`` is called with
     the message "<name> in CreateConfig" instead. Any other exception propagates.

     Args:
         func (Callable[[], Any]): A zero-argument callable to invoke.
         name (str): The feature name to put in the message.

     Returns:
         Any: The return value of ``func``.
     """
     try:
         result = func()
     except AttributeError:
         trt_util.fail_unavailable("{:} in CreateConfig".format(name))
     else:
         return result
Esempio n. 3
0
    def call_impl(self):
        """
        Builds an engine from this loader's network and builder configuration, optionally
        saves the tactic timing cache to ``self.timing_cache_path``, and returns the engine
        in serialized form.

        Returns:
            bytes: The serialized engine that was created.
        """
        # When `network` is a callable, whatever it returns belongs to this loader.
        network_ret, owns_network = util.invoke_if_callable(self._network)
        builder, network, parser = util.unpack_args(network_ret, num=3)

        if builder is None or network is None:
            G_LOGGER.critical("Expected to recevie a (builder, network) tuple for the `network` parameter, "
                              "but received: ({:}, {:})".format(builder, network))

        with contextlib.ExitStack() as stack:
            if not owns_network:
                provided = "Builder, Network, and Parser" if parser is not None else "Builder and Network"
                G_LOGGER.verbose("{:} were provided directly instead of via a Callable. This loader will not assume ownership. "
                                 "Please ensure that they are freed.".format(provided))
            else:
                # Register everything we own so that it is released when the stack unwinds.
                owned_objects = [builder, network]
                if parser is not None:
                    owned_objects.append(parser)
                for owned in owned_objects:
                    stack.enter_context(owned)

            config, owns_config = util.invoke_if_callable(self._config, builder, network)
            if not owns_config:
                G_LOGGER.verbose("Builder configuration was provided directly instead of via a Callable. This loader will not assume "
                                 "ownership. Please ensure it is freed.")
            else:
                stack.enter_context(config)

            # Only a calibrator that exposes `__enter__` is registered with the stack.
            # Polygraphy calibrator frees device buffers on exit.
            try:
                config.int8_calibrator.__enter__
            except AttributeError:
                calibrator_is_managed = False
            else:
                calibrator_is_managed = True

            if calibrator_is_managed:
                stack.enter_context(config.int8_calibrator)

            network_log_mode = "attrs"
            if G_LOGGER.severity <= G_LOGGER.ULTRA_VERBOSE:
                network_log_mode = "full"

            def describe_network():
                # Passed as a callable so the string is only built if the logger asks for it.
                return "Displaying TensorRT Network:\n" + trt_util.str_from_network(network, mode=network_log_mode)

            G_LOGGER.super_verbose(describe_network)

            G_LOGGER.start("Building engine with configuration:\n{:}".format(trt_util.str_from_config(config)))

            try:
                serialized_engine = builder.build_serialized_network(network, config)
            except AttributeError:
                # Fall back to building an engine object and serializing it ourselves.
                engine = builder.build_engine(network, config)
                if not engine:
                    G_LOGGER.critical("Invalid Engine. Please ensure the engine was built correctly")
                stack.enter_context(engine)
                serialized_engine = engine.serialize()

            if not serialized_engine:
                G_LOGGER.critical("Invalid Engine. Please ensure the engine_bytes was built correctly")

            try:
                timing_cache = config.get_timing_cache()
            except AttributeError:
                # No timing cache API on this config: that only matters if a save was requested.
                if self.timing_cache_path:
                    trt_util.fail_unavailable("save_timing_cache in EngineBytesFromNetwork")
            else:
                should_save_cache = timing_cache and self.timing_cache_path
                if should_save_cache:
                    with timing_cache.serialize() as cache_buffer:
                        util.save_file(cache_buffer, self.timing_cache_path, description="tactic timing cache")

            return serialized_engine
Esempio n. 4
0
    def call_impl(self, builder, network):
        """
        Creates a TensorRT builder configuration and applies this loader's settings to it:
        optimization profiles, workspace size, precision and sparsity flags, int8 calibration,
        tactic sources, timing cache, and algorithm selector.

        Args:
            builder (trt.Builder):
                    The TensorRT builder to use to create the configuration.
            network (trt.INetworkDefinition):
                    The TensorRT network for which to create the config. The network is used to
                    automatically create a default optimization profile if none are provided.

        Returns:
            trt.IBuilderConfig: The TensorRT builder configuration.
        """
        # NOTE(review): FreeOnException presumably releases the config if anything below raises - confirm.
        with util.FreeOnException([builder.create_builder_config()]) as (config, ):
            def try_run(func, name):
                # An AttributeError from `func` is reported as the named feature being unavailable.
                try:
                    return func()
                except AttributeError:
                    trt_util.fail_unavailable("{:} in CreateConfig".format(name))

            def try_set_flag(flag_name):
                # Looks up the flag by name inside the callable so that a missing flag is
                # also reported through `try_run`.
                return try_run(lambda: config.set_flag(getattr(trt.BuilderFlag, flag_name)), flag_name.lower())

            with G_LOGGER.indent():
                G_LOGGER.verbose("Setting TensorRT Optimization Profiles")
                # Work on a copy so that filling in defaults does not modify `self.profiles`.
                profiles = copy.deepcopy(self.profiles)
                for profile in profiles:
                    # Last trt_profile is used for set_calibration_profile.
                    trt_profile = profile.fill_defaults(network).to_trt(builder, network)
                    config.add_optimization_profile(trt_profile)
                G_LOGGER.info("Configuring with profiles: {:}".format(profiles))

            config.max_workspace_size = int(self.max_workspace_size)

            if self.strict_types:
                try_set_flag("STRICT_TYPES")

            if self.tf32:
                try_set_flag("TF32")
            else: # TF32 is on by default
                with contextlib.suppress(AttributeError):
                    config.clear_flag(trt.BuilderFlag.TF32)

            if self.fp16:
                try_set_flag("FP16")

            if self.int8:
                try_set_flag("INT8")
                if not network.has_explicit_precision:
                    if self.calibrator is not None:
                        # NOTE(review): `trt_profile` is only bound if `self.profiles` was non-empty - confirm
                        # that the constructor guarantees at least one profile.
                        input_metadata = trt_util.get_input_metadata_from_profile(trt_profile, network)
                        with contextlib.suppress(AttributeError): # Polygraphy calibrator has a reset method
                            self.calibrator.reset(input_metadata)
                        config.int8_calibrator = self.calibrator
                        try:
                            config.set_calibration_profile(trt_profile)
                        except Exception:
                            # Deliberately best-effort, but do not swallow KeyboardInterrupt or SystemExit.
                            G_LOGGER.extra_verbose("Cannot set calibration profile on TensorRT 7.0 and older.")
                    else:
                        G_LOGGER.warning("Network does not have explicit precision and no calibrator was provided. Please ensure "
                                         "that tensors in the network have dynamic ranges set, or provide a calibrator in order to use int8 mode.")

            if self.sparse_weights:
                try_set_flag("SPARSE_WEIGHTS")

            if self.tactic_sources is not None:
                # Fold the sources into a bitmask with one bit per source value.
                tactic_sources_flag = 0
                for source in self.tactic_sources:
                    tactic_sources_flag |= (1 << int(source))
                try_run(lambda: config.set_tactic_sources(tactic_sources_flag), name="tactic_sources")

            try:
                if self.timing_cache_path:
                    timing_cache_data = util.load_file(self.timing_cache_path, description="tactic timing cache")
                    cache = config.create_timing_cache(timing_cache_data)
                else:
                    # Create an empty timing cache by default so it will be populated during engine build.
                    # This way, consumers of CreateConfig have the option to use the cache later.
                    cache = config.create_timing_cache(b"")
            except AttributeError:
                # Without a timing cache API this only matters if the caller asked to load one.
                if self.timing_cache_path:
                    trt_util.fail_unavailable("load_timing_cache in CreateConfig")
            else:
                config.set_timing_cache(cache, ignore_mismatch=False)

            if self.algorithm_selector is not None:
                def set_algo_selector():
                    config.algorithm_selector = self.algorithm_selector
                try_run(set_algo_selector, "algorithm_selector")

            return config