def ctpn(sess, net, image_name):
    """Plot every text proposal that survives score filtering and NMS.

    Proposals scoring above ``TextLineCfg.TEXT_PROPOSALS_MIN_SCORE`` are sorted
    by score (highest first) and reduced with ``nms``. For each surviving
    proposal a figure is saved to ``./data/fig/fig_<k>.jpg`` that shows the
    resized image with that single box drawn and is titled with the box's score.

    Args:
        sess: session forwarded to ``test_ctpn``.
        net: network forwarded to ``test_ctpn``.
        image_name: path of the image to read with ``cv2.imread``.
    """
    timer = Timer()
    timer.tic()
    img = cv2.imread(image_name)
    img, scale = resize_im(img, scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, net, img)

    # Keep proposals above the minimum score, ordered from highest to lowest.
    new_scores = scores[:, np.newaxis]
    keep_inds = np.where(new_scores > TextLineCfg.TEXT_PROPOSALS_MIN_SCORE)[0]
    boxes, new_scores = boxes[keep_inds], new_scores[keep_inds]
    sorted_indices = np.argsort(new_scores.ravel())[::-1]
    boxes, new_scores = boxes[sorted_indices], new_scores[sorted_indices]

    keep_inds = nms(np.hstack((boxes, new_scores)),
                    TextLineCfg.TEXT_PROPOSALS_NMS_THRESH)
    boxes, new_scores = boxes[keep_inds], new_scores[keep_inds]

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    fig = plt.figure(figsize=(10, 14))
    try:
        for key, box in enumerate(boxes):
            img_inside = img.copy()
            # cv2.rectangle only accepts integer pixel coordinates.
            x1, y1, x2, y2 = (int(v) for v in box[:4])
            img_inside = cv2.rectangle(img_inside, (x1, y1), (x2, y2),
                                       color=(255, 0, 0), thickness=2)
            # Clear the figure so images do not pile up on the same axes.
            plt.clf()
            plt.imshow(img_inside)
            # new_scores is aligned with the filtered/sorted/NMS'd boxes; the raw
            # `scores` array is not, so it must not be indexed with `key`.
            plt.title('Scores: {0}'.format(float(new_scores[key, 0])))
            plt.savefig('./data/fig/fig_{0}.jpg'.format(key))
    finally:
        plt.close(fig)
def ctpn(sess, net, image_name, save_path1, save_path2):
    """Detect text lines in one image and render them with two drawers.

    The image is read, resized according to TextLineCfg and run through the
    network. TextDetector.detect then post-processes the proposals; its own
    comment says it filters and merges them. The resulting lines are passed to
    draw_boxes2 (together with ``save_path2``) and then to draw_boxes (together
    with ``save_path1``), in that order. Finally the elapsed time is printed.
    """
    stopwatch = Timer()
    stopwatch.tic()

    # Read the image and rescale it.
    source = cv2.imread(image_name)
    resized, scale = resize_im(source, scale=TextLineCfg.SCALE,
                               max_scale=TextLineCfg.MAX_SCALE)

    scores, proposals = test_ctpn(sess, net, resized)

    # Post-processing: detect() filters and merges the proposals into lines.
    detector = TextDetector()
    text_lines = detector.detect(proposals, scores[:, np.newaxis],
                                 resized.shape[:2])

    draw_boxes2(resized, text_lines, image_name, save_path2, scale)
    draw_boxes(resized, text_lines, image_name, save_path1, scale)

    stopwatch.toc()
    template = ('Detection took {:.3f}s for '
                '{:d} object proposals')
    print(template.format(stopwatch.total_time, text_lines.shape[0]))
def test_net(sess, net, imdb, weights_filename):
    """Test a CTPN network on an image database and evaluate its detections.

    Every image of ``imdb`` goes through the same steps:
      * it is resized and run through the network;
      * TextDetector post-processes the proposals;
      * check_unreasonable_box processes the resulting boxes;
      * those boxes are appended to ``all_boxes[i][1]``.
    The nested list ``all_boxes[image][class]`` is then pickled to
    ``<output_dir>/detections.pkl`` and passed to
    ``imdb.evaluate_detections``.

    Args:
        sess: session forwarded to ``test_ctpn``.
        net: network forwarded to ``test_ctpn``.
        imdb: image database providing ``image_index``, ``num_classes``,
            ``image_path_at`` and ``evaluate_detections``.
        weights_filename: forwarded to ``get_output_dir``.
    """
    timer = Timer()
    timer.tic()
    np.random.seed(cfg.RNG_SEED)
    num_images = len(imdb.image_index)
    output_dir = get_output_dir(imdb, weights_filename)

    # all_boxes[image_index][class_index] -> detections of that class.
    all_boxes = [[[] for _ in range(imdb.num_classes)]
                 for _ in range(num_images)]
    for i in range(num_images):
        image_path = imdb.image_path_at(i)
        print('***********', image_path)
        # Per-image timer. The overall `timer` is only stopped after the loop,
        # so its total_time presumably does not reflect this image yet.
        im_timer = Timer()
        im_timer.tic()
        img = cv2.imread(image_path)
        img, scale = resize_im(img, scale=TextLineCfg.SCALE,
                               max_scale=TextLineCfg.MAX_SCALE)
        scores, boxes = test_ctpn(sess, net, img)
        textdetector = TextDetector()
        boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
        im_timer.toc()
        print(('Detection took {:.3f}s for '
               '{:d} object proposals').format(im_timer.total_time,
                                               boxes.shape[0]))
        boxes = check_unreasonable_box(boxes, scale)
        # Stored under class index 1 (presumably the text class).
        # `list += array` appends the array's rows to the list.
        all_boxes[i][1] += boxes

    det_file = os.path.join(output_dir, 'detections.pkl')
    with open(det_file, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)
    imdb.evaluate_detections(all_boxes, output_dir)
    timer.toc()
