def process_boxes(images, im_data, iou_pred, roi_pred, angle_pred, score_maps, gt_idxs, gtso, lbso, features, net, ctc_loss, opts, converter, debug=False):
    '''iou_pred: text/non-text score; roi_pred: distances to the top/bottom/left/right edges;
    angle_pred: predicted angle; score_maps: ground-truth text mask'''
    gt_good = 0
    gt_proc = 0

    # 0. Filter the predicted rrois, keep only the ones that meet the requirements
    #    and compute the loss on them. Runs over the whole batch.
    rrois = []
    labels = []
    for bid in range(iou_pred.size(0)):

        gts = gtso[bid]  # ground-truth text regions
        lbs = lbso[bid]  # ground-truth transcriptions

        gts_count = {}  # per-region usage count, so a single region is not trained on too often

        iou_pred_np = iou_pred[bid].data.cpu().numpy()
        iou_map = score_maps[bid]
        to_walk = iou_pred_np.squeeze(0) * iou_map * (iou_pred_np.squeeze(0) > 0.5)  # points predicted as text that are also labeled text

        roi_p_bid = roi_pred[bid].data.cpu().numpy()  # predicted rois
        gt_idx = gt_idxs[bid]  # gt-region index of each pixel

        if debug:
            img = images[bid]
            img += 1
            img *= 128
            img = np.asarray(img, dtype=np.uint8)

        xy_text = np.argwhere(to_walk > 0)  # returned in (row, col) order
        random.shuffle(xy_text)
        xy_text = xy_text[0:min(xy_text.shape[0], 100)]  # sample at most 100 points

        # crop a region around each sampled point
        for i in range(0, xy_text.shape[0]):
            if opts.geo_type == 1:
                break
            pos = xy_text[i, :]
            gt_id = gt_idx[pos[0], pos[1]]  # gt region this point falls into

            if gt_id not in gts_count:
                gts_count[gt_id] = 0

            # 1. use each gt region at most a few times
            if gts_count[gt_id] > 2:
                continue

            # 2. skip '##' (ignore) transcriptions
            gt = gts[gt_id]      # gt quadrilateral for this point
            gt_txt = lbs[gt_id]  # gt transcription
            if gt_txt.startswith('##'):
                continue

            # 3. gt region height (short edge of the annotation)
            dhgt = gt[1] - gt[0]
            h_gt = math.sqrt(dhgt[0] * dhgt[0] + dhgt[1] * dhgt[1])
            if h_gt < 10:
                continue

            # 4. annotation reaches outside the image (size(3) is width, size(2) is height)
            if gt[:, 0].max() > im_data.size(3) or gt[:, 1].max() > im_data.size(2):
                continue

            # 5. predicted angle too far from the gt angle
            angle_sin = angle_pred[bid, 0, pos[0], pos[1]]
            angle_cos = angle_pred[bid, 1, pos[0], pos[1]]
            angle = math.atan2(angle_sin, angle_cos)

            angle_gt = (math.atan2((gt[2][1] - gt[1][1]), gt[2][0] - gt[1][0])
                        + math.atan2((gt[3][1] - gt[0][1]), gt[3][0] - gt[0][0])) / 2

            if math.fabs(angle_gt - angle) > math.pi / 16:  # more than 11.25 degrees (pi/16) apart
                continue

            # 6. midpoints of the four edges of the rotated quad -> corner points
            offset = roi_p_bid[:, pos[0], pos[1]]  # roi at this point: top/bottom/left/right distances
            posp = pos + 0.25  # (h, w) order, i.e. (y, x)
            # results are (x, y): argwhere returns (row, col), image coordinates are (y, x)
            pos_g = np.array([(posp[1] - offset[0] * math.sin(angle)) * 4, (posp[0] - offset[0] * math.cos(angle)) * 4])
            pos_g2 = np.array([(posp[1] + offset[1] * math.sin(angle)) * 4, (posp[0] + offset[1] * math.cos(angle)) * 4])
            pos_r = np.array([(posp[1] - offset[2] * math.cos(angle)) * 4, (posp[0] - offset[2] * math.sin(angle)) * 4])
            pos_r2 = np.array([(posp[1] + offset[3] * math.cos(angle)) * 4, (posp[0] + offset[3] * math.sin(angle)) * 4])

            center = (pos_g + pos_g2 + pos_r + pos_r2) / 2 - [4 * pos[1], 4 * pos[0]]
            # center = (pos_g + pos_g2 + pos_r + pos_r2) / 4  # plain midpoint

            dw = pos_r - pos_r2  # long edge
            dh = pos_g - pos_g2  # short edge
            w = math.sqrt(dw[0] * dw[0] + dw[1] * dw[1])  # long-edge length
            h = math.sqrt(dh[0] * dh[0] + dh[1] * dh[1])  # short-edge length

            rect = ((center[0], center[1]), (w, h), angle * 180 / math.pi)  # predicted rect; the net outputs sin/cos, angle is in radians
            pts = cv2.boxPoints(rect)  # the 4 corner points

            # 7. IoU between the rotated quad and the gt box, approximated with
            #    their axis-aligned bounding rectangles
            pred_bbox = cv2.boundingRect(pts)  # bounding rect of the quad, [x, y, w, h]
            pred_bbox = [pred_bbox[0], pred_bbox[1], pred_bbox[2], pred_bbox[3]]
            pred_bbox[2] += pred_bbox[0]
            pred_bbox[3] += pred_bbox[1]

            gt_bbox = [gt[:, 0].min(), gt[:, 1].min(), gt[:, 0].max(), gt[:, 1].max()]  # bounding rect of the gt quad

            inter = intersect(pred_bbox, gt_bbox)
            uni = union(pred_bbox, gt_bbox)
            ratio = area(inter) / float(area(uni))  # IoU of the two rectangles
            if ratio < 0.9:  # discard below 0.9
                continue

            hratio = min(h, h_gt) / max(h, h_gt)  # heights differ too much
            if hratio < 0.5:
                continue
            # 8. repackage the rroi in the layout expected by rroi_align
            angle = -angle / 3.1415926535 * 180
            rrois.append([bid, center[0], center[1], h, w, angle])  # collect every accepted rroi
            labels.append(gt_txt)
            gts_count[gt_id] += 1
            gt_proc += 1

        # 8.1. for debug: load an image by hand and test (everything above filters the predicted rrois)
        # img = cv2.imread('./rroi_align/data/timg.jpeg')
        # gts = [[[206,111],[199,95],[349,60],[355,80]]]
        # im_data = torch.from_numpy(img).unsqueeze(0).permute(0, 3, 1, 2)
        # im_data = im_data.to(torch.float).cuda()

        # 9. to steer the loss toward convergence, also crop the annotated rrois
        #    of every image and train on them
        if len(gts) != 0:
            gt = np.asarray(gts)
            center = (gt[:, 0, :] + gt[:, 1, :] + gt[:, 2, :] + gt[:, 3, :]) / 4  # center points
            dw = gt[:, 2, :] - gt[:, 1, :]
            dh = gt[:, 1, :] - gt[:, 0, :]
            poww = pow(dw, 2)
            powh = pow(dh, 2)
            w = np.sqrt(poww[:, 0] + poww[:, 1])
            h = np.sqrt(powh[:, 0] + powh[:, 1]) + random.randint(-2, 2)
            angle_gt = (np.arctan2((gt[:, 2, 1] - gt[:, 1, 1]), gt[:, 2, 0] - gt[:, 1, 0])
                        + np.arctan2((gt[:, 3, 1] - gt[:, 0, 1]), gt[:, 3, 0] - gt[:, 0, 0])) / 2  # angles
            angle_gt = -angle_gt / 3.1415926535 * 180  # the sign flip is required

            # 10. decide for every annotated rroi whether it can be used for training
            for gt_id in range(0, len(gts)):

                gt_txt = lbs[gt_id]  # check the transcription
                if gt_txt.startswith('##'):
                    continue

                gt = gts[gt_id]  # check the annotation
                if gt[:, 0].max() > im_data.size(3) or gt[:, 1].max() > im_data.size(2) or gt.min() < 0:
                    continue

                rrois.append([bid, center[gt_id][0], center[gt_id][1], h[gt_id], w[gt_id], angle_gt[gt_id]])  # collect the annotated rroi
                labels.append(gt_txt)
                gt_good += 1

        # 11. debug: show the annotated regions
        if debug:
            rois = torch.tensor(rrois).to(torch.float).cuda()
            pooled_height = 44
            maxratio = rois[:, 4] / rois[:, 3]
            maxratio = maxratio.max().item()
            pooled_width = math.ceil(pooled_height * maxratio)

            roipool = _RRoiAlign(pooled_height, pooled_width, 1.0)
            pooled_feat = roipool(im_data, rois.view(-1, 6))

            for i in range(pooled_feat.shape[0]):
                x_d = pooled_feat.data.cpu().numpy()[i]
                x_data_draw = x_d.swapaxes(0, 2)
                x_data_draw = x_data_draw.swapaxes(0, 1)
                x_data_draw += 1
                x_data_draw *= 128
                x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
                x_data_draw = x_data_draw[:, :, ::-1]
                cv2.imshow('crop %d' % i, x_data_draw)
                cv2.imwrite('./data/tshow/crop%d.jpg' % i, x_data_draw)
            cv2.imshow('img', img)
            cv2.waitKey(100)
    # 12. convert the labels into CTC targets (everything above built the rrois
    #     and labels for one batch)
    if len(labels) > 32:
        labels = labels[:32]
        rrois = rrois[:32]
    text, label_length = converter.encode(labels)

    # 13. rroi_align, forward the features and compute the ctc loss
    rois = torch.tensor(rrois).to(torch.float).cuda()
    pooled_height = 32
    maxratio = rois[:, 4] / rois[:, 3]
    maxratio = maxratio.max().item()
    pooled_width = math.ceil(pooled_height * maxratio)

    roipool = _RRoiAlign(pooled_height, pooled_width, 1.0)
    pooled_feat = roipool(im_data, rois.view(-1, 6))

    # 13.1 show all cropped regions
    alldebug = 0
    if alldebug:
        for i in range(pooled_feat.shape[0]):
            x_d = pooled_feat.data.cpu().numpy()[i]
            x_data_draw = x_d.swapaxes(0, 2)
            x_data_draw = x_data_draw.swapaxes(0, 1)
            x_data_draw += 1
            x_data_draw *= 128
            x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
            x_data_draw = x_data_draw[:, :, ::-1]
            cv2.imshow('crop %d' % i, x_data_draw)
            cv2.imwrite('./data/tshow/crop%d.jpg' % i, x_data_draw)

        for j in range(im_data.size(0)):
            img = im_data[j].cpu().numpy().transpose(1, 2, 0)
            img = (img + 1) * 128
            img = np.asarray(img, dtype=np.uint8)
            img = img[:, :, ::-1]
            cv2.imshow('img%d' % j, img)
            cv2.imwrite('./data/tshow/img%d.jpg' % j, img)
        cv2.waitKey(100)

    # ocr_features = net.forward_features(pooled_feat)
    # preds = net.forward_ocr(ocr_features)
    # preds = preds.permute(2, 0, 1)
    preds = net.ocr_forward(pooled_feat)

    preds_size = Variable(torch.IntTensor([preds.size(0)] * preds.size(1)))
    loss_ocr = ctc_loss(preds, text, preds_size, label_length) / preds.size(1)  # average over the batch

    return loss_ocr, gt_good, gt_proc
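# The IoU gate in step 7 above relies on `intersect`, `union` and `area`
# helpers imported from elsewhere in the repo. A minimal sketch of what they
# are assumed to compute on [x1, y1, x2, y2] rectangles (hypothetical names
# with a leading underscore; the project's real helpers may handle edge cases
# differently):
def _intersect_sketch(a, b):
    # overlapping rectangle of two axis-aligned boxes
    return [max(a[0], b[0]), max(a[1], b[1]), min(a[2], b[2]), min(a[3], b[3])]

def _union_sketch(a, b):
    # tightest axis-aligned box containing both inputs (an approximation of the
    # true union, which keeps the ratio above a cheap IoU proxy)
    return [min(a[0], b[0]), min(a[1], b[1]), max(a[2], b[2]), max(a[3], b[3])]

def _area_sketch(box):
    # area of an [x1, y1, x2, y2] box; degenerate boxes count as zero
    return max(0, box[2] - box[0]) * max(0, box[3] - box[1])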
def align_ocr(net, converter, im_data, boxo, features, debug=0):
    # crop and rectify the ocr region, then run recognition on it
    boxr = boxo[0:8].reshape(-1, 2)

    # 1. build the rroi parameters
    center = (boxr[0, :] + boxr[1, :] + boxr[2, :] + boxr[3, :]) / 4
    dw = boxr[2, :] - boxr[1, :]
    dh = boxr[1, :] - boxr[0, :]
    w = math.sqrt(dw[0] * dw[0] + dw[1] * dw[1])
    h = math.sqrt(dh[0] * dh[0] + dh[1] * dh[1])
    angle = math.atan2((boxr[2][1] - boxr[1][1]), boxr[2][0] - boxr[1][0])
    angle = -angle / 3.1415926535 * 180

    rroi = [0, int(center[0]), int(center[1]), h, w, angle]

    target_h = 11
    scale = target_h / max(1, h)
    target_gw = int(w * scale) + target_h
    target_gw = max(2, target_gw // 32) * 32

    rroialign = _RRoiAlign(target_h, target_gw, 1.0 / 4)
    rois = torch.tensor(rroi).to(torch.float).cuda()

    # 2. rroi_align on the shared detection features (level 1 sits at 1/4 of
    #    the input resolution, matching the spatial scale above)
    x = rroialign(features[1], rois.view(-1, 6))
    # x = rroialign(im_data, rois.view(-1, 6))  # alternative: crop straight from the image

    if debug:
        # note: the +1/*128 denormalization below only makes sense when x was
        # cropped from im_data rather than from the feature map
        for i in range(x.shape[0]):
            x_d = x.data.cpu().numpy()[i]
            x_data_draw = x_d.swapaxes(0, 2)
            x_data_draw = x_data_draw.swapaxes(0, 1)
            x_data_draw += 1
            x_data_draw *= 128
            x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
            x_data_draw = x_data_draw[:, :, ::-1]
            cv2.imshow('crop %d' % i, x_data_draw)
            cv2.imwrite('./data/tshow/crop%d.jpg' % i, x_data_draw)

            img = im_data[i].cpu().numpy().transpose(1, 2, 0)
            img = (img + 1) * 128
            img = np.asarray(img, dtype=np.uint8)
            img = img[:, :, ::-1]
            cv2.imshow('img%d' % i, img)
        cv2.waitKey(100)

    # features = net.forward_features(x)
    labels_pred = net.forward_ocr(x)
    # labels_pred = net.ocr_forward(x)
    # labels_pred = labels_pred.permute(0, 2, 1)

    _, labels_pred = labels_pred.max(1)
    labels_pred = labels_pred.transpose(1, 0).contiguous().view(-1)
    preds_size = Variable(torch.IntTensor([labels_pred.size(0)]))
    sim_preds = converter.decode(labels_pred.data, preds_size.data, raw=False)

    # ctc_f = labels_pred.data.cpu().numpy()
    # ctc_f = ctc_f.swapaxes(1, 2)
    # labels = ctc_f.argmax(2)
    # ind = np.unravel_index(labels, ctc_f.shape)
    # conf = np.mean(np.exp(ctc_f[ind]))
    # det_text, conf2, dec_s, splits = print_seq_ext(labels[0, :], codec)

    conf2 = 0.9  # placeholder confidence; the real score (commented out above) is not computed
    dec_s = 1
    return sim_preds, conf2, dec_s
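# Every _RRoiAlign call in this file takes rois in the layout
# [batch_index, center_x, center_y, height, width, angle_in_degrees], with the
# angle negated relative to atan2 in image coordinates. A standalone sketch of
# that conversion for a single 4-point box (a hypothetical helper mirroring the
# arithmetic in align_ocr above; it reuses the module-level math/np imports):
def quad_to_rroi(quad, batch_index=0):
    """quad: (4, 2) array of corner points, ordered like the gt annotations."""
    boxr = np.asarray(quad, dtype=np.float64).reshape(-1, 2)
    center = boxr.mean(axis=0)                    # mean of the 4 corners
    w = float(np.linalg.norm(boxr[2] - boxr[1]))  # long edge
    h = float(np.linalg.norm(boxr[1] - boxr[0]))  # short edge
    angle = math.atan2(boxr[2][1] - boxr[1][1], boxr[2][0] - boxr[1][0])
    return [batch_index, center[0], center[1], h, w, -angle * 180 / math.pi]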
def process_crnn(im_data, gtso, lbso, net, ctc_loss, converter, training):
    num_gt = len(gtso)
    rrois = []
    labels = []
    for kk in range(num_gt):
        gts = gtso[kk]
        lbs = lbso[kk]
        if len(gts) != 0:
            gt = np.asarray(gts)
            center = (gt[:, 0, :] + gt[:, 1, :] + gt[:, 2, :] + gt[:, 3, :]) / 4  # center points
            dw = gt[:, 2, :] - gt[:, 1, :]
            dh = gt[:, 1, :] - gt[:, 0, :]
            poww = pow(dw, 2)
            powh = pow(dh, 2)
            w = np.sqrt(poww[:, 0] + poww[:, 1])
            h = np.sqrt(powh[:, 0] + powh[:, 1]) + random.randint(-2, 2)
            angle_gt = (np.arctan2((gt[:, 2, 1] - gt[:, 1, 1]), gt[:, 2, 0] - gt[:, 1, 0])
                        + np.arctan2((gt[:, 3, 1] - gt[:, 0, 1]), gt[:, 3, 0] - gt[:, 0, 0])) / 2  # angles
            angle_gt = -angle_gt / 3.1415926535 * 180  # the sign flip is required

            # decide for every annotated rroi whether it can be used for training
            for gt_id in range(0, len(gts)):

                gt_txt = lbs[gt_id]  # check the transcription
                if gt_txt.startswith('##'):
                    continue

                gt = gts[gt_id]  # check the annotation
                if gt[:, 0].max() > im_data.size(3) or gt[:, 1].max() > im_data.size(2) or gt.min() < 0:
                    continue

                rrois.append([kk, center[gt_id][0], center[gt_id][1], h[gt_id], w[gt_id], angle_gt[gt_id]])  # collect the annotated rroi
                labels.append(gt_txt)

    text, label_length = converter.encode(labels)

    # rroi_align, forward the features and compute the ctc loss
    rois = torch.tensor(rrois).to(torch.float).cuda()
    pooled_height = 32
    maxratio = rois[:, 4] / rois[:, 3]
    maxratio = maxratio.max().item()
    pooled_width = math.ceil(pooled_height * maxratio)

    roipool = _RRoiAlign(pooled_height, pooled_width, 1.0)
    pooled_feat = roipool(im_data, rois.view(-1, 6))

    # debug: show all cropped regions
    alldebug = 0
    if alldebug:
        for i in range(pooled_feat.shape[0]):
            x_d = pooled_feat.data.cpu().numpy()[i]
            x_data_draw = x_d.swapaxes(0, 2)
            x_data_draw = x_data_draw.swapaxes(0, 1)
            x_data_draw += 1
            x_data_draw *= 128
            x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
            x_data_draw = x_data_draw[:, :, ::-1]
            cv2.imshow('crop %d' % i, x_data_draw)
            cv2.imwrite('./data/tshow/crop%d.jpg' % i, x_data_draw)
            # cv2.imwrite('./data/tshow/%s.jpg' % labels[i], x_data_draw)

        for j in range(im_data.size(0)):
            img = im_data[j].cpu().numpy().transpose(1, 2, 0)
            img = (img + 1) * 128
            img = np.asarray(img, dtype=np.uint8)
            img = img[:, :, ::-1]
            cv2.imshow('img%d' % j, img)
            cv2.imwrite('./data/tshow/img%d.jpg' % j, img)
        cv2.waitKey(100)

    if training:
        preds = net.ocr_forward(pooled_feat)
        preds_size = Variable(torch.IntTensor([preds.size(0)] * preds.size(1)))
        res = ctc_loss(preds, text, preds_size, label_length) / preds.size(1)  # average over the batch
    else:
        labels_pred = net.ocr_forward(pooled_feat)
        _, labels_pred = labels_pred.max(2)
        labels_pred = labels_pred.contiguous().view(-1)
        # labels_pred = labels_pred.transpose(1, 0).contiguous().view(-1)
        preds_size = Variable(torch.IntTensor([labels_pred.size(0)]))
        res = converter.decode(labels_pred.data, preds_size.data, raw=False)
        res = (res, labels)
    return res
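# Hypothetical call sites for process_crnn, assuming batches shaped like the
# output of own_collate / E2Ecollate below and an optimizer defined elsewhere:
#
#   im_data, gt_boxes, texts = batch
#   loss = process_crnn(im_data.cuda(), gt_boxes, texts, net, ctc_loss,
#                       converter, training=True)
#   optimizer.zero_grad(); loss.backward(); optimizer.step()
#
# At evaluation time the same call with training=False returns the decoded
# strings together with the ground-truth transcriptions:
#
#   decoded, gt_texts = process_crnn(im_data.cuda(), gt_boxes, texts, net,
#                                    ctc_loss, converter, training=False)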
class ImgDataset(Dataset):
    def __init__(self, root=None, csv_root=None, transform=None, target_transform=None):
        self.root = root
        with open(csv_root) as f:
            self.data = f.readlines()
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        per_label = self.data[idx].rstrip().split('\t')
        imgpath = os.path.join(self.root, per_label[0])
        srcimg = cv2.imread(imgpath)
        img = srcimg[:, :, ::-1].copy()  # BGR -> RGB
        if self.transform:
            img = self.transform(img)

        # img = torch.tensor(img, dtype=torch.float)
        img = torch.from_numpy(img)
        img = img.permute(2, 0, 1)
        img = img.float()

        temp = [[int(x) for x in per_label[2:6]]]
        roi = []
        for i in range(len(temp)):
            temp1 = np.asarray([[temp[i][0], temp[i][3]],
                                [temp[i][0], temp[i][1]],
                                [temp[i][2], temp[i][1]],
                                [temp[i][2], temp[i][3]]]])
            roi.append(temp1)

            # for debug show
            # cv2.rectangle(srcimg, (temp1[1][0], temp1[1][1]), (temp1[3][0], temp1[3][1]), (255, 0, 0), thickness=2)
            # temp1 = temp1.reshape(-1, 1, 2)
            # cv2.polylines(srcimg, [temp1], False, (0, 255, 255), thickness=3)
            # plt.imshow(srcimg)
            # plt.show()

        text = [per_label[1].lstrip(), per_label[6].lstrip()]
        return img, roi, text  # the gt_box annotation format is x1, y1, x2, y2


class ImgDataset2(Dataset):
    def __init__(self, root=None, csv_root=None, transform=None, target_transform=None):
        self.root = root
        with open(csv_root) as f:
            self.data = f.readlines()
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        per_label = self.data[idx].rstrip().split('\t')
        imgpath = os.path.join(self.root, per_label[0])
        srcimg = cv2.imread(imgpath)
        img = srcimg[:, :, ::-1].copy()  # BGR -> RGB
        if self.transform:
            img = self.transform(img)

        # img = torch.tensor(img, dtype=torch.float)
        img = torch.from_numpy(img)
        img = img.permute(2, 0, 1)
        img = img.float()

        temp = [[int(x) for x in per_label[2:6]], [int(x) for x in per_label[7:11]]]
        roi = []
        for i in range(len(temp)):
            temp1 = np.asarray([[temp[i][0], temp[i][3]],
                                [temp[i][0], temp[i][1]],
                                [temp[i][2], temp[i][1]],
                                [temp[i][2], temp[i][3]]])
            roi.append(temp1)

            # cv2.rectangle(srcimg, (temp[0][0], temp[0][1]), (temp[0][2], temp[0][3]), (255, 0, 0), thickness=2)
            # temp1 = temp1.reshape(-1, 1, 2)
            # cv2.polylines(srcimg, [temp1], False, (0, 255, 255), thickness=3)
            # plt.imshow(srcimg)
            # plt.show()

        text = [per_label[1].lstrip(), per_label[6].lstrip()]
        return img, roi, text  # the gt_box annotation format is x1, y1, x2, y2


def own_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size."""
    img = []
    gt_boxes = []
    texts = []
    for per_batch in batch:
        img.append(per_batch[0])
        gt_boxes.append(per_batch[1])
        texts.append(per_batch[2])
    return torch.stack(img, 0), gt_boxes, texts


class E2Edataset(Dataset):
    def __init__(self, train_list, input_size=512):
        super(E2Edataset, self).__init__()
        self.image_list = np.array(get_images(train_list))
        print('{} training images in {}'.format(self.image_list.shape[0], train_list))

        self.transform = transforms.Compose([
            transforms.ColorJitter(.3, .3, .3, .3),
            transforms.RandomGrayscale(p=0.1)
        ])

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        im_name = self.image_list[index]
        im = cv2.imread(im_name)  # image

        # load the annotations
        txt_fn = im_name.replace(os.path.basename(im_name).split('.')[1], 'txt')
        base_name = os.path.basename(txt_fn)
        txt_fn_gt = '{0}/gt_{1}'.format(os.path.dirname(im_name), base_name)
        text_polys, text_tags, labels_txt = load_gt_annoataion(txt_fn_gt, txt_fn_gt.find('/icdar-2015-Ch4/') != -1)

        # color augmentation
        pim = PIL.Image.fromarray(np.uint8(im))
        if self.transform:
            pim = self.transform(pim)
        im = np.array(pim)

        new_h, new_w, _ = im.shape
        score_map, geo_map, training_mask, gt_idx, gt_out, labels_out = generate_rbox(
            im, (new_h, new_w), text_polys, text_tags, labels_txt, vis=False)

        im = np.asarray(im, dtype=np.float32)  # normalize to [-1, 1)
        im /= 128
        im -= 1
        im = torch.from_numpy(im).permute(2, 0, 1)

        return im, score_map, geo_map, training_mask, gt_idx, gt_out, labels_txt


def E2Ecollate(batch):
    img = []
    gt_boxes = []
    texts = []
    for per_batch in batch:
        img.append(per_batch[0])
        gt_boxes.append(per_batch[5])
        texts.append(per_batch[6])
    return torch.stack(img, 0), gt_boxes, texts
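# Hypothetical wiring for the box-level dataset above; the paths are
# placeholders and the csv layout must match what ImgDataset.__getitem__
# parses (image path, transcriptions, and x1/y1/x2/y2 fields separated by tabs):
#
#   dataset = ImgDataset(root='./data/images', csv_root='./data/labels.csv')
#   loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True,
#                                        collate_fn=own_collate)
#   for im_data, gt_boxes, texts in loader:
#       loss = process_crnn(im_data.cuda(), gt_boxes, texts, net, ctc_loss,
#                           converter, training=True)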
if __name__ == '__main__':
    llist = './data/ICDAR2015.txt'
    data = E2Edataset(train_list=llist)

    E2Edataloader = torch.utils.data.DataLoader(data, batch_size=2, shuffle=False, collate_fn=E2Ecollate)

    # smoke test: iterate once over the dataset
    for index, data in enumerate(E2Edataloader):
        im, gt_boxes, texts = data
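# A hypothetical extension of the smoke test above: convert the first sample's
# gt quads into the rroi layout consumed by _RRoiAlign (see the quad_to_rroi
# sketch earlier in this file) to sanity-check the annotation parsing:
#
#   for quad in gt_boxes[0]:
#       print(quad_to_rroi(quad, batch_index=0))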