Пример #1
0
 def load_image(self, image_path):
     """Read an image file and convert it into a normalised batch tensor.

     The image is read through ``utils.cv2read``, cast to float32, resized
     with ``self.resize_image``, shifted by ``self.RGB_MEAN`` and finally
     divided by 255.

     Args:
         image_path: Location of the image, forwarded to ``utils.cv2read``.

     Returns:
         A ``(tensor, original_shape)`` tuple. ``tensor`` is the processed
         image with its last axis moved to the front and a leading axis of
         size one added; ``original_shape`` holds the first two dimensions
         of the image as it was read, i.e. before resizing.
     """
     pixels = utils.cv2read(image_path).astype('float32')
     # Remember the size before resize_image changes it.
     size_before_resize = pixels.shape[:2]

     pixels = self.resize_image(pixels)
     # Deliberately in place: the mean is removed first, then the result is
     # scaled, and both steps reuse the array returned by resize_image
     # (presumably still float32 -- confirm against resize_image).
     pixels -= self.RGB_MEAN
     pixels /= 255.

     channels_first = torch.from_numpy(pixels).permute(2, 0, 1)
     batched = channels_first.float().unsqueeze(0)
     return batched, size_before_resize
Пример #2
0
 def format_output(self, batch, output):
     """Write the detection results of one batch to disk.

     For every image in the batch a text file named ``res_<stem>.txt``
     (``<stem>`` being the image file name without directory and extension)
     is written into ``self.args['result_dir']``. A ``crop`` sub-directory
     is always created there as well. The content of the result file
     depends on ``self.args``:

     * ``polygon`` truthy: every polygon is written as comma-separated
       integer coordinates followed by its score. No score threshold is
       applied in this mode.
     * ``sort_boxes`` truthy: boxes whose score is not below
       ``box_thresh`` are converted with ``utils.trans_poly_to_rec``,
       clustered with ``utils.cluster_recs_with_width`` (``n_clusters=2``)
       and ordered by ``utils.read_out_2``. Each ordered box is cropped
       from the source image (with a 5 pixel margin, clipped to the
       original image size) into ``crop/<stem>_<n>.jpg`` and written to the
       result file as its flattened coordinates plus a ``big`` / ``small``
       label. Images without any box passing the threshold are skipped and
       get no result file.
     * otherwise: boxes whose score is not below ``box_thresh`` are written
       as comma-separated integer coordinates followed by the score.

     Args:
         batch: Mapping with the keys ``'image'`` (its ``size(0)`` gives the
             number of images), ``'shape'`` (per image, the original size;
             index 0 bounds the row slice and index 1 the column slice of
             the crops) and ``'filename'`` (per image source path).
         output: Pair ``(batch_boxes, batch_scores)``, both indexed per
             image.

     Returns:
         None. All results are side effects on the file system.
     """
     batch_boxes, batch_scores = output
     result_dir = self.args['result_dir']
     crop_img_path = os.path.join(result_dir, 'crop')
     os.makedirs(crop_img_path, exist_ok=True)
     crop_margin = 5  # extra pixels kept around every cropped box

     for index in range(batch['image'].size(0)):
         original_shape = batch['shape'][index]
         filename = batch['filename'][index]
         stem = os.path.splitext(os.path.basename(filename))[0]
         result_file_path = os.path.join(result_dir, 'res_' + stem + '.txt')
         boxes = batch_boxes[index]
         scores = batch_scores[index]

         if self.args['polygon']:
             with open(result_file_path, 'w', encoding='utf-8') as res:
                 for i, box in enumerate(boxes):
                     coords = np.array(box).reshape(-1).tolist()
                     result = ",".join(str(int(x)) for x in coords)
                     res.write(result + ',' + str(scores[i]) + "\n")
             continue

         if not self.args['sort_boxes']:
             with open(result_file_path, 'w', encoding='utf-8') as res:
                 for i in range(boxes.shape[0]):
                     score = scores[i]
                     if score < self.args['box_thresh']:
                         continue
                     coords = boxes[i, :, :].reshape(-1).tolist()
                     result = ",".join(str(int(x)) for x in coords)
                     res.write(result + ',' + str(score) + "\n")
             continue

         # --- sort_boxes mode -------------------------------------------
         # "not (a < b)" rather than "a >= b" keeps the original filter
         # exactly, including for scores that do not compare (e.g. NaN).
         new_boxes = [
             boxes[i, :, :] for i in range(boxes.shape[0])
             if not (scores[i] < self.args['box_thresh'])
         ]
         if not new_boxes:
             # continue, not return: an image without usable boxes must not
             # prevent the remaining images of the batch from being written.
             continue

         recs = [
             utils.trans_poly_to_rec(idx, box)
             for idx, box in enumerate(new_boxes)
         ]
         cluster_rec_ids = utils.cluster_recs_with_width(
             recs,
             new_boxes,
             type='AgglomerativeClustering_ward',
             n_clusters=2)
         cluster_recs = sorted(
             ([recs[box_id] for box_id in cluster_rec_ids[k]]
              for k in cluster_rec_ids.keys()),
             key=utils.width_sort,
             reverse=False)
         # The cluster that sorts last under utils.width_sort is labelled
         # "big" (presumably the wider one -- confirm in utils).
         bigger_idx = [b.idx for b in cluster_recs[-1]]
         output_recs = utils.read_out_2(recs, bigger_idx)

         # The image is only needed for cropping, so it is read here and
         # not for images or modes that never produce crops.
         raw_img = utils.cv2read(filename).astype('float32')
         output_idxs = []
         for crop_idx, rec in enumerate(output_recs):
             crop_path = os.path.join(
                 crop_img_path, stem + '_' + str(crop_idx) + '.jpg')
             crop_l = max(0, rec.l - crop_margin)
             crop_r = min(original_shape[1], rec.r + crop_margin)
             crop_u = max(0, rec.u - crop_margin)
             crop_d = min(original_shape[0], rec.d + crop_margin)
             cv2.imwrite(crop_path,
                         raw_img[crop_u:crop_d, crop_l:crop_r, :])
             output_idxs.append(rec.idx)

         big_ids = set(bigger_idx)
         with open(result_file_path, 'w', encoding='utf-8') as res:
             for idx in output_idxs:
                 fields = new_boxes[idx].reshape(-1).tolist()
                 fields.append('big' if idx in big_ids else 'small')
                 res.write(",".join(map(str, fields)) + "\n")
