def infer_impl(self, feed_dict):
    """
    Runs ``self.sess`` on the provided feed_dict, records how long the call took
    in ``self.inference_time``, and optionally saves a Chrome-trace timeline.

    Args:
        feed_dict: The inputs, forwarded unchanged to ``self.sess.run``.

    Returns:
        OrderedDict: Maps each name in ``self.output_names`` to the corresponding output.
    """
    G_LOGGER.extra_verbose("Received feed_dict: {:}".format(feed_dict))

    began = time.time()
    raw_outputs = self.sess.run(
        self.output_names,
        feed_dict=feed_dict,
        options=self.run_options,
        run_metadata=self.run_metadata,
    )
    finished = time.time()

    # Pair outputs with their names in the order they were requested.
    results = OrderedDict(zip(self.output_names, raw_outputs))
    self.inference_time = finished - began

    if self.timeline_dir is not None:
        # Imported lazily so that it is only required when timelines are requested.
        from tensorflow.python.client import timeline

        trace = timeline.Timeline(self.run_metadata.step_stats)
        trace_contents = trace.generate_chrome_trace_format()
        # One file per inference, numbered by the running inference counter.
        trace_path = os.path.join(self.timeline_dir, "run-{:}".format(self.num_inferences))
        util.save_file(contents=trace_contents, dest=trace_path, mode="w")

    self.num_inferences += 1
    return results
def call_impl(self):
    """
    Retrieves the graph and, depending on which destinations are configured,
    writes the serialized GraphDef to ``self.path``, writes the graph as
    TensorBoard events to ``self.tensorboard_dir``, and saves the serialized
    segment of every ``TRTEngineOp`` node to ``self.engine_dir``.

    Returns:
        Tuple[tf.Graph, Sequence[str]]: The TensorFlow graph, and the names of its outputs.
    """
    (graph, outputs), _ = util.invoke_if_callable(self._graph)

    if self.path:
        util.save_file(graph.as_graph_def().SerializeToString(), dest=self.path)

    if self.tensorboard_dir:
        G_LOGGER.info("Writing tensorboard events to {:}".format(self.tensorboard_dir))
        train_writer = tf.compat.v1.summary.FileWriter(self.tensorboard_dir)
        try:
            train_writer.add_graph(graph)
        finally:
            # Close the writer so that pending events are flushed to disk and
            # the underlying event file is released.
            train_writer.close()

    if self.engine_dir is not None:
        segment_number = 0
        for node in graph.as_graph_def().node:
            if node.op != "TRTEngineOp":
                continue
            # Each engine op is written to its own file, numbered in graph order.
            util.save_file(
                contents=node.attr["serialized_segment"].s,
                dest=os.path.join(self.engine_dir, "segment-{:}".format(segment_number)),
            )
            segment_number += 1

    return graph, outputs
def call_impl(self):
    """
    Obtains the bytes from ``self._bytes`` (invoking it if it is a callable)
    and writes them to ``self._path``.

    Returns:
        bytes: The bytes saved.
    """
    payload, _ = util.invoke_if_callable(self._bytes)
    util.save_file(contents=payload, dest=self._path)
    return payload
def write_calibration_cache(self, cache):
    """
    Stores the calibration cache in memory and, if a cache destination is
    configured, makes a best-effort attempt to write it there.

    Args:
        cache: The calibration cache. Must provide a ``tobytes()`` method.
    """
    self.cache_contents = cache.tobytes()
    self.has_cached_scales = True

    if self._cache is None:
        return

    try:
        util.save_file(contents=self.cache_contents, dest=self._cache, description="calibration cache")
    except Exception as err:
        # Writing the cache is best-effort, so failures are only reported.
        # Catch `Exception` rather than using a bare `except` so that
        # KeyboardInterrupt/SystemExit still propagate.
        G_LOGGER.warning(
            "Could not write to calibration cache: {:}.\nNote: Error was: {:}".format(self._cache, err)
        )
