Example 1
class YTVOSDataset(CustomDataset):
    CLASSES = ('person', 'giant_panda', 'lizard', 'parrot', 'skateboard',
               'sedan', 'ape', 'dog', 'snake', 'monkey', 'hand', 'rabbit',
               'duck', 'cat', 'cow', 'fish', 'train', 'horse', 'turtle',
               'bear', 'motorbike', 'giraffe', 'leopard', 'fox', 'deer', 'owl',
               'surfboard', 'airplane', 'truck', 'zebra', 'tiger', 'elephant',
               'snowboard', 'boat', 'shark', 'mouse', 'frog', 'eagle',
               'earless_seal', 'tennis_racket')

    def __init__(self,
                 ann_file,
                 img_prefix,
                 img_scale,
                 img_norm_cfg,
                 size_divisor=None,
                 proposal_file=None,
                 num_max_proposals=1000,
                 flip_ratio=0,
                 with_mask=True,
                 with_crowd=True,
                 with_label=True,
                 with_track=False,
                 extra_aug=None,
                 aug_ref_bbox_param=None,
                 resize_keep_ratio=True,
                 test_mode=False,
                 every_frame=False,
                 is_flow=False,
                 flow_test=False):
        # prefix of images path
        self.img_prefix = img_prefix

        # load annotations (and proposals)
        self.vid_infos = self.load_annotations(ann_file)

        self.every_frame = every_frame
        self.is_flow = is_flow
        self.flow_test = flow_test
        if self.flow_test or self.is_flow:
            self.cuda = True
        else:
            self.cuda = False
        if self.cuda:
            from mmcv import Config
            from mmdet.models import build_detector
            from mmcv.runner import load_checkpoint
            cfg = Config.fromfile(
                "../configs/masktrack_rcnn_r50_fpn_1x_flow_youtubevos.py")
            self.det_model = build_detector(cfg.model,
                                            train_cfg=cfg.train_cfg,
                                            test_cfg=cfg.test_cfg)
            load_checkpoint(self.det_model,
                            "../results/20200312-180434/epoch_9.pth")
            self.det_model = self.det_model.cuda()
            self.det_model.eval()
            for param in self.det_model.parameters():
                param.requires_grad = False

        # Set indexes for data loading
        img_ids = []  # training frames which have annotations
        img_ids_all = []  # all training frames
        img_ids_pairs = []  # flow data pairs
        for idx, vid_info in enumerate(self.vid_infos):
            vid_name = vid_info['filenames'][0].split('/')[0]
            folder_path = osp.join(self.img_prefix, vid_name)
            files = os.listdir(folder_path)
            files.sort()
            vid_info['filenames_all'] = [
                osp.join(vid_name, file) for file in files
            ]
            for _id in range(len(files)):
                img_ids_all.append((idx, _id))
                is_anno = vid_info['filenames_all'][_id] in vid_info[
                    'filenames']
                if is_anno and _id > 0:  # having annotation and is not the first frame.
                    ann_idx = vid_info['filenames'].index(
                        vid_info['filenames_all'][_id])
                    ann = self.get_ann_info(idx, ann_idx)
                    gt_bboxes = ann['bboxes']
                    # skip the image if there is no valid gt bbox
                    if len(gt_bboxes) == 0:
                        continue
                    # randomly select a key frame
                    key_id = _id - np.random.randint(1, min(10, _id))
                    img_ids_pairs.append(((idx, key_id), (idx, _id)))
            for frame_id in range(len(vid_info['filenames'])):
                img_ids.append((idx, frame_id))

        self.img_ids = img_ids
        self.img_ids_all = img_ids_all
        self.img_ids_pairs = img_ids_pairs

        if proposal_file is not None:
            self.proposals = self.load_proposals(proposal_file)
        else:
            self.proposals = None
        # filter images with no annotation during training
        if not test_mode:
            valid_inds = [
                i for i, (v, f) in enumerate(self.img_ids)
                if len(self.get_ann_info(v, f)['bboxes'])
            ]
            self.img_ids = [self.img_ids[i] for i in valid_inds]

        # (long_edge, short_edge) or [(long1, short1), (long2, short2), ...]
        self.img_scales = img_scale if isinstance(img_scale,
                                                  list) else [img_scale]
        assert mmcv.is_list_of(self.img_scales, tuple)
        # normalization configs
        self.img_norm_cfg = img_norm_cfg

        # max proposals per image
        self.num_max_proposals = num_max_proposals
        # flip ratio
        self.flip_ratio = flip_ratio
        assert flip_ratio >= 0 and flip_ratio <= 1
        # padding border to ensure the image size can be divided by
        # size_divisor (used for FPN)
        self.size_divisor = size_divisor

        # whether to load and transform mask annotations
        self.with_mask = with_mask
        # some datasets provide bbox annotations as ignore/crowd/difficult;
        # if `with_crowd` is True, this info is returned.
        self.with_crowd = with_crowd
        # with label is False for RPN
        self.with_label = with_label
        self.with_track = with_track
        # params for augmenting bbox in the reference frame
        self.aug_ref_bbox_param = aug_ref_bbox_param
        # in test mode or not
        self.test_mode = test_mode

        # set group flag for the sampler
        if not self.test_mode:
            self._set_group_flag()
        # transforms
        self.img_transform = ImageTransform(size_divisor=self.size_divisor,
                                            **self.img_norm_cfg)
        self.bbox_transform = BboxTransform()
        self.mask_transform = MaskTransform()
        self.numpy2tensor = Numpy2Tensor()

        # if use extra augmentation
        if extra_aug is not None:
            self.extra_aug = ExtraAugmentation(**extra_aug)
        else:
            self.extra_aug = None

        # image rescale if keep ratio
        self.resize_keep_ratio = resize_keep_ratio

    def __len__(self):
        if self.every_frame:
            return len(self.img_ids_all)
        elif self.is_flow:
            return len(self.img_ids_pairs)
        else:
            return len(self.img_ids)

    def __getitem__(self, idx):
        if self.test_mode:
            if self.every_frame:
                return self.prepare_test_img(self.img_ids_all[idx])
            else:
                return self.prepare_test_img(self.img_ids[idx])
        if self.is_flow:
            if self.flow_test:
                data = self.prepare_train_flow_test_img(
                    self.img_ids_pairs[idx])
            else:
                data = self.prepare_train_flow_img(self.img_ids_pairs[idx])
        else:
            data = self.prepare_train_img(self.img_ids[idx])
        return data

    def load_annotations(self, ann_file):
        self.ytvos = YTVOS(ann_file)
        self.cat_ids = self.ytvos.getCatIds()
        self.cat2label = {
            cat_id: i + 1
            for i, cat_id in enumerate(self.cat_ids)
        }
        self.vid_ids = self.ytvos.getVidIds()
        vid_infos = []
        for i in self.vid_ids:
            info = self.ytvos.loadVids([i])[0]
            info['filenames'] = info['file_names']
            vid_infos.append(info)
        return vid_infos

    def get_ann_info(self, idx, frame_id):
        vid_id = self.vid_infos[idx]['id']
        ann_ids = self.ytvos.getAnnIds(vidIds=[vid_id])
        ann_info = self.ytvos.loadAnns(ann_ids)
        return self._parse_ann_info(ann_info, frame_id)

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)
        for i in range(len(self)):
            vid_id, _ = self.img_ids[i]
            vid_info = self.vid_infos[vid_id]
            if vid_info['width'] / vid_info['height'] > 1:
                self.flag[i] = 1

    def bbox_aug(self, bbox, img_size):
        assert self.aug_ref_bbox_param is not None
        center_off = self.aug_ref_bbox_param[0]
        size_perturb = self.aug_ref_bbox_param[1]

        n_bb = bbox.shape[0]
        # bbox center offset
        center_offs = (2 * np.random.rand(n_bb, 2) - 1) * center_off
        # bbox resize ratios
        resize_ratios = (2 * np.random.rand(n_bb, 2) - 1) * size_perturb + 1
        # bbox: x1, y1, x2, y2
        centers = (bbox[:, :2] + bbox[:, 2:]) / 2.
        sizes = bbox[:, 2:] - bbox[:, :2]
        new_centers = centers + center_offs * sizes
        new_sizes = sizes * resize_ratios
        new_x1y1 = new_centers - new_sizes / 2.
        new_x2y2 = new_centers + new_sizes / 2.
        c_min = [0, 0]
        c_max = [img_size[1], img_size[0]]
        new_x1y1 = np.clip(new_x1y1, c_min, c_max)
        new_x2y2 = np.clip(new_x2y2, c_min, c_max)
        bbox = np.hstack((new_x1y1, new_x2y2)).astype(np.float32)
        return bbox

    def sample_ref(self, idx):
        # sample another frame in the same sequence as reference
        vid, frame_id = idx
        vid_info = self.vid_infos[vid]
        sample_range = range(len(vid_info['filenames']))
        valid_samples = []
        for i in sample_range:
            # check if the frame id is valid
            ref_idx = (vid, i)
            if i != frame_id and ref_idx in self.img_ids:
                valid_samples.append(ref_idx)
        assert len(valid_samples) > 0
        return random.choice(valid_samples)

    def prepare_train_flow_test_img(self, idx):

        # prepare a pair of images from the same sequence
        vid, key_frame_id = idx[0]
        _, cur_frame_id = idx[1]
        vid_info = self.vid_infos[vid]

        # load image
        key_img = mmcv.imread(
            osp.join(self.img_prefix, vid_info['filenames_all'][key_frame_id]))
        cur_img = mmcv.imread(
            osp.join(self.img_prefix, vid_info['filenames_all'][cur_frame_id]))
        h_orig, w_orig, _ = key_img.shape
        basename = osp.basename(vid_info['filenames_all'][key_frame_id])

        # apply transforms
        flip = bool(np.random.rand() < self.flip_ratio)
        img_scale = random_scale(self.img_scales)  # sample a scale
        cur_img, img_shape, pad_shape, scale_factor = self.img_transform(
            cur_img, img_scale, flip, keep_ratio=self.resize_keep_ratio)
        if not isinstance(scale_factor, float):
            scale_factor = tuple(scale_factor)
        cur_img = cur_img.copy()
        key_img, key_img_shape, _, ref_scale_factor = self.img_transform(
            key_img, img_scale, flip, keep_ratio=self.resize_keep_ratio)
        key_img = key_img.copy()

        # trans = torchvision.transforms.ToTensor()
        key_img = torch.from_numpy(key_img).cuda()
        cur_img = torch.from_numpy(cur_img).cuda()

        def resize(feat_map, size=(48, 64)):
            """Resize feature map to certain size."""
            key_feature = torch.nn.functional.interpolate(feat_map,
                                                          size,
                                                          mode='bilinear',
                                                          align_corners=True)
            return key_feature

        img_size = (384, 640)
        if key_img.shape[-2:] != img_size:
            key_img = resize(key_img.unsqueeze(0), img_size).squeeze(0)
            cur_img = resize(cur_img.unsqueeze(0), img_size).squeeze(0)

        key_feature_maps, _ = self.det_model.extract_feat(key_img.unsqueeze(0))
        cur_feature_maps, _ = self.det_model.extract_feat(cur_img.unsqueeze(0))

        key_feature_maps = [
            feat_map.squeeze(0) for feat_map in key_feature_maps
        ]
        cur_feature_maps = [
            feat_map.squeeze(0) for feat_map in cur_feature_maps
        ]

        data = dict(key_img=key_img,
                    cur_img=cur_img,
                    key_img_feats=key_feature_maps,
                    cur_img_feats=cur_feature_maps)
        return data

    def prepare_train_flow_img(self, idx):

        # prepare a pair of images from the same sequence
        vid, key_frame_id = idx[0]
        _, cur_frame_id = idx[1]
        vid_info = self.vid_infos[vid]

        # load image
        key_img = mmcv.imread(
            osp.join(self.img_prefix, vid_info['filenames_all'][key_frame_id]))
        cur_img = mmcv.imread(
            osp.join(self.img_prefix, vid_info['filenames_all'][cur_frame_id]))
        h_orig, w_orig, _ = cur_img.shape
        basename = osp.basename(vid_info['filenames_all'][key_frame_id])

        # load proposals if necessary
        if self.proposals is not None:
            proposals = self.proposals[idx][:self.num_max_proposals]
            # TODO: Handle empty proposals properly. Currently images with
            # no proposals are simply skipped, but they could in principle
            # still be used for training.
            if len(proposals) == 0:
                return None
            if not (proposals.shape[1] == 4 or proposals.shape[1] == 5):
                raise AssertionError(
                    'proposals should have shapes (n, 4) or (n, 5), '
                    'but found {}'.format(proposals.shape))
            if proposals.shape[1] == 5:
                scores = proposals[:, 4, None]
                proposals = proposals[:, :4]
            else:
                scores = None
        ann_idx = vid_info['filenames'].index(
            vid_info['filenames_all'][cur_frame_id])
        ann = self.get_ann_info(vid, ann_idx)
        gt_bboxes = ann['bboxes']
        gt_labels = ann['labels']

        if self.with_crowd:
            gt_bboxes_ignore = ann['bboxes_ignore']

        # skip the image if there is no valid gt bbox
        if len(gt_bboxes) == 0:
            return None

        # extra augmentation
        if self.extra_aug is not None:
            cur_img, gt_bboxes, gt_labels = self.extra_aug(
                cur_img, gt_bboxes, gt_labels)

        # apply transforms
        flip = bool(np.random.rand() < self.flip_ratio)

        img_scales = [(1280, 720), (640, 360)]
        # img_scale = random_scale(self.img_scales)  # sample a scale
        cur_img, img_shape, pad_shape, scale_factor = self.img_transform(
            cur_img, img_scales[1], flip, keep_ratio=self.resize_keep_ratio)
        if not isinstance(scale_factor, float):
            scale_factor = tuple(scale_factor)
        cur_img = cur_img.copy()
        key_img, key_img_shape, _, key_scale_factor = self.img_transform(
            key_img, img_scales[0], flip, keep_ratio=self.resize_keep_ratio)
        key_img = key_img.copy()
        if self.proposals is not None:
            proposals = self.bbox_transform(proposals, img_shape, scale_factor,
                                            flip)
            proposals = np.hstack([proposals, scores
                                   ]) if scores is not None else proposals
        gt_bboxes = self.bbox_transform(gt_bboxes, img_shape, scale_factor,
                                        flip)

        if self.with_crowd:
            gt_bboxes_ignore = self.bbox_transform(gt_bboxes_ignore, img_shape,
                                                   scale_factor, flip)
        if self.with_mask:
            if w_orig > h_orig:
                h, w = img_shape[0], img_shape[1]
                _scale_factor = tuple([w, h, w, h])
            else:
                _scale_factor = scale_factor
            gt_masks = self.mask_transform(ann['masks'], pad_shape,
                                           _scale_factor, flip)

        ori_shape = (vid_info['height'], vid_info['width'], 3)
        img_meta = dict(ori_shape=ori_shape,
                        img_shape=img_shape,
                        pad_shape=pad_shape,
                        scale_factor=scale_factor,
                        is_first=(cur_frame_id == 0),
                        flip=flip)

        data = dict(
            img=DC(to_tensor(key_img), stack=True),
            ref_img=DC(to_tensor(cur_img), stack=True),
            img_meta=DC(img_meta, cpu_only=True),
            gt_bboxes=DC(to_tensor(gt_bboxes)),
            # ref_bboxes=DC(to_tensor(ref_bboxes))
        )
        if self.proposals is not None:
            data['proposals'] = DC(to_tensor(proposals))
        if self.with_label:
            data['gt_labels'] = DC(to_tensor(gt_labels))
        # if self.with_track:
        #     data['gt_pids'] = DC(to_tensor(gt_pids))
        if self.with_crowd:
            data['gt_bboxes_ignore'] = DC(to_tensor(gt_bboxes_ignore))
        if self.with_mask:
            data['gt_masks'] = DC(gt_masks, cpu_only=True)
        data['train_flow'] = True

        if self.cuda:
            key_img_cuda = torch.from_numpy(key_img).cuda()
            cur_img_cuda = torch.from_numpy(cur_img).cuda()

            def resize(feat_map, size=(48, 64)):
                """Resize feature map to certain size."""
                key_feature = torch.nn.functional.interpolate(
                    feat_map, size, mode='bilinear', align_corners=True)
                return key_feature

            img_size = (384, 640)
            if key_img_cuda.shape[-2:] != img_size:
                key_img_cuda = resize(key_img_cuda.unsqueeze(0),
                                      img_size).squeeze(0)
                cur_img_cuda = resize(cur_img_cuda.unsqueeze(0),
                                      img_size).squeeze(0)

            key_feature_maps, _ = self.det_model.extract_feat(
                key_img_cuda.unsqueeze(0))
            cur_feature_maps, _ = self.det_model.extract_feat(
                cur_img_cuda.unsqueeze(0))

            key_feature_maps = [
                feat_map.squeeze(0) for feat_map in key_feature_maps
            ]
            cur_feature_maps = [
                feat_map.squeeze(0) for feat_map in cur_feature_maps
            ]

            data['key_feature_maps'] = key_feature_maps
            data['cur_feature_maps'] = cur_feature_maps

        return data

    def prepare_train_img(self, idx):
        # prepare a pair of images from the same sequence
        vid, frame_id = idx
        vid_info = self.vid_infos[vid]
        # load image
        if self.is_flow or self.every_frame:
            img = mmcv.imread(
                osp.join(self.img_prefix, vid_info['filenames_all'][frame_id]))
        else:
            img = mmcv.imread(
                osp.join(self.img_prefix, vid_info['filenames'][frame_id]))
        h_orig, w_orig, _ = img.shape
        basename = osp.basename(vid_info['filenames'][frame_id])
        _, ref_frame_id = self.sample_ref(idx)
        ref_img = mmcv.imread(
            osp.join(self.img_prefix, vid_info['filenames'][ref_frame_id]))
        # load proposals if necessary
        if self.proposals is not None:
            proposals = self.proposals[idx][:self.num_max_proposals]
            # TODO: Handle empty proposals properly. Currently images with
            # no proposals are simply skipped, but they could in principle
            # still be used for training.
            if len(proposals) == 0:
                return None
            if not (proposals.shape[1] == 4 or proposals.shape[1] == 5):
                raise AssertionError(
                    'proposals should have shapes (n, 4) or (n, 5), '
                    'but found {}'.format(proposals.shape))
            if proposals.shape[1] == 5:
                scores = proposals[:, 4, None]
                proposals = proposals[:, :4]
            else:
                scores = None

        ann = self.get_ann_info(vid, frame_id)
        ref_ann = self.get_ann_info(vid, ref_frame_id)
        gt_bboxes = ann['bboxes']
        gt_labels = ann['labels']
        ref_bboxes = ref_ann['bboxes']
        # the obj_ids attribute does not exist in the raw annotations;
        # it is added by _parse_ann_info
        ref_ids = ref_ann['obj_ids']
        gt_ids = ann['obj_ids']
        # compute matching of the reference frame with the current frame;
        # 0 denotes there is no match
        gt_pids = [ref_ids.index(i) + 1 if i in ref_ids else 0 for i in gt_ids]
        if self.with_crowd:
            gt_bboxes_ignore = ann['bboxes_ignore']

        # skip the image if there is no valid gt bbox
        if len(gt_bboxes) == 0:
            return None

        # extra augmentation
        if self.extra_aug is not None:
            img, gt_bboxes, gt_labels = self.extra_aug(img, gt_bboxes,
                                                       gt_labels)

        # apply transforms
        flip = bool(np.random.rand() < self.flip_ratio)
        img_scale = random_scale(self.img_scales)  # sample a scale
        img, img_shape, pad_shape, scale_factor = self.img_transform(
            img, img_scale, flip, keep_ratio=self.resize_keep_ratio)
        if not isinstance(scale_factor, float):
            scale_factor = tuple(scale_factor)
        img = img.copy()
        ref_img, ref_img_shape, _, ref_scale_factor = self.img_transform(
            ref_img, img_scale, flip, keep_ratio=self.resize_keep_ratio)
        ref_img = ref_img.copy()
        if self.proposals is not None:
            proposals = self.bbox_transform(proposals, img_shape, scale_factor,
                                            flip)
            proposals = np.hstack([proposals, scores
                                   ]) if scores is not None else proposals
        gt_bboxes = self.bbox_transform(gt_bboxes, img_shape, scale_factor,
                                        flip)
        ref_bboxes = self.bbox_transform(ref_bboxes, ref_img_shape,
                                         ref_scale_factor, flip)
        if self.aug_ref_bbox_param is not None:
            ref_bboxes = self.bbox_aug(ref_bboxes, ref_img_shape)
        if self.with_crowd:
            gt_bboxes_ignore = self.bbox_transform(gt_bboxes_ignore, img_shape,
                                                   scale_factor, flip)
        if self.with_mask:
            if w_orig > h_orig:
                h, w = img_shape[0], img_shape[1]
                _scale_factor = tuple([w, h, w, h])
            else:
                _scale_factor = scale_factor
            gt_masks = self.mask_transform(ann['masks'], pad_shape,
                                           _scale_factor, flip)

        ori_shape = (vid_info['height'], vid_info['width'], 3)
        img_meta = dict(ori_shape=ori_shape,
                        img_shape=img_shape,
                        pad_shape=pad_shape,
                        scale_factor=scale_factor,
                        is_first=(frame_id == 0),
                        flip=flip)

        data = dict(img=DC(to_tensor(img), stack=True),
                    ref_img=DC(to_tensor(ref_img), stack=True),
                    img_meta=DC(img_meta, cpu_only=True),
                    gt_bboxes=DC(to_tensor(gt_bboxes)),
                    ref_bboxes=DC(to_tensor(ref_bboxes)))
        if self.proposals is not None:
            data['proposals'] = DC(to_tensor(proposals))
        if self.with_label:
            data['gt_labels'] = DC(to_tensor(gt_labels))
        if self.with_track:
            data['gt_pids'] = DC(to_tensor(gt_pids))
        if self.with_crowd:
            data['gt_bboxes_ignore'] = DC(to_tensor(gt_bboxes_ignore))
        if self.with_mask:
            data['gt_masks'] = DC(gt_masks, cpu_only=True)
        return data

    def prepare_test_img(self, idx):
        """Prepare an image for testing (multi-scale and flipping)"""
        vid, frame_id = idx
        vid_info = self.vid_infos[vid]
        is_anno = True
        if self.every_frame:
            img = mmcv.imread(
                osp.join(self.img_prefix, vid_info['filenames_all'][frame_id]))
            is_anno = vid_info['filenames_all'][frame_id] in vid_info[
                'filenames']
        else:
            img = mmcv.imread(
                osp.join(self.img_prefix, vid_info['filenames'][frame_id]))
        proposal = None

        if self.every_frame:
            file_name = vid_info['filenames_all'][frame_id]
        else:
            file_name = vid_info['filenames'][frame_id]

        def prepare_single(img,
                           frame_id,
                           scale,
                           flip,
                           file_name,
                           proposal=None,
                           is_anno=True):
            _img, img_shape, pad_shape, scale_factor = self.img_transform(
                img, scale, flip, keep_ratio=self.resize_keep_ratio)
            _img = to_tensor(_img)
            _img_meta = dict(ori_shape=(vid_info['height'], vid_info['width'],
                                        3),
                             img_shape=img_shape,
                             pad_shape=pad_shape,
                             is_first=(frame_id == 0),
                             video_id=vid,
                             file_name=file_name,
                             frame_id=frame_id,
                             scale_factor=scale_factor,
                             flip=flip,
                             is_anno=is_anno)
            if proposal is not None:
                if proposal.shape[1] == 5:
                    score = proposal[:, 4, None]
                    proposal = proposal[:, :4]
                else:
                    score = None
                _proposal = self.bbox_transform(proposal, img_shape,
                                                scale_factor, flip)
                _proposal = np.hstack([_proposal, score
                                       ]) if score is not None else _proposal
                _proposal = to_tensor(_proposal)
            else:
                _proposal = None
            return _img, _img_meta, _proposal

        imgs = []
        img_metas = []
        proposals = []
        for scale in self.img_scales:
            _img, _img_meta, _proposal = prepare_single(
                img, frame_id, scale, False, file_name, proposal, is_anno)
            imgs.append(_img)
            img_metas.append(DC(_img_meta, cpu_only=True))
            proposals.append(_proposal)
            if self.flip_ratio > 0:
                _img, _img_meta, _proposal = prepare_single(
                    img, frame_id, scale, True, file_name, proposal, is_anno)
                imgs.append(_img)
                img_metas.append(DC(_img_meta, cpu_only=True))
                proposals.append(_proposal)
        data = dict(img=imgs, img_meta=img_metas)
        return data

    def _parse_ann_info(self, ann_info, frame_id, with_mask=True):
        """Parse bbox and mask annotation.

        Args:
            ann_info (list[dict]): Annotation info of an image.
            frame_id (int): Index of the frame within the video.
            with_mask (bool): Whether to parse mask annotations.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, masks, mask_polys, poly_lens.
        """
        gt_bboxes = []
        gt_labels = []
        gt_ids = []
        gt_bboxes_ignore = []
        # Two formats are provided.
        # 1. mask: a binary map of the same size as the image.
        # 2. polys: each mask consists of one or several polys; each poly is
        #    a list of floats.
        if with_mask:
            gt_masks = []
            gt_mask_polys = []
            gt_poly_lens = []
        for i, ann in enumerate(ann_info):
            # each ann is a list of masks
            # ann:
            # bbox: list of bboxes
            # segmentation: list of segmentation
            # category_id
            # area: list of area
            bbox = ann['bboxes'][frame_id]
            area = ann['areas'][frame_id]
            segm = ann['segmentations'][frame_id]
            if bbox is None: continue
            x1, y1, w, h = bbox
            if area <= 0 or w < 1 or h < 1:
                continue
            bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
            if ann['iscrowd']:
                gt_bboxes_ignore.append(bbox)
            else:
                gt_bboxes.append(bbox)
                gt_ids.append(ann['id'])
                gt_labels.append(self.cat2label[ann['category_id']])
            if with_mask:
                gt_masks.append(self.ytvos.annToMask(ann, frame_id))
                mask_polys = [
                    p for p in segm if len(p) >= 6
                ]  # valid polygons have >= 3 points (6 coordinates)
                poly_lens = [len(p) for p in mask_polys]
                gt_mask_polys.append(mask_polys)
                gt_poly_lens.extend(poly_lens)
        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        ann = dict(bboxes=gt_bboxes,
                   labels=gt_labels,
                   obj_ids=gt_ids,
                   bboxes_ignore=gt_bboxes_ignore)

        if with_mask:
            ann['masks'] = gt_masks
            # poly format is not used in the current implementation
            ann['mask_polys'] = gt_mask_polys
            ann['poly_lens'] = gt_poly_lens
        return ann
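
A minimal usage sketch for the class above, assuming a YouTube-VIS style annotation file and image folder exist on disk; the paths, scale, and normalization values below are placeholders rather than values from the original project:

img_norm_cfg = dict(mean=[123.675, 116.28, 103.53],
                    std=[58.395, 57.12, 57.375],
                    to_rgb=True)
dataset = YTVOSDataset(
    ann_file='data/ytvos/train.json',              # hypothetical path
    img_prefix='data/ytvos/train/JPEGImages',      # hypothetical path
    img_scale=(640, 360),                          # (long_edge, short_edge)
    img_norm_cfg=img_norm_cfg,
    size_divisor=32,
    flip_ratio=0.5,
    with_mask=True,
    with_track=True)
print(len(dataset))   # number of annotated frames with at least one valid box
sample = dataset[0]   # dict with 'img', 'ref_img', 'gt_bboxes', 'ref_bboxes', ...

With the default flags (every_frame=False, is_flow=False, test_mode=False) indexing the dataset goes through prepare_train_img, which pairs each annotated frame with a randomly sampled reference frame from the same video.
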
Example 2
class YTVOSMavinbeCloneDataset(CustomDataset):
    CLASSES=('person','giant_panda','lizard','parrot','skateboard','sedan',
        'ape','dog','snake','monkey','hand','rabbit','duck','cat','cow','fish',
        'train','horse','turtle','bear','motorbike','giraffe','leopard',
        'fox','deer','owl','surfboard','airplane','truck','zebra','tiger',
        'elephant','snowboard','boat','shark','mouse','frog','eagle','earless_seal',
        'tennis_racket')

    def __init__(self,
                 ann_file,
                 img_prefix,
                 img_scale,
                 img_norm_cfg,
                 size_divisor=None,
                 proposal_file=None,
                 num_max_proposals=1000,
                 flip_ratio=0,
                 with_mask=True,
                 with_crowd=True,
                 with_label=True,
                 with_track=False,
                 extra_aug=None,
                 aug_ref_bbox_param=None,
                 resize_keep_ratio=True,
                 test_mode=False):
        self.frame_id_counter = 0
        self.cap = cv2.VideoCapture('/SipMask-VIS/data/abc2_cuted.mp4')

        # while True:
        #     grabbed, frame = self.cap.read()
        #     if grabbed is False:
        #         return
        # self.videoProvider = VideoStreamProvider('/SipMask-VIS/data/abc2_cuted.mp4')
        # grabbed, img = self.videoProvider.read()
        # if grabbed is False:
        #     print("Error: can't read first frame!")
        #     exit(1)

        # prefix of images path
        self.img_prefix = img_prefix

        # load annotations (and proposals)
        self.vid_infos = self.load_annotations(ann_file)
        img_ids = []
        for idx, vid_info in enumerate(self.vid_infos):
            for frame_id in range(len(vid_info['filenames'])):
                img_ids.append((idx, frame_id))
        self.img_ids = img_ids
        if proposal_file is not None:
            self.proposals = self.load_proposals(proposal_file)
        else:
            self.proposals = None
        # filter images with no annotation during training
        if not test_mode:
            valid_inds = [i for i, (v, f) in enumerate(self.img_ids)
                if len(self.get_ann_info(v, f)['bboxes'])]
            self.img_ids = [self.img_ids[i] for i in valid_inds]

        # (long_edge, short_edge) or [(long1, short1), (long2, short2), ...]
        self.img_scales = img_scale if isinstance(img_scale,
                                                  list) else [img_scale]
        assert mmcv.is_list_of(self.img_scales, tuple)
        # normalization configs
        self.img_norm_cfg = img_norm_cfg

        # max proposals per image
        self.num_max_proposals = num_max_proposals
        # flip ratio
        self.flip_ratio = flip_ratio
        assert flip_ratio >= 0 and flip_ratio <= 1
        # padding border to ensure the image size can be divided by
        # size_divisor (used for FPN)
        self.size_divisor = size_divisor

        # whether to load and transform mask annotations
        self.with_mask = with_mask
        # some datasets provide bbox annotations as ignore/crowd/difficult;
        # if `with_crowd` is True, this info is returned.
        self.with_crowd = with_crowd
        # with label is False for RPN
        self.with_label = with_label
        self.with_track = with_track
        # params for augmenting bbox in the reference frame
        self.aug_ref_bbox_param = aug_ref_bbox_param
        # in test mode or not
        self.test_mode = test_mode

        # set group flag for the sampler
        if not self.test_mode:
            self._set_group_flag()
        # transforms
        self.img_transform = ImageTransform(
            size_divisor=self.size_divisor, **self.img_norm_cfg)
        self.bbox_transform = BboxTransform()
        self.mask_transform = MaskTransform()
        self.numpy2tensor = Numpy2Tensor()

        # if use extra augmentation
        if extra_aug is not None:
            self.extra_aug = ExtraAugmentation(**extra_aug)
        else:
            self.extra_aug = None

        # image rescale if keep ratio
        self.resize_keep_ratio = resize_keep_ratio

    def __len__(self):
        return len(self.img_ids)
    
    def __getitem__(self, idx):
        if self.test_mode:
            return self.prepare_test_img(self.img_ids[idx])
        data = self.prepare_train_img(self.img_ids[idx])
        return data
    
    def load_annotations(self, ann_file):
        self.ytvos = YTVOS(ann_file)
        self.cat_ids = self.ytvos.getCatIds()
        self.cat2label = {
            cat_id: i + 1
            for i, cat_id in enumerate(self.cat_ids)
        }
        self.vid_ids = self.ytvos.getVidIds()
        vid_infos = []
        for i in self.vid_ids:
            info = self.ytvos.loadVids([i])[0]
            info['filenames'] = info['file_names']
            vid_infos.append(info)
        return vid_infos

    def get_ann_info(self, idx, frame_id):
        vid_id = self.vid_infos[idx]['id']
        ann_ids = self.ytvos.getAnnIds(vidIds=[vid_id])
        ann_info = self.ytvos.loadAnns(ann_ids)
        return self._parse_ann_info(ann_info, frame_id)

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)
        for i in range(len(self)):
            vid_id, _ = self.img_ids[i]
            vid_info = self.vid_infos[vid_id]
            if vid_info['width'] / vid_info['height'] > 1:
                self.flag[i] = 1

    def bbox_aug(self, bbox, img_size):
        assert self.aug_ref_bbox_param is not None
        center_off = self.aug_ref_bbox_param[0]
        size_perturb = self.aug_ref_bbox_param[1]
        
        n_bb = bbox.shape[0]
        # bbox center offset
        center_offs = (2*np.random.rand(n_bb, 2) - 1) * center_off
        # bbox resize ratios
        resize_ratios = (2*np.random.rand(n_bb, 2) - 1) * size_perturb + 1
        # bbox: x1, y1, x2, y2
        centers = (bbox[:,:2]+ bbox[:,2:])/2.
        sizes = bbox[:,2:] - bbox[:,:2]
        new_centers = centers + center_offs * sizes
        new_sizes = sizes * resize_ratios
        new_x1y1 = new_centers - new_sizes/2.
        new_x2y2 = new_centers + new_sizes/2.
        c_min = [0,0]
        c_max = [img_size[1], img_size[0]]
        new_x1y1 = np.clip(new_x1y1, c_min, c_max)
        new_x2y2 = np.clip(new_x2y2, c_min, c_max)
        bbox = np.hstack((new_x1y1,new_x2y2)).astype(np.float32)
        return bbox

    def sample_ref(self, idx):
        # sample another frame in the same sequence as reference
        vid, frame_id = idx
        vid_info = self.vid_infos[vid]
        sample_range = range(len(vid_info['filenames']))
        valid_samples = []
        for i in sample_range:
            # check if the frame id is valid
            ref_idx = (vid, i)
            if i != frame_id and ref_idx in self.img_ids:
                valid_samples.append(ref_idx)
        assert len(valid_samples) > 0
        return random.choice(valid_samples)

    def prepare_train_img(self, idx):
        # prepare a pair of images from the same sequence
        vid, frame_id = idx
        vid_info = self.vid_infos[vid]
        # load image
        img = mmcv.imread(osp.join(self.img_prefix, vid_info['filenames'][frame_id]))
        basename = osp.basename(vid_info['filenames'][frame_id])
        _, ref_frame_id = self.sample_ref(idx)
        ref_img = mmcv.imread(osp.join(self.img_prefix, vid_info['filenames'][ref_frame_id]))
        # load proposals if necessary
        if self.proposals is not None:
            proposals = self.proposals[idx][:self.num_max_proposals]
            # TODO: Handle empty proposals properly. Currently images with
            # no proposals are simply skipped, but they could in principle
            # still be used for training.
            if len(proposals) == 0:
                return None
            if not (proposals.shape[1] == 4 or proposals.shape[1] == 5):
                raise AssertionError(
                    'proposals should have shapes (n, 4) or (n, 5), '
                    'but found {}'.format(proposals.shape))
            if proposals.shape[1] == 5:
                scores = proposals[:, 4, None]
                proposals = proposals[:, :4]
            else:
                scores = None

        ann = self.get_ann_info(vid, frame_id)
        ref_ann = self.get_ann_info(vid, ref_frame_id)
        gt_bboxes = ann['bboxes']
        gt_labels = ann['labels']
        ref_bboxes = ref_ann['bboxes']
        # the obj_ids attribute does not exist in the raw annotations;
        # it is added by _parse_ann_info
        ref_ids = ref_ann['obj_ids']
        gt_ids = ann['obj_ids']
        # compute matching of the reference frame with the current frame;
        # 0 denotes there is no match
        gt_pids = [ref_ids.index(i)+1 if i in ref_ids else 0 for i in gt_ids]
        if self.with_crowd:
            gt_bboxes_ignore = ann['bboxes_ignore']

        # skip the image if there is no valid gt bbox
        if len(gt_bboxes) == 0:
            return None

        # extra augmentation
        if self.extra_aug is not None:
            img, gt_bboxes, gt_labels = self.extra_aug(img, gt_bboxes,
                                                       gt_labels)

        # apply transforms
        flip = bool(np.random.rand() < self.flip_ratio)
        img_scale = random_scale(self.img_scales)  # sample a scale
        # print(img_scale)
        img, img_shape, pad_shape, scale_factor = self.img_transform(
            img, img_scale, flip, keep_ratio=self.resize_keep_ratio)
        img = img.copy()
        ref_img, ref_img_shape, _, ref_scale_factor = self.img_transform(
            ref_img, img_scale, flip, keep_ratio=self.resize_keep_ratio)
        ref_img = ref_img.copy()
        if self.proposals is not None:
            proposals = self.bbox_transform(proposals, img_shape, scale_factor,
                                            flip)
            proposals = np.hstack(
                [proposals, scores]) if scores is not None else proposals
        gt_bboxes = self.bbox_transform(gt_bboxes, img_shape, scale_factor,
                                        flip)
        ref_bboxes = self.bbox_transform(ref_bboxes, ref_img_shape, ref_scale_factor,
                                          flip)
        if self.aug_ref_bbox_param is not None:
            ref_bboxes = self.bbox_aug(ref_bboxes, ref_img_shape)
        if self.with_crowd:
            gt_bboxes_ignore = self.bbox_transform(gt_bboxes_ignore, img_shape,
                                                   scale_factor, flip)
        if self.with_mask:
            gt_masks = self.mask_transform(ann['masks'], pad_shape,
                                           scale_factor, flip)

        ori_shape = (vid_info['height'], vid_info['width'], 3)
        img_meta = dict(
            ori_shape=ori_shape,
            img_shape=img_shape,
            pad_shape=pad_shape,
            scale_factor=scale_factor,
            flip=flip)

        data = dict(
            img=DC(to_tensor(img), stack=True),
            ref_img=DC(to_tensor(ref_img), stack=True),
            img_meta=DC(img_meta, cpu_only=True),
            gt_bboxes=DC(to_tensor(gt_bboxes)),
            ref_bboxes=DC(to_tensor(ref_bboxes))
        )
        if self.proposals is not None:
            data['proposals'] = DC(to_tensor(proposals))
        if self.with_label:
            data['gt_labels'] = DC(to_tensor(gt_labels))
        if self.with_track:
            data['gt_pids'] = DC(to_tensor(gt_pids))
        if self.with_crowd:
            data['gt_bboxes_ignore'] = DC(to_tensor(gt_bboxes_ignore))
        if self.with_mask:
            data['gt_masks'] = DC(gt_masks, cpu_only=True)
        return data

    def prepare_test_img(self, idx):
        """Prepare an image for testing (multi-scale and flipping)"""
        frame_id = self.frame_id_counter
        vid = 0
        #vid, frame_id = idx
        vid_info = self.vid_infos[vid]
        proposal = None
        
        #grabbed, img = self.videoProvider.read()
        grabbed, img = self.cap.read()
        if grabbed is False:
            return

        def prepare_single(img, frame_id, scale, flip, proposal=None):
            ori_shape = img.shape
            _img, img_shape, pad_shape, scale_factor = self.img_transform(
                img, scale, flip, keep_ratio=self.resize_keep_ratio)
            _img = to_tensor(_img)
            _img_meta = dict(
                ori_shape=ori_shape,
                img_shape=img_shape,
                pad_shape=pad_shape,
                is_first=(frame_id == 0),
                video_id=vid,
                frame_id=frame_id,
                scale_factor=scale_factor,
                flip=flip)
            if proposal is not None:
                if proposal.shape[1] == 5:
                    score = proposal[:, 4, None]
                    proposal = proposal[:, :4]
                else:
                    score = None
                _proposal = self.bbox_transform(proposal, img_shape,
                                                scale_factor, flip)
                _proposal = np.hstack(
                    [_proposal, score]) if score is not None else _proposal
                _proposal = to_tensor(_proposal)
            else:
                _proposal = None
            return _img, _img_meta, _proposal

        imgs = []
        img_metas = []
        proposals = []
        for scale in self.img_scales:
            _img, _img_meta, _proposal = prepare_single(
                img, frame_id, scale, False, proposal)
            imgs.append(_img)
            img_metas.append(DC(_img_meta, cpu_only=True))
            proposals.append(_proposal)
            if self.flip_ratio > 0:
                _img, _img_meta, _proposal = prepare_single(
                    img, frame_id, scale, True, proposal)
                imgs.append(_img)
                img_metas.append(DC(_img_meta, cpu_only=True))
                proposals.append(_proposal)
        data = dict(img=imgs, img_meta=img_metas)
        self.frame_id_counter += 1
        return data

    def _parse_ann_info(self, ann_info, frame_id, with_mask=True):
        """Parse bbox and mask annotation.

        Args:
            ann_info (list[dict]): Annotation info of an image.
            frame_id (int): Index of the frame within the video.
            with_mask (bool): Whether to parse mask annotations.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, masks, mask_polys, poly_lens.
        """
        gt_bboxes = []
        gt_labels = []
        gt_ids = []
        gt_bboxes_ignore = []
        # Two formats are provided.
        # 1. mask: a binary map of the same size as the image.
        # 2. polys: each mask consists of one or several polys; each poly is
        #    a list of floats.
        if with_mask:
            gt_masks = []
            gt_mask_polys = []
            gt_poly_lens = []
        for i, ann in enumerate(ann_info):
            # each ann is a list of masks
            # ann:
            # bbox: list of bboxes
            # segmentation: list of segmentation
            # category_id
            # area: list of area
            bbox = ann['bboxes'][frame_id]
            area = ann['areas'][frame_id]
            segm = ann['segmentations'][frame_id]
            if bbox is None: continue
            x1, y1, w, h = bbox
            if area <= 0 or w < 1 or h < 1:
                continue
            bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
            if ann['iscrowd']:
                gt_bboxes_ignore.append(bbox)
            else:
                gt_bboxes.append(bbox)
                gt_ids.append(ann['id'])
                gt_labels.append(self.cat2label[ann['category_id']])
            if with_mask:
                gt_masks.append(self.ytvos.annToMask(ann, frame_id))
                mask_polys = [
                    p for p in segm if len(p) >= 6
                ]  # valid polygons have >= 3 points (6 coordinates)
                poly_lens = [len(p) for p in mask_polys]
                gt_mask_polys.append(mask_polys)
                gt_poly_lens.extend(poly_lens)
        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        ann = dict(
            bboxes=gt_bboxes, labels=gt_labels, obj_ids=gt_ids, bboxes_ignore=gt_bboxes_ignore)

        if with_mask:
            ann['masks'] = gt_masks
            # poly format is not used in the current implementation
            ann['mask_polys'] = gt_mask_polys
            ann['poly_lens'] = gt_poly_lens
        return ann
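
The clone above is wired for streaming inference: in test mode, prepare_test_img ignores the dataset index and reads successive frames from the video opened in __init__. A minimal sketch of iterating it; the paths and config values are placeholders, and only the .mp4 path comes from the code itself:

img_norm_cfg = dict(mean=[123.675, 116.28, 103.53],
                    std=[58.395, 57.12, 57.375],
                    to_rgb=True)
dataset = YTVOSMavinbeCloneDataset(
    ann_file='data/ytvos/valid.json',              # hypothetical path
    img_prefix='data/ytvos/valid/JPEGImages',      # hypothetical path
    img_scale=(640, 360),
    img_norm_cfg=img_norm_cfg,
    test_mode=True)
for i in range(len(dataset)):
    data = dataset[i]    # frames come from /SipMask-VIS/data/abc2_cuted.mp4
    if data is None:     # prepare_test_img returns None once the video ends
        break

Note that len(dataset) still reflects the annotation file rather than the number of video frames, which is why the None check on the returned data is needed.
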
Example 3
class YTVOSDataset(CustomDataset):
    CLASSES = ('person', 'giant_panda', 'lizard', 'parrot', 'skateboard',
               'sedan', 'ape', 'dog', 'snake', 'monkey', 'hand', 'rabbit',
               'duck', 'cat', 'cow', 'fish', 'train', 'horse', 'turtle',
               'bear', 'motorbike', 'giraffe', 'leopard', 'fox', 'deer', 'owl',
               'surfboard', 'airplane', 'truck', 'zebra', 'tiger', 'elephant',
               'snowboard', 'boat', 'shark', 'mouse', 'frog', 'eagle',
               'earless_seal', 'tennis_racket')

    def __init__(self,
                 ann_file,
                 pipeline,
                 data_root=None,
                 img_prefix='',
                 of_prefix=None,
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True,
                 seq_len=0,
                 step=1):

        # prefix of images path
        self.ann_file = ann_file
        self.data_root = data_root
        self.img_prefix = img_prefix
        self.of_prefix = of_prefix
        self.seg_prefix = seg_prefix
        self.proposal_file = proposal_file
        self.test_mode = test_mode
        self.filter_empty_gt = filter_empty_gt

        self.seq_len = seq_len

        # join paths if data_root is specified
        if self.data_root is not None:
            if not osp.isabs(self.ann_file):
                self.ann_file = osp.join(self.data_root, self.ann_file)
            if not (self.img_prefix is None or osp.isabs(self.img_prefix)):
                self.img_prefix = osp.join(self.data_root, self.img_prefix)
            if not (self.of_prefix is None or osp.isabs(self.of_prefix)):
                self.of_prefix = osp.join(self.data_root, self.of_prefix)
            if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)):
                self.seg_prefix = osp.join(self.data_root, self.seg_prefix)
            if not (self.proposal_file is None
                    or osp.isabs(self.proposal_file)):
                self.proposal_file = osp.join(self.data_root,
                                              self.proposal_file)

        # load annotations (and proposals)
        self.vid_infos = self.load_annotations(ann_file)

        self.sample_ids = []
        self._frame_ids = {}
        for vid_id, vid_info in self.vid_infos.items():

            #video_img_ids = []
            video_sample_ids = []
            for frame_id in range(len(vid_info['filenames'])):
                idx = (vid_id, frame_id)
                if test_mode or len(self.get_ann_info(idx)['masks']):
                    video_sample_ids.append(idx)
            if len(video_sample_ids) >= seq_len:
                self._frame_ids[vid_id] = video_sample_ids
                if seq_len > 0:
                    self.sample_ids = self.sample_ids + video_sample_ids[
                        0::step]
                    #if test_mode:
                    #    self.sample_ids = self.sample_ids+ [video_sample_ids[-1]]

                else:
                    self.sample_ids = self.sample_ids + [video_sample_ids[-1]]

        if proposal_file is not None:
            self.proposals = self.load_proposals(proposal_file)
        else:
            self.proposals = None

        # set group flag for the sampler
        if not self.test_mode:
            self._set_group_flag()
        # processing pipeline
        self.pipeline = Compose(pipeline)
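        # A minimal sketch of the `pipeline` argument this constructor
        # expects. The transform names follow the mmdetection convention this
        # class builds on; the exact entries and values below are assumptions,
        # not taken from the original config:
        #
        #   train_pipeline = [
        #       dict(type='LoadImageFromFile'),
        #       dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
        #       dict(type='Resize', img_scale=(640, 360), keep_ratio=True),
        #       dict(type='RandomFlip', flip_ratio=0.5),
        #       dict(type='Normalize', mean=[123.675, 116.28, 103.53],
        #            std=[58.395, 57.12, 57.375], to_rgb=True),
        #       dict(type='Pad', size_divisor=32),
        #       dict(type='DefaultFormatBundle'),
        #       dict(type='Collect',
        #            keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
        #   ]
        #   dataset = YTVOSDataset(ann_file='train.json',       # assumed path
        #                          pipeline=train_pipeline,
        #                          img_prefix='JPEGImages/',    # assumed path
        #                          seq_len=5, step=5)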

    def __len__(self):
        return len(self.sample_ids)

    def _rand_another(self, idx):
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def __getitem__(self, idx):
        if self.test_mode:
            return self.prepare_test_img(self.sample_ids[idx])
        while True:
            data = self.prepare_train_img(self.sample_ids[idx])
            if data is None:
                idx = self._rand_another(idx)
                continue
            return data

    def load_annotations(self, ann_file):
        self.ytvos = YTVOS(ann_file)
        self.cat_ids = self.ytvos.getCatIds()
        self.cat2label = {
            cat_id: i + 1
            for i, cat_id in enumerate(self.cat_ids)
        }
        self.vid_ids = self.ytvos.getVidIds()
        vid_infos = {}
        for i in self.vid_ids:
            info = self.ytvos.loadVids([i])[0]
            info['filenames'] = info['file_names']
            vid_infos[i] = info
        return vid_infos

    def get_ann_info(self, idx):
        vid_id, frame_id = idx
        #vid_id = self.vid_infos[idx]['id']
        ann_ids = self.ytvos.getAnnIds(vidIds=[vid_id])
        ann_info = self.ytvos.loadAnns(ann_ids)
        return self._parse_ann_info(self.get_image_info(idx), ann_info,
                                    frame_id)

    def get_image_info(self, idx):
        vid_id, frame_id = idx
        vid_info = self.vid_infos[vid_id]
        return dict(filename=vid_info['filenames'][frame_id],
                    height=vid_info['height'],
                    width=vid_info['width'],
                    video_id=vid_id,
                    frame_id=frame_id)

    def pre_pipeline(self, results, idx, prev_results=None):
        results['img_prefix'] = self.img_prefix
        results['of_prefix'] = self.of_prefix
        results['seg_prefix'] = self.seg_prefix
        results['proposal_file'] = self.proposal_file
        results['bbox_fields'] = []
        results['mask_fields'] = []
        results['seg_fields'] = []
        video_id, frame_id = idx
        results['video_id'] = video_id
        results['frame_id'] = frame_id
        if prev_results is not None:
            results.update(prev_results)

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.
        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0.
        """
        self.flag = np.zeros(len(self.sample_ids), dtype=np.uint8)
        for i, (vid_id, _) in enumerate(self.sample_ids):
            vid_info = self.vid_infos[vid_id]
            if vid_info['width'] / vid_info['height'] > 1:
                self.flag[i] = 1

    def sample_ref_seq(self, idx):
        vid, fid = idx
        ref_frame_ids = self._frame_ids[vid]
        frame_index = ref_frame_ids.index(idx)
        ref_frame_ids_1 = ref_frame_ids[max(frame_index -
                                            (self.seq_len), 0):frame_index]
        ref_frame_ids_2 = ref_frame_ids[frame_index + 1:frame_index +
                                        (self.seq_len)]
        ref_frame_ids = ref_frame_ids_1 + ref_frame_ids_2
        if len(ref_frame_ids) < self.seq_len - 1:
            print('frame_index', frame_index, frame_index - self.seq_len,
                  frame_index + self.seq_len, ref_frame_ids)
            return None, 0
        ref_frame_ids = random.sample(ref_frame_ids, self.seq_len - 1)
        ref_frame_ids.append((vid, fid))
        ref_frame_ids.sort()
        frame_index = ref_frame_ids.index(idx)
        return ref_frame_ids, frame_index

    def test_ref_seq(self, idx):
        vid, fid = idx
        ref_frame_ids = self._frame_ids[vid]
        frame_index = ref_frame_ids.index(idx)
        ref_frame_ids = ref_frame_ids[frame_index:frame_index + self.seq_len]
        if len(ref_frame_ids) < self.seq_len:
            print('frame_index', frame_index, frame_index - self.seq_len,
                  frame_index + self.seq_len, ref_frame_ids)
            ref_frame_ids = None
        else:
            ref_frame_ids.sort()
        return ref_frame_ids

    def prepare_train_img(self, idx):
        # prepare a sequence of frames from the same video

        seq = []
        flip_keys = ['flip', 'flip_direction']
        prev_results = None
        samples, ref_frame_index = self.sample_ref_seq(idx)
        if samples is None:
            return None

        for sample_idx in samples:

            img_info = self.get_image_info(sample_idx)
            ann_info = self.get_ann_info(sample_idx)
            results = dict(img_info=img_info, ann_info=ann_info)
            if self.proposals is not None:
                results['proposals'] = self.proposals[idx]
            self.pre_pipeline(results, sample_idx, prev_results=prev_results)
            seq.append(self.pipeline(results))
            if prev_results is None:
                prev_results = {
                    flip_key: results[flip_key]
                    for flip_key in flip_keys
                }
        return dict(img=seq[0]['img'],
                    img_meta=seq[0]['img_meta'],
                    inp_seq=seq,
                    ref_frame_index=ref_frame_index)

    def prepare_test_img(self, idx):

        vid, fid = idx
        ref_frame_ids = self._frame_ids[vid]
        ref_frame_ids.sort()
        #frame_idx = ref_frame_ids.index((str(vid), fid))
        frame_idx = ref_frame_ids.index((vid, fid))

        ref_frame_ids = ref_frame_ids[frame_idx:frame_idx + self.seq_len]
        seq = []
        for sample_idx in ref_frame_ids:

            img_info = self.get_image_info(sample_idx)
            results = dict(img_info=img_info)
            if self.proposals is not None:
                results['proposals'] = self.proposals[idx]
            self.pre_pipeline(results, sample_idx)
            seq.append(self.pipeline(results))

        return dict(img=seq[0]['img'],
                    img_meta=seq[0]['img_meta'],
                    inp_seq=seq)

    def _parse_ann_info(self, img_info, ann_info, frame_id):
        """Parse bbox and mask annotation.
        Args:
            img_info (dict): Image info of the frame.
            ann_info (list[dict]): Annotation info of an image.
            frame_id (int): Index of the frame within the video.
        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, masks, seg_map. "masks" are raw annotations and not
                decoded into binary masks.
        """
        gt_bboxes = []
        gt_labels = []
        gt_ids = []
        gt_bboxes_ignore = []
        gt_masks_ann = []
        for i, ann in enumerate(ann_info):
            # Each ann describes one object track and holds per-frame lists
            # indexed by frame_id:
            #   'bboxes': one [x, y, w, h] box (or None) per frame
            #   'segmentations': one segmentation (or None) per frame
            #   'areas': one area (or None) per frame
            # plus a single 'category_id' and instance 'id'.
            bbox = ann['bboxes'][frame_id]
            area = ann['areas'][frame_id]
            segm = ann['segmentations'][frame_id]
            if bbox is None:
                continue
            x1, y1, w, h = bbox
            if area <= 0 or w < 1 or h < 1:
                continue
            bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
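            # Example of the xywh -> xyxy conversion above (hypothetical box):
            # [10, 20, 30, 40] becomes [10, 20, 39, 59] (inclusive corners).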
            if ann.get('iscrowd', False):
                gt_bboxes_ignore.append(bbox)
            else:
                gt_bboxes.append(bbox)
                gt_ids.append(ann['id'])
                gt_labels.append(self.cat2label[ann['category_id']])
                gt_masks_ann.append(ann['segmentations'][frame_id])
        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
            gt_ids = np.array(gt_ids, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)
            gt_ids = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        seg_map = img_info['filename'].replace('jpg', 'png')

        ann = dict(bboxes=gt_bboxes,
                   labels=gt_labels,
                   inst_ids=gt_ids,
                   bboxes_ignore=gt_bboxes_ignore,
                   masks=gt_masks_ann,
                   seg_map=seg_map)

        return ann

    def _segm2json(self, results_list):
        """Dump the detection results to a json file.

        There are 3 types of results: proposals, bbox predictions, mask
        predictions, and they have different data types. This method will
        automatically recognize the type, and dump them to json files.

        Args:
            results (list[list | tuple | ndarray]): Testing results of the
                dataset.
            outfile_prefix (str): The filename prefix of the json files. If the
                prefix is "somepath/xxx", the json files will be named
                "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
                "somepath/xxx.proposal.json".

        Returns:
            dict[str: str]: Possible keys are "bbox", "segm", "proposal", and
                values are corresponding filenames.
        """

        results = []
        for rl_det, rl_seg in results_list:

            for d, s in zip(rl_det, rl_seg):
                results.append((d, s))

        json_results = []
        vid_objs = {}
        res_idx = 0
        for idx in range(len(self)):
            # assume results is ordered

            vid_id, fr_id = self.sample_ids[idx]
            ref_frame_ids = self._frame_ids[vid_id]
            vid_len = len(ref_frame_ids)
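            # Each test sample covers up to seq_len consecutive frames, so the
            # flattened per-frame results are consumed here in the same order
            # they were produced.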
            for frame_id in range(fr_id, min(fr_id + self.seq_len, vid_len)):

                is_last = frame_id == vid_len - 1

                det, seg = results[res_idx]
                res_idx += 1
                for obj_id in det:
                    bbox = det[obj_id]['bbox']

                    if obj_id in seg:
                        segm = seg[obj_id]
                        label = det[obj_id]['label']

                        if obj_id not in vid_objs:
                            vid_objs[obj_id] = {
                                'scores': [],
                                'cats': [],
                                'segms': {}
                            }
                        vid_objs[obj_id]['scores'].append(bbox[4])
                        vid_objs[obj_id]['cats'].append(label)
                        segm['counts'] = segm['counts'].decode()
                        vid_objs[obj_id]['segms'][frame_id] = segm
                if is_last:
                    # store the results of the current video
                    for obj_id, obj in vid_objs.items():
                        data = dict()

                        data['video_id'] = vid_id
                        data['score'] = np.array(obj['scores']).mean().item()
                        # majority voting for sequence category
                        data['category_id'] = np.bincount(np.array(
                            obj['cats'])).argmax().item() + 1
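                        # e.g. (hypothetical) obj['cats'] == [2, 2, 5] gives
                        # bincount().argmax() == 2, so category_id == 3
                        # (labels are 0-based, category ids 1-based).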
                        vid_seg = []
                        for fid in range(frame_id + 1):
                            if fid in obj['segms']:
                                vid_seg.append(obj['segms'][fid])
                            else:
                                vid_seg.append(None)
                        data['segmentations'] = vid_seg
                        json_results.append(data)
                    vid_objs = {}

        return [], json_results

    def results2json(self, results, outfile_prefix):
        """Dump the detection results to a json file.

        There are 3 types of results: proposals, bbox predictions, mask
        predictions, and they have different data types. This method will
        automatically recognize the type, and dump them to json files.

        Args:
            results (list[list | tuple | ndarray]): Testing results of the
                dataset.
            outfile_prefix (str): The filename prefix of the json files. If the
                prefix is "somepath/xxx", the json files will be named
                "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
                "somepath/xxx.proposal.json".

        Returns:
            dict[str: str]: Possible keys are "bbox", "segm", "proposal", and
                values are corresponding filenames.
        """
        result_files = dict()
        if isinstance(results[0], list):
            json_results = self._det2json(results)
            result_files['bbox'] = '{}.{}.json'.format(outfile_prefix, 'bbox')
            result_files['proposal'] = '{}.{}.json'.format(
                outfile_prefix, 'bbox')
            mmcv.dump(json_results, result_files['bbox'])
        elif isinstance(results[0], tuple):
            json_results = self._segm2json(results)
            result_files['bbox'] = '{}.{}.json'.format(outfile_prefix, 'bbox')
            result_files['proposal'] = '{}.{}.json'.format(
                outfile_prefix, 'bbox')
            result_files['segm'] = '{}.{}.json'.format(outfile_prefix, 'segm')
            mmcv.dump(json_results[0], result_files['bbox'])
            mmcv.dump(json_results[1], result_files['segm'])
        elif isinstance(results[0], np.ndarray):
            json_results = self._proposal2json(results)
            result_files['proposal'] = '{}.{}.json'.format(
                outfile_prefix, 'proposal')
            mmcv.dump(json_results, result_files['proposal'])
        else:
            raise TypeError('invalid type of results')
        return result_files

    def format_results(self, results, jsonfile_prefix=None, **kwargs):
        """Format the results to json (standard format for COCO evaluation).

        Args:
            results (list): Testing results of the dataset.
            jsonfile_prefix (str | None): The prefix of json files. It includes
                the file path and the prefix of filename, e.g., "a/b/prefix".
                If not specified, a temp file will be created. Default: None.

        Returns:
            tuple: (result_files, tmp_dir), result_files is a dict containing
                the json filepaths, tmp_dir is the temporary directory created
                for saving json files when jsonfile_prefix is not specified.
        """
        assert isinstance(results, list), 'results must be a list'
        '''
        assert len(results) == len(self), (
            'The length of results is not equal to the dataset len: {} != {}'.
            format(len(results), len(self)))
        '''

        if jsonfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            jsonfile_prefix = osp.join(tmp_dir.name, 'results')
        else:
            tmp_dir = None
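        # When no prefix is given, the json files are written into a
        # TemporaryDirectory; the caller (e.g. evaluate) cleans it up.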
        result_files = self.results2json(results, jsonfile_prefix)
        return result_files, tmp_dir

    def evaluate(self,
                 results,
                 metric='bbox',
                 logger=None,
                 jsonfile_prefix=None,
                 classwise=False,
                 proposal_nums=(100, 300, 1000),
                 iou_thrs=np.arange(0.5, 0.96, 0.05)):
        """Evaluation in COCO protocol.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            jsonfile_prefix (str | None): The prefix of json files. It includes
                the file path and the prefix of filename, e.g., "a/b/prefix".
                If not specified, a temp file will be created. Default: None.
            classwise (bool): Whether to evaluate the AP for each class.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thrs (Sequence[float]): IoU thresholds used for evaluating
                recalls. If set to a list, the average recall over all IoUs
                will also be computed. Default: np.arange(0.5, 0.96, 0.05).

        Returns:
            dict[str: float]
        """

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError('metric {} is not supported'.format(metric))

        result_files, tmp_dir = self.format_results(results, jsonfile_prefix)

        eval_results = {}
        ytvosGt = self.ytvos
        for metric in metrics:
            msg = 'Evaluating {}...'.format(metric)
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)

            if metric == 'proposal_fast':
                ar = self.fast_eval_recall(results,
                                           proposal_nums,
                                           iou_thrs,
                                           logger='silent')
                log_msg = []
                for i, num in enumerate(proposal_nums):
                    eval_results['AR@{}'.format(num)] = ar[i]
                    log_msg.append('\nAR@{}\t{:.4f}'.format(num, ar[i]))
                log_msg = ''.join(log_msg)
                print_log(log_msg, logger=logger)
                continue

            if metric not in result_files:
                raise KeyError('{} is not in results'.format(metric))
            try:
                ytvosDt = ytvosGt.loadRes(result_files[metric])
            except IndexError:
                print_log('The testing results of the whole dataset are empty.',
                          logger=logger,
                          level=logging.ERROR)
                break

            iou_type = 'bbox' if metric == 'proposal' else metric
            ytvosEval = YTVOSeval(ytvosGt, ytvosDt, iou_type)
            vid_ids = self.ytvos.getVidIds()
            ytvosEval.params.vidIds = vid_ids
            if metric == 'proposal':
                ytvosEval.params.useCats = 0
                ytvosEval.params.maxDets = list(proposal_nums)
                ytvosEval.evaluate()
                ytvosEval.accumulate()
                ytvosEval.summarize()
                metric_items = [
                    'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000', 'AR_m@1000',
                    'AR_l@1000'
                ]
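                # COCO-style summaries store AP values in stats[0:6] and AR
                # values in stats[6:12], hence the i + 6 offset below.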
                for i, item in enumerate(metric_items):
                    val = float('{:.3f}'.format(ytvosEval.stats[i + 6]))
                    eval_results[item] = val
            else:
                ytvosEval.evaluate()
                ytvosEval.accumulate()
                ytvosEval.summarize()
                if classwise:  # Compute per-category AP
                    pass  # TODO
                metric_items = [
                    'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
                ]
                for i in range(len(metric_items)):
                    key = '{}_{}'.format(metric, metric_items[i])
                    val = float('{:.3f}'.format(ytvosEval.stats[i]))
                    eval_results[key] = val
                eval_results['{}_mAP_copypaste'.format(metric)] = (
                    '{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
                    '{ap[4]:.3f} {ap[5]:.3f}').format(ap=ytvosEval.stats[:6])
        if tmp_dir is not None:
            tmp_dir.cleanup()
        return eval_results