示例#1
0
def writeToRAWIMAGE_andNMS(data_dir, image_name, category):
    """Apply per-category NMS to one image's detections and write the survivors.

    Reads ``<data_dir>/result/<image_name>.txt``. Every non-blank line is
    whitespace separated: ``<category> <score> <8 box coordinates>``. The
    boxes of each category are passed to
    ``func_utils.non_maximum_suppression`` and the kept rows are written to
    ``<data_dir>/After_nms_result/<image_name>.txt`` in the same layout.

    Args:
        data_dir: Directory containing the ``result`` sub-directory; the
            ``After_nms_result`` sub-directory is created inside it.
        image_name: File name (without ``.txt``) of the detections to process.
        category: Iterable of category names that may appear in the file.

    Returns:
        None. When the input file holds no detections the function returns
        early and does not create any output.

    Raises:
        KeyError: If a line names a category that is not in ``category``.
        IndexError: If a line has fewer than ten fields.
    """
    nms_pts = {cat: [] for cat in category}       # box coordinates per category
    nms_scores = {cat: [] for cat in category}    # confidences, same order as nms_pts
    found_any = False
    with open(os.path.join(data_dir, 'result', image_name + '.txt'), 'r') as f:
        for line in f:
            pre = line.split()
            if not pre:
                # Blank line (or trailing newline): nothing to parse.
                continue
            found_any = True
            # Swap the 2nd and 4th corner, i.e. reverse the winding order of
            # the quadrilateral (presumably the order the NMS helper expects).
            pre[4], pre[8] = pre[8], pre[4]
            pre[5], pre[9] = pre[9], pre[5]
            nms_pts[pre[0]].append(pre[2:])
            nms_scores[pre[0]].append(pre[1])
    if not found_any:
        # Nothing was detected on this image, so there is nothing to suppress.
        return

    nms_result = {cat: [] for cat in category}    # rows kept by NMS
    for cat in category:
        if not nms_pts[cat]:
            # Do not hand empty arrays to the NMS helper.
            continue
        pts_cat = np.asarray(nms_pts[cat], np.float32).reshape(-1, 4, 2)   # 4 corner points per box
        scores_cat = np.asarray(nms_scores[cat], np.float32)
        kept = func_utils.non_maximum_suppression(pts_cat, scores_cat)
        if kept.shape[0] != 0:
            nms_result[cat].extend(kept)

    out_dir = os.path.join(data_dir, 'After_nms_result')
    os.makedirs(out_dir, exist_ok=True)           # no check-then-create race
    with open(os.path.join(out_dir, image_name + '.txt'), 'w') as f:
        for cat, kept in nms_result.items():
            for predict in kept:
                row = list(predict)
                # Undo the corner swap done while reading so the output uses
                # the same point order as the input file.
                row[2], row[6] = row[6], row[2]
                row[3], row[7] = row[7], row[3]
                location = [str(value) for value in row[:8]]
                f.write(cat + " " + str(row[-1]) + " " + " ".join(location) + '\n')
示例#2
0
 def test(self, args, down_ratio):
     """Detect on every test slice, merge per raw image and run NMS again.

     For each slice delivered by the test dataset the model output is
     decoded, NMS is applied per category, and the kept boxes are handed to
     ``writeToRAWIMAGE`` together with the slice-to-image offsets from
     ``getLT``. Afterwards every ``*.txt`` in ``<data_dir>/result`` is passed
     to ``nms_and_write.writeToRAWIMAGE_andNMS``. Average time and FPS
     (first iteration excluded) are printed at the end.

     Args:
         args: Options object; reads ``dataset``, ``resume``, ``data_dir``,
             ``input_h`` and ``input_w`` (and is forwarded to
             ``func_utils.decode_prediction``).
         down_ratio: Forwarded to the dataset and to the decoder helper.
     """
     save_path = 'weights_' + args.dataset     # checkpoint directory
     self.model = self.load_model(self.model, os.path.join(save_path, args.resume))
     self.model = self.model.to(self.device)
     self.model.eval()
     dataset_module = self.dataset[args.dataset]
     dsets = dataset_module(data_dir=args.data_dir,
                            phase='test',
                            input_h=args.input_h,
                            input_w=args.input_w,
                            down_ratio=down_ratio)
     data_loader = torch.utils.data.DataLoader(dsets,
                                               batch_size=1,
                                               shuffle=False,
                                               num_workers=1,
                                               pin_memory=True)
     # torch.cuda.synchronize() rejects non-CUDA devices, so only call it
     # when the model actually runs on a GPU.
     on_cuda = torch.device(self.device).type == 'cuda'
     result_dir = os.path.join(args.data_dir, 'result')
     # The "no detections" branch below opens files in this directory itself.
     os.makedirs(result_dir, exist_ok=True)
     total_time = []
     LT_dict = getLT(args.data_dir)       # position mapping between slices and the large image
     for cnt, data_dict in enumerate(data_loader):
         image = data_dict['image'][0].to(self.device)
         img_id = data_dict['img_id'][0]
         print('processing {}/{} image ...'.format(cnt, len(data_loader)))
         begin_time = time.time()
         with torch.no_grad():
             pr_decs = self.model(image)
         if on_cuda:
             torch.cuda.synchronize(self.device)
         predictions = self.decoder.ctdet_decode(pr_decs)
         pts0, scores0 = func_utils.decode_prediction(predictions, dsets, args, img_id, down_ratio)
         # NMS inside the slice, one category at a time.
         results = {cat: [] for cat in dsets.category}
         for cat in dsets.category:
             if cat == 'background':
                 continue
             pts_cat = np.asarray(list(pts0[cat]), np.float32)
             scores_cat = np.asarray(list(scores0[cat]), np.float32)
             if pts_cat.shape[0]:
                 nms_results = func_utils.non_maximum_suppression(pts_cat, scores_cat)
                 results[cat].extend(nms_results)
         end_time = time.time()
         total_time.append(end_time - begin_time)
         # Map the slice boxes back onto the large image.
         for cat in dsets.category:
             if cat == 'background':
                 continue
             result = results[cat]
             if result:
                 result = np.array(result).astype(np.float32)
                 writeToRAWIMAGE(result, cat, args.data_dir, img_id, LT_dict)
             else:
                 # No target: make sure an (empty) txt exists for the raw image.
                 # NOTE(review): assumes img_id is "<raw image name>_<slice id>"
                 # with exactly one underscore - confirm against the dataset.
                 raw_image_name, _ = img_id.split('_')
                 with open(os.path.join(result_dir, raw_image_name + '.txt'), 'a+'):
                     pass
     # Run NMS on the merged results of every large image.
     txt_list = glob.glob(os.path.join(result_dir, '*.txt'))
     for txt in txt_list:
         # splitext keeps working for file names that contain extra dots.
         nms_image_name = os.path.splitext(os.path.basename(txt))[0]
         nms_and_write.writeToRAWIMAGE_andNMS(args.data_dir, nms_image_name, dsets.category)
     total_time = total_time[1:]          # drop the first (warm-up) measurement
     print('avg time is {}'.format(np.mean(total_time)))
     print('FPS is {}'.format(1. / np.mean(total_time)))
