Example #1
    def get_roi_rel_points_train(self, mask_pred, labels, cfg):
        """Get ``num_points`` most uncertain points with random points during
        train.

        Sample points in [0, 1] x [0, 1] coordinate space based on their
        uncertainty. The uncertainties are calculated for each point using
        the ``_get_uncertainty()`` function, which takes the point's logit
        prediction as input.

        Args:
            mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
                mask_height, mask_width) for class-specific or class-agnostic
                prediction.
            labels (Tensor): The ground truth class for each instance.
            cfg (dict): Training config of point head.

        Returns:
            point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
                that contains the coordinates of the sampled points.
        """
        num_points = cfg.num_points
        oversample_ratio = cfg.oversample_ratio
        importance_sample_ratio = cfg.importance_sample_ratio
        assert oversample_ratio >= 1
        assert 0 <= importance_sample_ratio <= 1
        batch_size = mask_pred.shape[0]
        num_sampled = int(num_points * oversample_ratio)
        point_coords = torch.rand(batch_size,
                                  num_sampled,
                                  2,
                                  device=mask_pred.device)
        point_logits = point_sample(mask_pred, point_coords)
        # It is crucial to calculate uncertainty based on the sampled
        # prediction value for the points. Calculating uncertainties of the
        # coarse predictions first and sampling them for points leads to
        # incorrect results. To illustrate this: assume uncertainty
        # func(logits) = -abs(logits), a sampled point between two coarse
        # predictions with -1 and 1 logits has 0 logits, and therefore 0
        # uncertainty value. However, if we calculate uncertainties for the
        # coarse predictions first, both will have -1 uncertainty,
        # and the sampled point will get -1 uncertainty.
        point_uncertainties = self._get_uncertainty(point_logits, labels)
        num_uncertain_points = int(importance_sample_ratio * num_points)
        num_random_points = num_points - num_uncertain_points
        idx = torch.topk(point_uncertainties[:, 0, :],
                         k=num_uncertain_points,
                         dim=1)[1]
        shift = num_sampled * torch.arange(
            batch_size, dtype=torch.long, device=mask_pred.device)
        idx += shift[:, None]
        point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
            batch_size, num_uncertain_points, 2)
        if num_random_points > 0:
            rand_roi_coords = torch.rand(batch_size,
                                         num_random_points,
                                         2,
                                         device=mask_pred.device)
            point_coords = torch.cat((point_coords, rand_roi_coords), dim=1)
        return point_coords
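
The docstring and the inline comment above pin down the behaviour of `_get_uncertainty()`: uncertainty(logits) = -abs(logits), computed on the logit of each instance's ground-truth class. A minimal standalone sketch of such a helper (the function name and the class-agnostic branch are assumptions, not taken from the example):

import torch


def get_uncertainty(mask_pred, labels):
    # mask_pred: (num_rois, num_classes, ...) logit predictions;
    # labels: (num_rois,) ground-truth class per instance.
    if mask_pred.shape[1] == 1:
        # Class-agnostic prediction: use the single channel as-is.
        gt_class_logits = mask_pred.clone()
    else:
        # Pick each instance's logit channel for its ground-truth class.
        inds = torch.arange(mask_pred.shape[0], device=mask_pred.device)
        gt_class_logits = mask_pred[inds, labels].unsqueeze(1)
    # Logits near zero (the decision boundary) yield uncertainty near zero,
    # the maximum; confident logits yield large negative values.
    return -torch.abs(gt_class_logits)

The unsqueeze(1) keeps a singleton channel dimension, which is why Example #1 indexes the result as point_uncertainties[:, 0, :] before topk.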
Example #2
    def _get_target_single(self, rois, rel_roi_points, gt_masks, cfg):
        """Get training target of KeyPointMaskHead for each image."""
        num_pos = rois.size(0)
        num_points = cfg.num_points
        if num_pos > 0:
            gt_masks_th = gt_masks.to_tensor(rois.dtype, rois.device)
            gt_masks_th = gt_masks_th.unsqueeze(1)
            rel_img_points = rel_roi_point_to_rel_img_point(
                rois, rel_roi_points, gt_masks_th.shape[2:])
            point_targets = point_sample(gt_masks_th,
                                         rel_img_points).squeeze(1)
        else:
            point_targets = rois.new_zeros((0, num_points))
        return point_targets
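
Both `_get_target_single` variants read mask values out with `point_sample()`. A sketch of what such a sampler does, assuming it is a thin wrapper around `torch.nn.functional.grid_sample` with the [0, 1] point coordinates rescaled to grid_sample's [-1, 1] convention:

import torch.nn.functional as F


def point_sample(input, points, align_corners=False):
    # input: (N, C, H, W); points: (N, P, 2) or (N, Hg, Wg, 2) with
    # (x, y) coordinates in [0, 1].
    add_dim = False
    if points.dim() == 3:
        # grid_sample expects a 4-D grid, so add a dummy width dim.
        add_dim = True
        points = points.unsqueeze(2)
    # Map [0, 1] coordinates to grid_sample's [-1, 1] range.
    output = F.grid_sample(
        input, 2.0 * points - 1.0, align_corners=align_corners)
    if add_dim:
        output = output.squeeze(3)
    return output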
Example #3
    def _get_target_single(self, rois, rel_roi_points, pos_assigned_gt_inds,
                           gt_masks, cfg):
        """Get the point-wise training target for a single image."""
        num_pos = rois.size(0)
        num_points = cfg.num_points
        if num_pos > 0:
            gt_masks_th = gt_masks.to_tensor(
                rois.dtype, rois.device).index_select(0, pos_assigned_gt_inds)
            gt_masks_th = gt_masks_th.unsqueeze(1)
            rel_img_points = rel_roi_point2rel_img_point(
                rois, rel_roi_points, gt_masks_th.shape[2:])
            point_targets = point_sample(gt_masks_th,
                                         rel_img_points).squeeze(1)
        else:
            point_targets = rois.new_zeros((0, num_points))
        return point_targets
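
Examples #2 and #3 differ mainly in the helper name (`rel_roi_point_to_rel_img_point` vs. `rel_roi_point2rel_img_point`) and in the index_select over the assigned ground-truth masks. A sketch of what such a converter computes, assuming rois in (batch_ind, x1, y1, x2, y2) format and semantics consistent with how Example #5 passes `feat.shape[2:]` and `1. / stride`:

import torch


def rel_roi_point_to_rel_img_point(rois, rel_roi_points, img_shape,
                                   spatial_scale=1.):
    # rois: (num_rois, 5) as (batch_ind, x1, y1, x2, y2);
    # rel_roi_points: (num_rois, num_points, 2) with (x, y) in [0, 1];
    # img_shape: (h, w) of the target map.
    # Map RoI-relative points to absolute image coordinates.
    abs_x = rois[:, None, 1] + rel_roi_points[:, :, 0] * (
        rois[:, None, 3] - rois[:, None, 1])
    abs_y = rois[:, None, 2] + rel_roi_points[:, :, 1] * (
        rois[:, None, 4] - rois[:, None, 2])
    h, w = img_shape
    # Rescale to the target map's resolution, then normalize to [0, 1].
    return torch.stack(
        [abs_x * spatial_scale / w, abs_y * spatial_scale / h], dim=2)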
Example #4
    def _mask_point_forward_train(self, x, sampling_results, mask_pred,
                                  gt_masks, img_metas):
        """Run forward function and calculate loss for point head in
        training."""
        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
        rel_roi_points = self.point_head.get_roi_rel_points_train(
            mask_pred, pos_labels, cfg=self.train_cfg)
        rois = bbox2roi([res.pos_bboxes for res in sampling_results])

        fine_grained_point_feats = self._get_fine_grained_point_feats(
            x, rois, rel_roi_points, img_metas)
        coarse_point_feats = point_sample(mask_pred, rel_roi_points)
        mask_point_pred = self.point_head(fine_grained_point_feats,
                                          coarse_point_feats)
        mask_point_target = self.point_head.get_targets(
            rois, rel_roi_points, sampling_results, gt_masks, self.train_cfg)
        loss_mask_point = self.point_head.loss(mask_point_pred,
                                               mask_point_target, pos_labels)

        return loss_mask_point
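
`get_targets` above receives batched rois and points, while the `_get_target_single` helpers of Examples #2/#3 work per image. A sketch of the glue between them, assuming rois carry their image index in column 0 (as in Example #5), `gt_masks` is a per-image list, and `sampling_results` exposes `pos_assigned_gt_inds` the way Example #4 exposes `pos_bboxes` and `pos_gt_labels`:

import torch


def get_point_targets(rois, rel_roi_points, sampling_results, gt_masks,
                      cfg, get_target_single):
    # get_target_single follows the signature of Example #3; here it is
    # passed in explicitly instead of being a bound method.
    point_targets = []
    for batch_ind, res in enumerate(sampling_results):
        inds = rois[:, 0] == batch_ind
        point_targets.append(
            get_target_single(rois[inds], rel_roi_points[inds],
                              res.pos_assigned_gt_inds,
                              gt_masks[batch_ind], cfg))
    return torch.cat(point_targets)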
Example #5
    def _get_fine_grained_point_feats(self, x, rois, rel_roi_points,
                                      img_metas):
        """Sample fine-grained point features from each level's feature map
        and concatenate them along the channel dimension."""
        num_imgs = len(img_metas)
        fine_grained_feats = []
        for idx in range(self.mask_roi_extractor.num_inputs):
            feats = x[idx]
            spatial_scale = 1. / float(
                self.mask_roi_extractor.featmap_strides[idx])
            point_feats = []
            for batch_ind in range(num_imgs):
                # unravel the batch dim: one image's feature map at a time
                feat = feats[batch_ind].unsqueeze(0)
                inds = (rois[:, 0].long() == batch_ind)
                if inds.any():
                    rel_img_points = rel_roi_point2rel_img_point(
                        rois[inds], rel_roi_points[inds], feat.shape[2:],
                        spatial_scale).unsqueeze(0)
                    point_feat = point_sample(feat, rel_img_points)
                    point_feat = point_feat.squeeze(0).transpose(0, 1)
                    point_feats.append(point_feat)
            fine_grained_feats.append(torch.cat(point_feats, dim=0))
        return torch.cat(fine_grained_feats, dim=1)
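
For intuition, the inner body of the double loop above reduces to one sampling step per (image, level) pair. A single-call sketch, reusing the `point_sample` and `rel_roi_point_to_rel_img_point` sketches from earlier (the function name is hypothetical):

def sample_point_feats_single(feat, rois, rel_roi_points, stride):
    # feat: (1, C, H, W) feature map of one image at one FPN level;
    # rois/rel_roi_points: only the entries belonging to this image.
    rel_img_points = rel_roi_point_to_rel_img_point(
        rois, rel_roi_points, feat.shape[2:], 1. / stride).unsqueeze(0)
    # point_sample returns (1, C, num_rois, num_points); squeeze and
    # transpose to (num_rois, C, num_points) so results can be stacked
    # across images (dim 0) and concatenated across levels (dim 1).
    return point_sample(feat, rel_img_points).squeeze(0).transpose(0, 1)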
Example #6
    def _mask_point_forward_train(self, x, gt_bboxes, mask_pred, gt_masks,
                                  gt_labels, img_metas):
        """Run forward function and calculate loss for point head in
        training."""
        # gt_labels: shape (num_gts,)
        # mask_pred: shape (num_gts, num_classes, 7, 7)
        rel_roi_points = self.point_head.get_roi_rel_points_train(
            mask_pred, gt_labels, cfg=self.train_cfg)

        # WARNING: when using FPN with multiple levels, pass plain x
        # instead of wrapping it in a list.
        if isinstance(x, torch.Tensor):
            x = [x]

        fine_grained_point_feats = self._get_fine_grained_point_feats(
            x, gt_bboxes, rel_roi_points, img_metas)
        coarse_point_feats = point_sample(mask_pred, rel_roi_points)
        mask_point_pred = self.point_head(fine_grained_point_feats,
                                          coarse_point_feats)
        mask_point_target = self.point_head.get_targets(
            gt_bboxes, rel_roi_points, gt_masks, img_metas, self.train_cfg)
        loss_mask_point = self.point_head.loss(mask_point_pred,
                                               mask_point_target, gt_labels)

        return loss_mask_point
    def _mask_point_forward_test(self, x, rois, label_pred, mask_pred,
                                 img_metas):
        """Mask refining process with point head in testing"""
        refined_mask_pred = mask_pred.clone()
        for subdivision_step in range(self.test_cfg.subdivision_steps):
            refined_mask_pred = F.interpolate(
                refined_mask_pred,
                scale_factor=self.test_cfg.scale_factor,
                mode='bilinear',
                align_corners=False)
            # If `subdivision_num_points` is greater than or equal to the
            # resolution of the next step, we can skip this step.
            num_rois, channels, mask_height, mask_width = \
                refined_mask_pred.shape
            if (self.test_cfg.subdivision_num_points >=
                    self.test_cfg.scale_factor**2 * mask_height * mask_width
                    and
                    subdivision_step < self.test_cfg.subdivision_steps - 1):
                continue
            point_indices, rel_roi_points = \
                self.point_head.get_roi_rel_points_test(
                    refined_mask_pred, label_pred, cfg=self.test_cfg)
            fine_grained_point_feats = self._get_fine_grained_point_feats(
                x, rois, rel_roi_points, img_metas)
            coarse_point_feats = point_sample(mask_pred, rel_roi_points)
            mask_point_pred = self.point_head(fine_grained_point_feats,
                                              coarse_point_feats)

            point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
            refined_mask_pred = refined_mask_pred.reshape(
                num_rois, channels, mask_height * mask_width)
            refined_mask_pred = refined_mask_pred.scatter_(
                2, point_indices, mask_point_pred)
            refined_mask_pred = refined_mask_pred.view(num_rois, channels,
                                                       mask_height, mask_width)

        return refined_mask_pred
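
Example #6 uses `get_roi_rel_points_test` to obtain both flat `point_indices` into the upsampled mask grid (consumed by `scatter_`) and the matching `rel_roi_points`. A sketch of what such a selector could look like, assuming the same negative-absolute-logit uncertainty as in Example #1 (the `get_uncertainty` helper from the first sketch) and `cfg.subdivision_num_points` as the point budget:

def get_roi_rel_points_test(mask_pred, pred_label, cfg):
    # mask_pred: (num_rois, num_classes, H, W) upsampled mask logits.
    uncertainty_map = get_uncertainty(mask_pred, pred_label)
    num_rois, _, mask_height, mask_width = uncertainty_map.shape
    h_step = 1.0 / mask_height
    w_step = 1.0 / mask_width
    uncertainty_map = uncertainty_map.view(num_rois,
                                           mask_height * mask_width)
    num_points = min(mask_height * mask_width,
                     cfg.subdivision_num_points)
    # Flat indices of the most uncertain grid cells per RoI.
    point_indices = uncertainty_map.topk(num_points, dim=1)[1]
    # Convert flat indices back to (x, y) cell-center coordinates in the
    # [0, 1] x [0, 1] RoI coordinate space.
    point_coords = uncertainty_map.new_zeros(num_rois, num_points, 2)
    point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
                                            mask_width).float() * w_step
    point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
                                            mask_width).float() * h_step
    return point_indices, point_coords

The cell-center offsets (w_step / 2, h_step / 2) keep the coordinates consistent with how `scatter_` later writes the refined logits back into the flattened (mask_height * mask_width) grid.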