Code Example #1
File: demo2.py Project: zzfancitizen/ctpn_win
def ctpn(sess, net, image_name):
    timer = Timer()
    timer.tic()

    img = cv2.imread(image_name)
    img, scale = resize_im(img,
                           scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, net, img)

    new_scores = scores[:, np.newaxis]

    keep_inds = np.where(new_scores > TextLineCfg.TEXT_PROPOSALS_MIN_SCORE)[0]
    boxes, new_scores = boxes[keep_inds], new_scores[keep_inds]

    sorted_indices = np.argsort(new_scores.ravel())[::-1]
    boxes, new_scores = boxes[sorted_indices], new_scores[sorted_indices]

    keep_inds = nms(np.hstack((boxes, new_scores)),
                    TextLineCfg.TEXT_PROPOSALS_NMS_THRESH)
    boxes, new_scores = boxes[keep_inds], new_scores[keep_inds]

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(10, 14))

    for key, box in enumerate(boxes):
        img_inside = img.copy()
        img_inside = cv2.rectangle(img_inside, (box[0], box[1]),
                                   (box[2], box[3]),
                                   color=(255, 0, 0),
                                   thickness=2)
        plt.imshow(img_inside)
        plt.title('Scores: {0}'.format(float(new_scores[key])))
        plt.savefig('./data/fig/fig_{0}.jpg'.format(key))
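
For reference, the nms() used above is greedy IoU suppression in the Fast R-CNN style, applied to an (N, 5) array of [x1, y1, x2, y2, score] rows. A minimal pure-NumPy sketch under that assumed convention, included only to make the example self-contained (the repo's own nms may differ in detail):

import numpy as np

def nms(dets, thresh):
    # dets: (N, 5) array of [x1, y1, x2, y2, score]; returns kept row indices
    x1, y1, x2, y2 = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = dets[:, 4].argsort()[::-1]  # highest score first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # intersection of the kept box with every remaining box
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0.0, xx2 - xx1 + 1) * np.maximum(0.0, yy2 - yy1 + 1)
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        # drop boxes that overlap the kept box too much
        order = order[np.where(iou <= thresh)[0] + 1]
    return keep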
Code Example #2
File: demo.py Project: smartcai/BankCardOCR
def ctpn(sess, net, image_name, save_path1, save_path2):
    timer = Timer()
    timer.tic()

    # read the image
    img = cv2.imread(image_name)
    img, scale = resize_im(img,
                           scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)
    # grayscale preprocessing (kept disabled)
    #img2 = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)
    #img2 = cv2.cvtColor(img2,cv2.COLOR_GRAY2RGB)
    #     base_name = im_name.split('\\')[-1]
    #     cv2.imwrite(os.path.join("data/results2", base_name), img2)

    scores, boxes = test_ctpn(sess, net, img)

    # post-processing: detect() both filters and merges proposals
    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    draw_boxes2(img, boxes, image_name, save_path2, scale)
    draw_boxes(img, boxes, image_name, save_path1, scale)

    # post-processing: detect2() only filters out small text boxes
    #     textdetector = TextDetector()
    #     boxes = textdetector.detect2(boxes, scores[:, np.newaxis], img.shape[:2])
    #     draw_boxes3(img, boxes,image_name, scale)

    timer.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0]))
Code Example #3
def test_net(sess, net, imdb, weights_filename):
    """Test a Fast R-CNN network on an image database."""
    timer = Timer()
    timer.tic()
    np.random.seed(cfg.RNG_SEED)
    num_images = len(imdb.image_index)
    output_dir = get_output_dir(imdb, weights_filename)
    # timers
    _t = {'im_detect': Timer(), 'misc': Timer()}
    # all_boxes = []
    all_boxes = [[[] for _ in range(imdb.num_classes)]
                 for _ in range(num_images)]
    print(all_boxes)
    for i in range(num_images):
        print('***********', imdb.image_path_at(i))
        img = cv2.imread(imdb.image_path_at(i))
        img, scale = resize_im(img, scale=TextLineCfg.SCALE, max_scale=TextLineCfg.MAX_SCALE)
        _t['im_detect'].tic()
        scores, boxes = test_ctpn(sess, net, img)
        detect_time = _t['im_detect'].toc(average=False)
        textdetector = TextDetector()
        boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
        print(('Detection took {:.3f}s for '
               '{:d} object proposals').format(detect_time, boxes.shape[0]))
        boxes = check_unreasonable_box(boxes, scale)
        all_boxes[i][1] += boxes
    det_file = os.path.join(output_dir, 'detections.pkl')
    with open(det_file, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    imdb.evaluate_detections(all_boxes, output_dir)
    timer.toc()
Code Example #4
File: demo.py Project: Skii3/temp
def ctpn(sess, net, image_name, boxlabel):
    timer = Timer()
    timer.tic()

    img = cv2.imread(image_name)

    img, scale = resize_im(img,
                           scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, net, img)

    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    img = draw_boxes(img, image_name, boxes, scale, None)
    boxlabel2 = np.transpose(
        np.array([
            boxlabel[:, 0], boxlabel[:, 1], boxlabel[:, 2], boxlabel[:, 1],
            boxlabel[:, 0], boxlabel[:, 3], boxlabel[:, 2], boxlabel[:, 3],
            np.ones(len(boxlabel))
        ]))
    draw_boxes(img, image_name, boxlabel2, 1, (0, 0, 0))
    timer.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0]))
    boxes = boxes / scale
    return boxes
Code Example #5
def ctpn(sess, net, image_name):
    global true_text, true_non_text, false_text, false_non_text
    base_name = image_name.split('/')[-1]
    label_name = image_name.split('/')[-2]
    img = cv2.imread(image_name)
    img, scale = resize_im(img,
                           scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, net, img)

    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    print(len(boxes))
    with open('boxes.txt', 'w') as f:
        f.write(str(len(boxes)))
    if len(boxes) > 0:
        if label_name == 'non_text':
            false_non_text += 1
        else:
            true_text += 1
            cv2.imwrite(os.path.join('data/results/text', base_name), img)
    else:
        if label_name == 'text':
            false_text += 1
        else:
            true_non_text += 1
            cv2.imwrite(os.path.join('data/results/non_text', base_name), img)
Code Example #6
    def get_model(cls):
        if cls.sess is None:
            cfg_from_file('ctpn/text.yml')

            # init session
            config = tf.ConfigProto(allow_soft_placement=True)
            cls.sess = tf.Session(config=config)
            # load network
            cls.net = get_network("VGGnet_test")
            # load model
            print(('Loading network {:s}... '.format("VGGnet_test")), end=' ')
            saver = tf.train.Saver()

            try:
                ckpt = tf.train.get_checkpoint_state(cfg.TEST.checkpoints_path)
                print('Restoring from {}...'.format(
                    ckpt.model_checkpoint_path),
                      end=' ')
                saver.restore(cls.sess, ckpt.model_checkpoint_path)
                print('done')
            except Exception:
                raise RuntimeError('Check your pretrained {:s}'.format(
                    ckpt.model_checkpoint_path))

            # warm up the network with two dummy forward passes on a gray image
            im = 128 * np.ones((300, 300, 3), dtype=np.uint8)
            for i in range(2):
                _, _ = test_ctpn(cls.sess, cls.net, im)
Code Example #7
 def detect(self, image_path):
     if self.session is None:
         self.load()
     regions = []
     img = cv2.imread(image_path)
     old_h, old_w, channels = img.shape
     img, scale = self.resize_im(img,
                                 scale=TextLineCfg.SCALE,
                                 max_scale=TextLineCfg.MAX_SCALE)
     new_h, new_w, channels = img.shape
     mul_h, mul_w = float(old_h) / float(new_h), float(old_w) / float(new_w)
     scores, boxes = test_ctpn(self.session, self.net, img)
     boxes = self.textdetector.detect(boxes, scores[:, np.newaxis],
                                      img.shape[:2])
     for box in boxes:
         left, top = int(box[0]), int(box[1])
         right, bottom = int(box[6]), int(box[7])
         score = float(box[8])
         left, top, right, bottom = int(left * mul_w), int(
             top * mul_h), int(right * mul_w), int(bottom * mul_h)
         r = {
             'score': float(score),
             'y': top,
             'x': left,
             'w': right - left,
             'h': bottom - top,
         }
         regions.append(r)
     return regions
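
A hypothetical driver for the detect() method above (CtpnDetector is an assumed stand-in name for whatever class hosts it; the dict keys match the method's return value):

# CtpnDetector is an assumed name for the class hosting detect()
detector = CtpnDetector()
for r in detector.detect('sample.jpg'):
    # each region is a score plus x, y, w, h in original-image pixels
    print('score={score:.2f} box=({x}, {y}, {w}x{h})'.format(**r))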
Code Example #8
def ctpnSource(im_name):
    if os.path.exists("data/results/"):
        shutil.rmtree("data/results/")  # the directory may contain nested subdirectories; remove recursively
    os.makedirs("data/results/")

    cfg_from_file('ctpn/text.yml')

    # init session
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)
    # load network
    net = get_network("VGGnet_test")
    # load model
    #    print(('Loading network {:s}... '.format("VGGnet_test")), end=' ')
    saver = tf.train.Saver()

    try:
        ckpt = tf.train.get_checkpoint_state(cfg.TEST.checkpoints_path)
        #        print('Restoring from {}...'.format(ckpt.model_checkpoint_path), end=' ')
        saver.restore(sess, ckpt.model_checkpoint_path)
#        print('done')
    except Exception:
        raise RuntimeError(
            'Check your pretrained {:s}'.format(ckpt.model_checkpoint_path))

    im = 128 * np.ones((300, 300, 3), dtype=np.uint8)
    for i in range(2):
        _, _ = test_ctpn(sess, net, im)

#    im_names = glob.glob(os.path.join(cfg.DATA_DIR, 'demo', '*.png')) + \
#               glob.glob(os.path.join(cfg.DATA_DIR, 'demo', '*.jpg'))

#    for im_name in im_names:
    print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
    print(('Demo for {:s}'.format(im_name)))
    ctpn(sess, net, im_name)
Code Example #9
def detectorload():
    cfg_from_file('TextDetection/ctpn/text.yml')

    # init session
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)
    # load network
    net = get_network("VGGnet_test")
    # load model
    print(('Loading network {:s}... '.format("VGGnet_test")), end=' ')
    saver = tf.train.Saver()
    try:
        ckpt = tf.train.get_checkpoint_state('TextDetection/checkpoints/')
        print(
            'Restoring from {}...'.format(ckpt.model_checkpoint_path), end=' ')
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('done')
    except Exception:
        raise RuntimeError(
            'Check your pretrained {:s}'.format(ckpt.model_checkpoint_path))

    im = 128 * np.ones((300, 300, 3), dtype=np.uint8)
    for i in range(2):
        _, _ = test_ctpn(sess, net, im)
    return net, sess

# typical usage at module level (im_name defined elsewhere):
net, sess = detectorload()
ctpn(sess, net, im_name)
Code Example #10
def ctpn(sess, net, image_name):
    timer = Timer()
    timer.tic()

    img = cv2.imread(image_name)
    img, scale = resize_im(img,
                           scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)

    # convert the OpenCV (BGR) image to a PIL image
    pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    # measure image sharpness (variance of the Laplacian)
    imageVar = cv2.Laplacian(img, cv2.CV_64F).var()
    if imageVar <= 5000:
        pil_img = ImageEnhance.Sharpness(pil_img).enhance(3.0)
    # convert the PIL image back to an OpenCV image
    img = cv2.cvtColor(np.asarray(pil_img), cv2.COLOR_RGB2BGR)

    scores, boxes = test_ctpn(sess, net, img)

    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    draw_boxes(img, image_name, boxes, scale)
    timer.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0]))
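
The imageVar check above uses the variance of the Laplacian as a blur metric: low variance means few strong edges, i.e. a blurrier image, and the 5000 cutoff is this author's heuristic rather than a general constant. A standalone sketch of the same measurement (computed here on a grayscale copy):

import cv2

def sharpness(img_bgr):
    # variance of the Laplacian response; higher means sharper
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    return cv2.Laplacian(gray, cv2.CV_64F).var()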
Code Example #11
def ctpn(sess, net, image_name):
    timer = Timer()
    timer.tic()

    img = cv2.imread(image_name)
    height, width = img.shape[:2]
    img = img[int(2 * height / 3.0):height, :]
    img, scale = resize_im(img,
                           scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, net, img)
    # for box in boxes:
    #     color = (0, 255, 0)
    #     cv2.line(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[1])), color, 2)
    #     cv2.line(img, (int(box[0]), int(box[1])), (int(box[0]), int(box[3])), color, 2)
    #     cv2.line(img, (int(box[2]), int(box[1])), (int(box[2]), int(box[3])), color, 2)
    #     cv2.line(img, (int(box[0]), int(box[3])), (int(box[2]), int(box[3])), color, 2)
    # base_name = image_name.split('/')[-1]
    # cv2.imwrite("data/results/test_"+base_name, img)
    # draw_boxes(img, image_name, boxes, scale)
    # print(boxes)
    # assert 0
    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    draw_boxes(img, image_name, boxes, scale)
    timer.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0]))
Code Example #12
def ctpn_area(sess,
              net,
              image_name,
              dst,
              draw_img=False,
              show_area=False,
              area_min=-0.1,
              area_max=1.1):
    #timer = Timer()
    #timer.tic()

    img = cv2.imread(image_name)
    if img is None:
        return 0.0
    img, scale = resize_im(img,
                           scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, net, img)

    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    ret = compute_area(img,
                       image_name,
                       boxes,
                       scale,
                       dst,
                       draw_img=draw_img,
                       show_area=show_area,
                       area_min=area_min,
                       area_max=area_max)
    #timer.toc()
    #print(('Detection took {:.3f}s for '
    #       '{:d} object proposals').format(timer.total_time, boxes.shape[0]))

    return ret
Code Example #13
def ctpn(sess, net, image_path):
    timer = Timer()
    timer.tic()

    img = cv2.imread(image_path)
    img_name = image_path.split('/')[-1]
    # resize the image and return its scale factor
    img, scale = resize_im(img,
                           scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)
    # run the image through the network to get 1000 scores and 1000 bboxes
    cls, scores, boxes = test_ctpn(sess, net, img)

    print('cls, scores, boxes', cls.shape, scores.shape, boxes.shape)

    # img_re = img
    # for i in range(np.shape(boxes)[0]):
    #     if cls[i] == 1:
    #         color = (255, 0, 0)
    #     else:
    #         color = (0, 255, 0)
    #     cv2.rectangle(img_re, (boxes[i][0],boxes[i][1]),(boxes[i][2],boxes[i][3]),color,1)
    # cv2.imwrite(os.path.join('./data/proposal_res', img_name), img_re)

    handwritten_filter = np.where(cls == 1)[0]
    handwritten_scores = scores[handwritten_filter]
    handwritten_boxes = boxes[handwritten_filter, :]

    print_filter = np.where(cls == 2)[0]
    print_scores = scores[print_filter]
    print_boxes = boxes[print_filter, :]

    handwritten_detector = TextDetector()

    print('print_filter', np.array(print_filter).shape)
    print('handwritten_boxes, handwritten_scores', handwritten_boxes.shape,
          handwritten_scores[:, np.newaxis].shape)

    filted_handwritten_boxes = handwritten_detector.detect(
        handwritten_boxes, handwritten_scores[:, np.newaxis], img.shape[:2])
    filted_print_boxes = handwritten_detector.detect(
        print_boxes, print_scores[:, np.newaxis], img.shape[:2])

    # boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    draw_boxes(img, filted_handwritten_boxes, (255, 0, 0))
    draw_boxes(img, filted_print_boxes, (0, 255, 0))

    img = cv2.resize(img,
                     None,
                     None,
                     fx=1.0 / scale,
                     fy=1.0 / scale,
                     interpolation=cv2.INTER_LINEAR)
    cv2.imwrite(os.path.join("data/results", img_name), img)

    timer.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0]))
Code Example #14
File: model.py Project: laugha/CTPN_OCR
def ctpn(img):
    """
    text box detect
    """
    scale, max_scale = Config.SCALE, Config.MAX_SCALE
    img, f = resize_im(img, scale=scale, max_scale=max_scale)
    scores, boxes = test_ctpn(sess, net, img)
    return scores, boxes, img
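
Callers of this variant run the post-processing themselves, as the other examples here do. A minimal sketch using the same names (TextDetector as imported in the other examples; 'sample.jpg' is a placeholder path):

import cv2
import numpy as np

scores, boxes, resized = ctpn(cv2.imread('sample.jpg'))
textdetector = TextDetector()
text_lines = textdetector.detect(boxes, scores[:, np.newaxis], resized.shape[:2])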
Code Example #15
File: model.py Project: dcrmg/chinese-ocr
def ctpn(img):
    """
    text box detect
    """
    scale, max_scale = Config.SCALE, Config.MAX_SCALE
    img, f = resize_im(img, scale=scale, max_scale=max_scale)
    scores, boxes = test_ctpn(sess, net, img)
    return scores, boxes, img
Code Example #16
def detect(src,
           dst,
           draw_img=False,
           show_area=False,
           area_min=-0.0,
           area_max=1.1):
    #if os.path.exists("data/results/"):
    #    shutil.rmtree("data/results/")
    if not os.path.exists(dst):
        os.makedirs(dst)
    cfg_from_file('ctpn/text.yml')

    # init session
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)
    # load network
    net = get_network("VGGnet_test")
    # load model
    print(('Loading network {:s}... '.format("VGGnet_test")), end=' ')
    saver = tf.train.Saver()

    try:
        ckpt = tf.train.get_checkpoint_state(cfg.TEST.checkpoints_path)
        print('Restoring from {}...'.format(ckpt.model_checkpoint_path),
              end=' ')
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('done')
    except Exception:
        raise RuntimeError(
            'Check your pretrained {:s}'.format(ckpt.model_checkpoint_path))

    im = 128 * np.ones((300, 300, 3), dtype=np.uint8)
    for i in range(2):
        _, _ = test_ctpn(sess, net, im)
    '''
    im_names = glob.glob(os.path.join(cfg.DATA_DIR, 'demo', '*.png')) + \
               glob.glob(os.path.join(cfg.DATA_DIR, 'demo', '*.jpg')) + \
               glob.glob(os.path.join(cfg.DATA_DIR, 'demo', '*.jpeg'))
    '''
    im_names = glob.glob(os.path.join(src, '*.png')) \
               + glob.glob(os.path.join(src, '*.PNG')) \
               + glob.glob(os.path.join(src, '*.jpg')) \
               + glob.glob(os.path.join(src, '*.JPG')) \
               + glob.glob(os.path.join(src, '*.jpeg')) \
               + glob.glob(os.path.join(src, '*.JPEG')) \
               + glob.glob(os.path.join(src, '*.bmp')) \
               + glob.glob(os.path.join(src, '*.BMP'))
    print("images:{}".format(len(im_names)))
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print(('Demo for {:s}'.format(im_name)))
        ctpn(sess,
             net,
             im_name,
             dst,
             draw_img=draw_img,
             show_area=show_area,
             area_min=area_min,
             area_max=area_max)
Code Example #17
File: web_test.py Project: lihow/chinese-ocr-win
def text_detect(img):
    # ctpn
    scale, max_scale = Config.SCALE, Config.MAX_SCALE
    img, f = resize_im(img, scale=scale, max_scale=max_scale)
    scores, boxes = test_ctpn(sess, net, img)
    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    text_recs, tmp = draw_boxes(img, boxes, caption='im_name', wait=True, is_display=False)
    return text_recs, tmp, img
Code Example #18
File: model.py Project: haswelliris/CHINESE-OCR-1
def ctpn(img):
    """
    text box detect
    """
    scale, max_scale = Config.SCALE, Config.MAX_SCALE
    # resize the image and log its dimensions
    print('original_size', img.shape)
    img, f = resize_im(img, scale=scale, max_scale=max_scale)
    print('resize', img.shape, f)
    scores, boxes = test_ctpn(sess, net, img)
    return scores, boxes, img
Code Example #19
def ctpn(img):

    img, scale = resize_im(img,
                           scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)

    scores, boxes = test_ctpn(sess, net, img)
    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])

    return scores, boxes, img, scale
Code Example #20
    def ctpn(self, image_name):
        img = cv2.imread(image_name)
        img, scale = self.resize_im(img, scale=600, max_scale=1000)  # values from the CTPN paper
        scores, boxes = test_ctpn(self.sess, self.net, img)
        # CTPN detector instance
        textdetector = TextDetector()
        boxes = textdetector.detect(boxes, scores[:, np.newaxis],
                                    img.shape[:2])
        min_y_sort_list, base_name = self.get_coordinates(
            img, image_name, boxes, scale)

        return min_y_sort_list, base_name
Code Example #21
def ctpn(sess, net, image_name):
    timer = Timer()
    timer.tic()
    img = cv2.imread(image_name)
    img, scale = resize_im(img, scale=TextLineCfg.SCALE, max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, net, img)

    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    draw_boxes(img, image_name, boxes, scale)
    timer.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0]))
Code Example #22
File: demo.py Project: wl5650/Bankcard_OCR
def ctpn(sess, net, image_name):
    timer = Timer()
    timer.tic()

    img = cv2.imread(image_name)
    img, scale = resize_im(img,
                           scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, net, img)

    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    draw_boxes(img, image_name, boxes, scale)
    timer.toc()
Code Example #23
File: infer.py Project: ForeversKing/sigin_project
def ctpn(sess, net, img):
    timer = Timer()
    timer.tic()
    img, scale = resize_im(img, scale=TextLineCfg.SCALE, max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, net, img)
    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    sort_index = np.argsort(boxes[:, -1])[::-1]
    boxes = boxes[sort_index]
    im, bboxes = draw_boxes(img, boxes, scale)
    timer.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0]))
    return im, bboxes
Code Example #24
def ctpn(sess, net, image_name, model):
    img = cv2.imread(image_name)

    #r = image_to_binary(img)
    #noise = np.ones(img.shape[:2],dtype="uint8") * 125
    #img = cv2.merge((r+noise, r, noise))
    
    img, scale = resize_im(img, scale=600, max_scale=1000)  # values from the CTPN paper
    print('ctpn', img.shape)
    scores, boxes = test_ctpn(sess, net, img)
    # CTPN detector instance
    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    get_coordinates(img, image_name, boxes, scale, model)
Code Example #25
File: demo.py Project: lreve915/text-detection-ctpn
def ctpn(sess, net, image_name):
    timer = Timer()
    timer.tic()

    img = cv2.imread(image_name)
    img, scale = resize_im(img, scale=TextLineCfg.SCALE, max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, net, img)

    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    draw_boxes(img, image_name, boxes, scale)
    timer.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0]))
Code Example #26
 def predict(self, image_name):
     img = cv2.imread(image_name)
     img, scale = self.resize_im(img,
                                 scale=TextLineCfg.SCALE,
                                 max_scale=TextLineCfg.MAX_SCALE)
     scores, boxes = test_ctpn(self.sess, self.net, img)
     # print('scores', scores)
     # mask = scores > 0.9
     # boxes = boxes[mask]
     # print('length of boxes', len(boxes))
     textdetector = TextDetector()
     boxes = textdetector.detect(boxes, scores[:, np.newaxis],
                                 img.shape[:2])
     return img, boxes, scale
Code Example #27
def ctpn(cv_image):
    os.chdir(CTPN_DIR)
    with ctpn_sess.as_default():
        img = cv_image
        img, scale = resize_im(img,
                               scale=TextLineCfg.SCALE,
                               max_scale=TextLineCfg.MAX_SCALE)
        scores, boxes = test_ctpn(ctpn_sess, ctpn_net, img)

        textdetector = TextDetector()
        boxes = textdetector.detect(boxes, scores[:, np.newaxis],
                                    img.shape[:2])
        boxes[:, 0:8] /= scale

    os.chdir(ROOT_DIR)
    return boxes
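
The boxes[:, 0:8] /= scale step maps the eight polygon coordinates back to the original image. Assuming the usual TextDetector row layout of four corner points followed by a score in column 8, a small helper can collapse one row to an axis-aligned rectangle:

def box_to_rect(row):
    # row: [x1, y1, x2, y2, x3, y3, x4, y4, score] -- assumed layout
    xs, ys = row[0:8:2], row[1:8:2]
    return int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys)), float(row[8])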
Code Example #28
def ctpn(sess, net, image_name):
    img = cv2.imread(image_name)
    im = check_img(img)
    timer = Timer()
    timer.tic()
    scores, boxes = test_ctpn(sess, net, im)
    timer.toc()
    CONF_THRESH = 0.9
    NMS_THRESH = 0.3
    dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32)
    keep = nms(dets, NMS_THRESH)
    dets = dets[keep, :]

    keep = np.where(dets[:, 4] >= 0.7)[0]
    dets = dets[keep, :]
    line = connect_proposal(dets[:, 0:4], dets[:, 4], im.shape)
    save_results(image_name, im, line, thresh=CONF_THRESH)
Code Example #29
def ctpn(img):
    timer = Timer()
    timer.tic()

    img, scale = resize_im(img,
                           scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, net, img)

    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    timer.toc()
    #print("\n----------------------------------------------")
    #print(('Detection took {:.3f}s for '
    #      '{:d} object proposals').format(timer.total_time, boxes.shape[0]))

    return scores, boxes, img, scale, timer.total_time, boxes.shape[0]
Code Example #30
def ctpn(sess, net, frame, draw):
    # timer = Timer()
    # timer.tic()

    img, scale = resize_im(
        frame, scale=TextLineCfg.SCALE, max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, net, img)

    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    buf = img.copy()
    crop = crop_image(buf, boxes, scale)

    # timer.toc()
    if draw == 1:
        draw_boxes(img, boxes, scale)
    return crop
Code Example #31
def local_result(list_img_path, save_dir, ctpn_path, base_net):
    save_dir = os.path.join(cfg.ROOT_DIR, 'data', save_dir)
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)
    os.makedirs(save_dir)

    yml_path = os.path.join(cfg.ROOT_DIR, 'ctpn/text.yml')
    cfg_from_file(yml_path)


    # init session
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)
    # load network
    training_flag = tf.placeholder(tf.bool)
    net = get_network(base_net, training_flag)
    # load model
    print(('Loading network {:s}... '.format(base_net)), end=' ')
    saver = tf.train.Saver()

    try:
        ckpt = tf.train.get_checkpoint_state(ctpn_path)
        print('Restoring from {}...'.format(ckpt.model_checkpoint_path), end=' ')
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('done')
    except Exception:
        raise RuntimeError(
            'Check your pretrained {:s}'.format(ckpt.model_checkpoint_path))

    im = 128 * np.ones((300, 300, 3), dtype=np.uint8)
    for i in range(2):
        _, _ = test_ctpn(sess, training_flag, net, im)

    im_names = list_img_path

    # im_names = glob.glob(os.path.join(cfg.DATA_DIR, 'demo', '*.png')) + \
    #            glob.glob(os.path.join(cfg.DATA_DIR, 'demo', '*.jpg'))

    xml_dir = '/data/kuaidi01/dataset_detect/VOC2007/Annotations'
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print(('Demo for {:s}'.format(im_name)))
        if im_name[-7:] == '164.jpg':
            continue
        xml_path = os.path.join(xml_dir, os.path.basename(im_name)[:-4] + '.xml')
        ctpn(sess, training_flag, net, im_name, save_dir, xml_path)
Code Example #32
def ctpn(sess, training_flag, net, image_name, save_all_dir):
    timer = Timer()
    timer.tic()

    img = cv2.imread(image_name)
    img, scale = resize_im(img, scale=TextLineCfg.SCALE, max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, training_flag, net, img)

    textdetector = TextDetector()
    boxes1, boxes2, boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    # img1 = img.copy()
    # draw_middle_boxes(img1, boxes1, scale)
    # img2 = img.copy()
    # draw_middle_boxes(img2, boxes2, scale)
    draw_boxes(img, image_name, boxes, scale, save_all_dir)
    timer.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0]))
Code Example #33
File: demo.py Project: lreve915/text-detection-ctpn
    # init session
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)
    # load network
    net = get_network("VGGnet_test")
    # load model
    print(('Loading network {:s}... '.format("VGGnet_test")), end=' ')
    saver = tf.train.Saver()

    try:
        ckpt = tf.train.get_checkpoint_state(cfg.TEST.checkpoints_path)
        print('Restoring from {}...'.format(ckpt.model_checkpoint_path), end=' ')
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('done')
    except Exception:
        raise RuntimeError(
            'Check your pretrained {:s}'.format(ckpt.model_checkpoint_path))

    im = 128 * np.ones((300, 300, 3), dtype=np.uint8)
    for i in range(2):
        _, _ = test_ctpn(sess, net, im)

    im_names = glob.glob(os.path.join(cfg.DATA_DIR, 'demo', '*.png')) + \
               glob.glob(os.path.join(cfg.DATA_DIR, 'demo', '*.jpg'))

    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print(('Demo for {:s}'.format(im_name)))
        ctpn(sess, net, im_name)