def call_impl(self):
    """
    Obtains the engine from ``self._engine`` (invoking it if it is a callable),
    serializes it and writes the result to ``self.path``.

    Returns:
        trt.ICudaEngine: The engine that was saved.
    """
    engine_obj, took_ownership = util.invoke_if_callable(self._engine)

    with contextlib.ExitStack() as guard:
        if took_ownership:
            # Only guard engines that this loader created itself.
            # NOTE(review): presumably this frees the engine if saving raises - confirm.
            guard.enter_context(util.FreeOnException([engine_obj]))

        serialized = bytes_from_engine(engine_obj)
        util.save_file(contents=serialized, dest=self.path, description="engine")

    return engine_obj
def write_calibration_cache(self, cache):
    """
    Keeps a copy of the calibration cache in memory and, when ``self._cache``
    is set, tries to save it there. Failures to save are logged as errors
    rather than raised from this method.

    Args:
        cache: The calibration cache. Must provide a ``tobytes()`` method.
    """
    self.cache_contents = cache.tobytes()
    self.has_cached_scales = True

    if self._cache is not None:
        try:
            util.save_file(
                contents=self.cache_contents,
                dest=self._cache,
                description="calibration cache",
            )
        except Exception as err:
            message = "Could not write to calibration cache: {:}.\nNote: Error was: {:}".format(self._cache, err)
            G_LOGGER.error(message)
def call_impl(self):
    """
    Builds a serialized engine from the configured network and builder
    configuration, and optionally saves the tactic timing cache to
    ``self.timing_cache_path``.

    Returns:
        bytes: The serialized engine that was created.
    """
    # If network is a callable, then we own its return value
    ret, owns_network = util.invoke_if_callable(self._network)
    builder, network, parser = util.unpack_args(ret, num=3)

    if builder is None or network is None:
        G_LOGGER.critical(
            "Expected to receive a (builder, network) tuple for the `network` parameter, "
            "but received: ({:}, {:})".format(builder, network)
        )

    # Everything entered into the stack is released when the build finishes or fails.
    with contextlib.ExitStack() as stack:
        if owns_network:
            stack.enter_context(builder)
            stack.enter_context(network)
            if parser is not None:
                stack.enter_context(parser)
        else:
            provided = "Builder and Network" if parser is None else "Builder, Network, and Parser"
            G_LOGGER.verbose(
                "{:} were provided directly instead of via a Callable. This loader will not assume ownership. "
                "Please ensure that they are freed.".format(provided)
            )

        config, owns_config = util.invoke_if_callable(self._config, builder, network)
        if owns_config:
            stack.enter_context(config)
        else:
            G_LOGGER.verbose(
                "Builder configuration was provided directly instead of via a Callable. This loader will not assume "
                "ownership. Please ensure it is freed."
            )

        # Only calibrators that support the context manager protocol are entered.
        try:
            config.int8_calibrator.__enter__  # Polygraphy calibrator frees device buffers on exit.
        except AttributeError:
            pass
        else:
            stack.enter_context(config.int8_calibrator)

        network_log_mode = "full" if G_LOGGER.severity <= G_LOGGER.ULTRA_VERBOSE else "attrs"
        # Passed as a lambda, presumably so that the string is only built when it is actually logged.
        G_LOGGER.super_verbose(
            lambda: (
                "Displaying TensorRT Network:\n" + trt_util.str_from_network(network, mode=network_log_mode)
            )
        )

        G_LOGGER.start("Building engine with configuration:\n{:}".format(trt_util.str_from_config(config)))

        try:
            engine_bytes = builder.build_serialized_network(network, config)
        except AttributeError:
            # Fall back to building an engine object and serializing it when
            # `build_serialized_network` is unavailable (presumably older TensorRT versions).
            engine = builder.build_engine(network, config)
            if not engine:
                G_LOGGER.critical("Invalid Engine. Please ensure the engine was built correctly")
            stack.enter_context(engine)
            engine_bytes = engine.serialize()

        if not engine_bytes:
            G_LOGGER.critical("Invalid Engine. Please ensure the engine_bytes was built correctly")

        try:
            timing_cache = config.get_timing_cache()
        except AttributeError:
            # Timing caches are unsupported here; only complain if the user asked to save one.
            if self.timing_cache_path:
                trt_util.fail_unavailable("save_timing_cache in EngineBytesFromNetwork")
        else:
            if timing_cache and self.timing_cache_path:
                with timing_cache.serialize() as buffer:
                    util.save_file(buffer, self.timing_cache_path, description="tactic timing cache")

        return engine_bytes