Example #1
0
    def infer_impl(self, feed_dict):
        """
        Run the TensorFlow session on ``feed_dict`` and collect the requested outputs.

        Also records the duration of the ``self.sess.run`` call in
        ``self.inference_time``. When ``self.timeline_dir`` is set, it writes a
        Chrome-trace timeline for this run to ``<timeline_dir>/run-<N>``, where
        ``N`` is ``self.num_inferences``.

        Args:
            feed_dict: Passed directly to ``self.sess.run`` as its ``feed_dict``.

        Returns:
            OrderedDict: Maps each name in ``self.output_names`` to the
            corresponding value returned by the session, in the same order.
        """
        # Format lazily: the feed dict may be large, and the string is only
        # needed when this log level is actually enabled.
        G_LOGGER.extra_verbose(lambda: "Received feed_dict: {:}".format(feed_dict))

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump if the system clock is adjusted during the run.
        start = time.perf_counter()
        inference_outputs = self.sess.run(self.output_names,
                                          feed_dict=feed_dict,
                                          options=self.run_options,
                                          run_metadata=self.run_metadata)
        end = time.perf_counter()
        self.inference_time = end - start

        out_dict = OrderedDict(zip(self.output_names, inference_outputs))

        if self.timeline_dir is not None:
            # Local import: only needed when a timeline is requested.
            from tensorflow.python.client import timeline

            trace = timeline.Timeline(self.run_metadata.step_stats)
            util.save_file(
                contents=trace.generate_chrome_trace_format(),
                dest=os.path.join(self.timeline_dir,
                                  "run-{:}".format(self.num_inferences)),
                mode="w",
            )
        # Incremented on every call so each timeline file gets a distinct index.
        self.num_inferences += 1

        return out_dict
Example #2
0
    def call_impl(self):
        """
        Resolve the TensorFlow graph and optionally export it.

        Depending on configuration, this:
          - writes the serialized GraphDef to ``self.path``;
          - writes TensorBoard graph events to ``self.tensorboard_dir``;
          - writes the ``serialized_segment`` attribute of every ``TRTEngineOp``
            node to ``<engine_dir>/segment-<N>``, numbered in graph order.

        Returns:
            Tuple[tf.Graph, Sequence[str]]: The TensorFlow graph, and the names of its outputs.
        """
        (graph, outputs), _ = util.invoke_if_callable(self._graph)

        if self.path:
            util.save_file(graph.as_graph_def().SerializeToString(),
                           dest=self.path)

        if self.tensorboard_dir:
            G_LOGGER.info("Writing tensorboard events to {:}".format(
                self.tensorboard_dir))
            train_writer = tf.compat.v1.summary.FileWriter(
                self.tensorboard_dir)
            try:
                train_writer.add_graph(graph)
            finally:
                # FileWriter writes events asynchronously. close() flushes pending
                # events to disk and releases the writer; without it the graph may
                # never be written.
                train_writer.close()

        if self.engine_dir is not None:
            graphdef = graph.as_graph_def()
            # Only TRTEngineOp nodes carry a serialized TensorRT engine segment.
            engine_nodes = [node for node in graphdef.node if node.op == "TRTEngineOp"]
            for segment_number, node in enumerate(engine_nodes):
                util.save_file(
                    contents=node.attr["serialized_segment"].s,
                    dest=os.path.join(
                        self.engine_dir,
                        "segment-{:}".format(segment_number)))

        return graph, outputs
Example #3
0
 def call_impl(self):
     """
     Resolve this loader's bytes (calling ``self._bytes`` if it is callable)
     and write them to ``self._path``.

     Returns:
         bytes: The bytes saved.
     """
     resolved = util.invoke_if_callable(self._bytes)
     payload, _owns_payload = resolved
     destination = self._path
     util.save_file(payload, destination)
     return payload
Example #4
0
        def write_calibration_cache(self, cache):
            """
            Store calibration scales and, if a cache path is configured, persist them.

            Args:
                cache: Object exposing ``tobytes()``; its bytes are stored in
                    ``self.cache_contents``.

            Side effects:
                Sets ``self.cache_contents`` and ``self.has_cached_scales``.
                If ``self._cache`` is not None, also writes the bytes there.
                Write failures are logged as a warning rather than raised.
            """
            self.cache_contents = cache.tobytes()
            self.has_cached_scales = True

            if self._cache is None:
                return

            try:
                util.save_file(contents=self.cache_contents, dest=self._cache, description="calibration cache")
            except Exception as err:
                # Saving is best-effort here, but the reason must be reported.
                # Catch Exception rather than a bare except so that
                # KeyboardInterrupt and SystemExit still propagate.
                G_LOGGER.warning("Could not write to calibration cache: {:}.\nNote: Error was: {:}".format(self._cache, err))
