# Example No. 1
# 0
def main():
    """Evaluate SiamRPN + Alpha-Refine (mask -> box) on a dataset with OPE.

    For every video in ``args.dataset`` (or only ``args.video`` when it is
    non-empty) the tracker is initialised on the first-frame ground truth.
    On each later frame:

    - the tracker's box is refined via ``RF_module.get_mask`` followed by
      ``mask2bbox`` (thresholded at ``args.thres``);
    - the refined box is clipped to the image;
    - the clipped box is written back into the tracker state.

    Per-video boxes are written, one comma-separated line per frame, to
    ``<save_dir>/<dataset>/<model_name>_<selector_path>/<video>.txt``.
    Videos whose result file already exists are skipped.

    Relies on module-level globals defined outside this function: ``args``,
    ``refine_path``, ``selector_path``, ``sr``, ``input_sz``, ``RF_type``,
    ``project_path_``, ``dataset_root_`` and ``save_dir``.
    """
    RF_module = RefineModule(refine_path,
                             selector_path,
                             search_factor=sr,
                             input_sz=input_sz)

    model_name = 'siamrpn_' + RF_type + '_m2b_{}'.format(args.thres)
    # Directory holding one '<video>.txt' per video. The resume check and the
    # writer must use the same directory, otherwise finished videos are never
    # recognised as done and get re-tracked on every run.
    model_path = os.path.join(save_dir, args.dataset,
                              model_name + '_' + str(selector_path))

    snapshot_path = os.path.join(
        project_path_, 'experiments/%s/model.pth' % args.tracker_name)
    config_path = os.path.join(
        project_path_, 'experiments/%s/config.yaml' % args.tracker_name)

    cfg.merge_from_file(config_path)

    # create model (a sub-class of `torch.nn.Module`) from the snapshot
    model = ModelBuilder()
    model = load_pretrain(model, snapshot_path).cuda().eval()

    # tracker = network + post-processing
    tracker = build_tracker(model)

    dataset = DatasetFactory.create_dataset(name=args.dataset,
                                            dataset_root=dataset_root_,
                                            load_img=False)

    # OPE tracking
    for v_idx, video in enumerate(dataset):
        if args.video != '' and video.name != args.video:
            continue  # only one specific video was requested
        result_path = os.path.join(model_path, '{}.txt'.format(video.name))
        if os.path.exists(result_path):
            continue  # already tracked by an earlier run

        toc = 0
        pred_bboxes = []
        scores = []
        track_times = []
        for idx, (img, gt_bbox) in enumerate(video):
            tic = cv2.getTickCount()
            if idx == 0:
                # Frame size, needed to clip boxes on later frames.
                H, W, _ = img.shape
                cx, cy, w, h = get_axis_aligned_bbox(np.array(gt_bbox))
                # Top-left (x, y) plus (w, h) derived from the returned centre.
                gt_bbox_ = [cx - (w - 1) / 2, cy - (h - 1) / 2, w, h]
                tracker.init(img, gt_bbox_)
                # Initialise the refinement module for this video; it is given
                # the frame converted BGR -> RGB.
                RF_module.initialize(cv2.cvtColor(img, cv2.COLOR_BGR2RGB),
                                     np.array(gt_bbox_))
                pred_bbox = gt_bbox_
                scores.append(None)
                pred_bboxes.append(pred_bbox)
            else:
                outputs = tracker.track(img)
                pred_bbox = outputs['bbox']
                # Refine: predict a mask around the tracker's box, then turn
                # the mask back into a box.
                mask_pred = RF_module.get_mask(
                    cv2.cvtColor(img, cv2.COLOR_BGR2RGB), np.array(pred_bbox))
                pred_bbox = mask2bbox(mask_pred,
                                      pred_bbox,
                                      MASK_THRESHOLD=args.thres)
                x1, y1, w, h = pred_bbox.tolist()
                # Apply the image-boundary / size limits.
                x1, y1, x2, y2 = bbox_clip(x1, y1, x1 + w, y1 + h, (H, W))
                w = x2 - x1
                h = y2 - y1

                pred_bbox = np.array([x1, y1, w, h])
                # Feed the refined box back so the next search is centred on it.
                tracker.center_pos = np.array([x1 + w / 2, y1 + h / 2])
                tracker.size = np.array([w, h])

                pred_bboxes.append(pred_bbox)
                scores.append(outputs['best_score'])
            # Measure once so the total and the per-frame times agree.
            elapsed = cv2.getTickCount() - tic
            toc += elapsed
            track_times.append(elapsed / cv2.getTickFrequency())
            if idx == 0:
                cv2.destroyAllWindows()
            if args.vis and idx > 0:
                gt_bbox = list(map(int, gt_bbox))
                pred_bbox = list(map(int, pred_bbox))
                cv2.rectangle(
                    img, (gt_bbox[0], gt_bbox[1]),
                    (gt_bbox[0] + gt_bbox[2], gt_bbox[1] + gt_bbox[3]),
                    (0, 255, 0), 3)
                cv2.rectangle(
                    img, (pred_bbox[0], pred_bbox[1]),
                    (pred_bbox[0] + pred_bbox[2], pred_bbox[1] + pred_bbox[3]),
                    (0, 255, 255), 3)
                cv2.putText(img, str(idx), (40, 40), cv2.FONT_HERSHEY_SIMPLEX,
                            1, (0, 255, 255), 2)
                cv2.imshow(video.name, img)
                cv2.waitKey(1)
        toc /= cv2.getTickFrequency()

        # save results
        os.makedirs(model_path, exist_ok=True)
        with open(result_path, 'w') as f:
            for x in pred_bboxes:
                f.write(','.join(str(i) for i in x) + '\n')
        # Guard against a zero elapsed time so the summary line never raises
        # after the results have already been written.
        fps = idx / toc if toc > 0 else 0.0
        print('({:3d}) Video: {:12s} Time: {:5.1f}s Speed: {:3.1f}fps'.format(
            v_idx + 1, video.name, toc, fps))