def ctpn(sess, net, image_name, boxlabel):
    """Detect text lines, draw them next to ground-truth boxes and return them.

    Args:
        sess: session forwarded to ``test_ctpn``.
        net: network forwarded to ``test_ctpn``.
        image_name: path of the image file.
        boxlabel: 2-D array of axis-aligned ground-truth boxes. Columns 0 and 2
            are used as x coordinates and columns 1 and 3 as y coordinates.
            They are drawn with scale 1 and color argument (0, 0, 0).
            NOTE(review): presumably already in resized-image coordinates;
            confirm.

    Returns:
        The detected boxes as a float array. Its first 8 columns (the corner
        coordinates) are divided by ``scale``; any further column (the score)
        is left unscaled.
    """
    timer = Timer()
    timer.tic()
    img = cv2.imread(image_name)
    img, scale = resize_im(img, scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, net, img)
    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    img = draw_boxes(img, image_name, boxes, scale, None)

    # Expand each label into the 8-coordinate corner layout used for the
    # detections, followed by a constant score of 1.
    boxlabel2 = np.transpose(
        np.array([
            boxlabel[:, 0], boxlabel[:, 1],
            boxlabel[:, 2], boxlabel[:, 1],
            boxlabel[:, 0], boxlabel[:, 3],
            boxlabel[:, 2], boxlabel[:, 3],
            np.ones(len(boxlabel)),
        ]))
    draw_boxes(img, image_name, boxlabel2, 1, (0, 0, 0))
    timer.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0]))

    # Only the 8 corner coordinates are in resized-image pixels. Dividing the
    # whole array would also rescale the score column.
    result = np.array(boxes, dtype=np.float64)
    if result.ndim == 2:
        result[:, :8] /= scale
    return result
def ctpn(sess, net, image_name):
    """Classify one image as text / non-text by whether any text line is found.

    The name of the image's parent directory is taken as the ground-truth label
    ('text' or 'non_text'). The function then:
      * updates the module-level counters true_text, true_non_text,
        false_text and false_non_text;
      * prints the number of detected lines and writes it to boxes.txt
        (the file is overwritten on every call);
      * writes the resized image into data/results/text or
        data/results/non_text according to the prediction.
    """
    global true_text, true_non_text, false_text, false_non_text
    # os.path handles the platform's separators. With split('/') a Windows path
    # stayed whole, so os.path.join below could point back at the input file.
    base_name = os.path.basename(image_name)
    label_name = os.path.basename(os.path.dirname(image_name))
    img = cv2.imread(image_name)
    img, scale = resize_im(img, scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, net, img)
    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    print(len(boxes))
    with open('boxes.txt', 'w') as f:
        f.write(str(len(boxes)))

    if len(boxes) > 0:
        if label_name == 'non_text':
            false_non_text += 1
        else:
            true_text += 1
        out_dir = 'data/results/text'
    else:
        if label_name == 'text':
            false_text += 1
        else:
            true_non_text += 1
        out_dir = 'data/results/non_text'
    # cv2.imwrite does not raise when the directory is missing; it just
    # returns False. So make sure the directory exists.
    os.makedirs(out_dir, exist_ok=True)
    cv2.imwrite(os.path.join(out_dir, base_name), img)
def get_model(cls):
    """Build the CTPN session and network once and cache them on ``cls``.

    Does nothing while ``cls.sess`` is already set. Otherwise it:
      * reads ctpn/text.yml;
      * builds "VGGnet_test";
      * restores the latest checkpoint from ``cfg.TEST.checkpoints_path``;
      * stores the session and network on ``cls``;
      * runs two warm-up inferences on a 300x300 grey image.

    Raises:
        RuntimeError: if no checkpoint is found or restoring it fails. In that
            case the new session is closed and ``cls.sess`` / ``cls.net`` are
            not assigned, so a half-initialised model is never cached.
    """
    if cls.sess is not None:
        return
    cfg_from_file('ctpn/text.yml')
    # init session
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)
    try:
        # load network
        net = get_network("VGGnet_test")
        # load model
        print(('Loading network {:s}... '.format("VGGnet_test")), end=' ')
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(cfg.TEST.checkpoints_path)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise RuntimeError(
                'Check your pretrained model: no checkpoint found in {}'.format(
                    cfg.TEST.checkpoints_path))
        print('Restoring from {}...'.format(ckpt.model_checkpoint_path),
              end=' ')
        try:
            saver.restore(sess, ckpt.model_checkpoint_path)
        except Exception as err:
            # Chain the TensorFlow error so the root cause stays visible.
            raise RuntimeError('Check your pretrained {:s}'.format(
                ckpt.model_checkpoint_path)) from err
        print('done')
    except Exception:
        sess.close()
        raise
    cls.sess, cls.net = sess, net

    # Two dummy runs on a grey image (warm-up).
    im = 128 * np.ones((300, 300, 3), dtype=np.uint8)
    for _ in range(2):
        test_ctpn(cls.sess, cls.net, im)
