def load_trained_model(self):
    """Load the frozen detection graph and the label-map category index.

    Side effects:
        * ``self.detection_graph`` -- a fresh ``tf.Graph`` into which the
          serialized ``GraphDef`` read from ``self.model_file_name`` is
          imported (with an empty name prefix).
        * ``self.tensor_dict`` -- for each of the standard detection output
          names below whose ``<name>:0`` tensor exists in that graph, the
          tensor is stored under ``<name>``. The dict itself is NOT created
          here; it must already exist on the instance (presumably set up in
          ``__init__`` -- confirm).
        * ``self.category_index`` -- built from ``self.label_map_file_name``
          with at most 90 categories, using display names.
    """
    output_names = (
        'num_detections',
        'detection_boxes',
        'detection_scores',
        'detection_classes',
        'detection_masks',
    )

    self.detection_graph = tf.Graph()
    with self.detection_graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(self.model_file_name, 'rb') as model_file:
            raw_bytes = model_file.read()
            graph_def.ParseFromString(raw_bytes)
            tf.import_graph_def(graph_def, name='')

        # Inside as_default(), the default graph is self.detection_graph.
        active_graph = tf.get_default_graph()
        present = set()
        for operation in active_graph.get_operations():
            for out_tensor in operation.outputs:
                present.add(out_tensor.name)

        # Models differ in which outputs they expose (e.g. masks are only
        # present for instance-segmentation models), so only register the
        # tensors that actually exist in the imported graph.
        for output_name in output_names:
            full_name = output_name + ':0'
            if full_name not in present:
                continue
            self.tensor_dict[output_name] = active_graph.get_tensor_by_name(
                full_name)

    label_map = label_map_util.load_labelmap(self.label_map_file_name)
    category_list = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes=90, use_display_name=True)
    self.category_index = label_map_util.create_category_index(category_list)
def _load_label_map(self):
    """Build a category index from the label map at ``self.PATH_TO_LABELS``.

    Prints a progress message, converts the label map into at most
    ``self.NUM_CLASSES`` categories (using display names), and returns the
    resulting index as produced by ``label_map_util.create_category_index``.
    """
    print('Loading label map...')
    raw_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
    return label_map_util.create_category_index(
        label_map_util.convert_label_map_to_categories(
            raw_map,
            max_num_classes=self.NUM_CLASSES,
            use_display_name=True,
        )
    )
def get_category_index(self, labels_file, num_classes):
    """Create a category index from a label map file.

    Adapted from the TF object recognition tutorial.

    Args:
        labels_file: Path to the label map file to load.
        num_classes: Maximum number of categories kept from the label map.

    Returns:
        The index produced by ``label_map_util.create_category_index`` for
        the label map's categories, using display names.
    """
    category_list = label_map_util.convert_label_map_to_categories(
        label_map_util.load_labelmap(labels_file),
        max_num_classes=num_classes,
        use_display_name=True,
    )
    return label_map_util.create_category_index(category_list)
import os

import numpy as np
import tensorflow as tf
from tensorflow.models.research.object_detection.utils import label_map_util
from tensorflow.models.research.object_detection.utils import visualization_utils as vis_util

# Path to the label map: the list of strings that are used to add the
# correct label to each detection box.
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
# Maximum number of categories kept when converting the label map.
NUM_CLASSES = 90

# Size, in inches, of the output images.
IMAGE_SIZE = (12, 8)

##################### Loading label map
# NOTE(review): this runs at import time; kept module-level because other
# code may rely on `label_map`, `categories` and `category_index` globals.
print('Loading label map...')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)


def get_frozen_graph(graph_file):
    """Read a frozen graph file from disk.

    Args:
        graph_file: Path to a serialized ``GraphDef`` (``.pb``) file.

    Returns:
        The parsed ``tf.GraphDef``.
    """
    # tf.gfile.FastGFile is deprecated ("Use tf.gfile.GFile"); GFile is the
    # supported equivalent and is what the rest of this code base uses.
    with tf.gfile.GFile(graph_file, "rb") as f:
        serialized = f.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    return graph_def


# The TensorRT inference graph file downloaded from Colab or your local machine.
pb_fname = "./model/trt_graph.pb"