def __init__(self, data):
    """
    Sets up the selector with either in-memory replay data or a deferred source.

    Args:
        data (Union[TacticReplayData, Any]):
                When this is a ``TacticReplayData`` instance, it is stored in ``self.data``
                and ``self.path`` is left as ``None``. Any other value is stored in
                ``self.path`` (presumably a file path - confirm against callers) and
                ``self.data`` is set to a new, empty ``TacticReplayData``.
    """
    if not ALGO_SELECTOR_ENABLED:
        trt_util.fail_unavailable("Algorithm selector")

    # The parent of a trampoline class has to be initialized explicitly.
    # Skipping this leads to a hard-to-diagnose segfault.
    IAlgorithmSelector.__init__(self)

    is_replay_data = isinstance(data, TacticReplayData)

    self.path = None if is_replay_data else data
    # A fresh container is always created first; it is swapped out below
    # when the caller supplied its own replay data.
    self.data = TacticReplayData()
    if is_replay_data:
        self.data = data
def try_run(func, name):
    """
    Invokes ``func`` and reports an unavailable feature if it raises ``AttributeError``.

    Args:
        func (Callable[[], Any]): Zero-argument callable to invoke.
        name (str): Feature name substituted into the "<name> in CreateConfig" message
                that is passed to ``trt_util.fail_unavailable``.

    Returns:
        Any: The return value of ``func`` when no ``AttributeError`` is raised.
    """
    try:
        outcome = func()
    except AttributeError:
        # Only AttributeError is handled; any other exception propagates to the caller.
        feature = "{:} in CreateConfig".format(name)
        trt_util.fail_unavailable(feature)
    else:
        return outcome
def call_impl(self):
    """
    Builds an engine from the configured network and returns it in serialized form.

    The builder, network and optional parser are obtained from ``self._network`` and the
    builder configuration from ``self._config``. Objects that were produced by invoking a
    callable are owned by this method and are registered on an ``ExitStack`` so that they
    are released when the method finishes. Objects that were provided directly are left
    for the caller to free.

    If ``self.timing_cache_path`` is set, the timing cache of the configuration is
    serialized and saved to that path after the build.

    Returns:
        bytes: The serialized engine that was created.
    """
    # If network is a callable, then we own its return value
    ret, owns_network = util.invoke_if_callable(self._network)
    builder, network, parser = util.unpack_args(ret, num=3)

    if builder is None or network is None:
        G_LOGGER.critical("Expected to receive a (builder, network) tuple for the `network` parameter, "
                          "but received: ({:}, {:})".format(builder, network))

    with contextlib.ExitStack() as stack:
        if owns_network:
            stack.enter_context(builder)
            stack.enter_context(network)
            if parser is not None:
                stack.enter_context(parser)
        else:
            provided = "Builder and Network" if parser is None else "Builder, Network, and Parser"
            G_LOGGER.verbose("{:} were provided directly instead of via a Callable. This loader will not assume ownership. "
                             "Please ensure that they are freed.".format(provided))

        config, owns_config = util.invoke_if_callable(self._config, builder, network)
        if owns_config:
            stack.enter_context(config)
        else:
            G_LOGGER.verbose("Builder configuration was provided directly instead of via a Callable. This loader will not assume "
                             "ownership. Please ensure it is freed.")

        # Only calibrators that implement the context manager protocol are entered.
        try:
            config.int8_calibrator.__enter__ # Polygraphy calibrator frees device buffers on exit.
        except AttributeError:
            pass
        else:
            stack.enter_context(config.int8_calibrator)

        network_log_mode = "full" if G_LOGGER.severity <= G_LOGGER.ULTRA_VERBOSE else "attrs"
        # The lambda defers building the (potentially large) network string until the logger needs it.
        G_LOGGER.super_verbose(lambda: ("Displaying TensorRT Network:\n" + trt_util.str_from_network(network, mode=network_log_mode)))

        G_LOGGER.start("Building engine with configuration:\n{:}".format(trt_util.str_from_config(config)))

        # Decide which build API to use from the attribute lookup alone. Wrapping the whole
        # build call in `except AttributeError` would treat an AttributeError raised *during*
        # the build as "API missing" and silently start a second build via the legacy API.
        build_serialized_network = getattr(builder, "build_serialized_network", None)
        if build_serialized_network is not None:
            engine_bytes = build_serialized_network(network, config)
        else:
            # Fallback for builders that do not provide `build_serialized_network`.
            engine = builder.build_engine(network, config)
            if not engine:
                G_LOGGER.critical("Invalid Engine. Please ensure the engine was built correctly")
            stack.enter_context(engine)
            engine_bytes = engine.serialize()

        if not engine_bytes:
            G_LOGGER.critical("Invalid Engine. Please ensure the engine_bytes was built correctly")

        try:
            timing_cache = config.get_timing_cache()
        except AttributeError:
            # Without `get_timing_cache` the cache cannot be saved; this only matters
            # when the caller actually asked for it to be saved.
            if self.timing_cache_path:
                trt_util.fail_unavailable("save_timing_cache in EngineBytesFromNetwork")
        else:
            if timing_cache and self.timing_cache_path:
                with timing_cache.serialize() as buffer:
                    util.save_file(buffer, self.timing_cache_path, description="tactic timing cache")

        return engine_bytes