示例#3
0
    def test(self, args, down_ratio):
        """Run the detector on the test set and save images with drawn boxes.

        For every test image the model output is decoded into points, scores
        and directions, NMS is applied per category, and each kept box is
        drawn on the original image together with a green line on the edge
        selected by its direction value. For the ``hrsc`` dataset the ground
        truth boxes are drawn in white. The annotated image is written to
        ``result_img_dota/<index>.png``. Average time and FPS (first
        iteration excluded) are printed at the end.

        Args:
            args: Options object; reads ``dataset``, ``resume``, ``data_dir``,
                ``input_h`` and ``input_w`` (and is forwarded to
                ``func_utils.decode_prediction``).
            down_ratio: Forwarded to the dataset and to the decoder helper.
        """
        save_path = 'weights_' + args.dataset
        self.model = self.load_model(self.model,
                                     os.path.join(save_path, args.resume))
        self.model = self.model.to(self.device)
        self.model.eval()

        dataset_module = self.dataset[args.dataset]
        dsets = dataset_module(data_dir=args.data_dir,
                               phase='test',
                               input_h=args.input_h,
                               input_w=args.input_w,
                               down_ratio=down_ratio)
        data_loader = torch.utils.data.DataLoader(dsets,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=1,
                                                  pin_memory=True)

        # torch.cuda.synchronize() rejects non-CUDA devices, so only call it
        # when the model actually runs on a GPU.
        on_cuda = torch.device(self.device).type == 'cuda'
        out_dir = 'result_img_dota'
        os.makedirs(out_dir, exist_ok=True)
        num_images = len(data_loader)
        name_width = len(str(num_images))   # zero-pad file names to a common width

        total_time = []
        for cnt, data_dict in enumerate(data_loader):
            image = data_dict['image'][0].to(self.device)
            img_id = data_dict['img_id'][0]
            print('processing {}/{} image ...'.format(cnt + 1, num_images))
            begin_time = time.time()
            with torch.no_grad():
                pr_decs = self.model(image)
            if on_cuda:
                torch.cuda.synchronize(self.device)

            predictions = self.decoder.ctdet_decode(pr_decs)
            pts0, scores0, directions0 = func_utils.decode_prediction(
                predictions, dsets, args, img_id, down_ratio)

            # Per-category NMS.
            results = {cat: [] for cat in dsets.category}
            for cat in dsets.category:
                if cat == 'background':
                    continue
                pts_cat = np.asarray(list(pts0[cat]), np.float32)
                scores_cat = np.asarray(list(scores0[cat]), np.float32)
                directions_cat = np.asarray(list(directions0[cat]), np.float32)
                if pts_cat.shape[0]:
                    nms_results = func_utils.non_maximum_suppression(
                        pts_cat, scores_cat, directions_cat)
                    results[cat].extend(nms_results)

            end_time = time.time()
            total_time.append(end_time - begin_time)

            # Visualisation: the loader does not shuffle, so cnt is also the
            # dataset index of the current image.
            ori_image = dsets.load_image(cnt)
            for cat in dsets.category:
                if cat == 'background':
                    continue
                for pred in results[cat]:
                    direction = pred[9]

                    # In this variant the first point of a row is the
                    # top-right corner, followed by br, bl and tl.
                    tr = np.asarray([pred[0], pred[1]], np.float32)
                    br = np.asarray([pred[2], pred[3]], np.float32)
                    bl = np.asarray([pred[4], pred[5]], np.float32)
                    tl = np.asarray([pred[6], pred[7]], np.float32)

                    box = np.asarray([tl, tr, br, bl], np.float32)
                    # np.int0 no longer exists in NumPy 2.x; int32 is the
                    # contour dtype OpenCV works with.
                    ori_image = cv2.drawContours(ori_image,
                                                 [box.astype(np.int32)], -1,
                                                 (255, 0, 255), 1, 1)

                    # Highlight the edge that marks the main direction.
                    if direction == 0:
                        edge = (tl, tr)
                    elif direction == 1:
                        edge = (tr, br)
                    elif direction == 2:
                        edge = (bl, br)
                    elif direction == 3:
                        edge = (tl, bl)
                    else:
                        edge = None
                    if edge is not None:
                        start, end = edge
                        cv2.line(ori_image,
                                 (int(start[0]), int(start[1])),
                                 (int(end[0]), int(end[1])),
                                 (0, 255, 0), 1, 1)

            if args.dataset == 'hrsc':
                # Overlay the ground-truth quadrilaterals in white.
                gt_anno = dsets.load_annotation(cnt)
                for pts_4 in gt_anno['pts']:
                    bl = pts_4[0, :]
                    tl = pts_4[1, :]
                    tr = pts_4[2, :]
                    br = pts_4[3, :]
                    box = np.asarray([bl, tl, tr, br], np.float32)
                    cv2.drawContours(ori_image, [box.astype(np.int32)], 0,
                                     (255, 255, 255), 1)

            # Save the annotated image under a 1-based, zero-padded name.
            cv2.imwrite(
                os.path.join(out_dir, str(cnt + 1).zfill(name_width) + '.png'),
                ori_image)

        total_time = total_time[1:]          # drop the first (warm-up) measurement
        print('avg time is {}'.format(np.mean(total_time)))
        print('FPS is {}'.format(1. / np.mean(total_time)))
