def call_impl(self):
    """
    Returns:
        (trt.Builder, trt.INetworkDefinition, trt.OnnxParser):
                A TensorRT builder, an empty network, and an ONNX parser attached to that network.
    """
    builder_network = create_network(
        explicit_precision=self.explicit_precision, explicit_batch=self.explicit_batch
    )
    with util.FreeOnException(builder_network) as (builder, network):
        onnx_parser = trt.OnnxParser(network, trt_util.get_trt_logger())
        return builder, network, onnx_parser
def call_impl(self):
    """
    Returns:
        (trt.IBuilder, trt.INetworkDefinition, trt.OnnxParser):
                A TensorRT network, as well as the builder used to create it,
                and the parser used to populate it.
    """
    with util.FreeOnException(super().call_impl()) as (builder, network, parser):
        model_bytes, _ = util.invoke_if_callable(self._model_bytes)
        parse_ok = parser.parse(model_bytes)
        trt_util.check_onnx_parser_errors(parser, parse_ok)
        return builder, network, parser
def call_impl(self):
    """
    Returns:
        trt.ICudaEngine: The engine that was saved.
    """
    engine, owns_engine = util.invoke_if_callable(self._engine)

    with contextlib.ExitStack() as stack:
        # Only attach the engine to the exception-cleanup stack if we created it here.
        if owns_engine:
            stack.enter_context(util.FreeOnException([engine]))

        serialized_engine = bytes_from_engine(engine)
        util.save_file(contents=serialized_engine, dest=self.path, description="engine")
        return engine
def call_impl(self):
    """
    Returns:
        bytes: The serialized engine.
    """
    engine, owns_engine = util.invoke_if_callable(self._engine)

    with contextlib.ExitStack() as stack:
        # Only free the engine on failure if this loader owns it.
        if owns_engine:
            stack.enter_context(util.FreeOnException([engine]))

        with engine.serialize() as serialized:
            return bytes(serialized)
def call_impl(self):
    """
    Returns:
        (trt.Builder, trt.INetworkDefinition, trt.OnnxParser, int):
                The builder, parsed network, and parser, plus the leading
                dimension of the first network input's shape.
    """
    from polygraphy.backend.onnx import util as onnx_util

    with util.FreeOnException(super().call_impl()) as (builder, network, parser):
        onnx_model, _ = util.invoke_if_callable(self.onnx_loader)

        # Take the shape of the first input; its leading dimension is returned to the caller.
        input_metadata = onnx_util.get_input_metadata(onnx_model.graph)
        _, first_input_shape = next(iter(input_metadata.values()))

        parse_ok = parser.parse(onnx_model.SerializeToString())
        trt_util.check_onnx_parser_errors(parser, parse_ok)
        return builder, network, parser, first_input_shape[0]
def call_impl(self):
    """
    Returns:
        (trt.Builder, trt.INetworkDefinition): The builder and empty network.
    """
    with util.FreeOnException([trt.Builder(trt_util.get_trt_logger())]) as (builder,):
        flags = 0
        if self.explicit_batch:
            flags |= 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        if self.explicit_precision:
            flags |= 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION)

        network = builder.create_network(flags=flags)
        if network is None:
            G_LOGGER.critical("Invalid network. See logging output above for details.")
        return builder, network
def call_impl(self):
    """
    Returns:
        (trt.IBuilder, trt.INetworkDefinition, trt.OnnxParser):
                A TensorRT network, as well as the builder used to create it,
                and the parser used to populate it.
    """
    path = util.invoke_if_callable(self.path)[0]

    if mod.version(trt.__version__) < mod.version("7.1"):
        # Older TensorRT cannot parse directly from a file path; fall back to parsing bytes.
        from polygraphy.backend.common import bytes_from_path

        return network_from_onnx_bytes(bytes_from_path(path), self.explicit_precision)

    with util.FreeOnException(super().call_impl()) as (builder, network, parser):
        # parse_from_file lets the parser keep track of the ONNX file's location so it
        # can resolve any external weights stored alongside the model.
        parse_ok = parser.parse_from_file(path)
        trt_util.check_onnx_parser_errors(parser, parse_ok)
        return builder, network, parser
