import os
import random

import numpy as np
import tensorflow as tf
from keras.applications.vgg16 import VGG16   # or tensorflow.keras, depending on the setup
from keras.models import Model

# Network, get_predict_data, get_train_data, write_txt_result, visualization, nms,
# bbox_transform_inv, clip_boxes and the constants (CHECKPOINTS_PATH, CLASSES,
# NMS_THRESH, CONF_THRESH, ...) are assumed to come from the project's own modules.


def predict():
    fasterRCNN = Network()
    fasterRCNN.build(is_training=False)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, os.path.join(CHECKPOINTS_PATH, "model_final.ckpt"))
        print("Model restored.")

        # VGG16 up to block5_conv3 serves as the shared feature extractor
        base_extractor = VGG16(include_top=False)
        extractor = Model(inputs=base_extractor.input,
                          outputs=base_extractor.get_layer('block5_conv3').output)

        predict_img_names = os.listdir(PREDICT_IMG_DATA_PATH)
        for predict_img_name in predict_img_names:
            img_data, img_info = get_predict_data(predict_img_name)
            features = extractor.predict(img_data, steps=1)

            rois, scores, regression_parameter = sess.run(
                [fasterRCNN._predictions["rois"],
                 fasterRCNN._predictions["cls_prob"],
                 fasterRCNN._predictions["bbox_pred"]],
                feed_dict={fasterRCNN.feature_map: features,
                           fasterRCNN.image_info: img_info})

            # Undo the preprocessing scale (img_info[2] is the resize ratio)
            boxes = rois[:, 1:5] / img_info[2]
            scores = np.reshape(scores, [scores.shape[0], -1])
            regression_parameter = np.reshape(regression_parameter,
                                              [regression_parameter.shape[0], -1])

            # Decode the regression deltas into absolute boxes and clip them to the image
            pred_boxes = bbox_transform_inv(boxes, regression_parameter)
            pred_boxes = clip_boxes(pred_boxes,
                                    [img_info[0] / img_info[2], img_info[1] / img_info[2]])

            result_list = []
            for class_index, class_name in enumerate(CLASSES[1:]):
                class_index += 1  # offset by one because the background class is skipped
                cls_boxes = pred_boxes[:, 4 * class_index:4 * (class_index + 1)]
                cls_scores = scores[:, class_index]
                detections = np.hstack((cls_boxes,
                                        cls_scores[:, np.newaxis])).astype(np.float32)
                keep = nms(detections, NMS_THRESH)
                detections = detections[keep, :]
                inds = np.where(detections[:, -1] >= CONF_THRESH)[0]  # filter by confidence
                for i in inds:
                    bbox = detections[i, :4]
                    score = detections[i, -1]
                    # result_for_a_class = [file_name, class_name, score, x1, y1, x2, y2]
                    result_for_a_class = [predict_img_name, class_name, score]
                    result_for_a_class.extend(bbox)
                    result_list.append(result_for_a_class)

            if len(result_list) == 0:
                continue
            if TXT_RESULT_WANTED:
                write_txt_result(result_list)
            if IS_VISIBLE:
                visualization(result_list)
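# predict() depends on bbox_transform_inv and clip_boxes to turn the per-class regression
# deltas into absolute, image-bounded boxes. If those helpers are not already defined in the
# repo, a minimal NumPy sketch following the standard Faster R-CNN box parameterization could
# look like this (function names match the calls above; the project's own implementation may
# differ in detail):

def bbox_transform_inv(boxes, deltas):
    # Decode (dx, dy, dw, dh) deltas relative to the proposal boxes
    # into absolute (x1, y1, x2, y2) coordinates.
    widths = boxes[:, 2] - boxes[:, 0] + 1.0
    heights = boxes[:, 3] - boxes[:, 1] + 1.0
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights

    dx = deltas[:, 0::4]
    dy = deltas[:, 1::4]
    dw = deltas[:, 2::4]
    dh = deltas[:, 3::4]

    pred_ctr_x = dx * widths[:, np.newaxis] + ctr_x[:, np.newaxis]
    pred_ctr_y = dy * heights[:, np.newaxis] + ctr_y[:, np.newaxis]
    pred_w = np.exp(dw) * widths[:, np.newaxis]
    pred_h = np.exp(dh) * heights[:, np.newaxis]

    pred_boxes = np.zeros(deltas.shape, dtype=deltas.dtype)
    pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w  # x1
    pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h  # y1
    pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w  # x2
    pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h  # y2
    return pred_boxes


def clip_boxes(boxes, img_shape):
    # Clip boxes to the image boundary; img_shape is [height, width].
    boxes[:, 0::4] = np.clip(boxes[:, 0::4], 0, img_shape[1] - 1)
    boxes[:, 1::4] = np.clip(boxes[:, 1::4], 0, img_shape[0] - 1)
    boxes[:, 2::4] = np.clip(boxes[:, 2::4], 0, img_shape[1] - 1)
    boxes[:, 3::4] = np.clip(boxes[:, 3::4], 0, img_shape[0] - 1)
    return boxes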
def train():
    fasterRCNN = Network()
    fasterRCNN.build(is_training=True)
    train_op = tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.9).minimize(
        fasterRCNN._losses['total_loss'])
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init_op)

        # VGG16 up to block5_conv3 serves as the shared feature extractor
        base_extractor = VGG16(include_top=False)
        extractor = Model(inputs=base_extractor.input,
                          outputs=base_extractor.get_layer('block5_conv3').output)

        train_img_names = os.listdir(TRAIN_IMG_DATA_PATH)
        trained_times = 0
        for epoch in range(1, MAX_EPOCH + 1):
            random.shuffle(train_img_names)
            for train_img_name in train_img_names:
                img_data, boxes, img_info = get_train_data(train_img_name)
                features = extractor.predict(img_data, steps=1)
                sess.run(train_op,
                         feed_dict={fasterRCNN.feature_map: features,
                                    fasterRCNN.gt_boxes: boxes,
                                    fasterRCNN.image_info: img_info})
                trained_times += 1

                # Log the total loss every 10 iterations
                if trained_times % 10 == 0:
                    total_loss = sess.run(fasterRCNN._losses['total_loss'],
                                          feed_dict={fasterRCNN.feature_map: features,
                                                     fasterRCNN.gt_boxes: boxes,
                                                     fasterRCNN.image_info: img_info})
                    print('epoch:{}, trained_times:{}, loss:{}'.format(
                        epoch, trained_times, total_loss))

            # Save an intermediate checkpoint every 10 epochs
            if epoch % 10 == 0:
                save_path = saver.save(
                    sess, os.path.join(CHECKPOINTS_PATH, "model_" + str(epoch) + ".ckpt"))
                print("Model saved in path: %s" % save_path)

        save_path = saver.save(sess, os.path.join(CHECKPOINTS_PATH, "model_final.ckpt"))
        print("Model saved in path: %s" % save_path)
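# A minimal entry point tying the two functions together could look like the sketch below.
# The --mode flag is illustrative only; the original script may simply call train() or
# predict() directly.

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train or run the Faster R-CNN model')
    parser.add_argument('--mode', choices=['train', 'predict'], default='train')
    args = parser.parse_args()

    if args.mode == 'train':
        train()
    else:
        predict()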