示例#4
0
    def test(self, args, down_ratio):
        """Draft of a sliding-window test routine that shows detections.

        Loads the checkpoint, lists ``datasets/test/images``, and for every
        listed file runs the model, decodes and NMS-filters the predictions
        per category, draws them and shows the image in an OpenCV window
        (pressing ``q`` exits the process).

        NOTE(review): as written the loop body uses ``image``, ``dsets``,
        ``img_id`` and ``cnt``, none of which is assigned in this function,
        so it looks unfinished - confirm before relying on it.

        Args:
            args: Options object; reads ``dataset`` and ``resume`` (and is
                forwarded to ``func_utils.decode_prediction``).
            down_ratio: Forwarded to ``func_utils.decode_prediction``.
        """
        save_path = 'weights_' + args.dataset
        self.model = self.load_model(self.model,
                                     os.path.join(save_path,
                                                  args.resume))  # restore the model from the given checkpoint
        self.model = self.model.to(self.device)
        self.model.eval()
        total_time = []
        images = os.listdir(
            'datasets/test/images')  # list the images folder; large images are meant to be detected with a sliding window
        for i in range(0, len(images)):
            print('processing {}/{} image ...'.format(i, len(images)))
            # NOTE(review): os.listdir returns bare file names, so the
            # directory is not part of the path given to imread - confirm.
            current_img = cv2.imread(images[i])
            # NOTE(review): OpenCV images are normally (height, width,
            # channels), so w and h look swapped here - confirm.
            w, h, _ = current_img.shape  # size of the current image
            # Intended to pick the best overlap for the 500-pixel window.
            # NOTE(review): the condition is truthy when (w - t) is NOT a
            # multiple of 500 and the loop never breaks, so sx ends up as the
            # last such t; sx and sy are also never read afterwards - confirm
            # the intended rule.
            if (w < 500):
                sx = 0
            else:
                for t in range(0, 500):
                    if (w - t) % 500:
                        sx = t
            if (h < 500):
                sy = 0
            else:
                for t in range(0, 500):
                    # NOTE(review): this tests w although the branch is about h.
                    if (w - t) % 500:
                        sy = t

            begin_time = time.time()
            with torch.no_grad():
                # NOTE(review): `image` is not assigned anywhere in this function.
                pr_decs = self.model(image)
            torch.cuda.synchronize(self.device)
            decoded_pts = []
            decoded_scores = []
            predictions = self.decoder.ctdet_decode(pr_decs)
            # NOTE(review): `dsets` and `img_id` are not assigned in this function.
            pts0, scores0 = func_utils.decode_prediction(
                predictions, dsets, args, img_id, down_ratio)
            decoded_pts.append(pts0)  # box information of the detections
            decoded_scores.append(scores0)  # score information of the detections
            # Per-category NMS.
            results = {cat: [] for cat in dsets.category}
            for cat in dsets.category:
                if cat == 'background':
                    continue
                pts_cat = []
                scores_cat = []
                for pts0, scores0 in zip(decoded_pts, decoded_scores):
                    pts_cat.extend(pts0[cat])
                    scores_cat.extend(scores0[cat])
                pts_cat = np.asarray(pts_cat, np.float32)
                scores_cat = np.asarray(scores_cat, np.float32)
                if pts_cat.shape[0]:
                    nms_results = func_utils.non_maximum_suppression(
                        pts_cat, scores_cat)
                    results[cat].extend(nms_results)
            end_time = time.time()
            total_time.append(end_time - begin_time)
            # NOTE(review): `cnt` is not assigned in this function.
            ori_image = dsets.load_image(cnt)
            height, width, _ = ori_image.shape
            # Visualise the detections.
            for cat in dsets.category:
                if cat == 'background':
                    continue
                result = results[cat]
                for pred in result:
                    score = pred[-1]
                    tl = np.asarray([pred[0], pred[1]], np.float32)
                    tr = np.asarray([pred[2], pred[3]], np.float32)
                    br = np.asarray([pred[4], pred[5]], np.float32)
                    bl = np.asarray([pred[6], pred[7]], np.float32)
                    # Midpoints of the four edges.
                    tt = (np.asarray(tl, np.float32) +
                          np.asarray(tr, np.float32)) / 2
                    rr = (np.asarray(tr, np.float32) +
                          np.asarray(br, np.float32)) / 2
                    bb = (np.asarray(bl, np.float32) +
                          np.asarray(br, np.float32)) / 2
                    ll = (np.asarray(tl, np.float32) +
                          np.asarray(bl, np.float32)) / 2
                    box = np.asarray([tl, tr, br, bl], np.float32)
                    cen_pts = np.mean(box, axis=0)
                    # One line from the box centre to each edge midpoint.
                    cv2.line(ori_image, (int(cen_pts[0]), int(cen_pts[1])),
                             (int(tt[0]), int(tt[1])), (0, 0, 255), 1, 1)
                    cv2.line(ori_image, (int(cen_pts[0]), int(cen_pts[1])),
                             (int(rr[0]), int(rr[1])), (255, 0, 255), 1, 1)
                    cv2.line(ori_image, (int(cen_pts[0]), int(cen_pts[1])),
                             (int(bb[0]), int(bb[1])), (0, 255, 0), 1, 1)
                    cv2.line(ori_image, (int(cen_pts[0]), int(cen_pts[1])),
                             (int(ll[0]), int(ll[1])), (255, 0, 0), 1, 1)
                    ori_image = cv2.drawContours(ori_image, [np.int0(box)], -1,
                                                 (255, 0, 255), 1, 1)
                    # Label the box with "<score> <category>" at its second corner.
                    cv2.putText(ori_image, '{:.2f} {}'.format(score, cat),
                                (box[1][0], box[1][1]),
                                cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 255, 255),
                                1, 1)
            # Show the result and wait for a key; 'q' closes the window and exits.
            cv2.imshow('pr_image', ori_image)
            k = cv2.waitKey(0) & 0xFF
            if k == ord('q'):
                cv2.destroyAllWindows()
                exit()
        # The first measurement is dropped before averaging.
        total_time = total_time[1:]
        print('avg time is {}'.format(np.mean(total_time)))
        print('FPS is {}'.format(1. / np.mean(total_time)))
