Ejemplo n.º 1
0
 def call_impl(self):
     """
     Load the frozen TensorFlow graph stored at ``self.path``.

     Returns:
         Tuple[tf.Graph, Sequence[str]]: The TensorFlow graph, and the names of its outputs.
     """
     frozen_graph = tf_util.load_graph(self.path)
     outputs = tf_util.get_graph_output_names(frozen_graph)
     return frozen_graph, outputs
Ejemplo n.º 2
0
    def __call__(self):
        """
        Loads a TensorFlow frozen model from ``self.path``.

        Returns:
            Tuple[tf.Graph, Sequence[str]]: The TensorFlow graph, and the names of its outputs.
        """
        loaded = tf_util.load_graph(self.path)
        names = tf_util.get_graph_output_names(loaded)
        return (loaded, names)
Ejemplo n.º 3
0
    def call_impl(self):
        """
        Load the Keras model at ``self.path`` and return the Keras backend session's graph.

        Returns:
            Tuple[tf.Graph, Sequence[str]]: The TensorFlow graph, and the names of its outputs.
        """

        from tensorflow.python import keras
        from tensorflow.python.keras import backend

        # The loaded model object is not used below; the load is kept for its effect.
        # NOTE(review): presumably the model is built into the backend session's
        # graph, which is what gets returned - confirm.
        _keras_model = keras.models.load_model(self.path)
        session_graph = backend.get_session().graph
        output_names = tf_util.get_graph_output_names(session_graph)
        return session_graph, output_names
Ejemplo n.º 4
0
    def call_impl(self):
        """
        Convert the input graph with TF-TRT and return the converted graph.

        The input ``(graph, output_names)`` pair is obtained from ``self._graph`` via
        ``util.invoke_if_callable``. It is passed to ``tf_trt.create_inference_graph``
        using this loader's workspace, batch-size, segment-size, dynamic-op and
        precision settings. The resulting GraphDef is imported into a fresh graph.

        Returns:
            Tuple[tf.Graph, Sequence[str]]: The TensorFlow graph, and the names of its outputs.
        """
        from tensorflow.contrib import tensorrt as tf_trt

        # Only the first element of invoke_if_callable's result (the (graph, outputs)
        # pair) is needed here.
        (graph, output_names), _ = util.invoke_if_callable(self._graph)

        # INT8 takes precedence over FP16; FP32 is the fallback.
        if self.int8:
            precision_mode = "INT8"
        elif self.fp16:
            precision_mode = "FP16"
        else:
            precision_mode = "FP32"

        G_LOGGER.info(
            "For TF-TRT, using outputs={:}, max_workspace_size_bytes={:}, max_batch_size={:}, "
            "minimum_segment_size={:}, is_dynamic_op={:}, precision_mode={:}".
            format(
                output_names,
                self.max_workspace_size,
                self.max_batch_size,
                self.minimum_segment_size,
                self.is_dynamic_op,
                precision_mode,
            ))

        graphdef = tf_trt.create_inference_graph(
            graph.as_graph_def(),
            outputs=output_names,
            max_workspace_size_bytes=self.max_workspace_size,
            max_batch_size=self.max_batch_size,
            minimum_segment_size=self.minimum_segment_size,
            is_dynamic_op=self.is_dynamic_op,
            precision_mode=precision_mode,
        )

        # Count TRTEngineOp nodes without touching their attrs: indexing a protobuf
        # message map with a missing key would insert a default entry into the node.
        num_engines = sum(1 for node in graphdef.node if node.op == "TRTEngineOp")
        G_LOGGER.info(
            "Found {:} engines in TFTRT graph".format(num_engines))

        # Import the converted GraphDef under its original names (name="") so that
        # output names resolve the same way in the new graph.
        with tf.Graph().as_default() as trt_graph:
            tf.import_graph_def(graphdef, name="")
            return trt_graph, tf_util.get_graph_output_names(trt_graph)
Ejemplo n.º 5
0
    def call_impl(self):
        """
        Import a TensorFlow checkpoint's meta graph and restore its variables into it.

        The checkpoint lives in ``self.dir``. If ``self.name`` is given, it is joined
        onto ``self.dir`` as the checkpoint prefix. Otherwise the prefix is taken from
        ``tf.train.get_checkpoint_state(self.dir)``, i.e. the directory's
        ``checkpoint`` file.

        Returns:
            Tuple[tf.Graph, Sequence[str]]: The TensorFlow graph, and the names of its outputs.
        """
        # If `name` is not provided, this expects that the directory contains a `checkpoint` file with the contents:
        #
        # model_checkpoint_path: "model"
        # all_model_checkpoint_paths: "model"
        #
        # where "model" is the checkpoint name
        if not os.path.isdir(self.dir):
            G_LOGGER.warning(
                "Specified checkpoint directory: {:} does not look like a directory."
                .format(self.dir))

        if self.name is None:
            G_LOGGER.verbose(
                "Checkpoint name was not explicitly provided, searching for `checkpoint` file"
            )
            checkpoint = tf.train.get_checkpoint_state(self.dir)
            if checkpoint is None:
                ckpt_file_contents = '\nmodel_checkpoint_path: "model"\nall_model_checkpoint_paths: "model"\n'
                # NOTE(review): this relies on G_LOGGER.critical raising; if it only
                # logs, the attribute access on `checkpoint` below fails on None.
                G_LOGGER.critical(
                    "Checkpoint directory: {:} does not contain a `checkpoint` file, and the checkpoint name was "
                    "not provided. Please either create a checkpoint file with the contents:\n{:} "
                    "\nWhere `model` is the name of the checkpoint, or explicitly provide the name with "
                    "--ckpt, not including file extensions".format(
                        self.dir, ckpt_file_contents))
            input_checkpoint = checkpoint.model_checkpoint_path
        else:
            input_checkpoint = os.path.join(self.dir, self.name)

        meta_file = input_checkpoint + ".meta"
        # Enter the Session itself rather than `Session.as_default()`. Entering a
        # Session also installs it (and its graph) as the default. Unlike
        # `as_default()`, exiting it closes the session deterministically instead of
        # leaving it open until garbage collection.
        with tf.Graph().as_default() as graph, tf.compat.v1.Session(
                graph=graph) as sess:
            saver = tf.compat.v1.train.import_meta_graph(meta_file,
                                                         clear_devices=True)
            saver.restore(sess, input_checkpoint)
            return graph, tf_util.get_graph_output_names(graph)