def detect(self, image_path):
    """Detect text regions in an image file.

    The model is loaded lazily via ``self.load()`` if ``self.session`` is None.
    Detections are mapped from the resized image back to the original image
    size.

    Args:
        image_path: path of the image to read with ``cv2.imread``.

    Returns:
        A list of dicts with keys 'score', 'y', 'x', 'w', 'h'. The box corners
        are taken from columns 0-1 (left/top) and 6-7 (right/bottom), and the
        score from column 8 of each detected line.

    Raises:
        IOError: if the image cannot be read (``cv2.imread`` returned None).
    """
    if self.session is None:
        self.load()
    img = cv2.imread(image_path)
    if img is None:
        raise IOError('Could not read image: {}'.format(image_path))
    old_h, old_w = img.shape[:2]
    img, scale = self.resize_im(img, scale=TextLineCfg.SCALE,
                                max_scale=TextLineCfg.MAX_SCALE)
    new_h, new_w = img.shape[:2]
    # Ratios that map resized-image coordinates back onto the original image.
    mul_h = float(old_h) / float(new_h)
    mul_w = float(old_w) / float(new_w)

    scores, boxes = test_ctpn(self.session, self.net, img)
    boxes = self.textdetector.detect(boxes, scores[:, np.newaxis],
                                     img.shape[:2])
    regions = []
    for box in boxes:
        # Truncate to int in resized space first, then rescale and truncate again.
        left = int(int(box[0]) * mul_w)
        top = int(int(box[1]) * mul_h)
        right = int(int(box[6]) * mul_w)
        bottom = int(int(box[7]) * mul_h)
        regions.append({
            'score': float(box[8]),
            'y': top,
            'x': left,
            'w': right - left,
            'h': bottom - top,
        })
    return regions
def ctpnSource(im_name):
    """Reset data/results/, load the CTPN model and run ``ctpn`` on one image.

    The TensorFlow session created here is closed before returning, including
    when loading or detection fails.

    Args:
        im_name: path of the image passed to ``ctpn``.

    Raises:
        RuntimeError: if no checkpoint is found in
            ``cfg.TEST.checkpoints_path`` or restoring it fails.
    """
    # Start from an empty results directory; rmtree also removes nested
    # sub-directories recursively.
    if os.path.exists("data/results/"):
        shutil.rmtree("data/results/")
    os.makedirs("data/results/")

    cfg_from_file('ctpn/text.yml')
    # init session
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)
    try:
        # load network
        net = get_network("VGGnet_test")
        # load model
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(cfg.TEST.checkpoints_path)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise RuntimeError(
                'Check your pretrained model: no checkpoint found in {}'.format(
                    cfg.TEST.checkpoints_path))
        try:
            saver.restore(sess, ckpt.model_checkpoint_path)
        except Exception as err:
            # Chain the TensorFlow error so the root cause stays visible.
            raise RuntimeError('Check your pretrained {:s}'.format(
                ckpt.model_checkpoint_path)) from err

        # Two dummy runs on a grey image (warm-up).
        im = 128 * np.ones((300, 300, 3), dtype=np.uint8)
        for _ in range(2):
            test_ctpn(sess, net, im)

        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print(('Demo for {:s}'.format(im_name)))
        ctpn(sess, net, im_name)
    finally:
        sess.close()
def detectorload():
    """Load the CTPN network stored under TextDetection/ and return it.

    The function reads TextDetection/ctpn/text.yml, builds "VGGnet_test",
    restores the latest checkpoint from TextDetection/checkpoints/ and runs
    two warm-up inferences on a 300x300 grey image.

    Returns:
        ``(net, sess)``: the network and the open session holding its weights.

    Raises:
        RuntimeError: if no checkpoint is found or restoring it fails. The
            session is closed in that case.
    """
    cfg_from_file('TextDetection/ctpn/text.yml')
    # init session
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)
    try:
        # load network
        net = get_network("VGGnet_test")
        # load model
        print(('Loading network {:s}... '.format("VGGnet_test")), end=' ')
        saver = tf.train.Saver()
        checkpoints_path = 'TextDetection/checkpoints/'
        ckpt = tf.train.get_checkpoint_state(checkpoints_path)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise RuntimeError(
                'Check your pretrained model: no checkpoint found in {}'.format(
                    checkpoints_path))
        print('Restoring from {}...'.format(ckpt.model_checkpoint_path),
              end=' ')
        try:
            saver.restore(sess, ckpt.model_checkpoint_path)
        except Exception as err:
            # Chain the TensorFlow error so the root cause stays visible.
            raise RuntimeError('Check your pretrained {:s}'.format(
                ckpt.model_checkpoint_path)) from err
        print('done')
    except Exception:
        sess.close()
        raise

    # Two dummy runs on a grey image (warm-up).
    im = 128 * np.ones((300, 300, 3), dtype=np.uint8)
    for _ in range(2):
        test_ctpn(sess, net, im)
    return net, sess