示例#5
0
 def test(self, args, down_ratio):
     """Run the detector on the test set and save annotated images.

     For every test image the model output is decoded, NMS is applied per
     category, and each kept box is drawn on the original image with a
     "<score> <category>" label. The annotated image is written to
     ``results/<index>.png``. Average time and FPS (first iteration
     excluded) are printed at the end.

     Args:
         args: Options object; reads ``dataset``, ``resume``, ``data_dir``,
             ``input_h`` and ``input_w`` (and is forwarded to
             ``func_utils.decode_prediction``).
         down_ratio: Forwarded to the dataset and to the decoder helper.
     """
     save_path = 'weights_' + args.dataset
     # Restore the model from the given checkpoint.
     self.model = self.load_model(self.model,
                                  os.path.join(save_path, args.resume))
     self.model = self.model.to(self.device)
     self.model.eval()
     dataset_module = self.dataset[args.dataset]
     dsets = dataset_module(data_dir=args.data_dir,
                            phase='test',
                            input_h=args.input_h,
                            input_w=args.input_w,
                            down_ratio=down_ratio)
     data_loader = torch.utils.data.DataLoader(dsets,
                                               batch_size=1,
                                               shuffle=False,
                                               num_workers=1,
                                               pin_memory=True)
     # torch.cuda.synchronize() rejects non-CUDA devices, so only call it
     # when the model actually runs on a GPU.
     on_cuda = torch.device(self.device).type == 'cuda'
     # Make sure the output directory exists before images are written to it.
     os.makedirs('results', exist_ok=True)
     total_time = []
     for cnt, data_dict in enumerate(data_loader):
         image = data_dict['image'][0].to(self.device)
         img_id = data_dict['img_id'][0]
         print('processing {}/{} image ...'.format(cnt, len(data_loader)))
         begin_time = time.time()
         with torch.no_grad():
             pr_decs = self.model(image)
         if on_cuda:
             torch.cuda.synchronize(self.device)
         predictions = self.decoder.ctdet_decode(pr_decs)
         pts0, scores0 = func_utils.decode_prediction(
             predictions, dsets, args, img_id, down_ratio)
         # Per-category NMS.
         results = {cat: [] for cat in dsets.category}
         for cat in dsets.category:
             if cat == 'background':
                 continue
             pts_cat = np.asarray(list(pts0[cat]), np.float32)
             scores_cat = np.asarray(list(scores0[cat]), np.float32)
             if pts_cat.shape[0]:
                 nms_results = func_utils.non_maximum_suppression(
                     pts_cat, scores_cat)
                 results[cat].extend(nms_results)
         end_time = time.time()
         total_time.append(end_time - begin_time)
         # Visualisation: the loader does not shuffle, so cnt is also the
         # dataset index of the current image.
         ori_image = dsets.load_image(cnt)
         for cat in dsets.category:
             if cat == 'background':
                 continue
             for pred in results[cat]:
                 score = pred[-1]
                 tl = np.asarray([pred[0], pred[1]], np.float32)
                 tr = np.asarray([pred[2], pred[3]], np.float32)
                 br = np.asarray([pred[4], pred[5]], np.float32)
                 bl = np.asarray([pred[6], pred[7]], np.float32)
                 box = np.asarray([tl, tr, br, bl], np.float32)
                 # np.int0 no longer exists in NumPy 2.x; int32 is the
                 # contour dtype OpenCV works with.
                 ori_image = cv2.drawContours(ori_image,
                                              [box.astype(np.int32)], -1,
                                              (255, 0, 255), 1, 1)
                 # The text origin must be integer pixel coordinates.
                 cv2.putText(ori_image, '{:.2f} {}'.format(score, cat),
                             (int(box[1][0]), int(box[1][1])),
                             cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 255, 255),
                             1, 1)
         cv2.imwrite('results/{}.png'.format(cnt), ori_image)
     total_time = total_time[1:]          # drop the first (warm-up) measurement
     print('avg time is {}'.format(np.mean(total_time)))
     print('FPS is {}'.format(1. / np.mean(total_time)))
