def load_graph(path):
    """
    Loads a TensorFlow frozen model.

    Args:
        path (Union[str, tf.Graph, tf.GraphDef]):
                A path to the frozen model, or a frozen TensorFlow graph or graphdef.

    Returns:
        tf.Graph: The TensorFlow graph
    """
    if isinstance(path, tf.Graph):
        return path

    if isinstance(path, str):
        graphdef = tf.compat.v1.GraphDef()

        # Imported lazily so protobuf is only required when parsing from a file.
        import google
        try:
            graphdef.ParseFromString(util.load_file(path, description="GraphDef"))
        except google.protobuf.message.DecodeError:
            G_LOGGER.backtrace()
            G_LOGGER.critical("Could not import TensorFlow GraphDef from: {:}. Is this a valid TensorFlow model?".format(path))
    elif isinstance(path, tf.compat.v1.GraphDef):
        graphdef = path
    else:
        # BUGFIX: previously an unsupported type fell through with `graphdef`
        # unbound, raising a confusing NameError below. Report it explicitly
        # instead (G_LOGGER.critical is expected to abort - presumably by raising).
        G_LOGGER.critical("Could not load TensorFlow graph from object of unsupported type: {:}".format(type(path)))

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graphdef, name="")
        return graph
def load_from_cache():
    # Returns the contents of the calibration cache, or None if the cache is
    # unset, empty, or unreadable. Note: closure - reads `self._cache` from the
    # enclosing scope.
    if self._cache is None or not util.get_file_size(self._cache):
        return None

    try:
        return util.load_file(self._cache, description="calibration cache")
    except Exception as err:
        # BUGFIX: narrowed from a bare `except:` (which would also swallow
        # KeyboardInterrupt/SystemExit) and include the underlying error for
        # easier debugging, consistent with the sibling implementation.
        G_LOGGER.warning(
            "Could not read from calibration cache: {:}\nNote: Error was: {:}".format(self._cache, err))
        return None
def load_from_cache():
    # Best-effort read of the calibration cache; any failure is logged and
    # reported as a missing cache (None). Closure - reads `self._cache`.
    cache_path = self._cache

    # An unset or empty cache file is treated the same as no cache at all.
    if cache_path is None or not util.get_file_size(cache_path):
        return None

    try:
        contents = util.load_file(cache_path, description="calibration cache")
    except Exception as err:
        G_LOGGER.error(
            "Could not read from calibration cache: {:}\nNote: Error was: {:}".format(self._cache, err))
        return None
    return contents
def call_impl(self):
    """
    Loads the raw contents of the file at ``self._path``.

    Returns:
        bytes: The contents of the file.
    """
    file_contents = util.load_file(self._path, description="bytes")
    return file_contents
def call_impl(self, builder, network):
    """
    Builds a TensorRT builder configuration from the options stored on this instance.

    Args:
        builder (trt.Builder): The TensorRT builder to use to create the configuration.
        network (trt.INetworkDefinition):
                The TensorRT network for which to create the config. The network is
                used to automatically create a default optimization profile if none
                are provided.

    Returns:
        trt.IBuilderConfig: The TensorRT builder configuration.
    """
    # FreeOnException releases the config if anything below raises.
    with util.FreeOnException([builder.create_builder_config()]) as (config, ):
        def try_run(func, name):
            # Runs `func`, reporting the feature `name` as unavailable when the
            # installed TensorRT version lacks the API (AttributeError).
            try:
                return func()
            except AttributeError:
                trt_util.fail_unavailable("{:} in CreateConfig".format(name))

        def try_set_flag(flag_name):
            # Sets a builder flag by name, failing gracefully on older TensorRT.
            return try_run(lambda: config.set_flag(getattr(trt.BuilderFlag, flag_name)), flag_name.lower())

        with G_LOGGER.indent():
            G_LOGGER.verbose("Setting TensorRT Optimization Profiles")
            # Deep-copy so fill_defaults() does not mutate the profiles stored on self.
            profiles = copy.deepcopy(self.profiles)
            for profile in profiles:
                # Last trt_profile is used for set_calibration_profile.
                trt_profile = profile.fill_defaults(network).to_trt(builder, network)
                config.add_optimization_profile(trt_profile)
            G_LOGGER.info("Configuring with profiles: {:}".format(profiles))

        config.max_workspace_size = int(self.max_workspace_size)

        if self.strict_types:
            try_set_flag("STRICT_TYPES")

        if self.tf32:
            try_set_flag("TF32")
        else:  # TF32 is on by default
            with contextlib.suppress(AttributeError):
                config.clear_flag(trt.BuilderFlag.TF32)

        if self.fp16:
            try_set_flag("FP16")

        if self.int8:
            try_set_flag("INT8")
            if not network.has_explicit_precision:
                if self.calibrator is not None:
                    input_metadata = trt_util.get_input_metadata_from_profile(trt_profile, network)
                    with contextlib.suppress(AttributeError):  # Polygraphy calibrator has a reset method
                        self.calibrator.reset(input_metadata)
                    config.int8_calibrator = self.calibrator
                    # BUGFIX: narrowed from a bare `except:` - only AttributeError
                    # indicates set_calibration_profile is missing (TensorRT <= 7.0);
                    # any other error should propagate rather than be swallowed.
                    try:
                        config.set_calibration_profile(trt_profile)
                    except AttributeError:
                        G_LOGGER.extra_verbose("Cannot set calibration profile on TensorRT 7.0 and older.")
                else:
                    G_LOGGER.warning(
                        "Network does not have explicit precision and no calibrator was provided. Please ensure "
                        "that tensors in the network have dynamic ranges set, or provide a calibrator in order to use int8 mode.")

        if self.sparse_weights:
            try_set_flag("SPARSE_WEIGHTS")

        if self.tactic_sources is not None:
            # Fold the enabled tactic sources into a single bitmask.
            tactic_sources_flag = 0
            for source in self.tactic_sources:
                tactic_sources_flag |= (1 << int(source))
            try_run(lambda: config.set_tactic_sources(tactic_sources_flag), name="tactic_sources")

        try:
            if self.timing_cache_path:
                timing_cache_data = util.load_file(self.timing_cache_path, description="tactic timing cache")
                cache = config.create_timing_cache(timing_cache_data)
            else:
                # Create an empty timing cache by default so it will be populated during engine build.
                # This way, consumers of CreateConfig have the option to use the cache later.
                cache = config.create_timing_cache(b"")
        except AttributeError:
            # Timing-cache APIs are unavailable on this TensorRT version; only a
            # hard failure if the user explicitly asked to load a cache file.
            if self.timing_cache_path:
                trt_util.fail_unavailable("load_timing_cache in CreateConfig")
        else:
            config.set_timing_cache(cache, ignore_mismatch=False)

        if self.algorithm_selector is not None:
            def set_algo_selector():
                config.algorithm_selector = self.algorithm_selector
            try_run(set_algo_selector, "algorithm_selector")

        return config