Example #1
import multiprocessing
import os

import cv2
import numpy as np
import torch
from lvis import LVIS
from skimage import io
from tqdm import tqdm


class LvisSaver(object):
    def __init__(self, model, path):
        self.lvis = LVIS(
            '/scratch/users/zzweng/datasets/lvis/lvis_v0.5_val.json')
        self.model = model
        self.model.eval()
        self.path = path

    def save(self, k=50):
        # Split the 5000 validation images into k chunks and process the
        # chunks concurrently in a worker pool.

        rng = np.linspace(0, 5000, k + 1, dtype=int)
        args = list(zip(rng[:-1], rng[1:]))
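        # e.g. with k=50 over 5000 images, args becomes
        # [(0, 100), (100, 200), ..., (4900, 5000)]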
        with multiprocessing.Pool(processes=6) as pool:
            pool.starmap(self._save_part, args)
        print('Done')

    def _save_part(self, start, end):
        print('Getting features from {} to {}'.format(start, end))
        img_ids = self.lvis.get_img_ids()[start:end]
        feature_y = []
        feature_x = []
        for img_id in tqdm(img_ids):
            img = self.lvis.load_imgs([img_id])[0]
            I = io.imread(img['coco_url'])
            if len(I.shape) == 2:  # skip grayscale images
                continue

            for ann_id in self.lvis.get_ann_ids(img_ids=[img_id]):
                ann = self.lvis.load_anns([ann_id])[0]
                b = np.array(ann['bbox']).astype(int)  # np.int is removed in NumPy >= 1.24
                try:
                    # zero out pixels outside the instance mask
                    I_masked = I * np.expand_dims(self.lvis.ann_to_mask(ann),
                                                  2)
                    patch = I_masked[b[1]:b[1] + b[3],
                                     b[0]:b[0] + b[2], :] / 255.
                    patch = cv2.resize(patch, (224, 224))
                    patch_tensor = torch.tensor(patch).float()
                    # add a batch dimension and reorder HWC -> NCHW
                    feat = self.model(
                        patch_tensor.unsqueeze(0).permute(0, 3, 1, 2)
                    )[1].detach().numpy().flatten()
                    feature_x.append(feat)
                    feature_y.append(ann['category_id'])
                except Exception:
                    print('skipping annotation with bbox', b)

        feature_x_arr = np.stack(feature_x)
        feature_y_arr = np.array(feature_y)
        print(feature_x_arr.shape, feature_y_arr.shape)

        np.save(
            os.path.join(self.path, 'val_feats_{}_{}_x.npy'.format(start,
                                                                   end)),
            feature_x_arr)
        np.save(
            os.path.join(self.path, 'val_feats_{}_{}_y.npy'.format(start,
                                                                   end)),
            feature_y_arr)
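
A minimal usage sketch (hypothetical: MyEmbeddingNet stands in for any
torch.nn.Module whose forward returns a tuple with the feature vector at
index 1, as the code above assumes; the path is a placeholder):

model = MyEmbeddingNet()  # hypothetical feature extractor
saver = LvisSaver(model, path='/tmp/lvis_feats')  # placeholder output dir
saver.save(k=50)  # writes val_feats_{start}_{end}_{x,y}.npy per chunk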
Example #2
# Imports assumed for an mmdetection-1.x-style codebase; mmcv_custom is a
# project-specific imread wrapper and the relative module paths are
# assumptions.
import os.path as osp
import warnings

import mmcv
import mmcv_custom
import numpy as np
from imagecorruptions import corrupt
from lvis import LVIS
from mmcv.parallel import DataContainer as DC

from .custom import CustomDataset
from .utils import random_scale, to_tensor


class LvisDataSet(CustomDataset):
    def __init__(self, samples_per_cls_file=None, **kwargs):
        self.samples_per_cls_file = samples_per_cls_file
        super(LvisDataSet, self).__init__(**kwargs)

    def load_annotations(self, ann_file):
        self.lvis = LVIS(ann_file)
        self.cat_ids = self.lvis.get_cat_ids()
        self.cat2label = {
            cat_id: i + 1
            for i, cat_id in enumerate(self.cat_ids)
        }
        # Placeholder lists indexed by (category_id - 1); every slot is
        # filled in the loop below (LVIS category ids are contiguous from 1).
        self.CLASSES = [None] * len(self.cat_ids)
        self.cat_instance_count = [None] * len(self.cat_ids)
        self.cat_image_count = [None] * len(self.cat_ids)
        img_count_lbl = ["r", "c", "f"]
        self.freq_groups = [[] for _ in img_count_lbl]
        self.cat_group_idxs = [None] * len(self.cat_ids)
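        # 'r' = rare, 'c' = common, 'f' = frequent. Composite keys denote
        # unions: 'cf' = common + frequent, 'rcf' = all three groups.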
        freq_group_count = {'f': 0, 'cf': 0, 'rcf': 0}
        self.cat_fake_idxs = {
            'f': [-1 for _ in self.cat_ids],
            'cf': [-1 for _ in self.cat_ids],
            'rcf': [-1 for _ in self.cat_ids]
        }
        self.freq_group_dict = {'rcf': (0, 1, 2), 'cf': (1, 2), 'f': (2, )}
        for value in self.lvis.cats.values():
            idx = value['id'] - 1
            self.CLASSES[idx] = value['name']
            self.cat_instance_count[idx] = value['instance_count']
            self.cat_image_count[idx] = value['image_count']
            group_idx = img_count_lbl.index(value["frequency"])
            self.freq_groups[group_idx].append(idx + 1)
            self.cat_group_idxs[idx] = group_idx
            if group_idx == 0:  # rare
                freq_group_count['rcf'] += 1
                self.cat_fake_idxs['rcf'][idx] = freq_group_count['rcf']
            elif group_idx == 1:  # common
                freq_group_count['rcf'] += 1
                freq_group_count['cf'] += 1
                self.cat_fake_idxs['rcf'][idx] = freq_group_count['rcf']
                self.cat_fake_idxs['cf'][idx] = freq_group_count['cf']
            elif group_idx == 2:  # freq
                freq_group_count['rcf'] += 1
                freq_group_count['cf'] += 1
                freq_group_count['f'] += 1
                self.cat_fake_idxs['rcf'][idx] = freq_group_count['rcf']
                self.cat_fake_idxs['cf'][idx] = freq_group_count['cf']
                self.cat_fake_idxs['f'][idx] = freq_group_count['f']
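
        # Worked example of the bookkeeping above: for four categories with
        # frequencies [rare, common, frequent, frequent], cat_fake_idxs ends
        # up as
        #   'rcf': [1, 2, 3, 4]    (1-based index over all categories)
        #   'cf':  [-1, 1, 2, 3]   (rare categories keep -1)
        #   'f':   [-1, -1, 1, 2]  (only frequent categories are numbered)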

        if self.samples_per_cls_file is not None:
            with open(self.samples_per_cls_file, 'w') as file:
                file.writelines(str(x) + '\n' for x in self.cat_instance_count)

        self.img_ids = self.lvis.get_img_ids()
        img_infos = []
        for i in self.img_ids:
            info = self.lvis.load_imgs([i])[0]
            info['filename'] = info['file_name']
            img_infos.append(info)
        return img_infos

    def get_ann_info(self, idx, freq_groups=('rcf', )):
        img_id = self.img_infos[idx]['id']
        ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
        ann_info = self.lvis.load_anns(ann_ids)
        return self._parse_ann_info(ann_info,
                                    self.with_mask,
                                    freq_groups=freq_groups)

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without ground truths."""
        valid_inds = []
        ids_with_ann = set(_['image_id'] for _ in self.lvis.anns.values())
        for i, img_info in enumerate(self.img_infos):
            if self.img_ids[i] not in ids_with_ann:
                continue
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
        return valid_inds

    def _parse_ann_info(self, ann_info, with_mask=True, freq_groups=('rcf', )):
        """Parse bbox and mask annotations.

        Args:
            ann_info (list[dict]): Annotation info of an image.
            with_mask (bool): Whether to parse mask annotations.
            freq_groups (tuple[str]): Frequency groups ('rcf', 'cf', 'f')
                for which to record the indices of valid ground truths.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, gt_valid_idxs and, if with_mask is True, masks,
                mask_polys, poly_lens.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        assert isinstance(freq_groups, tuple)
        gt_valid_idxs = {name: [] for name in freq_groups}
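        # gt_valid_idxs maps each requested group to the indices (among the
        # ground truths kept below) whose category falls in that group,
        # e.g. {'rcf': [0, 1, 2], 'f': [1]}.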
        gt_count = 0
        # Two formats are provided.
        # 1. mask: a binary map of the same size of the image.
        # 2. polys: each mask consists of one or several polys, each poly is a
        # list of float.
        if with_mask:
            gt_masks = []
            gt_mask_polys = []
            gt_poly_lens = []
        for i, ann in enumerate(ann_info):
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue

            group_idx = self.cat_group_idxs[ann['category_id'] - 1]
            for name in freq_groups:
                if group_idx in self.freq_group_dict[name]:
                    gt_valid_idxs[name].append(gt_count)
            gt_count += 1

            bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
            gt_bboxes.append(bbox)
            gt_labels.append(self.cat2label[ann['category_id']])
            if with_mask:
                gt_masks.append(self.lvis.ann_to_mask(ann))
                mask_polys = [
                    p for p in ann['segmentation'] if len(p) >= 6
                ]  # valid polygons have >= 3 points (6 coordinates)
                poly_lens = [len(p) for p in mask_polys]
                gt_mask_polys.append(mask_polys)
                gt_poly_lens.extend(poly_lens)
        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        ann = dict(
            bboxes=gt_bboxes,
            labels=gt_labels,
            bboxes_ignore=gt_bboxes_ignore,
            gt_valid_idxs=gt_valid_idxs  # indices of valid gts per frequency group
        )

        if with_mask:
            ann['masks'] = gt_masks
            # poly format is not used in the current implementation
            ann['mask_polys'] = gt_mask_polys
            ann['poly_lens'] = gt_poly_lens
        return ann

    def prepare_train_img(self, idx):
        img_info = self.img_infos[idx]
        # load image; strip the 'COCO_val2014_' prefix that LVIS file names
        # carry but which, as this code assumes, is absent on disk
        if 'COCO' in img_info['filename']:
            img = mmcv_custom.imread(
                osp.join(
                    self.img_prefix, img_info['filename']
                    [img_info['filename'].find('COCO_val2014_') +
                     len('COCO_val2014_'):]))
        else:
            img = mmcv_custom.imread(
                osp.join(self.img_prefix, img_info['filename']))
        # corruption
        if self.corruption is not None:
            img = corrupt(img,
                          severity=self.corruption_severity,
                          corruption_name=self.corruption)
        # load proposals if necessary
        if self.proposals is not None:
            proposals = self.proposals[idx][:self.num_max_proposals]
            # TODO: Handle empty proposals properly. Currently images with
            # no proposals are simply skipped, although in principle they
            # could still be used for training.
            if len(proposals) == 0:
                return None
            if not (proposals.shape[1] == 4 or proposals.shape[1] == 5):
                raise AssertionError(
                    'proposals should have shapes (n, 4) or (n, 5), '
                    'but found {}'.format(proposals.shape))
            if proposals.shape[1] == 5:
                scores = proposals[:, 4, None]
                proposals = proposals[:, :4]
            else:
                scores = None

        ann = self.get_ann_info(idx, freq_groups=('rcf', 'cf', 'f'))
        gt_bboxes = ann['bboxes']
        gt_labels = ann['labels']
        gt_valid_idxs = ann['gt_valid_idxs']
        if self.with_crowd:
            gt_bboxes_ignore = ann['bboxes_ignore']

        # skip the image if there is no valid gt bbox
        if len(gt_bboxes) == 0 and self.skip_img_without_anno:
            warnings.warn('Skipping image "%s": it has no valid gt bbox' %
                          osp.join(self.img_prefix, img_info['filename']))
            return None

        # extra augmentation
        if self.extra_aug is not None:
            img, gt_bboxes, gt_labels = self.extra_aug(img, gt_bboxes,
                                                       gt_labels)

        # apply transforms
        flip = np.random.rand() < self.flip_ratio
        # randomly sample a scale
        img_scale = random_scale(self.img_scales, self.multiscale_mode)
        img, img_shape, pad_shape, scale_factor = self.img_transform(
            img, img_scale, flip, keep_ratio=self.resize_keep_ratio)
        img = img.copy()
        if self.with_seg:
            gt_seg = mmcv_custom.imread(osp.join(
                self.seg_prefix, img_info['filename'].replace('jpg', 'png')),
                                        flag='unchanged')
            gt_seg = self.seg_transform(gt_seg.squeeze(), img_scale, flip)
            gt_seg = mmcv.imrescale(gt_seg,
                                    self.seg_scale_factor,
                                    interpolation='nearest')
            gt_seg = gt_seg[None, ...]
        if self.proposals is not None:
            proposals = self.bbox_transform(proposals, img_shape, scale_factor,
                                            flip)
            if scores is not None:
                proposals = np.hstack([proposals, scores])

        gt_bboxes = self.bbox_transform(gt_bboxes, img_shape, scale_factor,
                                        flip)

        if self.with_crowd:
            gt_bboxes_ignore = self.bbox_transform(gt_bboxes_ignore, img_shape,
                                                   scale_factor, flip)
        if self.with_mask:
            gt_masks = self.mask_transform(ann['masks'], pad_shape,
                                           scale_factor, flip)

        ori_shape = (img_info['height'], img_info['width'], 3)
        not_exhaustive_category_ids = img_info['not_exhaustive_category_ids']
        neg_category_ids = img_info['neg_category_ids']
        img_meta = dict(
            ori_shape=ori_shape,
            img_shape=img_shape,
            pad_shape=pad_shape,
            scale_factor=scale_factor,
            flip=flip,
            not_exhaustive_category_ids=not_exhaustive_category_ids,
            neg_category_ids=neg_category_ids,
            cat_instance_count=self.cat_instance_count,
            freq_groups=self.freq_groups,
            cat_fake_idxs=self.cat_fake_idxs,
            freq_group_dict=self.freq_group_dict,
        )

        data = dict(
            img=DC(to_tensor(img), stack=True),
            img_meta=DC(img_meta, cpu_only=True),
            gt_bboxes=DC(to_tensor(gt_bboxes)),
            gt_valid_idxs=DC(gt_valid_idxs, cpu_only=True),
        )
        if self.proposals is not None:
            data['proposals'] = DC(to_tensor(proposals))
        if self.with_label:
            data['gt_labels'] = DC(to_tensor(gt_labels))
        if self.with_crowd:
            data['gt_bboxes_ignore'] = DC(to_tensor(gt_bboxes_ignore))
        if self.with_mask:
            data['gt_masks'] = DC(gt_masks, cpu_only=True)
        if self.with_seg:
            data['gt_semantic_seg'] = DC(to_tensor(gt_seg), stack=True)
        return data

    def prepare_test_img(self, idx):
        """Prepare an image for testing (multi-scale and flipping)"""
        img_info = self.img_infos[idx]
        # load image; strip the 'COCO_val2014_' prefix, as in prepare_train_img
        if 'COCO' in img_info['filename']:
            img = mmcv_custom.imread(
                osp.join(
                    self.img_prefix, img_info['filename']
                    [img_info['filename'].find('COCO_val2014_') +
                     len('COCO_val2014_'):]))
        else:
            img = mmcv_custom.imread(
                osp.join(self.img_prefix, img_info['filename']))
        # corruption
        if self.corruption is not None:
            img = corrupt(img,
                          severity=self.corruption_severity,
                          corruption_name=self.corruption)
        # load proposals if necessary
        if self.proposals is not None:
            proposal = self.proposals[idx][:self.num_max_proposals]
            if not (proposal.shape[1] == 4 or proposal.shape[1] == 5):
                raise AssertionError(
                    'proposals should have shapes (n, 4) or (n, 5), '
                    'but found {}'.format(proposal.shape))
        else:
            proposal = None

        def prepare_single(img, scale, flip, proposal=None):
            _img, img_shape, pad_shape, scale_factor = self.img_transform(
                img, scale, flip, keep_ratio=self.resize_keep_ratio)
            _img = to_tensor(_img)
            _img_meta = dict(ori_shape=(img_info['height'], img_info['width'],
                                        3),
                             img_shape=img_shape,
                             pad_shape=pad_shape,
                             scale_factor=scale_factor,
                             flip=flip)
            if proposal is not None:
                if proposal.shape[1] == 5:
                    score = proposal[:, 4, None]
                    proposal = proposal[:, :4]
                else:
                    score = None
                _proposal = self.bbox_transform(proposal, img_shape,
                                                scale_factor, flip)
                if score is not None:
                    _proposal = np.hstack([_proposal, score])
                _proposal = to_tensor(_proposal)
            else:
                _proposal = None
            return _img, _img_meta, _proposal

        imgs = []
        img_metas = []
        proposals = []
        for scale in self.img_scales:
            _img, _img_meta, _proposal = prepare_single(
                img, scale, False, proposal)
            imgs.append(_img)
            img_metas.append(DC(_img_meta, cpu_only=True))
            proposals.append(_proposal)
            if self.flip_ratio > 0:
                _img, _img_meta, _proposal = prepare_single(
                    img, scale, True, proposal)
                imgs.append(_img)
                img_metas.append(DC(_img_meta, cpu_only=True))
                proposals.append(_proposal)
        data = dict(img=imgs, img_meta=img_metas)
        if self.proposals is not None:
            data['proposals'] = proposals
        return data
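
A minimal construction sketch, assuming an mmdetection-1.x-style
CustomDataset signature (the argument names, paths, and values here are
assumptions, not part of the example above):

dataset = LvisDataSet(
    ann_file='data/lvis/lvis_v0.5_train.json',  # placeholder path
    img_prefix='data/lvis/train2017/',          # placeholder path
    img_scale=(1333, 800),
    img_norm_cfg=dict(mean=[123.675, 116.28, 103.53],
                      std=[58.395, 57.12, 57.375],
                      to_rgb=True),
    samples_per_cls_file='samples_per_cls.txt')  # optional instance-count dump
data = dataset.prepare_train_img(0)  # dict with img, img_meta, gt_* entries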