def run(self, image, threshold=0.7, do_nms=True, show_result=True, nms_threshold=0.3):
        """Detect digit regions in ``image`` and return their boxes and scores.

        The pipeline is: propose candidate regions, score every candidate
        patch with the detector classifier, filter the candidates through
        ``self._get_thresholded_boxes``, optionally apply non-maximum
        suppression, and finally run the recognizer on the surviving patches
        to obtain the label that is drawn in the preview window.

        Parameters
        ----------
        image :
            Input handed to ``self._region_proposer.detect``. When
            ``show_result`` is true it is also drawn on and displayed with
            OpenCV.
            NOTE(review): earlier documentation described this as a filename,
            but it is passed straight to ``cv2.putText`` / ``cv2.imshow``, so
            it is presumably an image array -- confirm against callers.
        threshold : float, optional
            Forwarded to ``self._get_thresholded_boxes`` together with the
            detector scores.
        do_nms : bool, optional
            Run ``NonMaxSuppressor`` on the filtered boxes (skipped when no
            box is left).
        show_result : bool, optional
            Draw every remaining box with its recognized label and block in
            ``cv2.waitKey(0)`` until a key is pressed.
        nms_threshold : float, optional
            Forwarded to ``NonMaxSuppressor().run``.

        Returns
        -------
        bbs :
            Boxes left after filtering (and suppression, if enabled). Each
            box unpacks as ``(y1, y2, x1, x2)``.
        probs :
            Detector scores (column 1 of ``self._cls.predict_proba``) for the
            boxes in ``bbs``.
        """

        # Step 1: propose regions and crop them to the classifier input size.
        regions = self._region_proposer.detect(image)
        patches = regions.get_patches(dst_size=self._cls.input_shape)

        # Step 2: detector score of each candidate (second output column).
        detector_out = self._cls.predict_proba(patches)
        probs = detector_out[:, 1]

        # Step 3: filter candidates using the detector scores.
        all_boxes = regions.get_boxes()
        bbs, patches, probs = self._get_thresholded_boxes(all_boxes, patches, probs, threshold)

        # Step 4: non-maximum suppression, only when there is something left.
        if do_nms and len(bbs) != 0:
            suppressor = NonMaxSuppressor()
            bbs, patches, probs = suppressor.run(bbs, patches, probs, nms_threshold)

        # Step 5: recognize the remaining patches; the labels are only used
        # for the preview below and are not part of the return value.
        if len(patches) > 0:
            recognizer_out = self._recognizer.predict_proba(patches)
            y_pred = recognizer_out.argmax(axis=1)

        if not show_result:
            return bbs, probs

        # TODO: tidy up the show module.
        red = (0, 0, 255)
        for idx, box in enumerate(bbs):
            image = show.draw_box(image, box, 2)

            y1, _, x1, _ = box
            label = "{0}".format(y_pred[idx])
            # The label is anchored at the (x1, y1) corner of the box.
            cv2.putText(image, label, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 1, red, thickness=2)

        cv2.imshow("MSER + CNN", image)
        cv2.waitKey(0)

        return bbs, probs