net, sess = detectorload()
ctpn(sess, net, im_name)
def ctpn(sess, net, image_name, sharpness_threshold=5000,
         sharpness_factor=3.0):
    """Detect and draw text lines, sharpening low-sharpness images first.

    The variance of the Laplacian of the resized image is used as a sharpness
    measure. When it is <= ``sharpness_threshold`` the image is sharpened with
    PIL's ``ImageEnhance.Sharpness(...).enhance(sharpness_factor)`` before
    detection. The detected lines are drawn with draw_boxes and the elapsed
    time is printed.

    Args:
        sess: session forwarded to ``test_ctpn``.
        net: network forwarded to ``test_ctpn``.
        image_name: path of the image to read.
        sharpness_threshold: Laplacian-variance threshold at or below which
            the image is sharpened (default 5000, the previous hard-coded
            value).
        sharpness_factor: enhancement factor for ``ImageEnhance.Sharpness``
            (default 3.0, the previous hard-coded value).
    """
    timer = Timer()
    timer.tic()
    img = cv2.imread(image_name)
    img, scale = resize_im(img, scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)
    # Image sharpness: low variance of the Laplacian means few edges.
    image_var = cv2.Laplacian(img, cv2.CV_64F).var()
    if image_var <= sharpness_threshold:
        # OpenCV (BGR) -> PIL (RGB), sharpen, then back to an OpenCV BGR array.
        # The pure channel swap is skipped when no sharpening is needed.
        pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        pil_img = ImageEnhance.Sharpness(pil_img).enhance(sharpness_factor)
        img = cv2.cvtColor(np.asarray(pil_img), cv2.COLOR_RGB2BGR)
    scores, boxes = test_ctpn(sess, net, img)
    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    draw_boxes(img, image_name, boxes, scale)
    timer.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0]))
def ctpn(sess, net, image_name):
    """Detect and draw text lines in the bottom third of an image.

    Only rows from two thirds of the image height downwards are kept. That crop
    is resized, run through the network and post-processed by TextDetector.
    The resulting lines are drawn with draw_boxes, and the elapsed time is
    printed.
    """
    stopwatch = Timer()
    stopwatch.tic()

    full_image = cv2.imread(image_name)
    rows, _cols = full_image.shape[:2]
    first_row = int(2 * rows / 3.0)
    bottom_part = full_image[first_row:rows, :]

    resized, scale = resize_im(bottom_part, scale=TextLineCfg.SCALE,
                               max_scale=TextLineCfg.MAX_SCALE)
    scores, proposals = test_ctpn(sess, net, resized)

    detector = TextDetector()
    lines = detector.detect(proposals, scores[:, np.newaxis], resized.shape[:2])
    draw_boxes(resized, image_name, lines, scale)

    stopwatch.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(stopwatch.total_time, lines.shape[0]))
def ctpn_area(sess, net, image_name, dst, draw_img=False, show_area=False,
              area_min=-0.1, area_max=1.1):
    """Detect text lines in one image and return compute_area's result.

    Returns 0.0 when ``cv2.imread`` cannot read ``image_name``. Otherwise the
    resized image, the detected lines, the scale and all remaining arguments
    are forwarded to compute_area, and its return value is passed through.
    """
    source = cv2.imread(image_name)
    if source is None:
        return 0.0

    resized, scale = resize_im(source, scale=TextLineCfg.SCALE,
                               max_scale=TextLineCfg.MAX_SCALE)
    scores, proposals = test_ctpn(sess, net, resized)
    lines = TextDetector().detect(proposals, scores[:, np.newaxis],
                                  resized.shape[:2])
    return compute_area(resized, image_name, lines, scale, dst,
                        draw_img=draw_img, show_area=show_area,
                        area_min=area_min, area_max=area_max)
def ctpn(sess, net, image_path):
    """Detect handwritten and printed text lines separately and save a preview.

    In this variant ``test_ctpn`` returns a class id per proposal in addition
    to scores and boxes.
      * Class-1 proposals (presumably handwritten) are post-processed by their
        own TextDetector instance and drawn with color (255, 0, 0).
      * Class-2 proposals (presumably printed) are post-processed by a second
        instance and drawn with color (0, 255, 0).
    The annotated image is resized by 1/scale and written to
    data/results/<file name>. The elapsed time is printed.
    """
    timer = Timer()
    timer.tic()
    img = cv2.imread(image_path)
    # os.path.basename honours the platform's separators; split('/') would
    # leave Windows paths whole, so the output path could point at the input.
    img_name = os.path.basename(image_path)
    # Resize the image and keep the applied scale factor.
    img, scale = resize_im(img, scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)
    # Feed the network: per-proposal class ids, scores and boxes.
    cls, scores, boxes = test_ctpn(sess, net, img)

    handwritten_filter = np.where(cls == 1)[0]
    handwritten_scores = scores[handwritten_filter]
    handwritten_boxes = boxes[handwritten_filter, :]
    print_filter = np.where(cls == 2)[0]
    print_scores = scores[print_filter]
    print_boxes = boxes[print_filter, :]

    # One detector per group, so the two groups never share detector state.
    handwritten_detector = TextDetector()
    print_detector = TextDetector()
    filted_handwritten_boxes = handwritten_detector.detect(
        handwritten_boxes, handwritten_scores[:, np.newaxis], img.shape[:2])
    filted_print_boxes = print_detector.detect(
        print_boxes, print_scores[:, np.newaxis], img.shape[:2])

    draw_boxes(img, filted_handwritten_boxes, (255, 0, 0))
    draw_boxes(img, filted_print_boxes, (0, 255, 0))
    # Undo the resize so the saved image is (approximately) the input size.
    img = cv2.resize(img, None, None, fx=1.0 / scale, fy=1.0 / scale,
                     interpolation=cv2.INTER_LINEAR)
    out_dir = "data/results"
    # cv2.imwrite fails silently (returns False) if the directory is missing.
    os.makedirs(out_dir, exist_ok=True)
    cv2.imwrite(os.path.join(out_dir, img_name), img)
    timer.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0]))