def call_impl(self, builder, network):
    """
    Creates a builder configuration from the options stored on this loader.

    Optimization profiles, workspace size, precision flags, the int8 calibrator,
    tactic sources, the timing cache and the algorithm selector are applied in that order.

    Args:
        builder (trt.Builder):
                The TensorRT builder to use to create the configuration.
        network (trt.INetworkDefinition):
                The TensorRT network for which to create the config. The network is used to
                automatically create a default optimization profile if none are provided.

    Returns:
        trt.IBuilderConfig: The TensorRT builder configuration.
    """
    with util.FreeOnException([builder.create_builder_config()]) as (config, ):
        def try_run(func, name):
            # Runs `func`, reporting the feature as unavailable when the attribute it needs is missing.
            try:
                return func()
            except AttributeError:
                trt_util.fail_unavailable("{:} in CreateConfig".format(name))


        def try_set_flag(flag_name):
            # Looks up `flag_name` on trt.BuilderFlag lazily, so a missing flag is reported via try_run.
            return try_run(lambda: config.set_flag(getattr(trt.BuilderFlag, flag_name)), flag_name.lower())


        with G_LOGGER.indent():
            G_LOGGER.verbose("Setting TensorRT Optimization Profiles")
            # Work on a copy so that filling in defaults does not modify self.profiles.
            profiles = copy.deepcopy(self.profiles)
            # NOTE(review): if `profiles` is empty, `trt_profile` is never bound and the
            # int8 calibrator branch below would fail - confirm that callers always
            # supply at least one profile.
            for profile in profiles:
                # Last trt_profile is used for set_calibration_profile.
                trt_profile = profile.fill_defaults(network).to_trt(builder, network)
                config.add_optimization_profile(trt_profile)
            G_LOGGER.info("Configuring with profiles: {:}".format(profiles))

        config.max_workspace_size = int(self.max_workspace_size)

        if self.strict_types:
            try_set_flag("STRICT_TYPES")

        if self.tf32:
            try_set_flag("TF32")
        else: # TF32 is on by default
            with contextlib.suppress(AttributeError):
                config.clear_flag(trt.BuilderFlag.TF32)

        if self.fp16:
            try_set_flag("FP16")

        if self.int8:
            try_set_flag("INT8")
            if not network.has_explicit_precision:
                if self.calibrator is not None:
                    input_metadata = trt_util.get_input_metadata_from_profile(trt_profile, network)
                    with contextlib.suppress(AttributeError): # Polygraphy calibrator has a reset method
                        self.calibrator.reset(input_metadata)
                    config.int8_calibrator = self.calibrator
                    try:
                        config.set_calibration_profile(trt_profile)
                    except Exception:
                        # Best-effort: ordinary failures are only logged. Unlike a bare `except:`,
                        # this lets KeyboardInterrupt and SystemExit propagate.
                        G_LOGGER.extra_verbose("Cannot set calibration profile on TensorRT 7.0 and older.")
                else:
                    G_LOGGER.warning("Network does not have explicit precision and no calibrator was provided. Please ensure "
                                     "that tensors in the network have dynamic ranges set, or provide a calibrator in order to use int8 mode.")

        if self.sparse_weights:
            try_set_flag("SPARSE_WEIGHTS")

        if self.tactic_sources is not None:
            # Combine the requested sources into a bitmask with one bit per source value.
            tactic_sources_flag = 0
            for source in self.tactic_sources:
                tactic_sources_flag |= (1 << int(source))
            try_run(lambda: config.set_tactic_sources(tactic_sources_flag), name="tactic_sources")

        try:
            if self.timing_cache_path:
                timing_cache_data = util.load_file(self.timing_cache_path, description="tactic timing cache")
                cache = config.create_timing_cache(timing_cache_data)
            else:
                # Create an empty timing cache by default so it will be populated during engine build.
                # This way, consumers of CreateConfig have the option to use the cache later.
                cache = config.create_timing_cache(b"")
        except AttributeError:
            # A missing timing cache API is only an error when a cache file was requested.
            if self.timing_cache_path:
                trt_util.fail_unavailable("load_timing_cache in CreateConfig")
        else:
            config.set_timing_cache(cache, ignore_mismatch=False)

        if self.algorithm_selector is not None:
            def set_algo_selector():
                config.algorithm_selector = self.algorithm_selector
            try_run(set_algo_selector, "algorithm_selector")

        return config