示例#6
0
    def test(self, args, down_ratio):
        """Run the detector on the test set, dump boxes to text and show them.

        For every test image the model output is decoded and NMS is applied
        per category. Each kept box is appended once to
        ``./predict_fpt/<img_id>.txt`` as eight coordinates (two decimals)
        followed by the category, and is drawn on the original image, which
        is then displayed with matplotlib. For the ``hrsc`` dataset the
        ground-truth boxes are drawn in white. Average time and FPS (first
        iteration excluded) are printed at the end.

        Args:
            args: Options object; reads ``dataset``, ``resume``, ``data_dir``,
                ``input_h`` and ``input_w`` (and is forwarded to
                ``func_utils.decode_prediction``).
            down_ratio: Forwarded to the dataset and to the decoder helper.
        """
        save_path = 'weights_'+args.dataset
        predict_path = './predict_fpt'
        mkdir(predict_path)
        self.model = self.load_model(self.model, os.path.join(save_path, args.resume))
        self.model = self.model.to(self.device)
        self.model.eval()

        dataset_module = self.dataset[args.dataset]
        dsets = dataset_module(data_dir=args.data_dir,
                               phase='test',
                               input_h=args.input_h,
                               input_w=args.input_w,
                               down_ratio=down_ratio)
        data_loader = torch.utils.data.DataLoader(dsets,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=1,
                                                  pin_memory=True)

        # torch.cuda.synchronize() rejects non-CUDA devices, so only call it
        # when the model actually runs on a GPU.
        on_cuda = torch.device(self.device).type == 'cuda'

        total_time = []
        for cnt, data_dict in enumerate(data_loader):
            image = data_dict['image'][0].to(self.device)
            img_id = data_dict['img_id'][0]
            print('processing {}/{} image ...'.format(cnt, len(data_loader)))
            begin_time = time.time()
            with torch.no_grad():
                pr_decs = self.model(image)
            if on_cuda:
                torch.cuda.synchronize(self.device)

            predictions = self.decoder.ctdet_decode(pr_decs)
            pts0, scores0 = func_utils.decode_prediction(predictions, dsets, args, img_id, down_ratio)

            # Per-category NMS.
            results = {cat: [] for cat in dsets.category}
            for cat in dsets.category:
                if cat == 'background':
                    continue
                pts_cat = np.asarray(list(pts0[cat]), np.float32)
                scores_cat = np.asarray(list(scores0[cat]), np.float32)
                if pts_cat.shape[0]:
                    nms_results = func_utils.non_maximum_suppression(pts_cat, scores_cat)
                    results[cat].extend(nms_results)

            end_time = time.time()
            total_time.append(end_time-begin_time)

            # The loader does not shuffle, so cnt is also the dataset index.
            ori_image = dsets.load_image(cnt)
            annotation_lines = []   # one text line per kept box of this image
            for cat in dsets.category:
                if cat == 'background':
                    continue
                for pred in results[cat]:
                    coords = " ".join(f"{value:.2f}" for value in pred[:8])
                    annotation_lines.append(f"{coords} {cat} \n")

                    score = pred[-1]
                    tl = np.asarray([pred[0], pred[1]], np.float32)
                    tr = np.asarray([pred[2], pred[3]], np.float32)
                    br = np.asarray([pred[4], pred[5]], np.float32)
                    bl = np.asarray([pred[6], pred[7]], np.float32)

                    # Midpoints of the four edges.
                    tt = (tl + tr) / 2
                    rr = (tr + br) / 2
                    bb = (bl + br) / 2
                    ll = (tl + bl) / 2

                    box = np.asarray([tl, tr, br, bl], np.float32)
                    cen_pts = np.mean(box, axis=0)
                    centre = (int(cen_pts[0]), int(cen_pts[1]))
                    # One line from the box centre to each edge midpoint.
                    cv2.line(ori_image, centre, (int(tt[0]), int(tt[1])), (0,0,255),1,1)
                    cv2.line(ori_image, centre, (int(rr[0]), int(rr[1])), (255,0,255),1,1)
                    cv2.line(ori_image, centre, (int(bb[0]), int(bb[1])), (0,255,0),1,1)
                    cv2.line(ori_image, centre, (int(ll[0]), int(ll[1])), (255,0,0),1,1)

                    # np.int0 no longer exists in NumPy 2.x; int32 is the
                    # contour dtype OpenCV works with.
                    ori_image = cv2.drawContours(ori_image, [box.astype(np.int32)], -1, (255,0,255),1,1)
                    # The text origin must be integer pixel coordinates.
                    cv2.putText(ori_image, '{:.2f} {}'.format(score, cat),
                                (int(box[1][0]), int(box[1][1])),
                                cv2.FONT_HERSHEY_COMPLEX, 0.5, (0,255,255), 1,1)

            if annotation_lines:
                # Write each box exactly once and open the file once per
                # image. Append mode is kept, so lines already in the file
                # (e.g. from an earlier run) are preserved.
                with open(f"{predict_path}/{img_id}.txt", 'a+') as f:
                    f.writelines(annotation_lines)

            if args.dataset == 'hrsc':
                # Overlay the ground-truth quadrilaterals in white.
                gt_anno = dsets.load_annotation(cnt)
                for pts_4 in gt_anno['pts']:
                    bl = pts_4[0, :]
                    tl = pts_4[1, :]
                    tr = pts_4[2, :]
                    br = pts_4[3, :]
                    box = np.asarray([bl, tl, tr, br], np.float32)
                    cv2.drawContours(ori_image, [box.astype(np.int32)], 0, (255, 255, 255), 1)

            # NOTE(review): if load_image returns an OpenCV BGR array the
            # colours shown by matplotlib will be channel-swapped - confirm.
            plt.imshow(ori_image)
            plt.show()

        total_time = total_time[1:]          # drop the first (warm-up) measurement
        print('avg time is {}'.format(np.mean(total_time)))
        print('FPS is {}'.format(1./np.mean(total_time)))