def ctpn(img):
    """Run the CTPN network on one image (text box detection).

    Uses the module-level ``sess`` and ``net``. The image is resized with
    Config.SCALE / Config.MAX_SCALE first.

    Returns:
        ``(scores, boxes, resized_image)``: the raw proposal scores and boxes
        from ``test_ctpn`` and the resized image they refer to.
    """
    target_scale, target_max_scale = Config.SCALE, Config.MAX_SCALE
    resized, _factor = resize_im(img, scale=target_scale,
                                 max_scale=target_max_scale)
    proposal_scores, proposal_boxes = test_ctpn(sess, net, resized)
    return proposal_scores, proposal_boxes, resized
def ctpn(img):
    """Text box detection with the module-level ``sess`` / ``net``.

    Returns a 3-tuple of raw proposal scores, raw proposal boxes and the
    image after resizing with Config.SCALE / Config.MAX_SCALE.
    """
    min_side = Config.SCALE
    max_side = Config.MAX_SCALE
    network_input, _ratio = resize_im(img, scale=min_side, max_scale=max_side)
    result = test_ctpn(sess, net, network_input)
    scores, boxes = result
    return scores, boxes, network_input
def detect(src, dst, draw_img=False, show_area=False, area_min=-0.0,
           area_max=1.1):
    """Load the CTPN model and run ``ctpn`` on every image file in ``src``.

    Image files are the non-hidden entries directly inside ``src`` whose
    extension, compared case-insensitively, is .png, .jpg, .jpeg or .bmp.
    ``dst`` is created if missing. It is forwarded to ``ctpn`` for every image,
    together with the remaining keyword arguments. The TensorFlow session is
    closed before returning.

    Raises:
        RuntimeError: if no checkpoint is found in
            ``cfg.TEST.checkpoints_path`` or restoring it fails.
    """
    if not os.path.exists(dst):
        os.makedirs(dst)

    cfg_from_file('ctpn/text.yml')
    # init session
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)
    try:
        # load network
        net = get_network("VGGnet_test")
        # load model
        print(('Loading network {:s}... '.format("VGGnet_test")), end=' ')
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(cfg.TEST.checkpoints_path)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise RuntimeError(
                'Check your pretrained model: no checkpoint found in {}'.format(
                    cfg.TEST.checkpoints_path))
        print('Restoring from {}...'.format(ckpt.model_checkpoint_path),
              end=' ')
        try:
            saver.restore(sess, ckpt.model_checkpoint_path)
        except Exception as err:
            # Chain the TensorFlow error so the root cause stays visible.
            raise RuntimeError('Check your pretrained {:s}'.format(
                ckpt.model_checkpoint_path)) from err
        print('done')

        # Two dummy runs on a grey image (warm-up).
        im = 128 * np.ones((300, 300, 3), dtype=np.uint8)
        for _ in range(2):
            test_ctpn(sess, net, im)

        # One directory listing with a case-insensitive extension test.
        # Globbing '*.png' and '*.PNG' separately returns every file twice on
        # case-insensitive file systems (e.g. Windows).
        image_exts = {'.png', '.jpg', '.jpeg', '.bmp'}
        im_names = sorted(
            path for path in glob.glob(os.path.join(src, '*'))
            if os.path.splitext(path)[1].lower() in image_exts)
        print("images:{}".format(len(im_names)))
        for im_name in im_names:
            print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
            print(('Demo for {:s}'.format(im_name)))
            ctpn(sess, net, im_name, dst, draw_img=draw_img,
                 show_area=show_area, area_min=area_min, area_max=area_max)
    finally:
        sess.close()
def text_detect(img):
    """CTPN text detection: resize, run the network, merge proposals, draw.

    Uses the module-level ``sess`` and ``net``.

    Returns:
        ``(text_recs, tmp, resized_image)``. ``text_recs`` and ``tmp`` are the
        two values returned by draw_boxes (called with caption 'im_name' and
        display disabled).
    """
    min_side, max_side = Config.SCALE, Config.MAX_SCALE
    resized, _factor = resize_im(img, scale=min_side, max_scale=max_side)
    scores, proposals = test_ctpn(sess, net, resized)
    lines = TextDetector().detect(proposals, scores[:, np.newaxis],
                                  resized.shape[:2])
    text_recs, tmp = draw_boxes(resized, lines, caption='im_name', wait=True,
                                is_display=False)
    return text_recs, tmp, resized
