class Yolow(object):
    """TensorFlow wrapper that builds (or reloads) a YOLOv3 graph and runs it.

    On the first run for a given weights file the network is built, its
    weights are loaded through ``WeightLoader`` and the graph is frozen to a
    ``.pb`` file under ``freeze_dir``. Later instances whose frozen file
    already exists import that ``.pb`` instead of rebuilding.
    """

    # One session (and the process-wide default graph) is shared by every
    # instance; allow_growth keeps TF from reserving all GPU memory up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    default_models = load_default_models()
    model_types = list(default_models.keys())
    freeze_dir = 'data/pb/'
    # NOTE(review): relies on load_default_models() listing the full model
    # first and the tiny model second -- confirm the key order.
    tf_models = {model_types[0]: Yolov3, model_types[1]: Yolov3Tiny}

    def __init__(self, model_type, model_file, anchor_file, num_classes,
                 input_size, labels, is_training=False):
        """Resolve model paths, then import a frozen graph or build one.

        Args:
            model_type: One of ``model_types``. A falsy value selects
                ``model_types[0]``.
            model_file: Path to a weights file or to a frozen ``.pb`` graph.
                A falsy value selects the default file for ``model_type``
                under ``./data/bin/``.
            anchor_file: Anchors file. A falsy value selects
                ``data/anchors/<model_type>.txt``. Only used when the graph
                is built rather than imported.
            num_classes: Number of classes, passed to the model constructor.
            input_size: An int (square input) or a pair. The pair fills the
                placeholder as ``[None, size[0], size[1], 3]``, so it is
                presumably ``(height, width)``. A falsy value means 416.
            labels: Passed to ``Imager`` together with ``input_size``.
            is_training: Passed to the model constructor.

        Raises:
            ValueError: If ``model_type`` is unknown, or if an explicitly
                given ``model_file`` / ``anchor_file`` does not exist.
        """
        # Apply the default *before* validating. Previously the membership
        # test ran first, so a falsy model_type raised instead of defaulting.
        if not model_type:
            model_type = self.model_types[0]
        elif model_type not in self.model_types:
            raise ValueError(
                'model_type can only be either \'full\' or \'tiny\'.')
        self.model_type = model_type

        if not model_file:
            model_file = './data/bin/{}'.format(
                self.default_models.get(model_type))
        elif not os.path.exists(model_file):
            raise ValueError(
                'model file {} does not exist.'.format(model_file))
        self.model_file = model_file

        if self.model_file.endswith('.pb'):
            # Already a frozen graph: import it directly. Previously
            # frozen_filename was left unset here, causing AttributeError.
            self.frozen_filename = self.model_file
        else:
            stem = os.path.basename(self.model_file).split('.')[0]
            # NOTE(review): the cache name ignores input_size/num_classes, so
            # a graph frozen for one input size is reused for another.
            self.frozen_filename = os.path.join(
                self.freeze_dir, '_'.join(['frozen', stem]) + '.pb')

        if not input_size:
            input_size = 416
        if isinstance(input_size, int):
            self.input_size = input_size, input_size
        else:
            self.input_size = input_size
        self.labels = labels
        self.imer = Imager(self.input_size, self.labels)
        # Recorded on both paths so these attributes always exist.
        self.num_classes = num_classes
        self.is_training = is_training

        if os.path.exists(self.frozen_filename):
            self.defrost()
            graph = tf.get_default_graph()
            # tf.import_graph_def's default name prefix is 'import'.
            # NOTE(review): because the default graph is shared, a second
            # instance imports under 'import_1/' but these lookups still
            # resolve to the first instance's tensors.
            self.input = graph.get_tensor_by_name('import/input:0')
            self.output = graph.get_tensor_by_name(
                'import/detections/output:0')
        else:
            if not anchor_file:
                anchor_file = 'data/anchors/' + self.model_type + '.txt'
            elif not os.path.exists(anchor_file):
                raise ValueError(
                    '{} anchor file does not exist.'.format(anchor_file))
            self.anchor_file = anchor_file
            self.input = tf.placeholder(
                tf.float32,
                [None, self.input_size[0], self.input_size[1], 3],
                'input')
            self.model = self.tf_models[self.model_type](
                self.input, self.num_classes, self.input_size,
                self.anchor_file, self.is_training)
            with tf.variable_scope('detections'):
                self.output = self.model.graph()
            self.loader = WeightLoader(
                tf.global_variables('detections'), self.model_file)
            # Running load_now() stands in for the variable initializer;
            # presumably it assigns the file's weights -- confirm in
            # WeightLoader.
            self.sess.run(self.loader.load_now())
            self.freeze()

    def set_input(self, images):
        """Hand ``images`` to the Imager.

        A ``str`` is treated as a path and passed to ``imset_from_path``.
        Anything else is passed unchanged to ``imset``.
        """
        if isinstance(images, str):
            self.imer.imset_from_path(images)
        else:
            self.imer.imset(images)

    def predict(self, confidence_theshold=.6, iou_threshold=.5):
        """Run the network on the images last given to ``set_input``.

        Args:
            confidence_theshold: Forwarded to the global ``predict``
                post-processor. The misspelled name is kept for backward
                compatibility with keyword callers.
            iou_threshold: Forwarded to the same post-processor.

        Returns:
            Whatever ``self.imer.visualise_preds`` returns for the
            post-processed predictions.
        """
        input_list = self.imer.preprocess()
        feed_dict = {self.input: input_list}
        batch_detections = self.sess.run(self.output, feed_dict)
        # The bare name resolves to the global ``predict``, not this method:
        # class attributes are not part of a method's name lookup.
        pred_list = predict(batch_detections, confidence_theshold,
                            iou_threshold)
        return self.imer.visualise_preds(pred_list)

    def freeze(self):
        """Write the current graph, with variables folded to constants.

        The graph is converted using the values held in ``self.sess``, with
        ``detections/output`` as the output node. The result is serialised
        to ``self.frozen_filename``, creating its directory if needed.
        """
        graph_def = tf.graph_util.convert_variables_to_constants(
            sess=self.sess,
            input_graph_def=tf.get_default_graph().as_graph_def(),
            output_node_names=['detections/output'])
        out_dir = os.path.dirname(self.frozen_filename)
        if out_dir and not os.path.exists(out_dir):
            os.makedirs(out_dir)
        with tf.gfile.GFile(self.frozen_filename, 'wb') as f:
            f.write(graph_def.SerializeToString())

    def defrost(self):
        """Import the frozen graph at ``self.frozen_filename``.

        The graph is imported into the default graph under TF's default
        ``import`` name prefix.
        """
        print('Found frozen model {}, defrost and use!'.format(
            self.frozen_filename))
        with tf.gfile.GFile(self.frozen_filename, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def)