def call_impl(self):
    """
    Returns:
        (trt.Builder, trt.INetworkDefinition) or (trt.Builder, trt.INetworkDefinition, trt.OnnxParser):
                The builder and modified network, plus the parser when one was supplied.
    """
    ret, owns_network = util.invoke_if_callable(self._network)
    builder, network, parser = util.unpack_args(ret, num=3)

    with contextlib.ExitStack() as stack:
        # Only free the builder/network/parser on failure if this loader owns them.
        if owns_network:
            stack.enter_context(util.FreeOnException([builder, network, parser]))

        if self.outputs == constants.MARK_ALL:
            trt_util.mark_layerwise(network)
        elif self.outputs is not None:
            trt_util.mark_outputs(network, self.outputs)

        if self.exclude_outputs is not None:
            trt_util.unmark_outputs(network, self.exclude_outputs)

        return (builder, network) if parser is None else (builder, network, parser)
def call_impl(self, builder, network):
    """
    Creates a TensorRT builder configuration from the options set on this loader.

    Args:
        builder (trt.Builder):
                The TensorRT builder to use to create the configuration.
        network (trt.INetworkDefinition):
                The TensorRT network for which to create the config. The network is
                used to automatically create a default optimization profile if none
                are provided.

    Returns:
        trt.IBuilderConfig: The TensorRT builder configuration.
    """
    with util.FreeOnException([builder.create_builder_config()]) as (config,):

        def try_run(func, name):
            # Runs `func`, failing with a helpful message if the required API is
            # unavailable in the installed TensorRT version (AttributeError).
            try:
                return func()
            except AttributeError:
                trt_util.fail_unavailable("{:} in CreateConfig".format(name))

        def try_set_flag(flag_name):
            return try_run(lambda: config.set_flag(getattr(trt.BuilderFlag, flag_name)), flag_name.lower())

        with G_LOGGER.indent():
            G_LOGGER.verbose("Setting TensorRT Optimization Profiles")
            profiles = copy.deepcopy(self.profiles)
            for profile in profiles:
                # Last trt_profile is used for set_calibration_profile.
                trt_profile = profile.fill_defaults(network).to_trt(builder, network)
                config.add_optimization_profile(trt_profile)
            G_LOGGER.info("Configuring with profiles: {:}".format(profiles))

        config.max_workspace_size = int(self.max_workspace_size)

        if self.strict_types:
            try_set_flag("STRICT_TYPES")

        if self.tf32:
            try_set_flag("TF32")
        else:  # TF32 is on by default
            with contextlib.suppress(AttributeError):
                config.clear_flag(trt.BuilderFlag.TF32)

        if self.fp16:
            try_set_flag("FP16")

        if self.int8:
            try_set_flag("INT8")
            if not network.has_explicit_precision:
                if self.calibrator is not None:
                    # NOTE(review): assumes at least one profile exists so trt_profile is
                    # bound here — confirm self.profiles is never empty.
                    input_metadata = trt_util.get_input_metadata_from_profile(trt_profile, network)
                    with contextlib.suppress(AttributeError):  # Polygraphy calibrator has a reset method
                        self.calibrator.reset(input_metadata)
                    config.int8_calibrator = self.calibrator
                    # BUGFIX: was a bare `except:`, which hid unrelated errors. Only
                    # AttributeError indicates set_calibration_profile is unavailable
                    # (TensorRT 7.0 and older).
                    try:
                        config.set_calibration_profile(trt_profile)
                    except AttributeError:
                        G_LOGGER.extra_verbose("Cannot set calibration profile on TensorRT 7.0 and older.")
                else:
                    G_LOGGER.warning(
                        "Network does not have explicit precision and no calibrator was provided. Please ensure "
                        "that tensors in the network have dynamic ranges set, or provide a calibrator in order to use int8 mode."
                    )

        if self.sparse_weights:
            try_set_flag("SPARSE_WEIGHTS")

        if self.tactic_sources is not None:
            tactic_sources_flag = 0
            for source in self.tactic_sources:
                tactic_sources_flag |= 1 << int(source)
            try_run(lambda: config.set_tactic_sources(tactic_sources_flag), name="tactic_sources")

        try:
            if self.timing_cache_path:
                timing_cache_data = util.load_file(self.timing_cache_path, description="tactic timing cache")
                cache = config.create_timing_cache(timing_cache_data)
            else:
                # Create an empty timing cache by default so it will be populated during engine build.
                # This way, consumers of CreateConfig have the option to use the cache later.
                cache = config.create_timing_cache(b"")
        except AttributeError:
            if self.timing_cache_path:
                trt_util.fail_unavailable("load_timing_cache in CreateConfig")
        else:
            config.set_timing_cache(cache, ignore_mismatch=False)

        if self.algorithm_selector is not None:

            def set_algo_selector():
                config.algorithm_selector = self.algorithm_selector

            try_run(set_algo_selector, "algorithm_selector")

        return config