Example #5
0
    def call_impl(self):
        """
        Serialize the engine and write it to ``self.path``.

        When the engine was obtained from a callable (``owns_engine`` is true),
        the save runs inside ``util.FreeOnException([engine])``. Presumably this
        frees the engine if saving raises; confirm against that helper.

        Returns:
            trt.ICudaEngine: The engine that was saved.
        """
        engine, owns_engine = util.invoke_if_callable(self._engine)

        with contextlib.ExitStack() as guards:
            if owns_engine:
                guards.enter_context(util.FreeOnException([engine]))

            serialized_engine = bytes_from_engine(engine)
            util.save_file(contents=serialized_engine, dest=self.path, description="engine")
            return engine
Example #6
0
        def write_calibration_cache(self, cache):
            """
            Record calibration scales and, if a cache destination is configured, save them.

            Args:
                cache: Object exposing ``tobytes()``; its bytes become
                    ``self.cache_contents``.

            Any exception raised while saving to ``self._cache`` is reported via
            ``G_LOGGER.error`` instead of being propagated.
            """
            self.cache_contents = cache.tobytes()
            self.has_cached_scales = True

            if self._cache is not None:
                try:
                    util.save_file(contents=self.cache_contents,
                                   dest=self._cache,
                                   description="calibration cache")
                except Exception as exc:
                    message = "Could not write to calibration cache: {:}.\nNote: Error was: {:}".format(self._cache, exc)
                    G_LOGGER.error(message)
Example #7
0
    def call_impl(self):
        """
        Build a TensorRT engine from the provided network and builder configuration.

        ``self._network`` is resolved via ``util.invoke_if_callable``. It must yield
        a builder, a network, and an optional parser. ``self._config`` is resolved
        the same way, with ``(builder, network)`` passed to it when it is callable.
        Objects this loader owns are entered into an ``ExitStack`` and exited when
        this method returns or raises.

        If ``self.timing_cache_path`` is set, the builder config's tactic timing
        cache is also serialized to that path.

        Returns:
            bytes: The serialized engine that was created.
        """
        # If network is a callable, then we own its return value
        ret, owns_network = util.invoke_if_callable(self._network)
        builder, network, parser = util.unpack_args(ret, num=3)

        if builder is None or network is None:
            G_LOGGER.critical("Expected to receive a (builder, network) tuple for the `network` parameter, "
                              "but received: ({:}, {:})".format(builder, network))

        with contextlib.ExitStack() as stack:
            if owns_network:
                stack.enter_context(builder)
                stack.enter_context(network)
                if parser is not None:
                    stack.enter_context(parser)
            else:
                provided = "Builder and Network" if parser is None else "Builder, Network, and Parser"
                G_LOGGER.verbose("{:} were provided directly instead of via a Callable. This loader will not assume ownership. "
                                 "Please ensure that they are freed.".format(provided))

            config, owns_config = util.invoke_if_callable(self._config, builder, network)
            if owns_config:
                stack.enter_context(config)
            else:
                G_LOGGER.verbose("Builder configuration was provided directly instead of via a Callable. This loader will not assume "
                                 "ownership. Please ensure it is freed.")

            # Enter the calibrator only if it supports the context-manager protocol;
            # an AttributeError means either no calibrator or a plain one.
            try:
                config.int8_calibrator.__enter__ # Polygraphy calibrator frees device buffers on exit.
            except AttributeError:
                pass
            else:
                stack.enter_context(config.int8_calibrator)

            network_log_mode = "full" if G_LOGGER.severity <= G_LOGGER.ULTRA_VERBOSE else "attrs"
            # Passed as a callable so the (potentially large) network dump is only
            # built when this log level is enabled.
            G_LOGGER.super_verbose(lambda: ("Displaying TensorRT Network:\n" + trt_util.str_from_network(network, mode=network_log_mode)))

            G_LOGGER.start("Building engine with configuration:\n{:}".format(trt_util.str_from_config(config)))

            # Prefer build_serialized_network. If the builder lacks it (presumably an
            # older TensorRT), build an engine object and serialize that instead.
            try:
                engine_bytes = builder.build_serialized_network(network, config)
            except AttributeError:
                engine = builder.build_engine(network, config)
                if not engine:
                    # NOTE(review): assumes G_LOGGER.critical raises; otherwise
                    # enter_context below would receive a falsy engine.
                    G_LOGGER.critical("Invalid Engine. Please ensure the engine was built correctly")
                stack.enter_context(engine)
                engine_bytes = engine.serialize()

            if not engine_bytes:
                G_LOGGER.critical("Invalid Engine. Please ensure the engine_bytes was built correctly")

            # A config without get_timing_cache cannot save a timing cache. That is
            # only an error if the caller asked for one.
            try:
                timing_cache = config.get_timing_cache()
            except AttributeError:
                if self.timing_cache_path:
                    trt_util.fail_unavailable("save_timing_cache in EngineBytesFromNetwork")
            else:
                if timing_cache and self.timing_cache_path:
                    with timing_cache.serialize() as buffer:
                        util.save_file(buffer, self.timing_cache_path, description="tactic timing cache")

            return engine_bytes