示例#7
0
    def test(self, args, down_ratio):
        save_path = 'weights_' + args.dataset
        self.model = self.load_model(self.model,
                                     os.path.join(save_path, args.resume))
        self.model = self.model.to(self.device)
        self.model.eval()

        dataset_module = self.dataset[args.dataset]
        dsets = dataset_module(data_dir=args.data_dir,
                               phase='test',
                               input_h=args.input_h,
                               input_w=args.input_w,
                               down_ratio=down_ratio)
        data_loader = torch.utils.data.DataLoader(dsets,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=1,
                                                  pin_memory=True)

        total_time = []
        for cnt, data_dict in enumerate(data_loader):
            image = data_dict['image'][0].to(self.device)
            img_id = data_dict['img_id'][0]
            print('processing {}/{} image ...'.format(cnt, len(data_loader)))
            begin_time = time.time()
            with torch.no_grad():
                pr_decs = self.model(image)

            #self.imshow_heatmap(pr_decs[2], image)

            torch.cuda.synchronize(self.device)
            decoded_pts = []
            decoded_scores = []
            decoded_bd_pts = []  ##!!##
            # (batch, num_targets, )
            predictions = self.decoder.ctdet_decode(pr_decs)
            pts0, scores0, bd_pts0 = func_utils.decode_prediction(
                predictions, dsets, args, img_id, down_ratio)  ##!!##
            decoded_pts.append(pts0)
            decoded_scores.append(scores0)
            decoded_bd_pts.append(bd_pts0)  ##!!##
            #nms
            results = {cat: [] for cat in dsets.category}
            bd_results = {cat: [] for cat in dsets.category}  ##!!##
            for cat in dsets.category:
                if cat == 'background':
                    continue
                pts_cat = []
                scores_cat = []
                bd_pts_cat = []  ##!!##

                for pts0, scores0, bd_pts0 in zip(decoded_pts, decoded_scores,
                                                  decoded_bd_pts):  ##!!##
                    pts_cat.extend(pts0[cat])
                    scores_cat.extend(scores0[cat])
                    bd_pts_cat.extend(bd_pts0[cat])  ##!!##

                pts_cat = np.asarray(pts_cat, np.float32)
                scores_cat = np.asarray(scores_cat, np.float32)
                bd_pts_cat = np.asarray(bd_pts_cat, np.float32)  ##!!##

                if pts_cat.shape[0]:
                    nms_results, nms_bds = func_utils.non_maximum_suppression(
                        pts_cat, scores_cat, bd_pts_cat)
                    results[cat].extend(nms_results)
                    bd_results[cat].extend(nms_bds)

            end_time = time.time()
            total_time.append(end_time - begin_time)

            #"""
            ori_image = dsets.load_image(cnt)
            height, width, _ = ori_image.shape
            # ori_image = cv2.resize(ori_image, (args.input_w, args.input_h))
            # ori_image = cv2.resize(ori_image, (args.input_w//args.down_ratio, args.input_h//args.down_ratio))
            #nms
            for cat in dsets.category:
                if cat == 'background':
                    continue
                result = results[cat]
                bd = bd_results[cat]

                for pred, bds in zip(result, bd):

                    score = pred[-1]
                    tl = np.asarray([pred[0], pred[1]], np.float32)
                    tr = np.asarray([pred[2], pred[3]], np.float32)
                    br = np.asarray([pred[4], pred[5]], np.float32)
                    bl = np.asarray([pred[6], pred[7]], np.float32)

                    tt = (np.asarray(tl, np.float32) +
                          np.asarray(tr, np.float32)) / 2
                    rr = (np.asarray(tr, np.float32) +
                          np.asarray(br, np.float32)) / 2
                    bb = (np.asarray(bl, np.float32) +
                          np.asarray(br, np.float32)) / 2
                    ll = (np.asarray(tl, np.float32) +
                          np.asarray(bl, np.float32)) / 2

                    box = np.asarray([tl, tr, br, bl], np.float32)
                    cen_pts = np.mean(box, axis=0)
                    # cv2.line(ori_image, (int(cen_pts[0]), int(cen_pts[1])), (int(tt[0]), int(tt[1])), (0,0,255),1,1)
                    # cv2.line(ori_image, (int(cen_pts[0]), int(cen_pts[1])), (int(rr[0]), int(rr[1])), (255,0,255),1,1)
                    # cv2.line(ori_image, (int(cen_pts[0]), int(cen_pts[1])), (int(bb[0]), int(bb[1])), (0,255,0),1,1)
                    # cv2.line(ori_image, (int(cen_pts[0]), int(cen_pts[1])), (int(ll[0]), int(ll[1])), (255,0,0),1,1)

                    # cv2.line(ori_image, (int(cen_pts[0]), int(cen_pts[1])), (int(tl[0]), int(tl[1])), (0,0,255),1,1)
                    # cv2.line(ori_image, (int(cen_pts[0]), int(cen_pts[1])), (int(tr[0]), int(tr[1])), (255,0,255),1,1)
                    # cv2.line(ori_image, (int(cen_pts[0]), int(cen_pts[1])), (int(br[0]), int(br[1])), (0,255,0),1,1)
                    # cv2.line(ori_image, (int(cen_pts[0]), int(cen_pts[1])), (int(bl[0]), int(bl[1])), (255,0,0),1,1)

                    ######## 画绿色粗矩形框 #########
                    ori_image = cv2.drawContours(ori_image, [np.int0(box)], -1,
                                                 (0, 255, 0), 4, 1)
                    ###################################

                    ############## 边界点 ###############
                    # bds (2, 2N)
                    points = []
                    for i in range(bds.shape[1]):
                        points.append([bds[0, i], bds[1, i]])

                    for point in points:
                        point = tuple([int(i) for i in point])
                        ori_image = cv2.circle(ori_image, point, 3,
                                               (0, 0, 255), 4)
                    ####################################

                    # box = cv2.boxPoints(cv2.minAreaRect(box))
                    # ori_image = cv2.drawContours(ori_image, [np.int0(box)], -1, (0,255,0),1,1)
                    # cv2.putText(ori_image, '{:.2f} {}'.format(score, cat), (box[1][0], box[1][1]),
                    #             cv2.FONT_HERSHEY_COMPLEX, 0.5, (0,255,255), 1,1)

            # if args.dataset == 'ssdd': #
            #     gt_anno = dsets.load_annotation(cnt)
            #     for pts_4 in gt_anno['pts']:
            #         bl = pts_4[0, :]
            #         tl = pts_4[1, :]
            #         tr = pts_4[2, :]
            #         br = pts_4[3, :]
            #         cen_pts = np.mean(pts_4, axis=0)
            #         box = np.asarray([bl, tl, tr, br], np.float32)
            #         box = np.int0(box)
            #         cv2.drawContours(ori_image, [box], 0, (255, 0, 255), 4)
            cv2.imwrite('./result_images/{}_det.jpg'.format(img_id),
                        ori_image)  #
            # cv2.imshow('pr_image', ori_image)
            # k = cv2.waitKey(0) & 0xFF
            # if k == ord('q'):
            #     cv2.destroyAllWindows()
            #     exit()
            #"""

        total_time = total_time[1:]
        print('avg time is {}'.format(np.mean(total_time)))
        print('FPS is {}'.format(1. / np.mean(total_time)))
