Example #1
0
    def inference_single_image(self, cate_preds, seg_preds, featmap_size,
                               img_shape, ori_shape):
        """Convert one image's SOLO head outputs into an ``Instances``.

        Args:
            cate_preds, seg_preds: per-image predictions in the layout used by
                :meth:`inference`. ``cate_preds`` is indexed as
                ``[position, class]`` and ``seg_preds`` as
                ``[position, H, W]``; the ``seg_num_grids[i] ** 2`` positions
                of every level are concatenated in level order, and both
                inputs must have the same length.
            featmap_size (tuple): (height, width) of the mask feature map;
                masks are upsampled to 4x this size and then cropped to
                ``img_shape``.
            img_shape (tuple): (height, width) of the image fed into the model.
            ori_shape (tuple): (height, width) of the original image.

        Returns:
            Instances: the image's predictions after score thresholding,
            area filtering, Matrix NMS and top-k selection. It is empty when
            no candidate survives one of the filters.
        """
        assert len(cate_preds) == len(seg_preds)
        result = Instances(ori_shape)

        img_h, img_w = img_shape
        out_size = (featmap_size[0] * 4, featmap_size[1] * 4)

        # Candidate (position, class) pairs that clear the score threshold.
        candidate_mask = cate_preds > self.score_threshold
        cate_scores = cate_preds[candidate_mask]
        if len(cate_scores) == 0:
            return result
        candidates = candidate_mask.nonzero(as_tuple=False)
        pos_inds, cate_labels = candidates[:, 0], candidates[:, 1]

        # Level i owns seg_num_grids[i] ** 2 consecutive positions: expand the
        # per-level feature stride to one entry per position, then gather the
        # entries belonging to the candidate positions.
        n_stage = len(self.seg_num_grids)
        level_areas = cate_labels.new_tensor(self.seg_num_grids).pow(2)
        level_strides = cate_scores.new_tensor(
            [self.feature_strides[lvl] for lvl in range(n_stage)])
        strides = torch.repeat_interleave(level_strides, level_areas)[pos_inds]

        # Binarise the candidate masks and measure their areas.
        seg_preds = seg_preds[pos_inds]
        seg_masks = seg_preds > self.mask_threshold
        sum_masks = seg_masks.sum((1, 2)).float()

        # Discard masks whose area does not exceed their level's stride.
        keep = sum_masks > strides
        if keep.sum() == 0:
            return result
        seg_masks = seg_masks[keep]
        seg_preds = seg_preds[keep]
        sum_masks = sum_masks[keep]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # Weight each category score by the mean soft value inside its mask.
        mask_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
        cate_scores *= mask_scores

        # Keep at most nms_per_image candidates, highest score first.
        order = torch.argsort(cate_scores, descending=True)
        if len(order) > self.nms_per_image:
            order = order[:self.nms_per_image]
        seg_masks, seg_preds = seg_masks[order], seg_preds[order]
        sum_masks = sum_masks[order]
        cate_scores, cate_labels = cate_scores[order], cate_labels[order]

        # Rescore the remaining candidates with Matrix NMS.
        cate_scores = matrix_nms(seg_masks,
                                 cate_labels,
                                 cate_scores,
                                 kernel=self.nms_kernel,
                                 sigma=self.nms_sigma,
                                 sum_masks=sum_masks)

        # Drop candidates whose rescored value falls below update_threshold.
        keep = cate_scores >= self.update_threshold
        if keep.sum() == 0:
            return result
        seg_preds = seg_preds[keep]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # Keep at most max_detections_per_image, highest score first.
        order = torch.argsort(cate_scores, descending=True)
        if len(order) > self.max_detections_per_image:
            order = order[:self.max_detections_per_image]
        seg_preds = seg_preds[order]
        cate_scores = cate_scores[order]
        cate_labels = cate_labels[order]

        # Upsample to 4x the feature map, crop to img_shape, resize to the
        # original image size and binarise.
        soft_masks = F.interpolate(seg_preds.unsqueeze(0),
                                   size=out_size,
                                   mode='bilinear')[:, :, :img_h, :img_w]
        soft_masks = F.interpolate(soft_masks, size=ori_shape,
                                   mode='bilinear').squeeze(0)
        bit_masks = BitMasks(soft_masks > self.mask_threshold)

        result.pred_masks = bit_masks
        result.pred_boxes = bit_masks.get_bounding_boxes()
        result.scores = cate_scores
        result.pred_classes = cate_labels
        return result
