def t_net(prefix, epoch, batch_size, test_mode, thresh, min_face_size=20,
          stride=2, slide_window=False):
    """Run the MTCNN cascade over the training images and mine hard examples.

    The networks up to and including ``test_mode`` are loaded and every image
    returned by ``read_bboxes_data`` is detected. The raw detections are
    pickled to ``<save_path>/detections.pkl`` and then handed to
    ``save_hard_example`` to build training samples for the next stage.

    :param prefix: checkpoint prefixes, paired element-wise with ``epoch``;
        index 0 is PNet, index 1 is RNet
    :param epoch: checkpoint epochs; model path i is ``'%s-%s' % (prefix[i], epoch[i])``
    :param batch_size: per-network batch sizes (index 0 is used for PNet only
        when ``slide_window`` is True, index 1 for RNet)
    :param test_mode: ``"PNet"`` (results saved to ``path_config.rnet_save_hard_path``)
        or ``"RNet"`` (results saved to ``path_config.onet_save_hard_path``)
    :param thresh: detection thresholds forwarded to ``MtcnnDetector``
    :param min_face_size: forwarded to ``MtcnnDetector``
    :param stride: forwarded to ``MtcnnDetector``
    :param slide_window: use the sliding-window ``Detector`` for PNet instead
        of ``FcnDetector``
    :raises ValueError: if ``test_mode`` is not ``"PNet"`` or ``"RNet"``
    """
    # Fail fast. Previously an unsupported mode (including "ONet", which has no
    # hard-example output directory) loaded every model and ran detection over
    # the whole dataset before raising this same error at the end.
    if test_mode not in ("PNet", "RNet"):
        raise ValueError('网络类型(--test_mode)错误!')

    print("Test model: ", test_mode)
    model_path = ['%s-%s' % (x, y) for x, y in zip(prefix, epoch)]
    detectors = [None, None, None]

    # PNet is always the first stage of the cascade.
    print(model_path[0])
    if slide_window:
        detectors[0] = Detector(P_Net, 12, batch_size[0], model_path[0])
    else:
        detectors[0] = FcnDetector(P_Net, model_path[0])

    # RNet is only needed when mining samples for the ONet stage.
    if test_mode == "RNet":
        print("=================   {}   =================".format(test_mode))
        detectors[1] = Detector(R_Net, 24, batch_size[1], model_path[1])

    # Ground-truth boxes and images; a dict with keys 'images' and 'bboxes'.
    data = read_bboxes_data(path_config.point2_train_txt_path, path_config.images_dir)

    mtcnn_detector = MtcnnDetector(detectors=detectors, min_face_size=min_face_size,
                                   stride=stride, threshold=thresh, slide_window=slide_window)

    print('加载原始图片数据,以进行检测及生成困难样本...')
    test_data = TestLoader(data['images'])
    print('加载完成, 开始检测...')
    detections, _ = mtcnn_detector.detect_images(test_data)
    print('检测完成!')

    # Samples mined with stage N's detections are training data for stage N+1.
    if test_mode == "PNet":
        save_path = path_config.rnet_save_hard_path
    else:
        save_path = path_config.onet_save_hard_path
    print('保存检测结果的路径为:', save_path)
    # makedirs also creates missing parents; exist_ok avoids the
    # exists-then-mkdir race of the previous version.
    os.makedirs(save_path, exist_ok=True)

    save_file = os.path.join(save_path, "detections.pkl")
    with open(save_file, 'wb') as f:
        pickle.dump(detections, f, 1)
    print("%s测试完成,开始生成困难样本..." % test_mode)
    save_hard_example(data, test_mode, save_path)