# Example No. 2
# 0
def main():
    """Run a ``Tracker``-built tracker with Alpha-Refine and dump mask visualisations.

    Video selection and initialisation:

    - Every video in ``args.dataset`` is processed, or only ``args.video``
      when that is non-empty.
    - The tracker is built from ``args.tracker_name`` / ``args.tracker_param``.
    - It is initialised on the first-frame ground-truth box, and so is the
      refinement module.

    On every later frame:

    - the tracker's ``target_bbox`` is refined by ``RF_module.refine``;
    - the refined box is clipped to the image;
    - the clipped box is written back into the tracker's ``pos``,
      ``target_sz`` and ``target_scale``.

    Outputs:

    - A mask from ``RF_module.get_mask`` is passed to ``draw_mask`` on every
      frame.
    - When ``args.vis`` is set, the mask is also overlaid on the frame and
      saved as ``<vis_result>/<idx:06>.jpg``.

    Relies on module-level globals defined outside this function: ``args``,
    ``refine_path``, ``selector_path``, ``sr``, ``input_sz``, ``RF_type``,
    ``dataset_root_`` and ``COLORS``.

    NOTE(review): this looks like a one-off analysis script. It contains a
    hard-coded absolute output path and an unconditional ``draw_mask`` debug
    call; see the inline notes before reusing it.
    """
    # create tracker
    tracker_info = Tracker(args.tracker_name, args.tracker_param, None)
    params = tracker_info.get_parameters()
    params.visualization = args.vis
    params.debug = args.debug
    # 'use_visdom' is False; presumably server/port are then ignored — confirm.
    params.visdom_info = {
        'use_visdom': False,
        'server': '127.0.0.1',
        'port': 8097
    }
    tracker = tracker_info.tracker_class(params)
    '''Refinement module'''
    RF_module = RefineModule(refine_path,
                             selector_path,
                             search_factor=sr,
                             input_sz=input_sz)
    # Run identifier: tracker name/param, refine type, selector and run id.
    model_name = args.tracker_name + '_' + args.tracker_param + '{}-{}'.format(
        RF_type, selector_path) + '_%d' % (args.run_id)

    # create dataset
    dataset = DatasetFactory.create_dataset(name=args.dataset,
                                            dataset_root=dataset_root_,
                                            load_img=False)

    # OPE tracking
    for v_idx, video in enumerate(dataset):
        # Random overlay colour per video, indexed to shape (1, 1, C) with the
        # channel order reversed (presumably RGB -> BGR for OpenCV; assumes
        # each COLORS entry is a 3-tuple — TODO confirm).
        color = np.array(COLORS[random.randint(0,
                                               len(COLORS) - 1)])[None,
                                                                  None, ::-1]
        # NOTE(review): hard-coded, machine-specific absolute output path.
        vis_result = os.path.join(
            '/home/zxy/Desktop/AlphaRefine/CVPR21/material/quality_analysis/mask_vis',
            '{}'.format(video.name))

        if args.video != '':
            # test one special video
            if video.name != args.video:
                continue
            else:
                # NOTE(review): prints an empty line only; looks like a debug leftover.
                print()

        # NOTE(review): the output directory is created even when args.vis is False.
        if not os.path.exists(vis_result):
            os.makedirs(vis_result)

        toc = 0
        pred_bboxes = []
        scores = []
        track_times = []
        for idx, (img, gt_bbox) in enumerate(video):
            '''get RGB format image'''
            img_RGB = img[:, :, ::-1].copy()  # BGR --> RGB
            tic = cv2.getTickCount()
            if idx == 0:
                # Frame size; reused by bbox_clip on later frames.
                H, W, _ = img.shape
                cx, cy, w, h = get_axis_aligned_bbox(np.array(gt_bbox))
                # Top-left (x, y) plus (w, h) derived from the returned centre.
                gt_bbox_ = [cx - (w - 1) / 2, cy - (h - 1) / 2, w, h]
                '''Initialize'''
                gt_bbox_np = np.array(gt_bbox_)
                gt_bbox_torch = torch.from_numpy(gt_bbox_np.astype(np.float32))
                init_info = {}
                init_info['init_bbox'] = gt_bbox_torch
                _ = tracker.initialize(img_RGB, init_info)
                '''##### initilize refinement module for specific video'''
                RF_module.initialize(img_RGB, np.array(gt_bbox_))

                pred_bbox = gt_bbox_
                scores.append(None)
                pred_bboxes.append(pred_bbox)

            else:
                '''Track'''
                outputs = tracker.track(img_RGB)
                pred_bbox = outputs['target_bbox']
                '''##### refine tracking results #####'''
                # NOTE(review): img_RGB already holds this RGB frame; the
                # cvtColor call is redundant.
                pred_bbox = RF_module.refine(
                    cv2.cvtColor(img, cv2.COLOR_BGR2RGB), np.array(pred_bbox))

                x1, y1, w, h = pred_bbox.tolist()
                '''add boundary and min size limit'''
                x1, y1, x2, y2 = bbox_clip(x1, y1, x1 + w, y1 + h, (H, W))
                w = x2 - x1
                h = y2 - y1
                # Position is written as (y_center, x_center) and size as (h, w),
                # presumably the tracker's row/column convention — confirm.
                new_pos = torch.from_numpy(
                    np.array([y1 + h / 2, x1 + w / 2]).astype(np.float32))
                new_target_sz = torch.from_numpy(
                    np.array([h, w]).astype(np.float32))
                # Scale = sqrt(new area / base area).
                new_scale = torch.sqrt(new_target_sz.prod() /
                                       tracker.base_target_sz.prod())
                ##### update the tracker state with the clipped, refined box
                tracker.pos = new_pos.clone()
                tracker.target_sz = new_target_sz
                tracker.target_scale = new_scale

                # NOTE(review): pred_bbox here is still the unclipped refine
                # output, not the clipped box written into the tracker above.
                mask_pred = RF_module.get_mask(
                    cv2.cvtColor(img, cv2.COLOR_BGR2RGB), np.array(pred_bbox))
                # NOTE(review): debug leftover — the import runs in the frame
                # loop, and draw_mask runs with show=True and a hard-coded
                # save_dir on every frame, regardless of args.vis.
                from external.pysot.toolkit.visualization import draw_mask
                draw_mask(img,
                          mask_pred,
                          idx=idx,
                          show=True,
                          save_dir='dimpsuper_armask_crocodile-3')

                # NOTE(review): unlike the tracker state, this stores the
                # unclipped refined box.
                pred_bboxes.append(pred_bbox)
                # scores.append(outputs['best_score'])
            # Elapsed ticks are read twice; the two measurements differ slightly.
            toc += cv2.getTickCount() - tic
            track_times.append(
                (cv2.getTickCount() - tic) / cv2.getTickFrequency())
            if idx == 0:
                cv2.destroyAllWindows()
            if args.vis and idx > 0:
                im4show = img
                # Binarise the mask at 0.5 and add a trailing channel axis
                # (assumes mask_pred is 2-D (H, W) — confirm).
                mask_pred = np.uint8(mask_pred > 0.5)[:, :, None]
                contours, _ = cv2.findContours(mask_pred.squeeze(),
                                               cv2.RETR_LIST,
                                               cv2.CHAIN_APPROX_SIMPLE)
                # Masked pixels: half-brightness frame + colour * 128.
                # NOTE(review): the arithmetic stays in uint8, so this assumes
                # colour channels are 0/1 to avoid wrap-around — confirm.
                im4show = im4show * (1 - mask_pred) + np.uint8(
                    im4show * mask_pred /
                    2) + mask_pred * np.uint8(color) * 128
                pred_bbox = list(map(int, pred_bbox))
                # gt_bbox = list(map(int, gt_bbox))
                # cv2.rectangle(im4show, (gt_bbox[0], gt_bbox[1]),
                #               (gt_bbox[0]+gt_bbox[2], gt_bbox[1]+gt_bbox[3]), (0, 255, 0), 3)

                # cv2.rectangle(im4show, (pred_bbox[0], pred_bbox[1]),
                #               (pred_bbox[0]+pred_bbox[2], pred_bbox[1]+pred_bbox[3]), color[::-1].squeeze().tolist(), 3)

                cv2.drawContours(im4show, contours, -1, color[::-1].squeeze(),
                                 2)
                cv2.putText(im4show, str(idx), (40, 40),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
                # cv2.imshow(video.name, im4show)
                cv2.imwrite(os.path.join(vis_result, '{:06}.jpg'.format(idx)),
                            im4show)
                # With imshow commented out, this waitKey has no window to service.
                cv2.waitKey(1)
        # Convert accumulated ticks to seconds.
        toc /= cv2.getTickFrequency()