Example #2
0
    def inference_single_image(self, cate_preds, seg_preds_x, seg_preds_y,
                               featmap_size, img_shape, ori_shape):
        """Post-process decoupled mask predictions for a single image.

        A positive grid cell (row, col) of a level with ``S`` grids is decoded
        into a soft mask as the product of the x-branch mask for ``col`` and
        the y-branch mask for ``row``.

        Args:
            cate_preds (Tensor): category scores indexed as
                ``[position, class]``; the ``seg_num_grids[i] ** 2`` positions
                of every level are concatenated in level order.
            seg_preds_x (Tensor): x-branch masks indexed as
                ``[column, H, W]``, with the ``seg_num_grids[i]`` columns of
                every level concatenated in level order.
            seg_preds_y (Tensor): y-branch masks indexed as ``[row, H, W]``,
                laid out like ``seg_preds_x``.
            featmap_size (tuple): (height, width) of the mask feature map;
                masks are upsampled to 4x this size and then cropped to
                ``img_shape``.
            img_shape (tuple): (height, width) of the image fed into the model.
            ori_shape (tuple): (height, width) of the original image.

        Returns:
            Instances: the image's predictions after score thresholding,
            area filtering, Matrix NMS and top-k selection. It is empty when
            no candidate survives one of the filters.
        """
        result = Instances(ori_shape)

        # overall info.
        h, w = img_shape
        upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)

        # Candidate (position, class) pairs that clear the score threshold.
        inds = cate_preds > self.score_threshold
        cate_scores = cate_preds[inds]
        if len(cate_scores) == 0:
            return result
        inds = inds.nonzero(as_tuple=False)
        pos_inds = inds[:, 0]
        cate_labels = inds[:, 1]

        # Map every flat position to its pyramid level, then derive per
        # candidate: the level's first flat position (trans_diff), its first
        # row/column in seg_preds_y/seg_preds_x (seg_diff), its grid count and
        # its feature stride. The tables are built on cate_preds' device so
        # they can be indexed with pos_inds without a device mismatch.
        device = cate_preds.device
        n_stage = len(self.seg_num_grids)
        level_grids = torch.tensor(self.seg_num_grids, dtype=torch.long,
                                   device=device)
        level_areas = level_grids.pow(2)
        level_of_pos = torch.repeat_interleave(
            torch.arange(n_stage, device=device), level_areas)
        levels = level_of_pos[pos_inds]

        trans_diff = (level_areas.cumsum(0) - level_areas)[levels]
        seg_diff = (level_grids.cumsum(0) - level_grids)[levels]
        num_grids = level_grids[levels]
        strides = torch.tensor(
            [float(self.feature_strides[lvl]) for lvl in range(n_stage)],
            device=device)[levels]

        # Local grid cell of each candidate, offset into the concatenated
        # row/column mask stacks.
        local_inds = pos_inds - trans_diff
        y_inds = local_inds // num_grids + seg_diff
        x_inds = local_inds % num_grids + seg_diff

        seg_masks_soft = seg_preds_x[x_inds, ...] * seg_preds_y[y_inds, ...]
        seg_masks = seg_masks_soft > self.mask_threshold
        sum_masks = seg_masks.sum((1, 2)).float()

        # Discard masks whose area does not exceed their level's stride.
        keep = sum_masks > strides
        if keep.sum() == 0:
            return result

        seg_masks_soft = seg_masks_soft[keep, ...]
        seg_masks = seg_masks[keep, ...]
        cate_scores = cate_scores[keep]
        sum_masks = sum_masks[keep]
        cate_labels = cate_labels[keep]

        # Weight each category score by the mean soft value inside its mask.
        seg_score = (seg_masks_soft * seg_masks.float()).sum(
            (1, 2)) / sum_masks
        cate_scores *= seg_score

        # Keep at most nms_per_image candidates, highest score first.
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > self.nms_per_image:
            sort_inds = sort_inds[:self.nms_per_image]
        seg_masks_soft = seg_masks_soft[sort_inds, :, :]
        seg_masks = seg_masks[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        sum_masks = sum_masks[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # Rescore the remaining candidates with Matrix NMS.
        cate_scores = matrix_nms(seg_masks,
                                 cate_labels,
                                 cate_scores,
                                 kernel=self.nms_kernel,
                                 sigma=self.nms_sigma,
                                 sum_masks=sum_masks)

        # Drop candidates below update_threshold. Bail out when none remain
        # so that empty tensors never reach F.interpolate / BitMasks.
        keep = cate_scores >= self.update_threshold
        if keep.sum() == 0:
            return result
        seg_masks_soft = seg_masks_soft[keep, :, :]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]

        # Keep at most max_detections_per_image, highest score first.
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > self.max_detections_per_image:
            sort_inds = sort_inds[:self.max_detections_per_image]
        seg_masks_soft = seg_masks_soft[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # Upsample to 4x the feature map, crop to img_shape, resize to the
        # original image size and binarise. align_corners=False is the
        # default for bilinear mode; it is passed explicitly.
        seg_masks_soft = F.interpolate(seg_masks_soft.unsqueeze(0),
                                       size=upsampled_size_out,
                                       mode='bilinear',
                                       align_corners=False)[:, :, :h, :w]
        seg_masks = F.interpolate(seg_masks_soft,
                                  size=ori_shape,
                                  mode='bilinear',
                                  align_corners=False).squeeze(0)
        seg_masks = seg_masks > self.mask_threshold

        seg_masks = BitMasks(seg_masks)
        result.pred_masks = seg_masks
        result.pred_boxes = seg_masks.get_bounding_boxes()
        result.scores = cate_scores
        result.pred_classes = cate_labels
        return result