Beispiel #1
0
def predict():
    """Run Faster R-CNN inference on every image in PREDICT_IMG_DATA_PATH.

    Restores the trained weights from ``model_final.ckpt`` under
    CHECKPOINTS_PATH, computes VGG16 ``block5_conv3`` features for each
    image, and decodes the network's ROIs, class probabilities and box
    regression into per-class boxes. For every non-background class it then
    applies NMS (NMS_THRESH) and keeps detections scoring >= CONF_THRESH.

    Surviving detections of one image are passed as a single list to
    ``write_txt_result`` (when TXT_RESULT_WANTED) and ``visualization``
    (when IS_VISIBLE). Each entry has the form
    ``[file_name, class_name, score, x1, y1, x2, y2]``. Images without any
    surviving detection are skipped.
    """
    fasterRCNN = Network()
    fasterRCNN.build(is_training=False)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, os.path.join(CHECKPOINTS_PATH, "model_final.ckpt"))
        print("Model restored.")
        # VGG16 backbone truncated at block5_conv3; its output is fed into the
        # Faster R-CNN graph through the feature_map placeholder.
        base_extractor = VGG16(include_top=False)
        extractor = Model(inputs=base_extractor.input,
                          outputs=base_extractor.get_layer('block5_conv3').output)
        # os.listdir order is arbitrary; sort so images are processed (and
        # results written) in a reproducible order.
        predict_img_names = sorted(os.listdir(PREDICT_IMG_DATA_PATH))

        for predict_img_name in predict_img_names:
            img_data, img_info = get_predict_data(predict_img_name)
            features = extractor.predict(img_data, steps=1)
            rois, scores, regression_parameter = sess.run(
                [fasterRCNN._predictions["rois"],
                 fasterRCNN._predictions["cls_prob"],
                 fasterRCNN._predictions["bbox_pred"]],
                feed_dict={fasterRCNN.feature_map: features,
                           fasterRCNN.image_info: img_info})

            # NOTE(review): img_info looks like (height, width, scale) and
            # column 0 of rois is skipped (presumably a batch index); dividing
            # by the scale maps boxes back to original-image coordinates —
            # confirm against get_predict_data / Network.
            im_scale = img_info[2]
            boxes = rois[:, 1:5] / im_scale
            scores = np.reshape(scores, [scores.shape[0], -1])
            regression_parameter = np.reshape(
                regression_parameter, [regression_parameter.shape[0], -1])
            pred_boxes = bbox_transform_inv(boxes, regression_parameter)
            pred_boxes = clip_boxes(
                pred_boxes, [img_info[0] / im_scale, img_info[1] / im_scale])

            result_list = []
            # CLASSES[0] is the background class and is skipped; start=1 keeps
            # class_index aligned with the columns of scores / pred_boxes.
            for class_index, class_name in enumerate(CLASSES[1:], start=1):
                # pred_boxes holds 4 columns per class; take this class's slice.
                cls_boxes = pred_boxes[:, 4 * class_index:4 * (class_index + 1)]
                cls_scores = scores[:, class_index]
                detections = np.hstack(
                    (cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32)
                keep = nms(detections, NMS_THRESH)
                detections = detections[keep, :]

                # Keep only detections whose score (last column) reaches
                # CONF_THRESH.
                confident = detections[detections[:, -1] >= CONF_THRESH]
                for detection in confident:
                    # [fileName, class_name, score, x1, y1, x2, y2]
                    result_list.append(
                        [predict_img_name, class_name, detection[-1]]
                        + list(detection[:4]))

            if not result_list:
                continue

            if TXT_RESULT_WANTED:
                write_txt_result(result_list)

            if IS_VISIBLE:
                visualization(result_list)
Beispiel #2
0
def train(learning_rate=0.001, momentum=0.9, save_every=10, log_every=10):
    """Train the Faster R-CNN network on every image in TRAIN_IMG_DATA_PATH.

    For MAX_EPOCH epochs the image list is reshuffled, and each image is one
    training step. VGG16 ``block5_conv3`` features are computed with Keras.
    They are fed into the Faster R-CNN graph together with the ground-truth
    boxes and image info, and ``total_loss`` is minimised with momentum SGD.
    The backbone features enter through feed_dict as plain arrays, so
    ``train_op`` does not update the VGG16 weights.

    Checkpoints are written to CHECKPOINTS_PATH as ``model_<epoch>.ckpt``
    every ``save_every`` epochs and as ``model_final.ckpt`` at the end.

    Args:
        learning_rate: Momentum optimizer learning rate (formerly hard-coded
            to 0.001).
        momentum: Momentum coefficient (formerly hard-coded to 0.9).
        save_every: Save an intermediate checkpoint every this many epochs.
        log_every: Print the training loss every this many steps.

    Raises:
        ValueError: If ``save_every`` or ``log_every`` is smaller than 1.
    """
    if save_every < 1 or log_every < 1:
        raise ValueError("save_every and log_every must be >= 1")

    fasterRCNN = Network()
    fasterRCNN.build(is_training=True)
    total_loss = fasterRCNN._losses['total_loss']
    train_op = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=momentum).minimize(total_loss)
    # The initializer and the Saver only cover variables that exist at this
    # point, i.e. the Faster R-CNN graph, not the VGG16 backbone built below.
    # NOTE(review): keep this order — if Keras shares this session, running
    # the initializer after VGG16 is built would presumably overwrite its
    # pretrained weights.
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init_op)

        base_extractor = VGG16(include_top=False)
        extractor = Model(
            inputs=base_extractor.input,
            outputs=base_extractor.get_layer('block5_conv3').output)
        train_img_names = os.listdir(TRAIN_IMG_DATA_PATH)
        trained_times = 0

        for epoch in range(1, MAX_EPOCH + 1):
            random.shuffle(train_img_names)
            for train_img_name in train_img_names:
                img_data, boxes, img_info = get_train_data(train_img_name)
                features = extractor.predict(img_data, steps=1)
                feed_dict = {
                    fasterRCNN.feature_map: features,
                    fasterRCNN.gt_boxes: boxes,
                    fasterRCNN.image_info: img_info,
                }
                # Fetch the loss in the same run as the update. This reports
                # the loss the gradient step was computed from, and avoids a
                # second forward pass just for logging.
                _, loss_value = sess.run([train_op, total_loss],
                                         feed_dict=feed_dict)

                trained_times += 1
                if trained_times % log_every == 0:
                    print('epoch:{}, trained_times:{}, loss:{}'.format(
                        epoch, trained_times, loss_value))

            if epoch % save_every == 0:
                save_path = saver.save(
                    sess,
                    os.path.join(CHECKPOINTS_PATH,
                                 "model_{}.ckpt".format(epoch)))
                print("Model saved in path: %s" % save_path)
        save_path = saver.save(
            sess, os.path.join(CHECKPOINTS_PATH, "model_final.ckpt"))
        print("Model saved in path: %s" % save_path)