Example #1
import math
import random

import cv2
import numpy as np
import torch
from torch.autograd import Variable

# Project-local helpers assumed by this snippet: _RRoiAlign (the rotated ROI align op)
# and the axis-aligned box utilities intersect/union/area (a sketch follows the function).

def process_boxes(images, im_data, iou_pred, roi_pred, angle_pred, score_maps, gt_idxs, gtso, lbso, features, net, ctc_loss, opts, converter, debug=False):
  '''iou_pred: text/non-text score  roi_pred: distances to the top/bottom/left/right edges  angle_pred: predicted angle  score_maps: ground-truth text mask'''
  gt_good = 0
  gt_proc = 0

  # 0. filter the predicted rrois, keep the qualifying ones and compute the loss // operates over the whole batch
  rrois = []
  labels = []
  for bid in range(iou_pred.size(0)):
    
    gts = gtso[bid]               # text region polygons for this image
    lbs = lbso[bid]               # text strings for this image
    
    gt_proc = 0                   # note: reset per image, so the returned counts reflect the last image in the batch
    gt_good = 0
    
    gts_count = {}                # usage counter per region, so one region is not used for training too many times

    iou_pred_np = iou_pred[bid].data.cpu().numpy()
    iou_map = score_maps[bid]
    to_walk = iou_pred_np.squeeze(0) * iou_map * (iou_pred_np.squeeze(0) > 0.5)           # points predicted as text that are also labeled as text
    
    roi_p_bid = roi_pred[bid].data.cpu().numpy()      # predicted roi geometry
    gt_idx = gt_idxs[bid]                             # label index of the text region each point belongs to
    
    if debug:
      img = images[bid]
      img += 1
      img *= 128
      img = np.asarray(img, dtype=np.uint8)
    
    xy_text = np.argwhere(to_walk > 0)                          # returned order is (row, col)
    random.shuffle(xy_text)
    xy_text = xy_text[0:min(xy_text.shape[0], 100)]             # sample at most 100 points
    
    # crop at the predicted points
    for i in range(0, xy_text.shape[0]):
      if opts.geo_type == 1:
        break
      pos = xy_text[i, :]
      
      gt_id = gt_idx[pos[0], pos[1]]                # ground-truth region label for this point
      
      if gt_id not in gts_count:
        gts_count[gt_id] = 0
      
      # 1. use each text region at most a few times
      if gts_count[gt_id] > 2:                      # cap how often a single region is reused for training
        continue

      # 2. skip regions labeled '##' (ignore / illegible)
      gt = gts[gt_id]                               # ground-truth polygon and text for this point's region
      gt_txt = lbs[gt_id]
      if gt_txt.startswith('##'):
        continue

      # 3. the ground-truth region must be tall enough
      dhgt = gt[1] - gt[0]
      h_gt = math.sqrt(dhgt[0] * dhgt[0] + dhgt[1] * dhgt[1])       # length of the short side of the annotation
      if h_gt < 10:
        continue

      # 4. the annotated region must lie inside the image (x against width, y against height)
      if gt[:, 0].max() > im_data.size(3) or gt[:, 1].max() > im_data.size(2):
        continue
      
      # 5. the predicted angle must be close to the ground-truth angle
      angle_sin = angle_pred[bid, 0, pos[0], pos[1]]
      angle_cos = angle_pred[bid, 1, pos[0], pos[1]]
      angle = math.atan2(angle_sin, angle_cos)            # predicted angle
      angle_gt = ( math.atan2((gt[2][1] - gt[1][1]), gt[2][0] - gt[1][0]) + math.atan2((gt[3][1] - gt[0][1]), gt[3][0] - gt[0][0]) ) / 2
      if math.fabs(angle_gt - angle) > math.pi / 16:      # reject if the angles differ by more than pi / 16 (11.25 degrees)
        continue

      # 6. recover the midpoints of the four sides of the rotated box
      offset = roi_p_bid[:, pos[0], pos[1]]               # roi geometry at this point: top, bottom, left, right distances
      posp = pos + 0.25                                   # order is (h, w), i.e. (y, x)
      pos_g = np.array([(posp[1] - offset[0] * math.sin(angle)) * 4, (posp[0] - offset[0] * math.cos(angle)) * 4 ])         # result is (x, y); argwhere returned (row, col), i.e. (y, x); the maps are at 1/4 resolution, hence the factor 4
      pos_g2 = np.array([ (posp[1] + offset[1] * math.sin(angle)) * 4, (posp[0] + offset[1] * math.cos(angle)) * 4 ])
      pos_r = np.array([(posp[1] - offset[2] * math.cos(angle)) * 4, (posp[0] - offset[2] * math.sin(angle)) * 4 ])
      pos_r2 = np.array([(posp[1] + offset[3] * math.cos(angle)) * 4, (posp[0] + offset[3] * math.sin(angle)) * 4 ])
      
      center = (pos_g + pos_g2 + pos_r + pos_r2) / 2 - [4 * pos[1], 4 * pos[0]]
      #center = (pos_g + pos_g2 + pos_r + pos_r2) / 4   # center of the box
      dw = pos_r - pos_r2                               # long side vector
      dh = pos_g - pos_g2                               # short side vector
      w = math.sqrt(dw[0] * dw[0] + dw[1] * dw[1])      # long side length
      h = math.sqrt(dh[0] * dh[0] + dh[1] * dh[1])      # short side length
    
      rect = ( (center[0], center[1]), (w, h), angle * 180 / math.pi )        # predicted rotated rect; the net outputs sin and cos, angle is in radians
      pts = cv2.boxPoints(rect)                                               # four corner points of the rotated rect
      
      # 7. IoU between the rotated rect and the ground-truth box // computed here via axis-aligned rects
      pred_bbox = cv2.boundingRect(pts)                                       # bounding rect of the rotated rect, returned as (x, y, w, h)
      pred_bbox = list(pred_bbox)                                             # converted to [x1, y1, x2, y2] below
      pred_bbox[2] += pred_bbox[0]
      pred_bbox[3] += pred_bbox[1]
      gt_bbox = [gt[:, 0].min(), gt[:, 1].min(), gt[:, 0].max(), gt[:, 1].max()]    # bounding rect of the ground-truth box
      
      inter = intersect(pred_bbox, gt_bbox)             # intersection
      uni = union(pred_bbox, gt_bbox)                   # union
      ratio = area(inter) / float(area(uni))            # overlap ratio of the two rects
      
      if ratio < 0.9:                                   # discard if the overlap ratio is below 0.9
        continue
      hratio = min(h, h_gt) / max(h, h_gt)              # discard if the heights differ too much
      if hratio < 0.5:
        continue
      
      # 8. arrange the rroi in the format expected by rroi_align
      angle = -angle / math.pi * 180
      rrois.append([bid, center[0], center[1], h, w, angle])   # collect the rrois
      labels.append(gt_txt)
      gts_count[gt_id] += 1
      gt_proc += 1

    # 8.1. for debug: load an image manually for testing // everything above filters the predicted rrois
    # img = cv2.imread('./rroi_align/data/timg.jpeg')
    # gts = [[[206,111],[199,95],[349,60],[355,80]]]
    # im_data = torch.from_numpy(img).unsqueeze(0).permute(0, 3, 1, 2)  # display test
    # im_data = im_data.to(torch.float).cuda()

    # 9. to guide convergence, also crop every image's annotated rrois for training
    if len(gts) != 0:
      gt = np.asarray(gts)
      center = (gt[:, 0, :] + gt[:, 1, :] + gt[:, 2, :] + gt[:, 3, :]) / 4        # center point
      dw = gt[:, 2, :] - gt[:, 1, :]
      dh = gt[:, 1, :] - gt[:, 0, :]
      poww = pow(dw, 2)
      powh = pow(dh, 2)
      w = np.sqrt(poww[:, 0] + poww[:, 1])
      h = np.sqrt(powh[:, 0] + powh[:, 1]) + random.randint(-2, 2)
      angle_gt = ( np.arctan2((gt[:, 2, 1] - gt[:, 1, 1]), gt[:, 2, 0] - gt[:, 1, 0]) + np.arctan2((gt[:, 3, 1] - gt[:, 0, 1]), gt[:, 3, 0] - gt[:, 0, 0]) ) / 2        # angle
      angle_gt = -angle_gt / math.pi * 180                                        # the sign must be flipped

      # 10. decide for each annotated rroi whether to use it for training
      for gt_id in range(0, len(gts)):
        
        gt_txt = lbs[gt_id]                       # text check
        if gt_txt.startswith('##'):
          continue
        
        gt = gts[gt_id]                           # annotation bounds check
        if gt[:, 0].max() > im_data.size(3) or gt[:, 1].max() > im_data.size(2) or gt.min() < 0:
          continue
        
        rrois.append([bid, center[gt_id][0], center[gt_id][1], h[gt_id], w[gt_id], angle_gt[gt_id]])   # collect the annotated rroi
        labels.append(gt_txt)
        gt_good += 1
      

    # 11. debug: visualize the annotated regions
    if debug:
      rois = torch.tensor(rrois).to(torch.float).cuda()
      pooled_height = 44
      maxratio = rois[:, 4] / rois[:, 3]
      maxratio = maxratio.max().item()
      pooled_width = math.ceil(pooled_height * maxratio)

      roipool = _RRoiAlign(pooled_height, pooled_width, 1.0)        # instantiate the op
      pooled_feat = roipool(im_data, rois.view(-1, 6))

      for i in range(pooled_feat.shape[0]):

        x_d = pooled_feat.data.cpu().numpy()[i]
        x_data_draw = x_d.swapaxes(0, 2)
        x_data_draw = x_data_draw.swapaxes(0, 1)

        x_data_draw += 1
        x_data_draw *= 128
        x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
        x_data_draw = x_data_draw[:, :, ::-1]
        cv2.imshow('crop %d' % i, x_data_draw)
        cv2.imwrite('./data/tshow/crop%d.jpg' % i, x_data_draw)
            
      cv2.imshow('img', img)
      cv2.waitKey(100)


  # 12. convert the labels for ctc // everything above builds the rrois and labels for one batch
  if len(labels) > 32:
    labels = labels[:32]
    rrois = rrois[:32]
  text, label_length = converter.encode(labels)

  # 13. rroi_align, forward the features, and compute the ctc loss
  rois = torch.tensor(rrois).to(torch.float).cuda()
  pooled_height = 32
  maxratio = rois[:, 4] / rois[:, 3]
  maxratio = maxratio.max().item()
  pooled_width = math.ceil(pooled_height * maxratio)

  roipool = _RRoiAlign(pooled_height, pooled_width, 1.0)  # instantiate the op
  pooled_feat = roipool(im_data, rois.view(-1, 6))

  # 13.1 show all cropped regions
  alldebug = 0
  if alldebug:
      for i in range(pooled_feat.shape[0]):

        x_d = pooled_feat.data.cpu().numpy()[i]
        x_data_draw = x_d.swapaxes(0, 2)
        x_data_draw = x_data_draw.swapaxes(0, 1)

        x_data_draw += 1
        x_data_draw *= 128
        x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
        x_data_draw = x_data_draw[:, :, ::-1]
        cv2.imshow('crop %d' % i, x_data_draw)
        cv2.imwrite('./data/tshow/crop%d.jpg' % i, x_data_draw)

      for j in range(im_data.size(0)):
        img = im_data[j].cpu().numpy().transpose(1,2,0)
        img = (img + 1) * 128
        img = np.asarray(img, dtype=np.uint8)
        img = img[:, :, ::-1]
        cv2.imshow('img%d'%j, img)
        cv2.imwrite('./data/tshow/img%d.jpg' % j, img)
      cv2.waitKey(100)
      
  # ocr_features = net.forward_features(pooled_feat)
  # preds = net.forward_ocr(ocr_features)
  # preds = preds.permute(2, 0, 1)

  preds = net.ocr_forward(pooled_feat)

  preds_size = Variable(torch.IntTensor([preds.size(0)] * preds.size(1)))       # ctc loss input lengths
  loss_ocr = ctc_loss(preds, text, preds_size, label_length) / preds.size(1)    # average over the batch

  return loss_ocr, gt_good, gt_proc
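The IoU test in step 7 relies on the project's axis-aligned box helpers. A minimal sketch of what they plausibly look like, assuming boxes in [x1, y1, x2, y2] form as built above (the names come from the calls, not from the project's actual source):

def intersect(a, b):
    # overlap box of two [x1, y1, x2, y2] boxes; degenerate when they do not overlap
    return [max(a[0], b[0]), max(a[1], b[1]), min(a[2], b[2]), min(a[3], b[3])]

def union(a, b):
    # bounding box enclosing both [x1, y1, x2, y2] boxes
    return [min(a[0], b[0]), min(a[1], b[1]), max(a[2], b[2]), max(a[3], b[3])]

def area(box):
    # box area, clamped to 0 so an empty intersection contributes nothing
    return max(0, box[2] - box[0]) * max(0, box[3] - box[1])

Under this reading, ratio is intersection over the area of the enclosing box rather than classic IoU, which is a stricter overlap score.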
Example #2
def align_ocr(net, converter, im_data, boxo, features, debug=0):
    # crop the ocr region, rectify it, and run recognition
    boxr = boxo[0:8].reshape(-1, 2)

    # 1. prepare the rroi parameters
    center = (boxr[0, :] + boxr[1, :] + boxr[2, :] + boxr[3, :]) / 4

    dw = boxr[2, :] - boxr[1, :]
    dh = boxr[1, :] - boxr[0, :]
    w = math.sqrt(dw[0] * dw[0] + dw[1] * dw[1])
    h = math.sqrt(dh[0] * dh[0] + dh[1] * dh[1])

    angle = math.atan2((boxr[2][1] - boxr[1][1]), boxr[2][0] - boxr[1][0])
    angle = -angle / math.pi * 180
    rroi = [0, int(center[0]), int(center[1]), h, w, angle]

    target_h = 11
    scale = target_h / max(1, h)
    target_gw = int(w * scale) + target_h
    target_gw = max(2, target_gw // 32) * 32
    rroialign = _RRoiAlign(target_h, target_gw, 1.0 / 4)
    rois = torch.tensor(rroi).to(torch.float).cuda()

    if debug:
        # 2. rroi_align on the raw image, for visualization only
        x = rroialign(im_data, rois.view(-1, 6))
        for i in range(x.shape[0]):

            x_d = x.data.cpu().numpy()[i]
            x_data_draw = x_d.swapaxes(0, 2)
            x_data_draw = x_data_draw.swapaxes(0, 1)

            x_data_draw += 1
            x_data_draw *= 128
            x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
            x_data_draw = x_data_draw[:, :, ::-1]
            cv2.imshow('crop %d' % i, x_data_draw)
            cv2.imwrite('./data/tshow/crop%d.jpg' % i, x_data_draw)
            img = im_data[i].cpu().numpy().transpose(1, 2, 0)
            img = (img + 1) * 128
            img = np.asarray(img, dtype=np.uint8)
            img = img[:, :, ::-1]
            cv2.imshow('img%d' % i, img)
        cv2.waitKey(100)

    x = rroialign(features[1], rois.view(-1, 6))  # crop from the same features the detector used
    # features = net.forward_features(x)
    labels_pred = net.forward_ocr(x)
    # labels_pred = net.ocr_forward(x)
    # labels_pred = labels_pred.permute(0,2,1)

    _, labels_pred = labels_pred.max(1)
    labels_pred = labels_pred.transpose(1, 0).contiguous().view(-1)
    preds_size = Variable(torch.IntTensor([labels_pred.size(0)]))
    sim_preds = converter.decode(labels_pred.data, preds_size.data, raw=False)

    # ctc_f = labels_pred.data.cpu().numpy()
    # ctc_f = ctc_f.swapaxes(1, 2)

    # labels = ctc_f.argmax(2)

    # ind = np.unravel_index(labels, ctc_f.shape)
    # conf = np.mean( np.exp(ctc_f[ind]) )

    # det_text, conf2, dec_s, splits = print_seq_ext(labels[0, :], codec)
    conf2 = 0.9   # placeholder confidence (the commented-out code above derived it from the ctc output)
    dec_s = 1
    return sim_preds, conf2, dec_s
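A hypothetical call, assuming boxo is the detector's flat output for one region (eight quad coordinates x1, y1, ..., x4, y4, possibly followed by a score) and that net, converter, im_data and features come from the surrounding pipeline:

# example quad running roughly horizontally; the values are made up for illustration
boxo = np.array([120., 80., 118., 52., 300., 48., 302., 76., 0.98])
sim_preds, conf2, dec_s = align_ocr(net, converter, im_data, boxo, features)
print(sim_preds)   # decoded text for the cropped region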
Example #3
def process_crnn(im_data, gtso, lbso, net, ctc_loss, converter, training):
    num_gt = len(gtso)
    rrois = []
    labels = []
    for kk in range(num_gt):
        gts = gtso[kk]
        lbs = lbso[kk]
        if len(gts) != 0:
            gt = np.asarray(gts)
            center = (gt[:, 0, :] + gt[:, 1, :] + gt[:, 2, :] +
                      gt[:, 3, :]) / 4  # center point
            dw = gt[:, 2, :] - gt[:, 1, :]
            dh = gt[:, 1, :] - gt[:, 0, :]
            poww = pow(dw, 2)
            powh = pow(dh, 2)
            w = np.sqrt(poww[:, 0] + poww[:, 1])
            h = np.sqrt(powh[:, 0] + powh[:, 1]) + random.randint(-2, 2)
            angle_gt = (np.arctan2(
                (gt[:, 2, 1] - gt[:, 1, 1]), gt[:, 2, 0] - gt[:, 1, 0]) +
                        np.arctan2((gt[:, 3, 1] - gt[:, 0, 1]),
                                   gt[:, 3, 0] - gt[:, 0, 0])) / 2  # angle
            angle_gt = -angle_gt / math.pi * 180  # the sign must be flipped

            # 10. decide for each annotated rroi whether to use it for training
            for gt_id in range(0, len(gts)):

                gt_txt = lbs[gt_id]  # text check
                if gt_txt.startswith('##'):
                    continue

                gt = gts[gt_id]  # annotation bounds check
                if gt[:, 0].max() > im_data.size(
                        3) or gt[:, 1].max() > im_data.size(2) or gt.min() < 0:
                    continue

                rrois.append([
                    kk, center[gt_id][0], center[gt_id][1], h[gt_id], w[gt_id],
                    angle_gt[gt_id]
                ])  # collect the annotated rroi
                labels.append(gt_txt)

    text, label_length = converter.encode(labels)

    # 13. rroi_align, forward the features, and compute the ctc loss
    rois = torch.tensor(rrois).to(torch.float).cuda()
    pooled_height = 32
    maxratio = rois[:, 4] / rois[:, 3]
    maxratio = maxratio.max().item()
    pooled_width = math.ceil(pooled_height * maxratio)

    roipool = _RRoiAlign(pooled_height, pooled_width, 1.0)  # instantiate the op
    pooled_feat = roipool(im_data, rois.view(-1, 6))

    # 13.1 show all cropped regions
    alldebug = 0
    if alldebug:
        for i in range(pooled_feat.shape[0]):

            x_d = pooled_feat.data.cpu().numpy()[i]
            x_data_draw = x_d.swapaxes(0, 2)
            x_data_draw = x_data_draw.swapaxes(0, 1)

            x_data_draw += 1
            x_data_draw *= 128
            x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
            x_data_draw = x_data_draw[:, :, ::-1]
            cv2.imshow('crop %d' % i, x_data_draw)
            cv2.imwrite('./data/tshow/crop%d.jpg' % i, x_data_draw)
            # cv2.imwrite('./data/tshow/%s.jpg' % labels[i], x_data_draw)

        for j in range(im_data.size(0)):
            img = im_data[j].cpu().numpy().transpose(1, 2, 0)
            img = (img + 1) * 128
            img = np.asarray(img, dtype=np.uint8)
            img = img[:, :, ::-1]
            cv2.imshow('img%d' % j, img)
            cv2.imwrite('./data/tshow/img%d.jpg' % j, img)
        cv2.waitKey(100)

    if training:
        preds = net.ocr_forward(pooled_feat)

        preds_size = Variable(torch.IntTensor([preds.size(0)] *
                                              preds.size(1)))  # ctc loss input lengths
        res = ctc_loss(preds, text, preds_size, label_length) / preds.size(
            1)  # average over the batch
    else:
        labels_pred = net.ocr_forward(pooled_feat)

        _, labels_pred = labels_pred.max(2)
        labels_pred = labels_pred.contiguous().view(-1)
        # labels_pred = labels_pred.transpose(1, 0).contiguous().view(-1)
        preds_size = Variable(torch.IntTensor([labels_pred.size(0)]))
        res = converter.decode(labels_pred.data, preds_size.data, raw=False)
        res = (res, labels)
    return res
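process_crnn assumes a CTC label converter exposing encode and decode. A minimal sketch of that interface, assuming the crnn.pytorch-style strLabelConverter convention in which index 0 is reserved for the CTC blank; the project's real converter may differ:

class LabelConverter:
    # hypothetical stand-in for the converter passed into process_crnn
    def __init__(self, alphabet):
        self.alphabet = alphabet
        self.dict = {ch: i + 1 for i, ch in enumerate(alphabet)}   # 0 is the blank

    def encode(self, texts):
        # flatten all labels into one index tensor plus per-label lengths
        lengths = [len(t) for t in texts]
        flat = [self.dict[ch] for t in texts for ch in t]
        return torch.IntTensor(flat), torch.IntTensor(lengths)

    def decode(self, indices, lengths, raw=False):
        # greedy CTC decoding for one sequence: collapse repeats, then drop blanks
        seq = indices[:int(lengths[0])].tolist()
        if raw:
            return ''.join(self.alphabet[i - 1] if i > 0 else '-' for i in seq)
        out, prev = [], 0
        for i in seq:
            if i != 0 and i != prev:
                out.append(self.alphabet[i - 1])
            prev = i
        return ''.join(out)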
Example #4
import os

import cv2
import numpy as np
import PIL
import torch
from torch.utils.data import Dataset
from torchvision import transforms

# Project-local helpers assumed below: get_images, load_gt_annoataion, generate_rbox.


class ImgDataset(Dataset):
    def __init__(self, root=None, csv_root=None, transform=None, target_transform=None):
        self.root = root
        with open(csv_root) as f:
            self.data = f.readlines()
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        per_label = self.data[idx].rstrip().split('\t')
        imgpath = os.path.join(self.root, per_label[0])
        srcimg = cv2.imread(imgpath)
        img = srcimg[:, :, ::-1].copy()

        if self.transform:
            img = self.transform(img)
            # img = torch.tensor(img, dtype=torch.float)
        img = torch.from_numpy(img)
        img = img.permute(2, 0, 1)
        img = img.float()

        temp = [[int(x) for x in per_label[2:6]]]

        roi = []
        for i in range(len(temp)):
            temp1 = np.asarray([[temp[i][0], temp[i][3]], [temp[i][0], temp[i][1]], [temp[i][2], temp[i][1]], [temp[i][2], temp[i][3]]])
            roi.append(temp1)

        # for debug show
        #     cv2.rectangle(srcimg, (temp1[1][0], temp1[1][1]), (temp1[3][0], temp1[3][1]), (255, 0, 0), thickness=2)
        # #     temp1 = temp1.reshape(-1,1,2)
        # #     cv2.polylines(srcimg,[temp1],False,(0,255,255), thickness=3)
        # plt.imshow(srcimg)
        # plt.show()

        text = [per_label[1].lstrip(), per_label[6].lstrip()]

        return img, roi, text      # gt_box annotations are x1, y1, x2, y2; the text labels are returned alongside


class ImgDataset2(Dataset):
    def __init__(self, root=None, csv_root=None, transform=None, target_transform=None):
        self.root = root
        with open(csv_root) as f:
            self.data = f.readlines()
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        per_label = self.data[idx].rstrip().split('\t')
        imgpath = os.path.join(self.root, per_label[0])
        srcimg = cv2.imread(imgpath)
        img = srcimg[:, :, ::-1].copy()

        if self.transform:
            img = self.transform(img)
            # img = torch.tensor(img, dtype=torch.float)
        img = torch.from_numpy(img)
        img = img.permute(2, 0, 1)
        img = img.float()

        temp = [[int(x) for x in per_label[2:6]],
                [int(x) for x in per_label[7:11]]]

        roi = []
        for i in range(len(temp)):
            temp1 = np.asarray([[temp[i][0], temp[i][3]], [temp[i][0], temp[i][1]], [temp[i][2], temp[i][1]],
                                [temp[i][2], temp[i][3]]])
            roi.append(temp1)

        # cv2.rectangle(srcimg, (temp[0][0], temp[0][1]), (temp[0][2], temp[0][3]), (255, 0, 0), thickness=2)
        #     temp1 = temp1.reshape(-1,1,2)
        #     cv2.polylines(srcimg,[temp1],False,(0,255,255), thickness=3)
        # plt.imshow(srcimg)
        # plt.show()

        text = [per_label[1].lstrip(), per_label[6].lstrip()]

        return img, roi, text  # gt_box annotations are x1, y1, x2, y2; the text labels are returned alongside


def own_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
    img = []
    gt_boxes = []
    texts = []
    for per_batch in batch:
        img.append(per_batch[0])
        gt_boxes.append(per_batch[1])
        texts.append(per_batch[2])

    return torch.stack(img, 0), gt_boxes, texts
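A hypothetical usage sketch for the dataset plus collate function, assuming images of equal size (torch.stack requires it) and a tab-separated label file laid out as the parsing above expects (path, text, four box coordinates, second text); the paths here are made up:

dataset = ImgDataset(root='./data/images', csv_root='./data/labels.tsv')
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=own_collate)
imgs, gt_boxes, texts = next(iter(loader))   # imgs: (4, 3, H, W); boxes and texts remain per-image lists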


class E2Edataset(Dataset):
    def __init__(self, train_list, input_size=512):
        super(E2Edataset, self).__init__()
        self.image_list = np.array(get_images(train_list))

        print('{} training images in {}'.format(self.image_list.shape[0], train_list))

        self.transform = transforms.Compose([
            transforms.ColorJitter(.3, .3, .3, .3),
            transforms.RandomGrayscale(p=0.1)])

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        im_name = self.image_list[index]

        im = cv2.imread(im_name)        # image

        txt_fn = im_name.replace(os.path.basename(im_name).split('.')[1], 'txt')
        base_name = os.path.basename(txt_fn)
        txt_fn_gt = '{0}/gt_{1}'.format(os.path.dirname(im_name), base_name)

        # load the annotations
        text_polys, text_tags, labels_txt = load_gt_annoataion(txt_fn_gt, txt_fn_gt.find('/icdar-2015-Ch4/') != -1)

        pim = PIL.Image.fromarray(np.uint8(im))
        if self.transform:
            pim = self.transform(pim)
        im = np.array(pim)

        new_h, new_w, _ = im.shape
        score_map, geo_map, training_mask, gt_idx, gt_out, labels_out = generate_rbox(im, (new_h, new_w), text_polys, text_tags, labels_txt, vis=False)

        im = np.asarray(im, dtype=float)
        im /= 128
        im -= 1
        im = torch.from_numpy(im).permute(2, 0, 1)

        return im, score_map, geo_map, training_mask, gt_idx, gt_out, labels_txt


def E2Ecollate(batch):
    img = []
    gt_boxes = []
    texts = []
    for per_batch in batch:
        img.append(per_batch[0])
        gt_boxes.append(per_batch[5])
        texts.append(per_batch[6])

    return torch.stack(img, 0), gt_boxes, texts


if __name__ == '__main__':
    llist = './data/ICDAR2015.txt'

    data = E2Edataset(train_list=llist)

    E2Edataloader = torch.utils.data.DataLoader(data, batch_size=2, shuffle=False, collate_fn=E2Ecollate)

    for index, data in enumerate(E2Edataloader):
        im, gt_boxes, texts = data    # unpack the fields returned by E2Ecollate