Beispiel #1
0
    def get_seg_single(self,
                       cate_scores,
                       cate_labels,
                       seg_preds,
                       attention_maps,
                       strides,
                       featmap_size,
                       img_shape,
                       ori_shape,
                       scale_factor,
                       cfg,
                       rescale=False,
                       debug=False):
        """Select, re-score and upsample the instance masks of one image.

        Candidates are filtered by binary-mask area, re-scored with their
        mean in-mask confidence, suppressed with Matrix NMS, limited to the
        best ``cfg.max_per_img`` and finally resized to ``ori_shape[:2]``.

        Args:
            cate_scores: Per-candidate category scores.
            cate_labels: Per-candidate category labels.
            seg_preds: Soft mask predictions. Index 0 along dim 1 is taken,
                leaving one 2-D map per candidate (assumes a layout like
                (num_candidates, 1, H, W) -- TODO confirm with the caller).
            attention_maps: Accepted for interface compatibility only. The
                returned values never depended on it, so it is not
                processed.
            strides: Per-candidate area threshold; a candidate is kept only
                if its binary mask area is strictly larger.
            featmap_size: (height, width) pair; masks are first upsampled to
                four times this size.
            img_shape: (h, w, channels); the upsampled masks are cropped to
                ``h`` x ``w``.
            ori_shape: Shape whose first two entries are the output size.
            scale_factor: Not used.
            cfg: Config object providing ``mask_thr``, ``nms_pre``,
                ``kernel``, ``sigma``, ``update_thr`` and ``max_per_img``.
            rescale: Not used.
            debug: Not used.

        Returns:
            ``None`` if no candidate survives the filtering, otherwise a
            tuple ``(seg_masks, cate_labels, cate_scores)`` where
            ``seg_masks`` is the boolean result of thresholding the resized
            soft masks with ``cfg.mask_thr``.
        """
        # overall info.
        h, w, _ = img_shape
        upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)

        seg_preds = seg_preds[:, 0]

        # masks.
        seg_masks = seg_preds > cfg.mask_thr
        sum_masks = seg_masks.sum((1, 2)).float()

        # filter: drop candidates whose mask area does not exceed the
        # per-candidate threshold.
        keep = sum_masks > strides
        if keep.sum() == 0:
            return None

        seg_masks = seg_masks[keep, ...]
        seg_preds = seg_preds[keep, ...]
        sum_masks = sum_masks[keep]
        # Boolean indexing returns a copy, so the in-place update below does
        # not touch the caller's tensor.
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # mask scoring: mean soft confidence inside the binary mask.
        seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
        cate_scores *= seg_scores

        # sort and keep top nms_pre
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.nms_pre:
            sort_inds = sort_inds[:cfg.nms_pre]
        seg_masks = seg_masks[sort_inds, :, :]
        seg_preds = seg_preds[sort_inds, :, :]
        sum_masks = sum_masks[sort_inds]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # Matrix NMS
        cate_scores = matrix_nms(seg_masks,
                                 cate_labels,
                                 cate_scores,
                                 kernel=cfg.kernel,
                                 sigma=cfg.sigma,
                                 sum_masks=sum_masks)

        # filter.
        keep = cate_scores >= cfg.update_thr
        if keep.sum() == 0:
            return None
        seg_preds = seg_preds[keep, :, :]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # sort and keep top_k
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.max_per_img:
            sort_inds = sort_inds[:cfg.max_per_img]
        seg_preds = seg_preds[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # Upsample, crop to (h, w), then resize to the output resolution.
        # align_corners=False is the value F.interpolate already falls back
        # to when the argument is omitted; it is spelled out here.
        seg_preds = F.interpolate(seg_preds.unsqueeze(0),
                                  size=upsampled_size_out,
                                  mode='bilinear',
                                  align_corners=False)[:, :, :h, :w]
        seg_masks = F.interpolate(seg_preds,
                                  size=ori_shape[:2],
                                  mode='bilinear',
                                  align_corners=False).squeeze(0)
        seg_masks = seg_masks > cfg.mask_thr
        return seg_masks, cate_labels, cate_scores
Beispiel #2
0
    def get_seg_single(self,
                       cate_preds,
                       seg_preds,
                       featmap_size,
                       img_shape,
                       ori_shape,
                       scale_factor,
                       cfg,
                       rescale=False, debug=False):
        """Decode the instance masks of one image from grid predictions.

        Every (grid cell, class) pair scoring above ``cfg.score_thr`` becomes
        a candidate. Candidates are filtered by mask area against the stride
        of their level, re-scored by mean in-mask confidence, passed through
        Matrix NMS, truncated to ``cfg.max_per_img`` and resized to
        ``ori_shape[:2]``.

        Args:
            cate_preds: 2-D score tensor indexed as (grid cell, class); must
                have the same length as ``seg_preds``.
            seg_preds: One soft mask per grid cell (indexed by cell along
                dim 0 and reduced over dims 1 and 2).
            featmap_size: (height, width) pair; masks are first upsampled to
                four times this size.
            img_shape: (h, w, channels); the upsampled masks are cropped to
                ``h`` x ``w``.
            ori_shape: Shape whose first two entries are the output size.
            scale_factor: Not used.
            cfg: Config object providing ``score_thr``, ``mask_thr``,
                ``nms_pre``, ``kernel``, ``sigma``, ``update_thr`` and
                ``max_per_img``.
            rescale: Not used.
            debug: Not used.

        Returns:
            ``None`` if nothing survives, otherwise a tuple
            ``(seg_masks, cate_labels, cate_scores)``.
        """
        assert len(cate_preds) == len(seg_preds)

        img_h, img_w, _ = img_shape
        upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)

        # Candidate selection: one row of `hits` per (cell, class) above
        # the score threshold.
        above_thr = cate_preds > cfg.score_thr
        cate_scores = cate_preds[above_thr]
        if len(cate_scores) == 0:
            return None
        hits = above_thr.nonzero()
        cell_ids = hits[:, 0]
        cate_labels = hits[:, 1]

        # Build a per-cell stride table over the concatenated grids of all
        # levels, then look up the stride of every candidate's cell.
        level_ends = cate_labels.new_tensor(
            self.seg_num_grids).pow(2).cumsum(0)
        stride_table = cate_scores.new_ones(level_ends[-1])
        level_start = 0
        for level in range(len(self.seg_num_grids)):
            level_end = level_ends[level]
            stride_table[level_start:level_end] *= self.strides[level]
            level_start = level_end
        strides = stride_table[cell_ids]

        # Binary masks and their areas.
        seg_preds = seg_preds[cell_ids]
        seg_masks = seg_preds > cfg.mask_thr
        sum_masks = seg_masks.sum((1, 2)).float()

        # Discard candidates whose area does not exceed their stride.
        big_enough = sum_masks > strides
        if not big_enough.any():
            return None
        seg_masks = seg_masks[big_enough]
        seg_preds = seg_preds[big_enough]
        sum_masks = sum_masks[big_enough]
        cate_scores = cate_scores[big_enough]
        cate_labels = cate_labels[big_enough]

        # Maskness: mean soft confidence inside the binary mask.
        maskness = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
        cate_scores = cate_scores * maskness

        # Keep at most nms_pre candidates, best first.
        order = torch.argsort(cate_scores, descending=True)
        if len(order) > cfg.nms_pre:
            order = order[:cfg.nms_pre]
        seg_masks = seg_masks[order, :, :]
        seg_preds = seg_preds[order, :, :]
        sum_masks = sum_masks[order]
        cate_scores = cate_scores[order]
        cate_labels = cate_labels[order]

        # Matrix NMS returns the decayed scores.
        cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                                 kernel=cfg.kernel, sigma=cfg.sigma,
                                 sum_masks=sum_masks)

        # Drop candidates whose decayed score fell below update_thr.
        survivors = cate_scores >= cfg.update_thr
        if not survivors.any():
            return None
        seg_preds = seg_preds[survivors, :, :]
        cate_scores = cate_scores[survivors]
        cate_labels = cate_labels[survivors]

        # Keep at most max_per_img results, best first.
        order = torch.argsort(cate_scores, descending=True)
        if len(order) > cfg.max_per_img:
            order = order[:cfg.max_per_img]
        seg_preds = seg_preds[order, :, :]
        cate_scores = cate_scores[order]
        cate_labels = cate_labels[order]

        # Upsample, crop to the image shape, resize to the output size and
        # binarise.
        upsampled = F.interpolate(seg_preds.unsqueeze(0),
                                  size=upsampled_size,
                                  mode='bilinear')
        seg_preds = upsampled[:, :, :img_h, :img_w]
        resized = F.interpolate(seg_preds,
                                size=ori_shape[:2],
                                mode='bilinear')
        seg_masks = resized.squeeze(0) > cfg.mask_thr
        return seg_masks, cate_labels, cate_scores
    def get_seg_single(
            self,
            cls_scores,  # (5)[h/s_i*w/s_i, 80]
            fcn_params,  # (5)[h/s_i*w/s_i, 169]
            mask_feat_pred,  # (1, 8, h/8, w/8)
            mlvl_points,
            img_shape,
            ori_shape,
            scale_factor,
            cfg,
            rescale=False):
        """Decode the instance masks of one image with dynamic mask heads.

        Every (location, class) pair scoring above ``cfg.score_thr`` yields
        one set of dynamic convolution parameters. These are applied to the
        shared mask features as a grouped convolution, giving one mask per
        candidate. Candidates are then re-scored by mean in-mask
        probability, passed through Matrix NMS, truncated to
        ``cfg.max_per_img`` and resized to ``ori_shape[:2]``.

        Args:
            cls_scores: Sequence of per-level 2-D score tensors indexed as
                (location, class); concatenated along dim 0.
            fcn_params: Sequence of per-level dynamic-parameter tensors, one
                row per location; concatenated along dim 0.
            mask_feat_pred: Shared mask features; the last two dims are the
                feature-map size (assumes batch size 1 -- TODO confirm).
            mlvl_points: Not used.
            img_shape: (H, W, channels); upsampled masks are cropped to
                ``H`` x ``W``.
            ori_shape: Shape whose first two entries are the output size.
            scale_factor: Not used.
            cfg: Config object providing ``score_thr``, ``mask_thr``,
                ``nms_pre``, ``kernel``, ``sigma``, ``update_thr`` and
                ``max_per_img``.
            rescale: Not used.

        Returns:
            ``None`` if nothing survives, otherwise a tuple
            ``(seg_masks, cate_labels, cate_scores)``.
        """
        assert len(cls_scores) == len(fcn_params)
        featmap_size = mask_feat_pred.size()[-2:]
        upsampled_size_out = (featmap_size[0] * self.mask_downsample,
                              featmap_size[1] * self.mask_downsample)
        H, W, _ = img_shape

        cls_scores = torch.cat(cls_scores, dim=0)
        fcn_params = torch.cat(fcn_params, dim=0)

        inds = (cls_scores > cfg.score_thr)
        cate_scores = cls_scores[inds]
        if len(cate_scores) == 0:
            return None
        inds = inds.nonzero()  # one (location, class) row per candidate
        cate_labels = inds[:, 1]
        param_preds = fcn_params[inds[:, 0]]
        num_insts = param_preds.shape[0]

        # forward: replicate the shared features once per candidate and run
        # all dynamic heads at once as a grouped convolution.
        weight, bias = self.parse_dynamic_params(param_preds)
        mask_feat_pred = mask_feat_pred.repeat((num_insts, 1, 1, 1))
        mask_feat_pred = mask_feat_pred.reshape(1, -1, featmap_size[0],
                                                featmap_size[1])
        for i, (w, b) in enumerate(zip(weight, bias)):
            mask_feat_pred = F.conv2d(mask_feat_pred,
                                      w,
                                      bias=b,
                                      stride=1,
                                      padding=0,
                                      groups=num_insts)
            # ReLU between layers, but not after the last one.
            if i < len(weight) - 1:
                mask_feat_pred = F.relu(mask_feat_pred)

        # One single-channel map per candidate, upsampled by
        # mask_downsample / mask_out_stride.
        mask_feat_pred = mask_feat_pred.reshape(-1, 1, featmap_size[0],
                                                featmap_size[1])
        mask_logits = aligned_bilinear(
            mask_feat_pred, int(self.mask_downsample / self.mask_out_stride))
        # Tensor.sigmoid() is out-of-place: its result must be kept,
        # otherwise raw logits would be compared against cfg.mask_thr and
        # averaged into the maskness score.
        mask_probs = mask_logits.sigmoid()

        # Drop the singleton channel dim, leaving one 2-D map per candidate.
        mask_probs = mask_probs.permute(1, 0, 2, 3).squeeze(0)
        seg_masks = mask_probs > cfg.mask_thr

        # mask score
        sum_masks = seg_masks.sum((1, 2)).float()

        # remove empty masks (they would divide by zero below)
        ind_valid = sum_masks > 0
        if ind_valid.sum() == 0:
            return None
        mask_probs = mask_probs[ind_valid]
        seg_masks = seg_masks[ind_valid]
        sum_masks = sum_masks[ind_valid]
        cate_scores = cate_scores[ind_valid]
        cate_labels = cate_labels[ind_valid]

        seg_scores = (mask_probs * seg_masks.float()).sum((1, 2)) / sum_masks
        cate_scores *= seg_scores

        # sort and keep top nms_pre
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.nms_pre:
            sort_inds = sort_inds[:cfg.nms_pre]
        seg_masks = seg_masks[sort_inds, :, :]
        mask_probs = mask_probs[sort_inds, :, :]
        sum_masks = sum_masks[sort_inds]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # Matrix NMS
        cate_scores = matrix_nms(seg_masks,
                                 cate_labels,
                                 cate_scores,
                                 kernel=cfg.kernel,
                                 sigma=cfg.sigma,
                                 sum_masks=sum_masks)

        # filter.
        keep = cate_scores >= cfg.update_thr
        if keep.sum() == 0:
            return None
        mask_probs = mask_probs[keep, :, :]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # sort and keep top_k
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.max_per_img:
            sort_inds = sort_inds[:cfg.max_per_img]
        mask_probs = mask_probs[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # Upsample, crop to (H, W), then resize to the output resolution.
        # align_corners=False is the value F.interpolate already falls back
        # to when the argument is omitted; it is spelled out here.
        mask_probs = F.interpolate(mask_probs.unsqueeze(0),
                                   size=upsampled_size_out,
                                   mode='bilinear',
                                   align_corners=False)[:, :, :H, :W]
        seg_masks = F.interpolate(mask_probs,
                                  size=ori_shape[:2],
                                  mode='bilinear',
                                  align_corners=False).squeeze(0)
        seg_masks = seg_masks > cfg.mask_thr
        return seg_masks, cate_labels, cate_scores
Beispiel #4
0
    def get_seg_single(self,
                       cate_preds,
                       seg_preds_x,
                       seg_preds_y,
                       featmap_size,
                       img_shape,
                       ori_shape,
                       scale_factor,
                       cfg,
                       rescale=False,
                       debug=False):
        """Decode the instance masks of one image from x/y mask branches.

        Every (grid cell, class) pair scoring above ``cfg.score_thr`` becomes
        a candidate whose soft mask is the element-wise product of one map
        from ``seg_preds_x`` (chosen by the cell's column) and one from
        ``seg_preds_y`` (chosen by the cell's row). Candidates are filtered
        by area, re-scored by mean in-mask confidence, passed through
        Matrix NMS, truncated to ``cfg.max_per_img`` and resized to
        ``ori_shape[:2]``.

        Args:
            cate_preds: 2-D score tensor indexed as (grid cell, class), with
                the cells of all levels concatenated along dim 0.
            seg_preds_x: Column-branch maps of all levels, concatenated
                along dim 0.
            seg_preds_y: Row-branch maps of all levels, concatenated along
                dim 0.
            featmap_size: (height, width) pair; masks are first upsampled to
                four times this size.
            img_shape: (h, w, channels); the upsampled masks are cropped to
                ``h`` x ``w``.
            ori_shape: Shape whose first two entries are the output size.
            scale_factor: Not used.
            cfg: Config object providing ``score_thr``, ``mask_thr``,
                ``nms_pre``, ``kernel``, ``sigma``, ``update_thr`` and
                ``max_per_img``.
            rescale: Not used.
            debug: Not used.

        Returns:
            ``None`` if nothing survives, otherwise a tuple
            ``(seg_masks, cate_labels, cate_scores)``.
        """
        # overall info.
        h, w, _ = img_shape
        upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)

        # Per-cell lookup tables over the concatenated grids of all levels:
        #   trans_diff: flat index of the first cell of the cell's level
        #   seg_diff:   offset of the level inside seg_preds_x / seg_preds_y
        #   num_grids:  grid side length of the cell's level
        #   strides:    stride of the cell's level
        trans_size = torch.Tensor(self.seg_num_grids).pow(2).cumsum(0).long()
        seg_size = torch.Tensor(self.seg_num_grids).cumsum(0).long()
        n_cells = trans_size[-1].item()
        device = cate_preds.device
        trans_diff = torch.ones(n_cells, device=device).long()
        num_grids = torch.ones(n_cells, device=device).long()
        seg_diff = torch.ones(n_cells, device=device).long()
        strides = torch.ones(n_cells, device=device)

        n_stage = len(self.seg_num_grids)
        trans_diff[:trans_size[0]] *= 0
        seg_diff[:trans_size[0]] *= 0
        num_grids[:trans_size[0]] *= self.seg_num_grids[0]
        strides[:trans_size[0]] *= self.strides[0]

        for ind_ in range(1, n_stage):
            start, end = trans_size[ind_ - 1], trans_size[ind_]
            trans_diff[start:end] *= trans_size[ind_ - 1]
            seg_diff[start:end] *= seg_size[ind_ - 1]
            num_grids[start:end] *= self.seg_num_grids[ind_]
            strides[start:end] *= self.strides[ind_]

        # process.
        inds = (cate_preds > cfg.score_thr)
        cate_scores = cate_preds[inds]

        inds = inds.nonzero()
        trans_diff = torch.index_select(trans_diff, dim=0, index=inds[:, 0])
        seg_diff = torch.index_select(seg_diff, dim=0, index=inds[:, 0])
        num_grids = torch.index_select(num_grids, dim=0, index=inds[:, 0])
        strides = torch.index_select(strides, dim=0, index=inds[:, 0])

        # Row / column of each candidate inside its level's grid, shifted to
        # index into the concatenated y / x branch outputs.
        y_inds = (inds[:, 0] - trans_diff) // num_grids
        x_inds = (inds[:, 0] - trans_diff) % num_grids
        y_inds += seg_diff
        x_inds += seg_diff

        cate_labels = inds[:, 1]
        seg_masks_soft = seg_preds_x[x_inds, ...] * seg_preds_y[y_inds, ...]
        seg_masks = seg_masks_soft > cfg.mask_thr
        sum_masks = seg_masks.sum((1, 2)).float()

        # Keep candidates whose area exceeds the stride of their level; bail
        # out before the division if none is left.
        keep = sum_masks > strides
        if keep.sum() == 0:
            return None

        seg_masks_soft = seg_masks_soft[keep, ...]
        seg_masks = seg_masks[keep, ...]
        cate_scores = cate_scores[keep]
        sum_masks = sum_masks[keep]
        cate_labels = cate_labels[keep]
        # maskness
        seg_score = (seg_masks_soft * seg_masks.float()).sum(
            (1, 2)) / sum_masks
        cate_scores *= seg_score

        # sort and keep top nms_pre
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.nms_pre:
            sort_inds = sort_inds[:cfg.nms_pre]
        seg_masks_soft = seg_masks_soft[sort_inds, :, :]
        seg_masks = seg_masks[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        sum_masks = sum_masks[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # Matrix NMS
        cate_scores = matrix_nms(seg_masks,
                                 cate_labels,
                                 cate_scores,
                                 kernel=cfg.kernel,
                                 sigma=cfg.sigma,
                                 sum_masks=sum_masks)

        # filter: without this guard an empty selection would be handed to
        # F.interpolate below.
        keep = cate_scores >= cfg.update_thr
        if keep.sum() == 0:
            return None
        seg_masks_soft = seg_masks_soft[keep, :, :]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # sort and keep top_k
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.max_per_img:
            sort_inds = sort_inds[:cfg.max_per_img]
        seg_masks_soft = seg_masks_soft[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # Upsample, crop to (h, w), then resize to the output resolution.
        # align_corners=False is the value F.interpolate already falls back
        # to when the argument is omitted; it is spelled out here.
        seg_masks_soft = F.interpolate(seg_masks_soft.unsqueeze(0),
                                       size=upsampled_size_out,
                                       mode='bilinear',
                                       align_corners=False)[:, :, :h, :w]
        seg_masks = F.interpolate(seg_masks_soft,
                                  size=ori_shape[:2],
                                  mode='bilinear',
                                  align_corners=False).squeeze(0)
        seg_masks = seg_masks > cfg.mask_thr
        return seg_masks, cate_labels, cate_scores
Beispiel #5
0
    def aug_test(self, imgs, img_metas, rescale=False):
        """Test with augmentations and merge the per-view predictions.

        Each augmented view is run through the network and
        ``bbox_head.get_seg_aug``. For every image, the candidates of all
        views are resized to the spatial size of the first view's masks
        (and flipped back when the view's meta has ``flip`` set),
        concatenated, passed through Matrix NMS and resized to the
        ``ori_shape`` of the first meta.

        Args:
            imgs: Sequence of augmented views.
            img_metas: Sequence of per-view meta lists; entry ``[i][0]`` is
                read for ``img_shape``, ``ori_shape`` and ``flip``.
            rescale: Not read by this method; masks are always resized to
                ``ori_shape``.

        Returns:
            A list with one ``(seg_masks, cate_labels, cate_scores)`` tuple
            per image, or ``None`` as soon as one image has no candidate
            left after Matrix NMS.
        """
        cfg = self.test_cfg
        out_size = img_metas[0][0]['ori_shape'][:2]

        # One forward pass per augmented view.
        per_view_results = []
        for view, view_metas in zip(imgs, img_metas):
            feats = self.extract_feat(view)
            seg_preds, cate_preds = self.bbox_head(feats, eval=True)
            view_shape = view_metas[0]['img_shape']
            per_view_results.append(
                self.bbox_head.get_seg_aug(seg_preds, cate_preds, view_shape,
                                           cfg))

        merged_outputs = []
        # Regroup from "per view -> per image" to "per image -> per view".
        for views_of_image in zip(*per_view_results):
            seg_masks, seg_preds, sum_masks, cate_scores, cate_labels = (
                list(field) for field in zip(*views_of_image))

            # Bring every later view to the spatial size of the first view.
            ref_size = tuple(seg_masks[0].shape[-2:])
            for view_idx in range(1, len(seg_masks)):
                masks_i = F.interpolate(
                    seg_masks[view_idx].float().unsqueeze(0),
                    size=ref_size,
                    mode='bilinear',
                    align_corners=False).squeeze(0)
                preds_i = F.interpolate(
                    seg_preds[view_idx].unsqueeze(0),
                    size=ref_size,
                    mode='bilinear',
                    align_corners=False).squeeze(0)
                if img_metas[view_idx][0]['flip']:
                    # Undo the flip along the last dimension.
                    masks_i = torch.flip(masks_i, dims=[2])
                    preds_i = torch.flip(preds_i, dims=[2])
                seg_masks[view_idx] = masks_i
                seg_preds[view_idx] = preds_i

            seg_masks = torch.cat(seg_masks, dim=0)
            seg_preds = torch.cat(seg_preds, dim=0)
            sum_masks = torch.cat(sum_masks, dim=0)
            cate_scores = torch.cat(cate_scores, dim=0)
            cate_labels = torch.cat(cate_labels, dim=0)

            # Keep at most nms_pre candidates, best first.
            order = torch.argsort(cate_scores, descending=True)
            if len(order) > cfg.nms_pre:
                order = order[:cfg.nms_pre]
            seg_masks = seg_masks[order, :, :]
            seg_preds = seg_preds[order, :, :]
            sum_masks = sum_masks[order]
            cate_scores = cate_scores[order]
            cate_labels = cate_labels[order]

            # Matrix NMS returns the decayed scores.
            cate_scores = matrix_nms(seg_masks,
                                     cate_labels,
                                     cate_scores,
                                     kernel=cfg.kernel,
                                     sigma=cfg.sigma,
                                     sum_masks=sum_masks)

            # Drop candidates whose decayed score fell below update_thr.
            survivors = cate_scores >= cfg.update_thr
            if not survivors.any():
                # NOTE(review): this aborts the whole call and discards the
                # outputs already collected for other images -- confirm
                # whether callers expect a per-image None entry instead.
                return None
            seg_preds = seg_preds[survivors, :, :]
            cate_scores = cate_scores[survivors]
            cate_labels = cate_labels[survivors]

            # Keep at most max_per_img results, best first.
            order = torch.argsort(cate_scores, descending=True)
            if len(order) > cfg.max_per_img:
                order = order[:cfg.max_per_img]
            seg_preds = seg_preds[order, :, :]
            cate_scores = cate_scores[order]
            cate_labels = cate_labels[order]

            # Resize to the original image size and binarise.
            resized = F.interpolate(seg_preds.unsqueeze(0),
                                    size=out_size,
                                    mode='bilinear',
                                    align_corners=False).squeeze(0)
            seg_masks = resized > cfg.mask_thr
            merged_outputs.append((seg_masks, cate_labels, cate_scores))

        return merged_outputs
Beispiel #6
0
    def get_seg_single_threshed(self,
                                cate_preds,
                                seg_preds,
                                cate_labels,
                                featmap_size,
                                img_shape,
                                ori_shape,
                                scale_factor,
                                cfg,
                                rescale=False, debug=False):
        """Post-process already selected candidates of one image.

        Unlike the score-thresholding variants, every entry of
        ``cate_preds`` is treated as a candidate. Candidates are re-scored
        by mean in-mask confidence, passed through Matrix NMS, truncated to
        ``cfg.max_per_img`` and resized to ``ori_shape[:2]``.

        Args:
            cate_preds: Per-candidate category scores; same length as
                ``seg_preds`` (presumably 1-D -- TODO confirm). The tensor
                is not modified.
            seg_preds: One soft mask per candidate (reduced over dims 1
                and 2).
            cate_labels: Per-candidate category labels.
            featmap_size: (height, width) pair; masks are first upsampled to
                four times this size.
            img_shape: (h, w, channels); the upsampled masks are cropped to
                ``h`` x ``w``.
            ori_shape: Shape whose first two entries are the output size.
            scale_factor: Not used.
            cfg: Config object providing ``mask_thr``, ``nms_pre``,
                ``kernel``, ``sigma``, ``update_thr`` and ``max_per_img``.
            rescale: Not used.
            debug: Not used.

        Returns:
            ``None`` if nothing survives, otherwise a tuple
            ``(seg_masks, cate_labels, cate_scores)``.
        """
        assert len(cate_preds) == len(seg_preds)

        # overall info.
        h, w, _ = img_shape
        upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)

        if len(cate_preds) == 0:
            return None

        # masks.
        seg_masks = seg_preds > cfg.mask_thr
        sum_masks = seg_masks.sum((1, 2)).float()

        # Drop empty masks: dividing by a zero area would produce NaN
        # scores, which could then reach the sort and Matrix NMS.
        keep = sum_masks > 0
        if keep.sum() == 0:
            return None
        seg_masks = seg_masks[keep, ...]
        seg_preds = seg_preds[keep, ...]
        sum_masks = sum_masks[keep]
        cate_scores = cate_preds[keep]
        cate_labels = cate_labels[keep]

        # mask scoring: average confidence on the mask area. Computed
        # out-of-place so the caller's ``cate_preds`` is never overwritten.
        seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
        cate_scores = cate_scores * seg_scores

        # sort and keep top nms_pre
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.nms_pre:
            sort_inds = sort_inds[:cfg.nms_pre]
        seg_masks = seg_masks[sort_inds, :, :]
        seg_preds = seg_preds[sort_inds, :, :]
        sum_masks = sum_masks[sort_inds]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # Matrix NMS
        cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                                 kernel=cfg.kernel, sigma=cfg.sigma,
                                 sum_masks=sum_masks)

        # filter.
        keep = cate_scores >= cfg.update_thr
        if keep.sum() == 0:
            return None
        seg_preds = seg_preds[keep, :, :]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # sort and keep top_k
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.max_per_img:
            sort_inds = sort_inds[:cfg.max_per_img]
        seg_preds = seg_preds[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # Upsample, crop to (h, w), then resize to the output resolution.
        # align_corners=False is the value F.interpolate already falls back
        # to when the argument is omitted; it is spelled out here.
        seg_preds = F.interpolate(seg_preds.unsqueeze(0),
                                  size=upsampled_size_out,
                                  mode='bilinear',
                                  align_corners=False)[:, :, :h, :w]
        seg_masks = F.interpolate(seg_preds,
                                  size=ori_shape[:2],
                                  mode='bilinear',
                                  align_corners=False).squeeze(0)
        seg_masks = seg_masks > cfg.mask_thr
        return seg_masks, cate_labels, cate_scores