Example #1
0
 def __init__(self,
              model_dir,
              softmax_layer='retrained_layer:0',
              namespace='classification'):
     """Load a serialized graph and index its labels by softmax node id.

     Args:
         model_dir: Directory containing ``state/model.pb``; it is also
             handed to ``load_labels``.
         softmax_layer: Name of the tensor to fetch from the imported
             graph.
         namespace: Name under which the graph is imported. When falsy,
             ``softmax_layer`` is looked up without a prefix.

     Raises:
         Exception: If any loaded label has no ``node_id`` entry.
     """
     tf.reset_default_graph()
     self.sess = tf.Session()
     self.namespace = namespace

     # Read the serialized GraphDef and import it under `namespace`.
     graph_file = os.path.join(model_dir, 'state/model.pb')
     with tf.gfile.FastGFile(graph_file, 'rb') as handle:
         frozen_graph = tf.GraphDef()
         frozen_graph.ParseFromString(handle.read())
         tf.import_graph_def(frozen_graph, name=namespace)

     # Imported tensors are addressed as "<namespace>/<tensor name>".
     tensor_name = (namespace + '/' + softmax_layer
                    if namespace else softmax_layer)
     self.softmax_tensor = self.sess.graph.get_tensor_by_name(tensor_name)

     # Re-key the labels by node id; every label must carry one.
     all_labels = load_labels(model_dir)
     self.labels_by_node_id = {}
     for label_id, label in all_labels.items():
         node_id = label.get('node_id')
         if node_id is None:
             raise Exception(
                 'No Softmax node_id is known for label {}, aborting'.
                 format(label_id))
         self.labels_by_node_id[node_id] = label
Example #2
0
    def __init__(self, base_model_path, scaffold_path, **kwargs):
        """Set up settings, labels and the network definition for training.

        Args:
            base_model_path: Directory of the base model. Its last path
                component is used as the model name (ex
                inception_resnet_v2) and its ``state`` subdirectory is
                searched for checkpoints.
            scaffold_path: Path passed to ``load_labels``.
            **kwargs: Per-instance overrides applied on top of
                ``DEFAULT_SETTINGS``. Supplying ``base_checkpoint_path``
                skips the checkpoint lookup.
        """
        self.base_model_path = base_model_path
        # normpath drops a trailing separator, so ".../inception_resnet_v2/"
        # still yields the directory name instead of an empty string.
        self.model_name = os.path.basename(os.path.normpath(base_model_path))
        self.scaffold_path = scaffold_path

        # Work on a copy: mutating DEFAULT_SETTINGS itself would leak this
        # instance's overrides (and its resolved checkpoint path) into every
        # instance created afterwards.
        self.settings = dict(DEFAULT_SETTINGS)
        self.settings.update(kwargs)

        if 'base_checkpoint_path' not in self.settings:
            checkpoint_dir = os.path.abspath(
                os.path.join(self.base_model_path, 'state'))
            base_path = tf.train.latest_checkpoint(checkpoint_dir)
            if base_path is None:
                # No checkpoint found in the state directory: fall back to
                # "<base_model_path>/state/<model_name>.ckpt".
                base_path = os.path.join(self.base_model_path, 'state',
                                         self.model_name + '.ckpt')
            self.settings['base_checkpoint_path'] = base_path

        self.labels = load_labels(scaffold_path)
        self.num_classes = len(self.labels)
        self.preprocess = preprocessing_factory.get_preprocessing(
            self.model_name, is_training=True)
        self.model_definition = nets_factory.get_network_fn(self.model_name,
                                                            self.num_classes,
                                                            is_training=True)
Example #3
0
 def __init__(self, base_model_path, scaffold_path, **kwargs):
     """Store paths, build per-instance settings and load the labels.

     Args:
         base_model_path: Directory of the base model; used to derive the
             default ``base_graph_path``.
         scaffold_path: Path passed to ``load_labels``.
         **kwargs: Per-instance overrides applied on top of
             ``DEFAULT_SETTINGS``.
     """
     self.base_model_path = base_model_path
     self.scaffold_path = scaffold_path
     # Work on a copy: mutating DEFAULT_SETTINGS itself would leak this
     # instance's overrides (and its base_graph_path) into every instance
     # created afterwards.
     self.settings = dict(DEFAULT_SETTINGS)
     self.settings.update(kwargs)
     # `in` works on both Python 2 and 3; dict.has_key is Python 2 only.
     if 'base_graph_path' not in self.settings:
         self.settings[
             'base_graph_path'] = base_model_path + '/state/model.pb'
     self.labels = load_labels(scaffold_path)
Example #4
0
    def __init__(self, model_dir, model_name):
        """Prepare preprocessing, labels and the network for inference.

        Args:
            model_dir: Directory passed to ``load_labels``.
            model_name: Name given to the preprocessing and network
                factories (ex inception_resnet_v2).

        Raises:
            Exception: If any loaded label has no ``node_id`` entry.
        """
        tf.reset_default_graph()
        self.model_dir = model_dir
        self.model_name = model_name
        self.preprocess = preprocessing_factory.get_preprocessing(
            model_name, is_training=False)

        # Re-key the labels by node id; every label must carry one.
        all_labels = load_labels(model_dir)
        self.labels_by_node_id = {}
        for label_id, label in all_labels.items():
            node_id = label.get('node_id')
            if node_id is None:
                raise Exception(
                    'No Softmax node_id is known for label {}, aborting'.
                    format(label_id))
            self.labels_by_node_id[node_id] = label

        # The class count is the number of distinct node ids.
        self.num_classes = len(self.labels_by_node_id)
        self.model_definition = nets_factory.get_network_fn(
            model_name, self.num_classes, is_training=False)
Example #5
0
    def __init__(self, model_file, options=None):
        """Load model state into a fresh session and look up its tensors.

        Args:
            model_file: Path passed to ``load_labels`` and
                ``load_model_state``.
            options: Optional mapping of setting overrides, applied after
                the built-in ``tau``, ``min_confidence`` and
                ``show_suppressed`` values.

        Raises:
            IndexError: If ``load_labels`` returns no labels.
        """
        # Work on a copy: mutating DEFAULT_SETTINGS itself would leak the
        # overrides below into every other user of the shared defaults.
        settings = dict(DEFAULT_SETTINGS)
        settings['tau'] = 0.25
        settings['min_confidence'] = 0.2
        settings['show_suppressed'] = True
        if options:
            settings.update(options)

        tf.reset_default_graph()
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
        # Only support 1 label for now. list() keeps this working on
        # Python 3, where dict.values() is a non-indexable view.
        self.label_meta = list(load_labels(model_file).values())[0]
        load_model_state(self.sess, model_file)
        # Input/output tensors are fetched by their hard-coded graph names.
        self.pred_boxes = self.sess.graph.get_tensor_by_name(
            'decoder_2/pred_boxes_test:0')
        self.pred_confidences = self.sess.graph.get_tensor_by_name(
            'decoder_2/pred_confidences_test:0')
        self.x_in = self.sess.graph.get_tensor_by_name(
            'fifo_queue_1_DequeueMany:0')

        if settings['use_rezoom']:
            self._set_rezoom(settings)
        self.settings = settings