def ctpn(img):
    """Text box detection with the module-level ``sess`` / ``net``.

    The image shape is printed before and after resizing (together with the
    resize factor) for debugging.

    Returns:
        ``(scores, boxes, resized_image)`` with the raw proposals from
        ``test_ctpn``.
    """
    min_side, max_side = Config.SCALE, Config.MAX_SCALE
    print('original_size', img.shape)
    # Resize the image; resize_im also returns the factor that was applied.
    resized, factor = resize_im(img, scale=min_side, max_scale=max_side)
    print('resize', resized.shape, factor)
    proposal_scores, proposal_boxes = test_ctpn(sess, net, resized)
    return proposal_scores, proposal_boxes, resized
def ctpn(img):
    """Resize ``img``, run CTPN and merge the proposals into text lines.

    Uses the module-level ``sess`` and ``net``.

    Returns:
        ``(scores, text_lines, resized_image, scale)``. ``scores`` are the raw
        proposal scores; ``text_lines`` is TextDetector's output.
    """
    resized, scale = resize_im(img, scale=TextLineCfg.SCALE,
                               max_scale=TextLineCfg.MAX_SCALE)
    scores, proposals = test_ctpn(sess, net, resized)
    text_lines = TextDetector().detect(proposals, scores[:, np.newaxis],
                                       resized.shape[:2])
    return scores, text_lines, resized, scale
def ctpn(self, image_name):
    """Detect text lines in an image file and return their coordinates.

    The image is resized with scale=600 / max_scale=1000 (values following the
    CTPN paper) using ``self.resize_im``. It is run through ``self.net`` in
    ``self.sess`` and post-processed by a fresh TextDetector.

    Returns:
        ``(min_y_sort_list, base_name)`` as produced by
        ``self.get_coordinates``.
    """
    image = cv2.imread(image_name)
    # 600 / 1000 follow the CTPN paper.
    image, scale = self.resize_im(image, scale=600, max_scale=1000)
    scores, proposals = test_ctpn(self.sess, self.net, image)
    # CTPN text-line detector instance.
    lines = TextDetector().detect(proposals, scores[:, np.newaxis],
                                  image.shape[:2])
    sorted_by_y, base_name = self.get_coordinates(image, image_name, lines,
                                                  scale)
    return sorted_by_y, base_name
def ctpn(sess, net, image_name):
    """Detect text lines in an image file, draw them and print the timing."""
    stopwatch = Timer()
    stopwatch.tic()

    resized, scale = resize_im(cv2.imread(image_name),
                               scale=TextLineCfg.SCALE,
                               max_scale=TextLineCfg.MAX_SCALE)
    scores, proposals = test_ctpn(sess, net, resized)
    lines = TextDetector().detect(proposals, scores[:, np.newaxis],
                                  resized.shape[:2])
    draw_boxes(resized, image_name, lines, scale)

    stopwatch.toc()
    elapsed, count = stopwatch.total_time, lines.shape[0]
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(elapsed, count))
def ctpn(sess, net, image_name):
    """Detect text lines in an image file and draw them.

    A Timer is started and stopped around the work, but its result is not
    reported by this variant.
    """
    stopwatch = Timer()
    stopwatch.tic()

    source = cv2.imread(image_name)
    resized, scale = resize_im(source, scale=TextLineCfg.SCALE,
                               max_scale=TextLineCfg.MAX_SCALE)
    scores, proposals = test_ctpn(sess, net, resized)
    detector = TextDetector()
    lines = detector.detect(proposals, scores[:, np.newaxis], resized.shape[:2])
    draw_boxes(resized, image_name, lines, scale)

    stopwatch.toc()
def ctpn(sess, net, img):
    """Detect text lines in an in-memory image and draw them.

    The detected lines are ordered by their last column (presumably the
    score), highest first, before drawing. The elapsed time is printed.

    Returns:
        ``(im, bboxes)``: the two values returned by draw_boxes.
    """
    stopwatch = Timer()
    stopwatch.tic()

    resized, scale = resize_im(img, scale=TextLineCfg.SCALE,
                               max_scale=TextLineCfg.MAX_SCALE)
    scores, proposals = test_ctpn(sess, net, resized)
    lines = TextDetector().detect(proposals, scores[:, np.newaxis],
                                  resized.shape[:2])
    # Descending order of the last column.
    descending = np.argsort(lines[:, -1])[::-1]
    lines = lines[descending]
    im, bboxes = draw_boxes(resized, lines, scale)

    stopwatch.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(stopwatch.total_time, lines.shape[0]))
    return im, bboxes
def ctpn(sess, net, image_name, model):
    """Detect text lines in an image file and pass them to get_coordinates.

    The image is resized with scale=600 / max_scale=1000 (values following the
    CTPN paper), and 'ctpn' plus the resized shape are printed. The proposals
    are merged by a TextDetector. get_coordinates's return value is discarded.
    """
    source = cv2.imread(image_name)
    # 600 / 1000 follow the CTPN paper.
    resized, scale = resize_im(source, scale=600, max_scale=1000)
    print('ctpn', resized.shape)
    scores, proposals = test_ctpn(sess, net, resized)
    # CTPN text-line detector instance.
    detector = TextDetector()
    lines = detector.detect(proposals, scores[:, np.newaxis], resized.shape[:2])
    get_coordinates(resized, image_name, lines, scale, model)