Пример #3
0
# import utils.cv2read


def parse_args():
    """Parse the command-line options of this script.

    Both options are mandatory strings:

    * ``--img``: path of the image to process.
    * ``--gt``: path of the text file opened alongside the image.

    Returns:
        The ``argparse.Namespace`` with the attributes ``img`` and ``gt``.
    """
    cli = argparse.ArgumentParser()
    for flag in ('--img', '--gt'):
        cli.add_argument(flag, type=str, required=True)
    return cli.parse_args()


if __name__ == '__main__':
    args = parse_args()
    if not os.path.isdir(args.img):
        # img = cv2.imread(args.img)
        img = utils.cv2read(args.img)
        font = cv2.FONT_HERSHEY_SIMPLEX
        with open(args.gt, 'r', encoding='utf-8') as fp:
            for idx, line in enumerate(fp):
                box = line.strip().split(',')
                label = box[-1]
                box = box[:-1]
                box = list(map(int, box))
                box = np.array(box).astype(np.int32).reshape(-1, 2)
                cv2.polylines(img, [box], True, (0, 255, 0), 2)
                if label == 'big':
                    cv2.putText(
                        img, str(idx),
                        (int(np.min(box[:, 0])), int(np.mean(box[:, 1]))),
                        font, 1, (255, 0, 0), 2)
                else: