def save(self, image_name, image, region_scores, affinity_scores, confidence_mask):
    boxes, polys = craft_utils.getDetBoxes(region_scores / 255,
                                           affinity_scores / 255,
                                           0.7, 0.4, 0.4, False)
    boxes = np.array(boxes, np.int32) * 2
    if len(boxes) > 0:
        # clip in place: a bare np.clip() call would discard its result
        np.clip(boxes[:, :, 0], 0, image.shape[1], out=boxes[:, :, 0])
        np.clip(boxes[:, :, 1], 0, image.shape[0], out=boxes[:, :, 1])
        for box in boxes:
            cv2.polylines(image, [np.reshape(box, (-1, 1, 2))], True, (0, 0, 255))
    target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(region_scores / 255)
    target_gaussian_affinity_heatmap_color = imgproc.cvt2HeatmapImg(affinity_scores / 255)
    confidence_mask_gray = imgproc.cvt2HeatmapImg(confidence_mask)
    gt_scores = np.hstack([target_gaussian_heatmap_color,
                           target_gaussian_affinity_heatmap_color])
    confidence_mask_gray = np.hstack([np.zeros_like(confidence_mask_gray),
                                      confidence_mask_gray])
    output = np.concatenate([gt_scores, confidence_mask_gray], axis=0)
    output = np.hstack([image, output])
    out_path = os.path.join(os.path.dirname(__file__), 'output',
                            "%s_input.jpg" % image_name)
    print(out_path)
    if not os.path.exists(os.path.dirname(out_path)):
        os.mkdir(os.path.dirname(out_path))
    cv2.imwrite(out_path, output)
def saveImage(imagename, image, bboxes, affinity_bboxes, region_scores,
              affinity_scores, confidence_mask):
    output_image = np.uint8(image.copy())
    output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR)
    if len(bboxes) > 0:
        affinity_bboxes = np.int32(affinity_bboxes)
        for i in range(affinity_bboxes.shape[0]):
            cv2.polylines(output_image,
                          [np.reshape(affinity_bboxes[i], (-1, 1, 2))],
                          True, (255, 0, 0))
        for i in range(len(bboxes)):
            _bboxes = np.int32(bboxes[i])
            for j in range(_bboxes.shape[0]):
                cv2.polylines(output_image,
                              [np.reshape(_bboxes[j], (-1, 1, 2))],
                              True, (0, 0, 255))
    target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(region_scores / 255)
    target_gaussian_affinity_heatmap_color = imgproc.cvt2HeatmapImg(affinity_scores / 255)
    heat_map = np.concatenate([target_gaussian_heatmap_color,
                               target_gaussian_affinity_heatmap_color], axis=1)
    confidence_mask_gray = imgproc.cvt2HeatmapImg(confidence_mask)
    output = np.concatenate([output_image, heat_map, confidence_mask_gray], axis=1)
    outpath = os.path.join('./output', imagename)
    if not os.path.exists(os.path.dirname(outpath)):
        os.mkdir(os.path.dirname(outpath))
    cv2.imwrite(outpath, output)
def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly):
    t0 = time.time()

    # resize
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
        image, args.canvas_size, interpolation=cv2.INTER_AREA,
        mag_ratio=args.mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio

    # preprocessing
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
    x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
    if cuda:
        x = x.cuda()

    # forward pass
    with torch.no_grad():
        y, feature = net(x)

    # make score and link map
    score_text = y[0, :, :, 0].cpu().data.numpy()
    score_link = y[0, :, :, 1].cpu().data.numpy()

    # Post-processing
    boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                           text_threshold, link_threshold,
                                           low_text, poly)

    # coordinate adjustment
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k in range(len(polys)):
        if polys[k] is None:
            polys[k] = boxes[k]

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    return boxes, polys, ret_score_text
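# A minimal usage sketch for test_net above, under stated assumptions: the
# CRAFT model class lives in craft.py, the checkpoint path and copyStateDict
# helper exist as used elsewhere in this file, and args carries canvas_size
# and mag_ratio. 'sample.jpg' is a hypothetical input.
import cv2
import torch
from craft import CRAFT

net = CRAFT()
net.load_state_dict(copyStateDict(torch.load('weights/craft_mlt_25k.pth',
                                             map_location='cpu')))
net.eval()

image = imgproc.loadImage('sample.jpg')
boxes, polys, score_heatmap = test_net(net, image,
                                       text_threshold=0.7,
                                       link_threshold=0.4,
                                       low_text=0.4,
                                       cuda=False, poly=False)
for box in boxes:
    cv2.polylines(image, [box.astype(int).reshape((-1, 1, 2))],
                  True, (0, 0, 255), 2)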
def get_bbox(image_batch, ratio_w, ratio_h):
    images_torch = torch.from_numpy(image_batch).unsqueeze(0).cuda().permute(0, 3, 1, 2)
    with torch.no_grad(), torch.jit.optimized_execution(True):
        pred = model(images_torch)
    rs_tensor = pred[:, 0, :, :].cpu().numpy()
    as_tensor = pred[:, 1, :, :].cpu().numpy()

    render_img = rs_tensor[0].copy()
    render_img = np.hstack((render_img, as_tensor[0]))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)
    if args.verbose:
        cv2.imshow('score', ret_score_text)

    ret = []
    for (rs_img, as_img) in zip(rs_tensor, as_tensor):
        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(rs_img, as_img, text_threshold,
                                               link_threshold, low_text, use_poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None:
                polys[k] = boxes[k]

        frame_bboxes = []
        polys = merge_bboxes(polys)
        for [tl, tr, br, bl] in polys:
            frame_bboxes.append({
                'x': int(tl[0]),
                'y': int(tl[1]),
                'width': int(tr[0] - tl[0]),
                'height': int(br[1] - tr[1])
            })
        ret.append({'bboxes': frame_bboxes})
    return ret[0]['bboxes']
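# merge_bboxes is called above but not defined in this snippet. A minimal
# sketch of one plausible implementation, stated as an assumption rather
# than the original code: greedily merge quadrilaterals whose axis-aligned
# bounding rectangles overlap, returning merged quads as [tl, tr, br, bl].
import numpy as np

def merge_bboxes(polys):
    quads = [np.asarray(p) for p in polys]
    rects = [(q[:, 0].min(), q[:, 1].min(), q[:, 0].max(), q[:, 1].max())
             for q in quads]
    merged, used = [], [False] * len(rects)
    for i, (x0, y0, x1, y1) in enumerate(rects):
        if used[i]:
            continue
        used[i] = True
        changed = True
        while changed:  # absorb every rectangle that overlaps the current one
            changed = False
            for j, (a0, b0, a1, b1) in enumerate(rects):
                if not used[j] and a0 <= x1 and x0 <= a1 and b0 <= y1 and y0 <= b1:
                    x0, y0 = min(x0, a0), min(y0, b0)
                    x1, y1 = max(x1, a1), max(y1, b1)
                    used[j] = changed = True
        merged.append(np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]))
    return merged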
def test_net(net, image, text_threshold, link_threshold, low_text, cuda, image_path):
    t0 = time.time()

    # resize
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
        image, args.canvas_size, interpolation=cv2.INTER_LINEAR,
        mag_ratio=args.mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio

    # preprocessing
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
    x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
    if cuda:
        x = x.cuda()

    # forward pass
    y, _ = net(x)

    # make score and link map
    score_text = y[0, :, :, 0].cpu().data.numpy()
    score_link = y[0, :, :, 1].cpu().data.numpy()

    t0 = time.time() - t0
    t1 = time.time()

    if args.debug:
        np.save(os.path.join('./debug',
                             os.path.basename(image_path).split('.')[0] + '_score_text.npy'),
                score_text)
        np.save(os.path.join('./debug',
                             os.path.basename(image_path).split('.')[0] + '_score_link.npy'),
                score_link)

    # Post-processing
    boxes = craft_utils.getDetBoxes(score_text, score_link, text_threshold,
                                    link_threshold, low_text)
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)

    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    if args.show_time:
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, ret_score_text
def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, ocr_type):
    t0 = time.time()

    # resize
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
        image, args.canvas_size, interpolation=cv2.INTER_LINEAR,
        mag_ratio=args.mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio

    # preprocessing
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
    x = x.unsqueeze(0)                          # [c, h, w] to [b, c, h, w]
    if cuda:
        x = x.cuda()

    # forward pass
    y, _ = net(x)

    # make score and link map
    score_text = y[0, :, :, 0].cpu().detach().numpy()
    score_link = y[0, :, :, 1].cpu().detach().numpy()

    t0 = time.time() - t0
    t1 = time.time()

    # Post-processing
    boxes, polys = utils.getDetBoxes(score_text, score_link, text_threshold,
                                     link_threshold, low_text, poly, ocr_type)

    # coordinate adjustment
    boxes = utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    if ocr_type == 'single_char':
        boxes = utils.cluster_sort(image.shape, boxes)
    for k in range(len(polys)):
        if polys[k] is None:
            polys[k] = boxes[k]

    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    if args.show_time:
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, ret_score_text
def test_net(self, image, text_threshold, link_threshold, low_text, cuda, poly,
             refine_net=None):
    t0 = time.time()

    # resize
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
        image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio

    # preprocessing
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
    x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
    if cuda:
        x = x.cuda()

    # forward pass
    with torch.no_grad():
        y, feature = self.net(x)

    # make score and link map
    score_text = y[0, :, :, 0].cpu().data.numpy()
    score_link = y[0, :, :, 1].cpu().data.numpy()

    # refine link
    if refine_net is not None:
        with torch.no_grad():
            y_refiner = refine_net(y, feature)
        score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

    t0 = time.time() - t0
    t1 = time.time()

    # Post-processing
    boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                           text_threshold, link_threshold,
                                           low_text, poly)

    # coordinate adjustment
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k in range(len(polys)):
        if polys[k] is None:
            polys[k] = boxes[k]

    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    # if show_time: print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, ret_score_text
def generate_gt(net_pretrained, image, boxes, labels, args):
    # allocate the three maps separately; a chained assignment
    # (a = b = np.zeros(...)) would alias them to one array
    region_gt = np.zeros((args.canvas_size // 2, args.canvas_size // 2), dtype=np.float32)
    link_gt = np.zeros((args.canvas_size // 2, args.canvas_size // 2), dtype=np.float32)
    conf_map = np.zeros((args.canvas_size // 2, args.canvas_size // 2), dtype=np.float32)
    gaussian = generate_gaussian(500)

    for i, box in enumerate(boxes):
        # Crop bounding box region
        warped = transform_image(image, box)

        # Apply pretrained network
        score_text, target_ratio = gt_net(net_pretrained, warped, args)

        # render results (optional)
        render_img = imgproc.cvt2HeatmapImg(score_text.copy())
        watershed = watershed_labeling(render_img)
        box_chr = chr_annotation(watershed)

        wordlen = len(labels[i])
        sconf = float(wordlen - min(wordlen, abs(wordlen - len(box_chr)))) / float(wordlen)
        h, w = np.shape(score_text)
        if sconf < 0.5:
            # low confidence: fall back to splitting the word box evenly per character
            box_chr = []
            bw = w // wordlen
            for j in range(wordlen):
                box_adj = np.array([[j * bw, 0], [(j + 1) * bw, 0],
                                    [(j + 1) * bw, h], [j * bw, h]])
                box_chr.append(box_adj)
            sconf = 0.5

        box_aff = []
        for k in range(len(box_chr) - 1):
            box_aff.append(get_affinity(box_chr[k], box_chr[k + 1]))

        conf_box = np.ones((h, w), dtype='float32') * sconf
        region_box = np.zeros((h, w), dtype='float32')
        link_box = np.zeros((h, w), dtype='float32')
        for rbox in box_chr:
            region_box = restore(gaussian, region_box, rbox)
        for abox in box_aff:
            link_box = restore(gaussian, link_box, abox)

        box_adj = craft_utils.adjustResultCoordinates([np.float64(box)],
                                                      target_ratio, target_ratio, 0.5)[0]
        region_gt = restore(region_box, region_gt, box_adj)
        link_gt = restore(link_box, link_gt, box_adj)
        conf_map = restore(conf_box, conf_map, box_adj)

    print(region_gt.shape, link_gt.shape, conf_map.shape)
    return region_gt, link_gt, conf_map
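# get_affinity above is not defined in this snippet. A minimal sketch of what
# it might compute, assuming the usual CRAFT affinity-box construction: each
# character quad is split into an upper and a lower triangle, and the affinity
# quad connects the triangle centers of adjacent characters.
import numpy as np

def get_affinity(box1, box2):
    """box1, box2: 4x2 character quads ordered tl, tr, br, bl."""
    def triangle_centers(box):
        tl, tr, br, bl = [np.asarray(p, dtype=np.float32) for p in box]
        c = (tl + tr + br + bl) / 4.0       # quad center
        top = (tl + tr + c) / 3.0           # center of the upper triangle
        bottom = (bl + br + c) / 3.0        # center of the lower triangle
        return top, bottom

    top1, bot1 = triangle_centers(box1)
    top2, bot2 = triangle_centers(box2)
    # affinity quad, again ordered tl, tr, br, bl
    return np.array([top1, top2, bot2, bot1])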
def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly,
             filename, result_folder=result_folder):
    t0 = time.time()

    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
        image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio

    # preprocessing (TensorFlow: add the batch dimension via expand_dims)
    x = imgproc.normalizeMeanVariance(img_resized)
    x = tf.expand_dims(x, 0)
    print(x.shape)

    # forward pass
    y, _ = net(x)

    # make score and link map
    score_text = y[0, :, :, 0].numpy()
    score_link = y[0, :, :, 1].numpy()

    t0 = time.time() - t0
    t1 = time.time()

    # Post-processing
    boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                           text_threshold, link_threshold,
                                           low_text, poly)

    # coordinate adjustment
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k in range(len(polys)):
        if polys[k] is None:
            polys[k] = boxes[k]

    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)
    cv2.imwrite(result_folder + filename + "_mask.jpg", ret_score_text)

    # if show_time: print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, ret_score_text
def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly):
    t0 = time.time()

    # resize
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
        image, args.canvas_size, interpolation=cv2.INTER_LINEAR,
        mag_ratio=args.mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio

    # preprocessing
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
    x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
    if cuda:
        x = x.cuda()

    # forward pass
    y, _ = net(x)

    # make score and link map
    score_text = y[0, :, :, 0].cpu().data.numpy()
    score_link = y[0, :, :, 1].cpu().data.numpy()

    t0 = time.time() - t0
    t1 = time.time()

    # post-processing
    boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                           text_threshold, link_threshold,
                                           low_text, poly)

    # coordinate adjustment
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k in range(len(polys)):
        if polys[k] is None:
            polys[k] = boxes[k]

    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    if args.show_time:
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, ret_score_text
def test_net(self, net, image, text_threshold, link_threshold, low_text, poly,
             refine_net=None):
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
        image, 1280, interpolation=cv.INTER_LINEAR, mag_ratio=1.5)
    ratio_h = ratio_w = 1 / target_ratio

    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
    x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]

    with torch.no_grad():
        y, feature = net(x)

    # make score and link map
    score_text = y[0, :, :, 0].cpu().data.numpy()
    score_link = y[0, :, :, 1].cpu().data.numpy()

    # Post-processing
    boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                           text_threshold, link_threshold,
                                           low_text, poly)

    # coordinate adjustment
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k in range(len(polys)):
        if polys[k] is None:
            polys[k] = boxes[k]

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    return boxes, polys, ret_score_text
def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly,
             refine_net=None, overlap=0.0):
    t0 = time.time()

    # resizing is skipped here; the image is processed at its original scale
    img_resized = image
    ratio_h = ratio_w = 1

    # preprocessing
    x = imgproc.normalizeMeanVariance(img_resized)

    # optionally split the image into overlapping tiles for the forward pass
    split_coord = []
    if overlap > 0.0 and overlap < 1.0:
        x, split_coord = splitOverlap(x, overlap)

    x = torch.from_numpy(x).permute(0, 3, 1, 2)  # [b, h, w, c] to [b, c, h, w]
    x = Variable(x)
    if cuda:
        x = x.cuda()

    # forward pass
    with torch.no_grad():
        y, feature = net(x)

    # make score and link map, re-joining the overlapped tiles
    score_text = joinOverlap(y[:, :, :, 0].cpu().data.numpy(), split_coord)
    if refine_net is None:
        score_link = joinOverlap(y[:, :, :, 1].cpu().data.numpy(), split_coord)
    else:
        # refine link
        with torch.no_grad():
            y_refiner = refine_net(y, feature)
        score_link = joinOverlap(y_refiner[:, :, :, 0].cpu().data.numpy(), split_coord)

    t0 = time.time() - t0
    t1 = time.time()

    # Post-processing
    boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                           text_threshold, link_threshold,
                                           low_text, poly)

    # coordinate adjustment
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k in range(len(polys)):
        if polys[k] is None:
            polys[k] = boxes[k]

    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    if args.show_time:
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, ret_score_text
    score_region, score_link, conf_map = generate_gt(net, image, gt_boxes, gt_words, args)
    torch.save(score_region, label_dir + 'region.pt')
    torch.save(score_link, label_dir + 'link.pt')
    torch.save(conf_map, label_dir + 'conf.pt')


if __name__ == '__main__':
    import ocr

    score_region = torch.load('/home/ubuntu/Kyumin/craft/data/IC13/labels/train/100/region.pt')
    score_link = torch.load('/home/ubuntu/Kyumin/craft/data/IC13/labels/train/100/link.pt')
    conf_map = torch.load('/home/ubuntu/Kyumin/craft/data/IC13/labels/train/100/conf.pt')
    image = imgproc.loadImage('/home/ubuntu/Kyumin/Autotation/data/IC13/images/train/100.jpg')
    print(score_region.shape, score_link.shape, conf_map.shape)

    # cv2.imshow('original', image)
    cv2.imshow('region', imgproc.cvt2HeatmapImg(score_region))
    cv2.imshow('link', score_link)
    cv2.imshow('conf', conf_map)

    net = CRAFT().cuda()
    net.load_state_dict(test.copyStateDict(torch.load('weights/craft_mlt_25k.pth')))
    net.eval()
    _, _, ref_text, ref_link, _ = test.test_net(net, image, ocr.argument_parser().parse_args())
    cv2.imshow('ref text', imgproc.cvt2HeatmapImg(ref_text))
    cv2.imshow('ref link', ref_link)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
def main():
    import os
    os.makedirs('result', exist_ok=True)
    text_render.prepare_renderer()

    with open('alphabet-all-v5.txt', 'r') as fp:
        dictionary = [s[:-1] for s in fp.readlines()]
    model_ocr = OCR(dictionary, 768)
    model_ocr.load_state_dict(torch.load('ocr.ckpt', map_location='cpu'), strict=False)
    model_ocr.eval()

    model = CRAFT_net()
    sd = torch.load('detect.ckpt', map_location='cpu')
    model.load_state_dict(sd['model'])
    model = model.cpu()
    model.eval()

    img = cv2.imread(args.image)
    img_bbox = np.copy(img)
    img_bbox_all = np.copy(img)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_resized, target_ratio, _, pad_w, pad_h = imgproc.resize_aspect_ratio(
        img, args.size, cv2.INTER_LINEAR, mag_ratio=1)
    img_to_overlay = np.copy(img_resized)
    ratio_h = ratio_w = 1 / target_ratio
    img_resized = imgproc.normalizeMeanVariance(img_resized)
    print(img_resized.shape)

    rscore, ascore, mask = test(model, img_resized)
    overlay = imgproc.cvt2HeatmapImg(rscore + ascore)
    boxes, polys = craft_utils.getDetBoxes(rscore, ascore, args.text_threshold,
                                           args.link_threshold, args.low_text, False)
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h, ratio_net=2)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2)
    for k in range(len(polys)):
        if polys[k] is None:
            polys[k] = boxes[k]

    # merge textlines
    polys = merge_bboxes(polys, can_merge_textline)
    for [tl, tr, br, bl] in polys:
        x = int(tl[0])
        y = int(tl[1])
        width = int(tr[0] - tl[0])
        height = int(br[1] - tr[1])
        cv2.rectangle(img_bbox_all, (x, y), (x + width, y + height),
                      color=(255, 0, 0), thickness=2)

    # run OCR for each textline
    textlines = run_ocr(img_bbox, polys, dictionary, model_ocr, 32)

    # merge textlines into text regions; filter textlines without characters
    text_regions: List[BBox] = []
    new_textlines = []
    for (poly_regions, textline_indices, majority_dir) in merge_bboxes_text_region(textlines):
        [tl, tr, br, bl] = poly_regions
        x = int(tl[0]) - 5
        y = int(tl[1]) - 5
        width = int(tr[0] - tl[0]) + 10
        height = int(br[1] - tr[1]) + 10
        text = ''
        logprob_lengths = []
        for textline_idx in textline_indices:
            if not text:
                text = textlines[textline_idx].text
            else:
                last_ch = text[-1]
                cur_ch = textlines[textline_idx].text[0]
                if ord(last_ch) > 255 and ord(cur_ch) > 255:
                    text += textlines[textline_idx].text
                else:
                    text += ' ' + textlines[textline_idx].text
            logprob_lengths.append((np.log(textlines[textline_idx].prob),
                                    len(textlines[textline_idx].text)))
        vc = count_valuable_text(text)
        total_logprobs = 0.0
        for (logprob, length) in logprob_lengths:
            total_logprobs += logprob * length
        total_logprobs /= sum([x[1] for x in logprob_lengths])
        # filter out text regions without characters
        if vc > 1:
            region = BBox(x, y, width, height, text, np.exp(total_logprobs))
            region.textline_indices = []
            region.majority_dir = majority_dir
            text_regions.append(region)
            for textline_idx in textline_indices:
                region.textline_indices.append(len(new_textlines))
                new_textlines.append(textlines[textline_idx])
    textlines = new_textlines

    # create mask
    from text_mask_utils import filter_masks, main_process
    mask_resized = cv2.resize(mask, (mask.shape[1] * 2, mask.shape[0] * 2),
                              interpolation=cv2.INTER_LINEAR)
    if pad_h > 0:
        mask_resized = mask_resized[:-pad_h, :]
    elif pad_w > 0:
        mask_resized = mask_resized[:, :-pad_w]
    mask_resized = cv2.resize(mask_resized,
                              (img.shape[1] // 2, img.shape[0] // 2),
                              interpolation=cv2.INTER_LINEAR)
    img_resized_2 = cv2.resize(img, (img.shape[1] // 2, img.shape[0] // 2),
                               interpolation=cv2.INTER_LINEAR)
    mask_resized[mask_resized > 250] = 255
    text_lines = [(a.x // 2, a.y // 2, a.w // 2, a.h // 2) for a in textlines]
    mask_ccs, cc2textline_assignment = filter_masks(mask_resized, text_lines)
    cv2.imwrite('result/mask_filtered.png', reduce(cv2.bitwise_or, mask_ccs))
    final_mask, textline_colors = main_process(img_resized_2, mask_ccs,
                                               text_lines, cc2textline_assignment)
    final_mask = cv2.resize(final_mask, (img.shape[1], img.shape[0]),
                            interpolation=cv2.INTER_LINEAR)

    # run inpainting
    img_inpainted = run_inpainting(img, final_mask)

    # translate text region texts
    texts = '\n'.join([r.text for r in text_regions])
    trans_ret = baidu_translator.translate('ja', 'zh-CN', texts)
    translated_sentences = []
    batch = len(text_regions)
    if len(trans_ret) < batch:
        translated_sentences.extend(trans_ret)
        translated_sentences.extend([''] * (batch - len(trans_ret)))
    elif len(trans_ret) > batch:
        translated_sentences.extend(trans_ret[:batch])
    else:
        translated_sentences.extend(trans_ret)

    # render translated texts
    img_canvas = np.copy(img_inpainted)
    for trans_text, region in zip(translated_sentences, text_regions):
        print(region.text)
        print(trans_text)
        print(region.majority_dir, region.x, region.y, region.w, region.h)
        img_bbox = cv2.rectangle(img_bbox, (region.x, region.y),
                                 (region.x + region.w, region.y + region.h),
                                 color=(0, 0, 255), thickness=2)
        for idx in region.textline_indices:
            txtln = textlines[idx]
            img_bbox = cv2.rectangle(img_bbox, (txtln.x, txtln.y),
                                     (txtln.x + txtln.w, txtln.y + txtln.h),
                                     color=textline_colors[idx], thickness=2)
        if region.majority_dir == 'h':
            text_render.put_text_horizontal(img_canvas, trans_text,
                                            len(region.textline_indices),
                                            region.x, region.y, region.w, region.h,
                                            textline_colors[idx], None)
        else:
            text_render.put_text_vertical(img_canvas, trans_text,
                                          len(region.textline_indices),
                                          region.x, region.y, region.w, region.h,
                                          textline_colors[idx], None)

    cv2.imwrite('result/rs.png', imgproc.cvt2HeatmapImg(rscore))
    cv2.imwrite('result/as.png', imgproc.cvt2HeatmapImg(ascore))
    cv2.imwrite('result/textline.png', overlay)
    cv2.imwrite('result/bbox.png', img_bbox)
    cv2.imwrite('result/bbox_unfiltered.png', img_bbox_all)
    cv2.imwrite('result/overlay.png',
                cv2.cvtColor(overlay_image(img_to_overlay,
                                           cv2.resize(overlay,
                                                      (img_resized.shape[1], img_resized.shape[0]),
                                                      interpolation=cv2.INTER_LINEAR)),
                             cv2.COLOR_RGB2BGR))
    cv2.imwrite('result/mask.png', final_mask)
    cv2.imwrite('result/masked.png', cv2.cvtColor(img_inpainted, cv2.COLOR_RGB2BGR))
    cv2.imwrite('result/final.png', cv2.cvtColor(img_canvas, cv2.COLOR_RGB2BGR))
def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly,
             image_path, refine_net=None):
    t0 = time.time()
    img_h, img_w, c = image.shape

    # resize
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
        image, args.canvas_size, interpolation=cv2.INTER_LINEAR,
        mag_ratio=args.mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio
    h, w, c = image.shape

    # preprocessing
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
    x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
    if cuda:
        x = x.cuda()

    # forward pass
    y, feature = net(x)

    # make score and link map
    score_text = y[0, :, :, 0].cpu().data.numpy()  # region score
    score_link = y[0, :, :, 1].cpu().data.numpy()  # affinity score

    # refine link
    if refine_net is not None:
        y_refiner = refine_net(y, feature)
        score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

    t0 = time.time() - t0
    t1 = time.time()

    # Post-processing: CRAFT box extraction (note low_text is hard-coded to 0.4 here)
    boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                           text_threshold, link_threshold,
                                           0.4, poly)

    # coordinate adjustment
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k in range(len(polys)):
        if polys[k] is None:
            polys[k] = boxes[k]

    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)
    Plus_score_text = imgproc.cvMakeScores(render_img)

    filename, file_ext = os.path.splitext(os.path.basename(image_path))

    if args.show_time:
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    post_folder = './output/post'      # binarized version of the original image
    resize_folder = './output/resize'  # resized original image
    if not os.path.isdir(resize_folder + '/'):
        os.makedirs(resize_folder + '/')
    resize_file = resize_folder + "/resize_" + filename + '_mask.jpg'

    # convert the CRAFT-resized image to RGB, then build the composite image
    IMG_RGB2 = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray((IMG_RGB2 * 255).astype(np.uint8))
    images = np.array(pil_image)
    images = cv2.cvtColor(images, cv2.COLOR_BGR2GRAY)
    # binarize for compositing (Otsu threshold)
    ret, thresh = cv2.threshold(images, 0, 255,
                                cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    text_score = cv2.resize(Plus_score_text, None, fx=2, fy=2,
                            interpolation=cv2.INTER_LINEAR)

    # rescale back to the original size
    thresh = cv2.resize(thresh, (img_w, img_h))          # binarized original
    text_score = cv2.resize(text_score, (img_w, img_h))  # binarized region score
    text_score = Image.fromarray(text_score.astype(np.uint8))
    text_score = np.array(text_score)

    if not os.path.isdir('./output/og_bri' + '/'):  # folder for binarized originals
        os.makedirs('./output/og_bri' + '/')
    if not os.path.isdir('./output/score/'):        # folder for binarized score images
        os.makedirs('./output/score/')
    cv2.imwrite('./output/og_bri' + "/og_" + filename + '.jpg', thresh)        # save binarized original
    cv2.imwrite('./output/score' + "/score_" + filename + '.jpg', text_score)  # save binarized score image

    img_h = thresh.shape[0]
    img_w = thresh.shape[1]
    IMG_RGB2 = cv2.resize(IMG_RGB2, (img_w, img_h))  # resize back to the original size
    cv2.imwrite(resize_file, IMG_RGB2)

    return boxes, polys, ret_score_text
        target = np.zeros([height, width], dtype=np.uint8)
        affinities = []
        for i in range(len(words)):
            character_bbox = np.array(bboxes[i])
            total_letters = 0
            for char_num in range(character_bbox.shape[0] - 1):
                target, affinity = self.add_affinity(target,
                                                     character_bbox[total_letters],
                                                     character_bbox[total_letters + 1])
                affinities.append(affinity)
                total_letters += 1
        if len(affinities) > 0:
            affinities = np.concatenate(affinities, axis=0)
        return target, affinities


if __name__ == '__main__':
    gaussian = GaussianTransformer(1024, 1.5)
    gaussian.saveGaussianHeat()
    bbox = np.array([[[1, 200], [510, 200], [510, 510], [1, 510]]])
    print(bbox.shape)
    bbox = bbox.transpose((2, 1, 0))
    print(bbox.shape)
    weight, target = gaussian.generate_target((1024, 1024, 3), bbox.copy())
    target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(weight.copy() / 255)
    cv2.imshow('test', target_gaussian_heatmap_color)
    cv2.waitKey()
    cv2.imwrite("test.jpg", target_gaussian_heatmap_color)
def detect_net(net, image, text_threshold, link_threshold, low_text, cuda,
               poly, refine_net, res_path):
    t0 = time.time()

    origin_image_1_channel = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    origin_image_3_color = np.array(image)

    # resize
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
        image, args.canvas_size, interpolation=cv2.INTER_LINEAR,
        mag_ratio=args.mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio

    # preprocessing
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
    x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
    if cuda:
        x = x.cuda()

    # forward pass
    with torch.no_grad():
        y, feature = net(x)

    # make score and link map
    score_text = y[0, :, :, 0].cpu().data.numpy()
    score_link = y[0, :, :, 1].cpu().data.numpy()
    cv2.imwrite("core_link.jpg", score_text * 255)
    cv2.imwrite("score_link.jpg", score_link * 255)

    # refine link
    if refine_net is not None:
        with torch.no_grad():
            y_refiner = refine_net(y, feature)
        score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

    t0 = time.time() - t0
    t1 = time.time()

    # Post-processing: boxes produced by CRAFT
    boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                           text_threshold, link_threshold,
                                           low_text, poly)

    # coordinate adjustment
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)

    # Group fragmented boxes: neighbours are merged into the same group via
    # breadth-first search. First build a distance matrix over box centers.
    all_rect_cx_cy = np.zeros((len(boxes), 2))
    for i in range(len(boxes)):
        box = boxes[i]
        left = min(box[0][0], box[1][0], box[2][0], box[3][0])
        right = max(box[0][0], box[1][0], box[2][0], box[3][0])
        top = min(box[0][1], box[1][1], box[2][1], box[3][1])
        bottom = max(box[0][1], box[1][1], box[2][1], box[3][1])
        top = int(top)
        bottom = int(bottom)
        left = int(left)
        right = int(right)
        all_rect_cx_cy[i][0] = ((left + right) / 2) / 4  # damp the x axis; still needs tuning
        all_rect_cx_cy[i][1] = ((top + bottom) / 2)

    mat_distance = []
    for i in range(len(all_rect_cx_cy)):
        mat_distance.append(
            np.sqrt(np.sum((all_rect_cx_cy - all_rect_cx_cy[i]) ** 2, axis=-1)))
    print("generate distance mat;len:", len(mat_distance))

    segment_group = []
    ind_group = -1
    search_queue = deque()
    cnt_processed = 0
    processed = set()
    # breadth-first search
    while cnt_processed < len(all_rect_cx_cy):
        # refill the queue with an unprocessed seed when it runs empty
        if len(search_queue) == 0:
            for i in range(len(all_rect_cx_cy)):
                if i not in processed:
                    search_queue.append(i)
                    segment_group.append([])
                    ind_group += 1
                    break
        current_node = search_queue.popleft()  # FIFO pop: the core of BFS
        if current_node not in processed:      # skip already-visited nodes
            cnt_processed += 1
            processed.add(current_node)
            inds = np.argsort(mat_distance[current_node])
            segment_group[ind_group].append(boxes[current_node])
            cnt_company = 0
            distance_threshold = 20
            for index in inds:
                # walk neighbours in distance order; stop beyond the threshold
                if mat_distance[current_node][index] > distance_threshold:
                    break
                cnt_company += 1
                if cnt_company > 200:
                    print("error")
                    exit()
                if index not in search_queue:  # enqueue unseen neighbours
                    search_queue.append(index)

    # merge the boxes of each group into one rectangle
    merge_boxes = []
    for segment in segment_group:
        left_s = []
        right_s = []
        top_s = []
        bottom_s = []
        for box in segment:
            left = min(box[0][0], box[1][0], box[2][0], box[3][0])
            right = max(box[0][0], box[1][0], box[2][0], box[3][0])
            top = min(box[0][1], box[1][1], box[2][1], box[3][1])
            bottom = max(box[0][1], box[1][1], box[2][1], box[3][1])
            left_s.append(math.floor(left))
            right_s.append(math.floor(right))
            top_s.append(math.floor(top))
            bottom_s.append(math.floor(bottom))
        merge_boxes.append([min(left_s), min(top_s), max(right_s), max(bottom_s)])

    json_record = []
    for rect in merge_boxes:
        threshold_hw = min(rect[3] - rect[1], rect[2] - rect[0]) * 0.2
        crop = origin_image_1_channel[rect[1]:rect[3], rect[0]:rect[2]]
        # adaptive thresholding
        binary_img = cv2.adaptiveThreshold(crop, 255,
                                           cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                           cv2.THRESH_BINARY_INV, 31, 10)
        debug_write(binary_img, "all")
        _, contours, _ = cv2.findContours(binary_img, cv2.RETR_EXTERNAL,
                                          cv2.CHAIN_APPROX_SIMPLE)
        group = []
        for i in range(len(contours)):
            rect_char = cv2.boundingRect(contours[i])
            group.append(rect_char)
        group.sort(key=lambda rect: rect[0])

        if len(group) >= 1:
            last_x_start = group[0][0]
            last_x_end = group[0][0] + group[0][2]
            last = group[0]
            i = 1
            # merge the strokes of '=' and similar multi-part symbols
            while i < len(group) and i >= 1:
                now = group[i]
                cx = now[0] + now[2] / 2
                cy = now[1] + now[3] / 2
                last_cy = last[1] + last[3] / 2
                y_near = abs(last_cy - cy) < (last_x_end - last_x_start) * 0.6
                if last_x_start < cx and cx < last_x_end and y_near:
                    group.pop(i)
                    i -= 1
                    x1 = min(now[0], group[i][0])
                    y1 = min(now[1], group[i][1])
                    x2 = max(now[0] + now[2], group[i][0] + group[i][2])
                    y2 = max(now[1] + now[3], group[i][1] + group[i][3])
                    group[i] = (x1, y1, x2 - x1, y2 - y1)
                else:
                    last_x_start = group[i][0]
                    last_x_end = group[i][0] + group[i][2]
                    last = group[i]
                i += 1

        # recognize each character box and collect the results
        json_record_perline = []
        rect_set = []
        res_set = []

        def detect_rect(rect_char, binary_img, before_str):
            crop_char = binary_img[rect_char[1]:rect_char[1] + rect_char[3],
                                   rect_char[0]:rect_char[0] + rect_char[2]]
            # skip regions that are too small
            if crop_char.shape[0] < 2 and crop_char.shape[1] < 2:
                return ''
            # CRNN recognition
            crnn_text_result = recognizer(crop_char)
            return compress(crnn_text_result)

        res_str = ''
        for i in range(len(group)):
            rect_char = group[i]
            if max(rect_char[2], rect_char[3]) < threshold_hw:
                continue
            res = detect_rect(rect_char, binary_img, before_str=res_str)
            res_set.append(res)
            rect_set.append(rect_char)
            res_str += res
        print(res_str)

        eq = res_str.split('=')
        if len(eq) >= 2:
            res_str = res_str.replace("/", "d")
            json_record.append({
                'rect_expression': (rect[0], rect[1],
                                    rect[2] - rect[0], rect[3] - rect[1]),
                'expression': json_record_perline
            })
            with open("resjson/" + res_str + ".json", 'w') as file_object:
                file_object.write(json.dumps({
                    'rect_expression': (rect[0], rect[1],
                                        rect[2] - rect[0], rect[3] - rect[1]),
                    'expression': json_record_perline
                }))
            if str_to_num(eq[0]) == str_to_num(eq[-1]):
                cv2.line(origin_image_3_color, (rect[0], rect[3]),
                         (rect[2], rect[3]), (46, 255, 87), 2)
                cv2.imwrite('./res/' + res_str + '.png',
                            origin_image_1_channel[rect[1]:rect[3], rect[0]:rect[2]])
            elif eq[-1] == "":
                cv2.rectangle(origin_image_3_color, (rect[0], rect[1]),
                              (rect[2], rect[3]), (255, 46, 87), 2)
                cv2.imwrite('./res/O' + res_str + '.png',
                            origin_image_1_channel[rect[1]:rect[3], rect[0]:rect[2]])
            else:
                cv2.rectangle(origin_image_3_color, (rect[0], rect[1]),
                              (rect[2], rect[3]), (46, 87, 255), 2)
                cv2.imwrite('./res/X' + res_str + '.png',
                            origin_image_1_channel[rect[1]:rect[3], rect[0]:rect[2]])

    print(res_path)
    cv2.imwrite(res_path, origin_image_3_color)

    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k in range(len(polys)):
        if polys[k] is None:
            polys[k] = boxes[k]

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)
    cv2.imwrite("xxxx.png", ret_score_text)

    data2 = json.dumps(json_record)
    return data2
def test(modelpara):
    # load net
    net = CRAFT()  # initialize
    print('Loading weights from checkpoint {}'.format(modelpara))
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(modelpara)))
    else:
        net.load_state_dict(copyStateDict(torch.load(modelpara, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()
    t = time.time()

    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k + 1, len(image_list), image_path), end='\n')
        image = imgproc.loadImage(image_path)
        res = image.copy()

        gh_pred, bboxes_pred, polys_pred, size_heatmap = test_net(
            net, image, args.text_threshold, args.link_threshold,
            args.low_text, args.cuda, args.poly)

        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        result_dir = os.path.join(result_folder, filename)
        os.makedirs(result_dir, exist_ok=True)

        for gh_img, field in zip(gh_pred, CLASSES):
            img = imgproc.cvt2HeatmapImg(gh_img)
            img_path = os.path.join(result_dir, 'res_{}_{}.jpg'.format(filename, field))
            cv2.imwrite(img_path, img)

        h, w = image.shape[:2]
        img = cv2.resize(image, size_heatmap)[::, ::, ::-1]
        img_path = os.path.join(result_dir, 'res_{}.jpg'.format(filename))
        cv2.imwrite(img_path, img)

        res = cv2.resize(res, size_heatmap)
        for polys, field in zip(polys_pred, CLASSES):
            TEXT_WIDTH = 10 * len(field) + 10
            TEXT_HEIGHT = 15
            polys = np.int32([poly.reshape((-1, 1, 2)) for poly in polys])
            res = cv2.polylines(res, polys, True, (0, 0, 255), 2)
            # draw a label background box to the left of each polygon
            for poly in polys:
                poly[1, 0] = [poly[0, 0, 0] - 10, poly[0, 0, 1]]
                poly[2, 0] = [poly[0, 0, 0] - 10, poly[0, 0, 1] + TEXT_HEIGHT]
                poly[3, 0] = [poly[0, 0, 0] - TEXT_WIDTH, poly[0, 0, 1] + TEXT_HEIGHT]
                poly[0, 0] = [poly[0, 0, 0] - TEXT_WIDTH, poly[0, 0, 1]]
            res = cv2.fillPoly(res, polys, (224, 224, 224))
            for poly in polys:
                res = cv2.putText(res, field, tuple(poly[3, 0] + [+5, -5]),
                                  cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0),
                                  thickness=1)

        res_file = os.path.join(result_dir, 'res_{}_bbox.jpg'.format(filename))
        cv2.imwrite(res_file, res[::, ::, ::-1])

    print("elapsed time : {}s".format(time.time() - t))
def train(args):
    file_utils.mkdir(dir=[args.save_models])
    if args.vis_train:
        file_utils.mkdir(dir=['./vis/'])

    ''' MAKE DATASET '''
    datasets = webtoon_dataset(opt.DETECTION_TRAIN_IMAGE_PATH,
                               opt.DETECTION_TRAIN_LABEL_PATH, args.train_size)
    train_data_loader = DataLoader(datasets, batch_size=args.batch, shuffle=True)

    ''' INITIALIZE MODEL, GPU, OPTIMIZER, and LOSS '''
    model = LTD()
    model = torch.nn.DataParallel(model).cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           weight_decay=args.lr_decay_gamma)
    criterion = LTD_LOSS()

    step_idx = 0
    model.train()
    print('[TEXT DETECTION TRAINING KICK-OFF]')

    ''' KICK OFF TRAINING PROCESS '''
    for e in range(args.epoch):
        start = time.time()

        ''' LOAD MATERIAL FOR TRAINING FROM DATALOADER '''
        for idx, (image, region_score_GT, affinity_score_GT, confidence) in enumerate(train_data_loader):

            ''' ADJUST LEARNING RATE PER 20000 ITERATIONS '''
            if idx % args.lr_decay_step == 0 and idx != 0:
                step_idx += 1
                # adjust_learning_rate(optimizer, args.lr, step_idx)

            ''' CONVERT NUMPY => TORCH '''
            images = Variable(image.type(torch.FloatTensor)).cuda()
            region_score_GT = Variable(region_score_GT.type(torch.FloatTensor)).cuda()
            affinity_score_GT = Variable(affinity_score_GT.type(torch.FloatTensor)).cuda()
            confidence = Variable(confidence.type(torch.FloatTensor)).cuda()

            ''' PASS THE MODEL AND PREDICT SCORES '''
            y, _ = model(images)
            score_region = y[:, :, :, 0].cuda()
            score_affinity = y[:, :, :, 1].cuda()

            if args.vis_train:
                if idx % 20 == 0 and idx != 0 and e % 2 == 0:
                    for idx2 in range(args.batch):
                        render_img1 = score_region[idx2].cpu().detach().numpy().copy()
                        render_img2 = score_affinity[idx2].cpu().detach().numpy().copy()
                        render_img = np.hstack((render_img1, render_img2))
                        render_img = imgproc.cvt2HeatmapImg(render_img)
                        cv2.imwrite('./vis/e' + str(e) + '-s' + str(idx) + '-' + str(idx2) + '.jpg',
                                    render_img)

            ''' CALCULATE LOSS VALUE AND UPDATE WEIGHTS '''
            optimizer.zero_grad()
            loss = criterion(region_score_GT, affinity_score_GT,
                             score_region, score_affinity, confidence)
            loss.backward()
            optimizer.step()

            if idx % args.display_interval == 0:
                end = time.time()
                print('epoch: {}, iter:[{}/{}], lr:{}, loss: {:.8f}, Time Cost: {:.4f}s'
                      .format(e, idx, len(train_data_loader), args.lr, loss.item(), end - start))
                start = time.time()

        ''' SAVE MODEL PER 2 EPOCH '''
        start = time.time()
        if e % args.save_interval == 0:
            print('save model ... :' + args.save_models)
            torch.save(model.module.state_dict(),
                       args.save_models + 'ltd' + repr(e) + '.pth')
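# LTD_LOSS is used above but not defined in this snippet. A minimal sketch of
# a plausible implementation, assuming the usual CRAFT-style objective: pixel
# MSE on the region and affinity maps, weighted by the per-pixel confidence
# map that generate_gt produces for weakly supervised samples. The argument
# order matches the criterion(...) call in train() above.
import torch
import torch.nn as nn

class LTD_LOSS(nn.Module):
    def forward(self, region_gt, affinity_gt, region_pred, affinity_pred, confidence):
        # confidence down-weights pixels whose pseudo ground truth is unreliable
        loss_region = torch.mean(confidence * (region_pred - region_gt) ** 2)
        loss_affinity = torch.mean(confidence * (affinity_pred - affinity_gt) ** 2)
        return loss_region + loss_affinity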
def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly,
             refine_net=None):
    t0 = time.time()

    # resize
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
        image, args.canvas_size, interpolation=cv2.INTER_LINEAR,
        mag_ratio=args.mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio

    # preprocessing
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
    x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
    if cuda:
        x = x.cuda()

    # forward pass
    with torch.no_grad():
        y, feature = net(x)

    # make score and link map
    score_text = y[0, :, :, 0].cpu().data.numpy()
    score_link = y[0, :, :, 1].cpu().data.numpy()

    # refine link
    if refine_net is not None:
        with torch.no_grad():
            y_refiner = refine_net(y, feature)
        score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

    t0 = time.time() - t0
    t1 = time.time()

    # Post-processing
    boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                           text_threshold, link_threshold,
                                           low_text, poly)

    # coordinate adjustment
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)

    # Group fragmented boxes: neighbours are merged into the same group via
    # breadth-first search. First build a distance matrix over box centers.
    all_rect_cx_cy = np.zeros((len(boxes), 2))
    for i in range(len(boxes)):
        box = boxes[i]
        left = min(box[0][0], box[1][0], box[2][0], box[3][0])
        right = max(box[0][0], box[1][0], box[2][0], box[3][0])
        top = min(box[0][1], box[1][1], box[2][1], box[3][1])
        bottom = max(box[0][1], box[1][1], box[2][1], box[3][1])
        top = int(top)
        bottom = int(bottom)
        left = int(left)
        right = int(right)
        all_rect_cx_cy[i][0] = ((left + right) / 2) / 4  # damp the x axis; still needs tuning
        all_rect_cx_cy[i][1] = ((top + bottom) / 2)

    mat_distance = []
    for i in range(len(all_rect_cx_cy)):
        mat_distance.append(
            np.sqrt(np.sum((all_rect_cx_cy - all_rect_cx_cy[i]) ** 2, axis=-1)))
    print("generate distance mat;len:", len(mat_distance))

    segment_group = []
    ind_group = -1
    search_queue = deque()
    cnt_processed = 0
    processed = set()
    # breadth-first search
    while cnt_processed < len(all_rect_cx_cy):
        # refill the queue with an unprocessed seed when it runs empty
        if len(search_queue) == 0:
            for i in range(len(all_rect_cx_cy)):
                if i not in processed:
                    search_queue.append(i)
                    segment_group.append([])
                    ind_group += 1
                    break
        current_node = search_queue.popleft()  # FIFO pop: the core of BFS
        if current_node not in processed:      # skip already-visited nodes
            cnt_processed += 1
            processed.add(current_node)
            inds = np.argsort(mat_distance[current_node])
            segment_group[ind_group].append(boxes[current_node])
            cnt_company = 0
            distance_threshold = 20
            for index in inds:
                # walk neighbours in distance order; stop beyond the threshold
                if mat_distance[current_node][index] > distance_threshold:
                    break
                cnt_company += 1
                if cnt_company > 200:
                    print("error")
                    exit()
                if index not in search_queue:  # enqueue unseen neighbours
                    search_queue.append(index)

    # merge the boxes of each group into one rectangle
    merge_boxes = []
    for segment in segment_group:
        left_s = []
        right_s = []
        top_s = []
        bottom_s = []
        for box in segment:
            left = min(box[0][0], box[1][0], box[2][0], box[3][0])
            right = max(box[0][0], box[1][0], box[2][0], box[3][0])
            top = min(box[0][1], box[1][1], box[2][1], box[3][1])
            bottom = max(box[0][1], box[1][1], box[2][1], box[3][1])
            left_s.append(math.floor(left))
            right_s.append(math.floor(right))
            top_s.append(math.floor(top))
            bottom_s.append(math.floor(bottom))
        merge_boxes.append([min(left_s), min(top_s), max(right_s), max(bottom_s)])

    for rect in merge_boxes:
        threshold_hw = min(rect[3] - rect[1], rect[2] - rect[0]) * 0.2
        crop = i_image[rect[1]:rect[3], rect[0]:rect[2]]
        ret, binary_img = cv2.threshold(crop, 175, 255,
                                        cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        _, contours, _ = cv2.findContours(binary_img, cv2.RETR_EXTERNAL,
                                          cv2.CHAIN_APPROX_SIMPLE)
        group = []
        for i in range(len(contours)):
            rect_char = cv2.boundingRect(contours[i])
            group.append(rect_char)
        group.sort(key=lambda rect: rect[0])

        last_x_start = group[0][0]
        last_x_end = group[0][0] + group[0][2]
        last = group[0]
        i = 1
        # merge the strokes of '=' and similar multi-part symbols
        while i < len(group) and i >= 1:
            now = group[i]
            cx = now[0] + now[2] / 2
            cy = now[1] + now[3] / 2
            last_cy = last[1] + last[3] / 2
            y_near = abs(last_cy - cy) < (last_x_end - last_x_start) * 0.6
            if last_x_start < cx and cx < last_x_end and y_near:
                group.pop(i)
                i -= 1
                x1 = min(now[0], group[i][0])
                y1 = min(now[1], group[i][1])
                x2 = max(now[0] + now[2], group[i][0] + group[i][2])
                y2 = max(now[1] + now[3], group[i][1] + group[i][3])
                group[i] = (x1, y1, x2 - x1, y2 - y1)
            else:
                last_x_start = group[i][0]
                last_x_end = group[i][0] + group[i][2]
                last = group[i]
            i += 1

        if len(group) < 4 or len(group) > 16:
            continue

        # recognize each character box and collect the results
        rect_set = []
        res_set = []

        def detect_rect(rect_char, binary_img):
            crop_char = binary_img[rect_char[1]:rect_char[1] + rect_char[3],
                                   rect_char[0]:rect_char[0] + rect_char[2]]
            crop_char = torch.tensor(crop_char, dtype=torch.int)
            crop_char = adapt_size(crop_char)
            crop_char = crop_char.float().cuda()
            res = classifer_box.eval(crop_char.unsqueeze(0)).squeeze().int().item()
            debug_write(crop_char[0].cpu().int().numpy().astype(np.uint8) * 255,
                        config.CLASS_toString[res])
            return res

        for i in range(len(group)):
            rect_char = group[i]
            if max(rect_char[2], rect_char[3]) < threshold_hw:
                continue
            res = detect_rect(rect_char, binary_img)
            res_set.append(res)
            rect_set.append(rect_char)

        res_str = ''
        for i in range(len(res_set)):
            res = res_set[i]
            res_str += config.CLASS_toString[res]
            # the right side of '=' is often faint: re-binarize that part
            # of the crop and run detection again
            if config.CLASS_is_eq(res):
                rect_char = rect_set[i]
                crop = i_image[rect[1]:rect[3], rect[0]:rect[2]][:, rect_char[0] + rect_char[2]:]
                if crop.shape[0] * crop.shape[1] < 4:
                    break
                crop = convert_to_binary_inv(crop)
                debug_write(crop, '')
                _, contours_right, _ = cv2.findContours(crop, cv2.RETR_EXTERNAL,
                                                        cv2.CHAIN_APPROX_SIMPLE)
                group_right = []
                for i in range(len(contours_right)):
                    rect_char_right = cv2.boundingRect(contours_right[i])
                    group_right.append(rect_char_right)
                group_right.sort(key=lambda rect: rect[0])
                for rect_char in group_right:
                    if max(rect_char[2], rect_char[3]) < crop.shape[0] * 0.3:
                        continue
                    res_right = detect_rect(rect_char, crop)
                    res_str += config.CLASS_toString[res_right]
                break

        eq = res_str.split('=')
        if len(eq) == 2:
            global i_image_3_color
            res_str = res_str.replace("/", "d")
            print(res_str)
            if str_to_num(eq[0]) == str_to_num(eq[1]):
                cv2.rectangle(i_image_3_color, (rect[0], rect[1]),
                              (rect[2], rect[3]), (46, 255, 87), 2)
                cv2.imwrite('./res/' + res_str + '.png',
                            i_image[rect[1]:rect[3], rect[0]:rect[2]])
            elif eq[1] == "":
                cv2.rectangle(i_image_3_color, (rect[0], rect[1]),
                              (rect[2], rect[3]), (46, 87, 255), 2)
                cv2.imwrite('./res/' + res_str + '.png',
                            i_image[rect[1]:rect[3], rect[0]:rect[2]])
            else:
                cv2.rectangle(i_image_3_color, (rect[0], rect[1]),
                              (rect[2], rect[3]), (255, 46, 87), 2)
                cv2.imwrite('./res/x_' + res_str + '.png',
                            i_image[rect[1]:rect[3], rect[0]:rect[2]])

    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k in range(len(polys)):
        if polys[k] is None:
            polys[k] = boxes[k]

    cv2.imshow('', i_image_3_color)
    cv2.waitKey()

    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    if args.show_time:
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, ret_score_text
if __name__ == '__main__':
    gaussian = GaussianTransformer(512, 0.4, 0.2)
    gaussian.saveGaussianHeat()
    gaussian._test()
    bbox0 = np.array([[[0, 0], [100, 0], [100, 100], [0, 100]]])
    image = np.zeros((500, 500), np.uint8)
    bbox1 = np.array([[[100, 0], [200, 0], [200, 100], [100, 100]]])
    bbox2 = np.array([[[100, 100], [200, 100], [200, 200], [100, 200]]])
    bbox3 = np.array([[[0, 100], [100, 100], [100, 200], [0, 200]]])
    bbox4 = np.array([[[96, 0], [151, 9], [139, 64], [83, 58]]])
    # image = gaussian.add_region_character(image, bbox)
    image = gaussian.generate_region((500, 500, 1), [bbox4])
    target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(image.copy() / 255)
    cv2.imshow("test", target_gaussian_heatmap_color)
    cv2.imwrite("test.jpg", target_gaussian_heatmap_color)
    cv2.waitKey()
    # Post-processing
    boxes, polys = craft_utils.getDetBoxes(region, affinity, text_threshold,
                                           link_threshold, low_text, poly)

    # coordinate adjustment
    # boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    # polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    # for k in range(len(polys)):
    #     if polys[k] is None:
    #         polys[k] = boxes[k]

    # render results (optional)
    render_img = region.copy()
    render_img = np.hstack((render_img, affinity))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    for i, box in enumerate(boxes):
        # minAreaRect returns (center (x, y), (width, height), rotation angle)
        _, (kernel_w, kernel_h), _ = cv2.minAreaRect(box)
        kernel_w, kernel_h = int(kernel_w), int(kernel_h)
        if kernel_w < kernel_h:
            kernel_w, kernel_h = kernel_h, kernel_w
        box = np.array(box).astype(np.int32).reshape((-1))
        box = box.reshape(-1, 2)
        # cv2.polylines(image, [box.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2)
        # perspective-transform the Gaussian kernel; coordinates are (column, row)
        src = np.float32(box)  # top-left, bottom-left, bottom-right, top-right
        tgt = np.float32([(0, 0), (kernel_w, 0), (kernel_w, kernel_h), (0, kernel_h)])
        M = cv2.getPerspectiveTransform(src, tgt)
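# The snippet above stops after computing M. A minimal sketch of the usual
# next step, stated as an assumption rather than the original code: warp a
# square Gaussian kernel into the detected box and accumulate it on a
# region-score canvas. Since M maps box -> kernel rectangle, the kernel is
# pasted back with the inverse transform; 'gaussian_kernel' and
# 'region_canvas' are hypothetical names not defined in the snippet.
import cv2
import numpy as np

kernel = cv2.resize(gaussian_kernel, (kernel_w, kernel_h))  # fit the box size
M_inv = np.linalg.inv(M)                                    # kernel rect -> box
warped = cv2.warpPerspective(kernel, M_inv,
                             (region_canvas.shape[1], region_canvas.shape[0]))
region_canvas = np.maximum(region_canvas, warped)           # keep the stronger score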