def predict(self, image_name):
    """Run CTPN on an image file and merge the proposals into text lines.

    Uses ``self.resize_im``, ``self.sess`` and ``self.net``.

    Returns:
        ``(resized_image, text_lines, scale)``.
    """
    source = cv2.imread(image_name)
    resized, scale = self.resize_im(source, scale=TextLineCfg.SCALE,
                                    max_scale=TextLineCfg.MAX_SCALE)
    scores, proposals = test_ctpn(self.sess, self.net, resized)
    lines = TextDetector().detect(proposals, scores[:, np.newaxis],
                                  resized.shape[:2])
    return resized, lines, scale
def ctpn(cv_image):
    """Detect text lines in an OpenCV image with the preloaded CTPN model.

    The working directory is switched to CTPN_DIR while the model runs
    (presumably because CTPN helpers resolve paths relative to it; TODO
    confirm). It is reset to ROOT_DIR afterwards, also when detection raises,
    so a failure no longer leaves the process in CTPN_DIR.

    Args:
        cv_image: image array passed to ``resize_im``.

    Returns:
        The text-line boxes with their first 8 columns (corner coordinates)
        divided by ``scale``, presumably mapping them back to ``cv_image``'s
        size.
    """
    os.chdir(CTPN_DIR)
    try:
        with ctpn_sess.as_default():
            img, scale = resize_im(cv_image, scale=TextLineCfg.SCALE,
                                   max_scale=TextLineCfg.MAX_SCALE)
            scores, boxes = test_ctpn(ctpn_sess, ctpn_net, img)
            textdetector = TextDetector()
            boxes = textdetector.detect(boxes, scores[:, np.newaxis],
                                        img.shape[:2])
            boxes[:, 0:8] /= scale
    finally:
        os.chdir(ROOT_DIR)
    return boxes
def ctpn(sess, net, image_name):
    """Detect text with the proposal-connection pipeline and save the result.

    Steps:
      * proposals from ``test_ctpn`` are reduced with ``nms``;
      * proposals scoring below 0.7 are dropped;
      * the rest are connected into text lines by ``connect_proposal``;
      * ``save_results`` is called with ``thresh=0.9``.
    Only the network forward pass is timed.
    """
    CONF_THRESH = 0.9          # threshold handed to save_results
    NMS_THRESH = 0.3           # overlap threshold passed to nms
    MIN_PROPOSAL_SCORE = 0.7   # proposals scoring below this are discarded

    raw = cv2.imread(image_name)
    im = check_img(raw)
    timer = Timer()
    timer.tic()
    scores, boxes = test_ctpn(sess, net, im)
    timer.toc()

    # Each row: 4 box coordinates followed by the proposal score (column 4).
    dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32)
    dets = dets[nms(dets, NMS_THRESH), :]
    dets = dets[np.where(dets[:, 4] >= MIN_PROPOSAL_SCORE)[0], :]

    line = connect_proposal(dets[:, 0:4], dets[:, 4], im.shape)
    save_results(image_name, im, line, thresh=CONF_THRESH)
def ctpn(img):
    """Detect text lines in an in-memory image and report the timing.

    Uses the module-level ``sess`` and ``net``.

    Returns:
        A 6-tuple ``(scores, text_lines, resized_image, scale, seconds,
        line_count)``. ``seconds`` is the Timer's total_time after stopping
        and ``line_count`` is ``text_lines.shape[0]``.
    """
    stopwatch = Timer()
    stopwatch.tic()
    resized, scale = resize_im(img, scale=TextLineCfg.SCALE,
                               max_scale=TextLineCfg.MAX_SCALE)
    scores, proposals = test_ctpn(sess, net, resized)
    lines = TextDetector().detect(proposals, scores[:, np.newaxis],
                                  resized.shape[:2])
    stopwatch.toc()
    return (scores, lines, resized, scale, stopwatch.total_time,
            lines.shape[0])
def ctpn(sess, net, frame, draw):
    """Detect text lines in a frame and return the cropped text regions.

    Args:
        sess: session forwarded to ``test_ctpn``.
        net: network forwarded to ``test_ctpn``.
        frame: image array passed to ``resize_im``.
        draw: when equal to 1 (``True`` now also works), the detected boxes
            are drawn onto the resized frame with draw_boxes.

    Returns:
        Whatever crop_image returns for the detected boxes.
    """
    img, scale = resize_im(frame, scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, net, img)
    textdetector = TextDetector()
    boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2])
    # Crop from a copy, presumably so that drawing on `img` below cannot
    # affect the crops.
    crop = crop_image(img.copy(), boxes, scale)
    # Compare by value. `draw is 1` relied on CPython small-int caching,
    # triggers a SyntaxWarning on Python 3.8+ and was False for draw=True.
    if draw == 1:
        draw_boxes(img, boxes, scale)
    return crop