示例#2
0
    def run(self,
            image,
            threshold=0.7,
            do_nms=True,
            show_result=True,
            nms_threshold=0.3):
        """Detect digits in ``image`` and read off the dominant number.

        The pipeline is: propose candidate regions, score every candidate
        patch with the detector classifier, filter the candidates through
        ``self._get_thresholded_boxes``, optionally apply non-maximum
        suppression, then recognize the surviving patches.

        Parameters
        ----------
        image :
            Input handed to ``self._region_proposer.detect``; also drawn on
            and displayed with OpenCV in the preview path.
            NOTE(review): earlier documentation described this as a filename,
            but it is passed straight to ``cv2.putText`` / ``cv2.imshow``, so
            it is presumably an image array -- confirm against callers.
        threshold : float, optional
            Forwarded to ``self._get_thresholded_boxes`` together with the
            detector scores.
        do_nms : bool, optional
            Run ``NonMaxSuppressor`` on the filtered boxes (skipped when no
            box is left).
        show_result : bool, optional
            Enables the preview window. Because every branch taken when
            patches survive returns early, the preview is only reached when
            no patch survives filtering.
        nms_threshold : float, optional
            Forwarded to ``NonMaxSuppressor().run``.

        Returns
        -------
        The result depends on how many "dominant" boxes exist, i.e. boxes
        whose area is at least 97% of the largest box area:

        * exactly one  -> the recognized label of that box;
        * exactly two  -> a two-digit value, the box with the smaller ``x1``
          supplying the tens digit;
        * otherwise    -> the array of recognized labels for all boxes.

        When no patch survives filtering, ``bbs`` is returned instead.
        Boxes unpack as ``(y1, y2, x1, x2)``.
        """

        # Step 1: propose regions and crop them to the classifier input size.
        regions = self._region_proposer.detect(image)
        patches = regions.get_patches(dst_size=self._cls.input_shape)

        # Step 2: detector score of each candidate (second output column).
        probs = self._cls.predict_proba(patches)[:, 1]

        # Step 3: filter candidates using the detector scores.
        bbs, patches, probs = self._get_thresholded_boxes(
            regions.get_boxes(), patches, probs, threshold)

        # Step 4: non-maximum suppression, only when there is something left.
        if do_nms and len(bbs) != 0:
            suppressor = NonMaxSuppressor()
            bbs, patches, probs = suppressor.run(
                bbs, patches, probs, nms_threshold)

        if len(patches) > 0:
            digit_scores = self._recognizer.predict_proba(patches)
            y_pred = digit_scores.argmax(axis=1)

            # Box areas; the leading 0 keeps the running-maximum floor of 0.
            areas = [(y2 - y1) * (x2 - x1) for y1, y2, x1, x2 in bbs]
            largest = max([0] + areas)
            cutoff = largest * 0.97
            dominant = [k for k, area in enumerate(areas) if area >= cutoff]

            if len(dominant) == 1:
                return y_pred[dominant[0]]

            if len(dominant) == 2:
                tens, ones = dominant
                # Element 2 of a box is x1: the left-most box is the tens digit.
                if bbs[tens][2] > bbs[ones][2]:
                    tens, ones = ones, tens
                return y_pred[tens] * 10 + y_pred[ones]

            return y_pred

        if not show_result:
            return bbs

        # TODO: tidy up the show module.
        red = (0, 0, 255)
        for idx, box in enumerate(bbs):
            image = show.draw_box(image, box, 2)

            y1, y2, x1, x2 = box
            caption = "{0}\n{1}\n{2}".format(y_pred[idx], box, probs[idx])

            # The caption is anchored at the centre of the box.
            centre = ((x1 + x2) // 2, (y1 + y2) // 2)
            cv2.putText(image,
                        caption, centre,
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, red,
                        thickness=1)

        cv2.imshow("MSER + CNN", image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

        return bbs
示例#3
0
        else:
            i = 0

        print i
        if i == 0:
            image = img
        else:
            image = cv2.flip(img, -1)

        digit_dic = {}
        bbs, probs, y_preds = results[i]
        for bb, prob, y_pred in zip(bbs, probs, y_preds):
            if prob < threshold:
                continue

            image = show.draw_box(image, bb, 2)

            y1, y2, x1, x2 = bb
            msg = "{0}".format(y_pred)
            cv2.putText(image,
                        msg, (x1, y1),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1, (0, 0, 255),
                        thickness=2)

            digit_dic[x1] = y_pred

        sorted_list = sorted(digit_dic.items(), key=operator.itemgetter(0))
        digit_list = [item[1] for item in sorted_list]

        print 'Predict', os.path.basename(img_file), digit_list