示例#8
0
    def test(self, args, down_ratio):
        """Run inference on the test split and export the detections as text.

        The checkpoint ``args.resume`` is restored from the directory
        ``'weights_' + args.dataset``.  Every image of the test dataset is
        pushed through the model, decoded, and filtered by per-category NMS.
        The surviving detections are appended to
        ``txt_results/<args.out_dir>/per_image_txt/<img_id>.txt`` and, once all
        images are processed, every file found in that directory is
        concatenated into
        ``txt_results/<args.out_dir>/total_result_<args.out_dir>.txt``.

        Each exported line has the form
        ``<img_id>.tif <int(category)> <score> <coordinates...>`` where the
        coordinates are truncated to integers.  Category names therefore have
        to be convertible with ``int()``.

        Args:
            args: Parsed options; this method reads ``dataset``, ``resume``,
                ``data_dir``, ``input_h``, ``input_w`` and ``out_dir`` (and
                forwards ``args`` to ``func_utils.decode_prediction``).
            down_ratio: Down-sampling ratio forwarded to the dataset and to
                ``func_utils.decode_prediction``.

        Notes:
            * Per-image files are opened in append mode, so results from an
              earlier run with the same ``args.out_dir`` are kept and extended.
            * Detections are also drawn onto the loaded image, but the image is
              not written to disk by this method.
        """
        save_path = 'weights_' + args.dataset
        # Restore the model from the requested checkpoint and switch to eval mode.
        self.model = self.load_model(self.model, os.path.join(save_path, args.resume))
        self.model = self.model.to(self.device)
        self.model.eval()

        # Dataset class registered under the name ``args.dataset``.
        dataset_module = self.dataset[args.dataset]
        dsets = dataset_module(data_dir=args.data_dir,
                               phase='test',
                               input_h=args.input_h,
                               input_w=args.input_w,
                               down_ratio=down_ratio)
        data_loader = torch.utils.data.DataLoader(dsets,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=1,
                                                  pin_memory=True)

        txt_dir = os.path.join("txt_results", args.out_dir, "per_image_txt")
        os.makedirs(txt_dir, exist_ok=True)

        total_time = []
        for cnt, data_dict in enumerate(data_loader):
            image = data_dict['image'][0].to(self.device)
            img_id = data_dict['img_id'][0]
            print('processing {}/{} image ...'.format(cnt, len(data_loader)))

            begin_time = time.time()
            with torch.no_grad():
                pr_decs = self.model(image)
            torch.cuda.synchronize(self.device)

            predictions = self.decoder.ctdet_decode(pr_decs)
            # Per-category box points and scores of the decoded detections.
            pts0, scores0 = func_utils.decode_prediction(predictions, dsets, args, img_id, down_ratio)

            # Per-category non-maximum suppression.
            results = {cat: [] for cat in dsets.category}
            for cat in dsets.category:
                if cat == 'background':
                    continue
                pts_cat = np.asarray(pts0[cat], np.float32)
                scores_cat = np.asarray(scores0[cat], np.float32)
                if pts_cat.shape[0]:
                    nms_results = func_utils.non_maximum_suppression(pts_cat, scores_cat)
                    results[cat].extend(nms_results)
            end_time = time.time()
            total_time.append(end_time - begin_time)

            # Only used as a canvas for the (unsaved) visualization below.
            ori_image = dsets.load_image(cnt)

            lines = []
            for cat in dsets.category:
                if cat == 'background':
                    continue
                detections = results[cat]
                # Highest confidence first.
                detections.sort(key=lambda det: det[-1], reverse=True)
                for det in detections:
                    # Layout of one detection: coordinates first, score last.
                    score = det[-1]
                    box = np.asarray(det[:8], np.float32).reshape(4, 2)
                    coords = [int(pos) for pos in det[:-1]]
                    cat_id = int(cat)

                    # Boxes leaving the [0, 1023] range are handed to inter_poly
                    # (presumably clipping against a 1024x1024 image - TODO
                    # confirm); a None result means the box is discarded.
                    if min(coords) < 0 or max(coords) > 1023:
                        clipped = inter_poly(coords)
                        if clipped is None:
                            continue
                        coords = list(clipped)

                    rect = cv2.minAreaRect(np.array(coords).reshape(4, 2))
                    # Export only boxes whose min-area rectangle covers at
                    # least 1000 px^2 (the original comment mentioned 10000,
                    # the code has always used 1000).
                    if rect[1][0] * rect[1][1] >= 1000:
                        row = [str(img_id) + '.tif', cat_id, score] + coords
                        lines.append(" ".join(map(str, row)) + '\n')

                    # np.int0 no longer exists in NumPy 2.0; int32 is what
                    # OpenCV expects for contours.  putText needs plain ints.
                    ori_image = cv2.drawContours(ori_image, [box.astype(np.int32)], -1,
                                                 (255, 0, 255), 1, 1)
                    cv2.putText(ori_image, '{:.2f} {}'.format(score, cat),
                                (int(box[1][0]), int(box[1][1])),
                                cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 255, 255), 1, 1)

            # To keep the visualization, write it out here, e.g.
            # cv2.imwrite('results/{}.png'.format(img_id), ori_image)
            with open(os.path.join(txt_dir, str(img_id) + '.txt'), 'a+') as f:
                f.writelines(lines)

        # The first measurement is dropped (presumably warm-up cost).
        total_time = total_time[1:]
        print('avg time is {}'.format(np.mean(total_time)))
        print('FPS is {}'.format(1. / np.mean(total_time)))

        # Merge every per-image file into one result file; sorted so that the
        # merged output does not depend on the directory listing order.
        txt_names = sorted(os.listdir(txt_dir))
        print(len(txt_names))
        merged = []
        for txt in txt_names:
            with open(os.path.join(txt_dir, txt), 'r') as f:
                merged.append(f.read())
        with open(os.path.join("txt_results", args.out_dir,
                               'total_result_' + args.out_dir + '.txt'), 'w+') as f:
            f.write("".join(merged))