def local_result(list_img_path, save_dir, ctpn_path, base_net,
                 xml_dir='/data/kuaidi01/dataset_detect/VOC2007/Annotations'):
    """Run CTPN over a list of images and write results under data/<save_dir>.

    Setup:
      * ``cfg.ROOT_DIR/data/<save_dir>`` is recreated from scratch;
      * the config is read from ``cfg.ROOT_DIR/ctpn/text.yml``;
      * ``base_net`` is built with a boolean ``training_flag`` placeholder;
      * the latest checkpoint in ``ctpn_path`` is restored.
    For every image, ``ctpn`` is called with the annotation path
    ``<xml_dir>/<image stem>.xml``. Paths ending in '164.jpg' are skipped. The
    TensorFlow session is closed before returning.

    Args:
        list_img_path: iterable of image paths.
        save_dir: name of the output directory under ``cfg.ROOT_DIR/data``.
        ctpn_path: directory containing the checkpoint.
        base_net: network name for ``get_network``.
        xml_dir: directory of the VOC annotation files (default: the path that
            used to be hard-coded).

    Raises:
        RuntimeError: if no checkpoint is found or restoring it fails.
    """
    save_dir = os.path.join(cfg.ROOT_DIR, 'data', save_dir)
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)
    os.makedirs(save_dir)
    yml_path = os.path.join(cfg.ROOT_DIR, 'ctpn/text.yml')
    cfg_from_file(yml_path)
    # init session
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)
    try:
        # load network
        training_flag = tf.placeholder(tf.bool)
        net = get_network(base_net, training_flag)
        # load model
        print(('Loading network {:s}... '.format(base_net)), end=' ')
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(ctpn_path)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise RuntimeError(
                'Check your pretrained model: no checkpoint found in {}'.format(
                    ctpn_path))
        print('Restoring from {}...'.format(ckpt.model_checkpoint_path),
              end=' ')
        try:
            saver.restore(sess, ckpt.model_checkpoint_path)
        except Exception as err:
            # Chain the TensorFlow error so the root cause stays visible.
            raise RuntimeError('Check your pretrained {:s}'.format(
                ckpt.model_checkpoint_path)) from err
        print('done')

        # Two dummy runs on a grey image (warm-up).
        im = 128 * np.ones((300, 300, 3), dtype=np.uint8)
        for _ in range(2):
            test_ctpn(sess, training_flag, net, im)

        for im_name in list_img_path:
            print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
            print(('Demo for {:s}'.format(im_name)))
            # NOTE(review): this image is deliberately skipped; reason unknown.
            if im_name.endswith('164.jpg'):
                continue
            # splitext also handles extensions that are not 3 characters long
            # (e.g. '.jpeg'), unlike slicing off the last 4 characters.
            stem = os.path.splitext(os.path.basename(im_name))[0]
            xml_path = os.path.join(xml_dir, stem + '.xml')
            ctpn(sess, training_flag, net, im_name, save_dir, xml_path)
    finally:
        sess.close()
def ctpn(sess, training_flag, net, image_name, save_all_dir, xml_path=None):
    """Detect text lines in one image, draw them and print the timing.

    Args:
        sess: session forwarded to ``test_ctpn``.
        training_flag: placeholder forwarded to ``test_ctpn``.
        net: network forwarded to ``test_ctpn``.
        image_name: path of the image file.
        save_all_dir: forwarded to draw_boxes (presumably the output
            directory).
        xml_path: optional ground-truth annotation path. It is accepted but
            not used here. It exists so that callers passing an annotation
            path as a sixth argument (as ``local_result`` does) no longer
            raise TypeError.
    """
    timer = Timer()
    timer.tic()
    img = cv2.imread(image_name)
    img, scale = resize_im(img, scale=TextLineCfg.SCALE,
                           max_scale=TextLineCfg.MAX_SCALE)
    scores, boxes = test_ctpn(sess, training_flag, net, img)
    textdetector = TextDetector()
    # detect() returns two intermediate box sets followed by the final text
    # lines; only the final lines are used here.
    _, _, boxes = textdetector.detect(boxes, scores[:, np.newaxis],
                                      img.shape[:2])
    draw_boxes(img, image_name, boxes, scale, save_all_dir)
    timer.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0]))
# Demo script: load the CTPN model, then run `ctpn` on every .png/.jpg in
# <cfg.DATA_DIR>/demo.

# init session
config = tf.ConfigProto(allow_soft_placement=True)
sess = tf.Session(config=config)
# load network
net = get_network("VGGnet_test")
# load model
print(('Loading network {:s}... '.format("VGGnet_test")), end=' ')
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(cfg.TEST.checkpoints_path)
if ckpt is None or not ckpt.model_checkpoint_path:
    # Raising a plain string is itself a TypeError in Python 3, so raise a
    # real exception with a useful message.
    raise RuntimeError(
        'Check your pretrained model: no checkpoint found in {}'.format(
            cfg.TEST.checkpoints_path))
print('Restoring from {}...'.format(ckpt.model_checkpoint_path), end=' ')
try:
    saver.restore(sess, ckpt.model_checkpoint_path)
except Exception as err:
    # Chain the TensorFlow error so the root cause stays visible.
    raise RuntimeError('Check your pretrained {:s}'.format(
        ckpt.model_checkpoint_path)) from err
print('done')

# Two dummy runs on a grey image (warm-up).
im = 128 * np.ones((300, 300, 3), dtype=np.uint8)
for i in range(2):
    _, _ = test_ctpn(sess, net, im)

im_names = glob.glob(os.path.join(cfg.DATA_DIR, 'demo', '*.png')) + \
    glob.glob(os.path.join(cfg.DATA_DIR, 'demo', '*.jpg'))
for im_name in im_names:
    print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
    print(('Demo for {:s}'.format(im_name)))
    ctpn(sess, net, im_name)