# 示例#2 (Example #2)
# 0
class PNetTester(object):
    """Evaluate MTCNN detections against per-image ground-truth boxes.

    Intended for testing PNet. RNet / ONet are added to the cascade as well
    when their model paths are supplied.
    """

    def __init__(self, models_path, image_path, ground_truth_file):
        """Load the detector, the ground truth and the list of images to test.

        :param models_path: [PNet, RNet, ONet] model paths; a ``None`` entry
            leaves that network out of the cascade
        :param image_path: directory searched recursively for ``.jpg`` files,
            or the path of a single ``.jpg`` file
        :param ground_truth_file: text file with one line per image:
            ``<image file name> <value> <value> ...``; the values form the
            ground-truth box used for IoU (presumably x1 y1 x2 y2 -- confirm)
        :raises ValueError: if ``image_path`` is neither a directory nor a
            ``.jpg`` file
        """
        # Resolve candidate images first, so a bad path fails before any model
        # is loaded. Previously this fell through to an UnboundLocalError.
        if os.path.isdir(image_path):
            candidates = PNetTester.search_file(image_path)
        elif os.path.isfile(image_path) and image_path.endswith('.jpg'):
            candidates = [image_path]
        else:
            raise ValueError('image_path must be a directory or a .jpg file: {}'.format(image_path))

        # Build the detector cascade from whichever models were supplied.
        pnet_model = None
        rnet_model = None
        onet_model = None
        if models_path[0] is not None:
            pnet_model = FcnDetector(P_Net, models_path[0])
        if models_path[1] is not None:
            rnet_model = Detector(R_Net, 24, 256, models_path[1])
        if models_path[2] is not None:
            onet_model = Detector(O_Net, 48, 16, models_path[2])

        self.detector = MtcnnDetector([pnet_model, rnet_model, onet_model], min_face_size=20, stride=2,
                                      threshold=[0.7, 0.7, 0.7], scale_factor=0.79, slide_window=False)

        # Ground truth: image file name -> numpy array of box values.
        self.ground_map = dict()
        with open(ground_truth_file, 'r') as truth_file:
            for line in truth_file:
                fields = line.split()  # tolerant of repeated spaces / tabs
                if not fields:
                    continue  # skip blank lines
                self.ground_map[fields[0]] = np.array([float(v) for v in fields[1:]])

        # Only images that have a ground-truth entry can be evaluated
        # (dict lookup instead of a linear scan of a list).
        self.images_path = [path for path in candidates
                            if os.path.basename(path) in self.ground_map]
        print('待检测图片数量:', len(self.images_path))
        self.test_loader = TestLoader(self.images_path)

    @staticmethod
    def search_file(search_path):
        """Recursively collect ``.jpg`` files under ``search_path``.

        File names containing a space are skipped.

        :param search_path: directory to search, e.g. ./2018年收集的素材并已校正
        :return: list of paths of all matching ``.jpg`` files below the directory
        """
        jpg_path_list = list()
        # os.walk yields (directory path, sub-directory names, file names).
        for root_path, dir_names, file_names in os.walk(search_path):
            for filename in file_names:
                if filename.endswith('.jpg') and filename.find(' ') == -1:
                    jpg_path_list.append(os.path.join(root_path, filename))
        return jpg_path_list

    def test(self, threshold, model_iter):
        """Evaluate detections against the ground truth at an IoU threshold.

        An image counts as recalled when at least one detected box has
        IoU > ``threshold`` with its ground-truth box. Images with no such box,
        including images with no detections at all, are collected in
        ``self.hard_samples``.

        :param threshold: IoU threshold above which a detected box is positive
        :param model_iter: model iteration tag; only used by the disabled
            visualization code below
        :return: tuple ``(precision, acc_pos, acc_all, recall)``:
            precision -- mean fraction of an image's boxes that are positive,
            averaged over recalled images;
            acc_pos -- mean column-4 value (presumably the detection score)
            of the positive boxes, averaged over recalled images;
            acc_all -- the same mean over all boxes of recalled images;
            recall -- recalled images / number of test images.
            The averages are 0.0 when no image is recalled, and all four are
            0.0 when there are no test images.
        """
        if not self.images_path:
            # Nothing to evaluate; avoids an IndexError below and a division by zero.
            print('No images to test.')
            self.hard_samples = []
            return 0.0, 0.0, 0.0, 0.0

        all_boxes, landmark = self.detector.detect_images(self.test_loader)

        hard_samples = list()
        recalled = 0  # TP: images with a box above the IoU threshold; recall = TP / (TP + FN)
        acc_pos = 0.0
        acc_all = 0.0
        precision = 0.0  # precision = TP / (TP + FP), averaged per recalled image
        # Output directory for the (disabled) visualization code.
        save_path = os.path.join(os.path.dirname(self.images_path[0]), '..', 'result')
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        for index, image_path in enumerate(self.images_path):
            image_name = os.path.basename(image_path)
            ground_truth = self.ground_map[image_name]
            boxes = all_boxes[index]
            if len(boxes) == 0:
                print('图片{}检测不到车牌'.format(image_name))
                hard_samples.append(image_path)
                continue
            # IoU of each detected box with the ground-truth box.
            iou = np.ones((len(boxes),))
            gt_boxes = np.array([ground_truth])
            for j, box in enumerate(boxes):
                iou[j] = IoU(box, gt_boxes)
            # NOTE: disabled visualization code, kept verbatim for debugging.
            '''
            # 画图
            im = cv2.imread(image_path)
            for j, box in enumerate(all_boxes[index]):
                # if image_name == '20180929172716720_23609_dqp001_甘A5T470.jpg':
                #    pdb.set_trace()
                # 绘制iou大于阈值的pos框
                if iou[j] > threshold:
                    cv2.rectangle(im, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])),
                                  (255, 0, 0), 2)
                    cv2.putText(im, '{:s}|{:.2f}|{:.2f}'.format('p', box[4], iou[j]),
                                (int(box[0]), int(box[1])), cv2.FONT_HERSHEY_SIMPLEX,
                                1, (255, 0, 0))
                    for k in range(5):
                        cv2.circle(im, (landmark[index][j][2*k], landmark[index][j][2*k+1]), 1, (0, 0, 255), 4)
                    
            # 绘制ground truth
            cv2.rectangle(im, (int(ground_truth[0]), int(ground_truth[1])),
                          (int(ground_truth[2]), int(ground_truth[3])),
                          (0, 255, 0), 2)
            cv2.imwrite(os.path.join(save_path, os.path.splitext(image_name)[0] + '_' + model_iter + '.jpg'), im)
            
            print('IoU:\n', iou)
            print('average iou = {}'.format(sum(iou) / sum(iou != 0)))
            '''
            # Accumulate per-image metrics for images with a positive box.
            if iou.max() > threshold:
                positive = iou > threshold
                recalled += 1
                acc_pos += np.mean(boxes[positive, 4])
                acc_all += np.mean(boxes[:, 4])
                precision += len(boxes[positive, 4]) / len(boxes[:, 4])
            else:
                hard_samples.append(image_path)

        self.hard_samples = hard_samples
        # Guard against ZeroDivisionError when no image reached the threshold.
        if recalled:
            precision /= recalled
            acc_pos /= recalled
            acc_all /= recalled
        total = self.test_loader.size
        recall = recalled / total if total else 0.0
        print('IoU threshold={}:'.format(threshold), 'precision={},'.format(precision),
              ' acc-pos={},'.format(acc_pos), 'acc-all={}'.format(acc_all), 'recall={}'.format(recall))
        return precision, acc_pos, acc_all, recall