def call_impl(self):
    """
    Load the frozen TensorFlow model stored at ``self.path``.

    Returns:
        Tuple[tf.Graph, Sequence[str]]: The TensorFlow graph, and the names of its outputs.
    """
    loaded_graph = tf_util.load_graph(self.path)
    output_names = tf_util.get_graph_output_names(loaded_graph)
    return loaded_graph, output_names
def __call__(self):
    """
    Loads a TensorFlow frozen model from ``self.path``.

    Returns:
        Tuple[tf.Graph, Sequence[str]]: The TensorFlow graph, and the names of its outputs.
    """
    frozen_graph = tf_util.load_graph(self.path)
    graph_outputs = tf_util.get_graph_output_names(frozen_graph)
    return frozen_graph, graph_outputs
def call_impl(self):
    """
    Load a Keras model from ``self.path`` and return the backend session's graph.

    Returns:
        Tuple[tf.Graph, Sequence[str]]: The TensorFlow graph, and the names of its outputs.
    """
    from tensorflow.python import keras
    from tensorflow.python.keras import backend

    # The model object is not used directly. The code below reads the graph of the
    # Keras backend session, presumably populated by this load — TODO confirm.
    _loaded_model = keras.models.load_model(self.path)
    session_graph = backend.get_session().graph
    output_names = tf_util.get_graph_output_names(session_graph)
    return session_graph, output_names
def call_impl(self):
    """
    Convert the input TensorFlow graph into a TF-TRT graph.

    The input graph's ``GraphDef`` is passed to ``tf_trt.create_inference_graph``.
    The number of resulting ``TRTEngineOp`` nodes is logged, and the converted
    ``GraphDef`` is imported into a fresh ``tf.Graph``.

    Returns:
        Tuple[tf.Graph, Sequence[str]]: The TensorFlow graph, and the names of its outputs.
    """
    from tensorflow.contrib import tensorrt as tf_trt

    # `util.invoke_if_callable` presumably invokes `self._graph` when it is callable.
    # Its first result is unpacked as (graph, output_names); the second is ignored.
    (graph, output_names), _ = util.invoke_if_callable(self._graph)

    # INT8 takes precedence over FP16, which takes precedence over the FP32 default.
    precision_mode = "FP16" if self.fp16 else "FP32"
    precision_mode = "INT8" if self.int8 else precision_mode

    G_LOGGER.info(
        "For TF-TRT, using outputs={:}, max_workspace_size_bytes={:}, max_batch_size={:}, "
        "minimum_segment_size={:}, is_dynamic_op={:}, precision_mode={:}".format(
            output_names,
            self.max_workspace_size,
            self.max_batch_size,
            self.minimum_segment_size,
            self.is_dynamic_op,
            precision_mode,
        ))

    graphdef = tf_trt.create_inference_graph(
        graph.as_graph_def(),
        outputs=output_names,
        max_workspace_size_bytes=self.max_workspace_size,
        max_batch_size=self.max_batch_size,
        minimum_segment_size=self.minimum_segment_size,
        is_dynamic_op=self.is_dynamic_op,
        precision_mode=precision_mode,
    )

    # Each TRTEngineOp node is one segment that TF-TRT offloaded to TensorRT.
    # Only the count is needed, so the serialized segments are not read.
    num_engines = sum(1 for node in graphdef.node if node.op == "TRTEngineOp")
    G_LOGGER.info("Found {:} engines in TFTRT graph".format(num_engines))

    # Use a separate name so the input `graph` is not shadowed by the converted one.
    with tf.Graph().as_default() as trt_graph:
        tf.import_graph_def(graphdef, name="")
        return trt_graph, tf_util.get_graph_output_names(trt_graph)
def call_impl(self):
    """
    Load a TensorFlow graph from a checkpoint and restore its variables.

    The checkpoint prefix is ``os.path.join(self.dir, self.name)`` when ``self.name``
    is set. Otherwise it is taken from the ``checkpoint`` file inside ``self.dir``.
    The graph structure is imported from ``<prefix>.meta``.

    Returns:
        Tuple[tf.Graph, Sequence[str]]: The TensorFlow graph, and the names of its outputs.
    """
    # If `name` is not provided, this expects that the directory contains a `checkpoint` file with the contents:
    #
    # model_checkpoint_path: "model"
    # all_model_checkpoint_paths: "model"
    #
    # where "model" is the checkpoint name
    if not os.path.isdir(self.dir):
        G_LOGGER.warning(
            "Specified checkpoint directory: {:} does not look like a directory."
            .format(self.dir))

    if self.name is None:
        G_LOGGER.verbose(
            "Checkpoint name was not explicitly provided, searching for `checkpoint` file"
        )
        checkpoint = tf.train.get_checkpoint_state(self.dir)
        if checkpoint is None:
            ckpt_file_contents = '\nmodel_checkpoint_path: "model"\nall_model_checkpoint_paths: "model"\n'
            # NOTE(review): this relies on G_LOGGER.critical raising. If it only logs,
            # the attribute access on `checkpoint` below fails on None.
            G_LOGGER.critical(
                "Checkpoint directory: {:} does not contain a `checkpoint` file, and the checkpoint name was "
                "not provided. Please either create a checkpoint file with the contents:\n{:} "
                "\nWhere `model` is the name of the checkpoint, or explicitly provide the name with "
                "--ckpt, not including file extensions".format(
                    self.dir, ckpt_file_contents))
        input_checkpoint = checkpoint.model_checkpoint_path
    else:
        input_checkpoint = os.path.join(self.dir, self.name)

    meta_file = input_checkpoint + ".meta"
    # Use the Session itself as the context manager, not just `Session.as_default()`.
    # That makes it the default session and also closes it on exit. Previously the
    # session was left open until garbage collection.
    with tf.Graph().as_default() as graph, tf.compat.v1.Session(
            graph=graph) as sess:
        saver = tf.compat.v1.train.import_meta_graph(meta_file,
                                                     clear_devices=True)
        saver.restore(sess, input_checkpoint)
        return graph, tf_util.get_graph_output_names(graph)