# NOTE: the methods below are excerpts from PointRend / MaskFormer-style
# point heads (mmdet / mmseg 2.x-era code). At module level they rely on
# imports along these lines; the exact paths depend on the surrounding
# tree and are stated here as an assumption:
#   import os
#   import torch
#   import torch.nn.functional as F
#   from mmcv.ops import point_sample, rel_roi_point_to_rel_img_point
#   from mmdet.core import bbox2roi, reduce_mean
# plus the ``get_uncertain_point_coords_with_randomness`` and
# ``calculate_uncertainty`` helpers referenced further down.


def _mask_point_forward_once(self, x, sampling_results, mask_pred,
                             img_metas):
    """Mask refining process once with point head in training."""
    label_pred = torch.cat([res.pos_gt_labels for res in sampling_results])
    rois = bbox2roi([res.pos_bboxes for res in sampling_results])
    refined_mask_pred = mask_pred.clone()
    refined_mask_pred = F.interpolate(
        refined_mask_pred.unsqueeze(1),
        scale_factor=self.test_cfg.scale_factor,
        mode='bilinear',
        align_corners=False)
    num_rois, channels, mask_height, mask_width = refined_mask_pred.shape
    point_indices, rel_roi_points = \
        self.point_head.get_roi_rel_points_test(
            refined_mask_pred, label_pred, cfg=self.test_cfg)
    fine_grained_point_feats = self._get_fine_grained_point_feats(
        x, rois, rel_roi_points, img_metas)
    coarse_point_feats = point_sample(mask_pred.unsqueeze(1),
                                      rel_roi_points)
    mask_point_pred = self.point_head(fine_grained_point_feats,
                                      coarse_point_feats)
    point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
    refined_mask_pred = refined_mask_pred.reshape(num_rois, channels,
                                                  mask_height * mask_width)
    refined_mask_pred = refined_mask_pred.scatter_(2, point_indices,
                                                   mask_point_pred)
    refined_mask_pred = refined_mask_pred.view(num_rois, channels,
                                               mask_height, mask_width)
    return refined_mask_pred

def _onnx_get_fine_grained_point_feats(self, x, rois, rel_roi_points):
    """Export the process of sampling fine grained feats to onnx.

    Args:
        x (tuple[Tensor]): Feature maps of all scale levels.
        rois (Tensor): shape (num_rois, 5).
        rel_roi_points (Tensor): A tensor of shape (num_rois, num_points,
            2) that contains [0, 1] x [0, 1] normalized coordinates of the
            most uncertain points from the [mask_height, mask_width] grid.

    Returns:
        Tensor: The fine grained features for each point, has shape
        (num_rois, feats_channels, num_points).
    """
    batch_size = x[0].shape[0]
    num_rois = rois.shape[0]
    fine_grained_feats = []
    for idx in range(self.mask_roi_extractor.num_inputs):
        feats = x[idx]
        spatial_scale = 1. / float(
            self.mask_roi_extractor.featmap_strides[idx])
        rel_img_points = rel_roi_point_to_rel_img_point(
            rois, rel_roi_points, feats, spatial_scale)
        channels = feats.shape[1]
        num_points = rel_img_points.shape[1]
        rel_img_points = rel_img_points.reshape(batch_size, -1,
                                                num_points, 2)
        point_feats = point_sample(feats, rel_img_points)
        point_feats = point_feats.transpose(1, 2).reshape(
            num_rois, channels, num_points)
        fine_grained_feats.append(point_feats)
    return torch.cat(fine_grained_feats, dim=1)

def get_roi_rel_points_train(self, mask_pred, labels, cfg):
    """Get ``num_points`` most uncertain points with random points during
    train.

    Sample points in [0, 1] x [0, 1] coordinate space based on their
    uncertainty. The uncertainties are calculated for each point using
    the '_get_uncertainty()' function that takes the point's logit
    prediction as input.

    Args:
        mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
            mask_height, mask_width) for class-specific or class-agnostic
            prediction.
        labels (list): The ground truth class for each instance.
        cfg (dict): Training config of point head.

    Returns:
        point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
        that contains the coordinates of the sampled points.
    """
    num_points = cfg.num_points
    oversample_ratio = cfg.oversample_ratio
    importance_sample_ratio = cfg.importance_sample_ratio
    assert oversample_ratio >= 1
    assert 0 <= importance_sample_ratio <= 1
    batch_size = mask_pred.shape[0]
    num_sampled = int(num_points * oversample_ratio)
    point_coords = torch.rand(
        batch_size, num_sampled, 2, device=mask_pred.device)
    point_logits = point_sample(mask_pred, point_coords)
    # It is crucial to calculate uncertainty based on the sampled
    # prediction value for the points. Calculating uncertainties of the
    # coarse predictions first and sampling them for points leads to
    # incorrect results. To illustrate this: assume uncertainty_func(
    # logits) = -abs(logits), a sampled point between two coarse
    # predictions with -1 and 1 logits has 0 logits, and therefore 0
    # uncertainty value. However, if we calculate uncertainties for the
    # coarse predictions first, both will have -1 uncertainty, and the
    # sampled point will get -1 uncertainty.
    point_uncertainties = self._get_uncertainty(point_logits, labels)
    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points
    idx = torch.topk(
        point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    shift = num_sampled * torch.arange(
        batch_size, dtype=torch.long, device=mask_pred.device)
    idx += shift[:, None]
    point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
        batch_size, num_uncertain_points, 2)
    if num_random_points > 0:
        rand_roi_coords = torch.rand(
            batch_size, num_random_points, 2, device=mask_pred.device)
        point_coords = torch.cat((point_coords, rand_roi_coords), dim=1)
    return point_coords

def get_points_train(self, seg_logits, uncertainty_func, cfg):
    """Sample points for training.

    Sample points in [0, 1] x [0, 1] coordinate space based on their
    uncertainty. The uncertainties are calculated for each point using
    the 'uncertainty_func' function that takes the point's logit
    prediction as input.

    Args:
        seg_logits (Tensor): Semantic segmentation logits, shape
            (batch_size, num_classes, height, width).
        uncertainty_func (func): Uncertainty calculation function.
        cfg (dict): Training config of point head.

    Returns:
        point_coords (Tensor): A tensor of shape (batch_size, num_points,
        2) that contains the coordinates of ``num_points`` sampled points.
    """
    num_points = cfg.num_points
    oversample_ratio = cfg.oversample_ratio
    importance_sample_ratio = cfg.importance_sample_ratio
    assert oversample_ratio >= 1
    assert 0 <= importance_sample_ratio <= 1
    batch_size = seg_logits.shape[0]
    num_sampled = int(num_points * oversample_ratio)
    point_coords = torch.rand(
        batch_size, num_sampled, 2, device=seg_logits.device)
    point_logits = point_sample(seg_logits, point_coords)
    # It is crucial to calculate uncertainty based on the sampled
    # prediction value for the points. Calculating uncertainties of the
    # coarse predictions first and sampling them for points leads to
    # incorrect results. To illustrate this: assume uncertainty_func(
    # logits) = -abs(logits), a sampled point between two coarse
    # predictions with -1 and 1 logits has 0 logits, and therefore 0
    # uncertainty value. However, if we calculate uncertainties for the
    # coarse predictions first, both will have -1 uncertainty, and the
    # sampled point will get -1 uncertainty.
    point_uncertainties = uncertainty_func(point_logits)
    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points
    idx = torch.topk(
        point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    shift = num_sampled * torch.arange(
        batch_size, dtype=torch.long, device=seg_logits.device)
    idx += shift[:, None]
    point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
        batch_size, num_uncertain_points, 2)
    if num_random_points > 0:
        rand_point_coords = torch.rand(
            batch_size, num_random_points, 2, device=seg_logits.device)
        point_coords = torch.cat((point_coords, rand_point_coords), dim=1)
    return point_coords

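# --- Usage sketch (not from the library; a minimal torch-only
# re-implementation for illustration). It mirrors the importance-sampling
# scheme of ``get_points_train`` above: oversample random points, score
# them with an uncertainty function, keep the top-k most uncertain, then
# pad with fresh random points. ``point_sample`` is emulated with
# ``F.grid_sample`` by mapping [0, 1] coordinates to the [-1, 1] range it
# expects; the toy uncertainty function and all shapes are assumptions.
import torch
import torch.nn.functional as F


def toy_point_sample(feats, points):
    """feats: (B, C, H, W); points: (B, P, 2) in [0, 1] x [0, 1]."""
    grid = 2.0 * points.unsqueeze(2) - 1.0  # (B, P, 1, 2) in [-1, 1]
    return F.grid_sample(feats, grid, align_corners=False).squeeze(3)


def toy_uncertainty(logits):
    """Least-confident binary logits (|logit| near 0) score highest."""
    return -logits.abs()


seg_logits = torch.randn(2, 1, 16, 16)
num_points, oversample_ratio, importance_sample_ratio = 8, 3, 0.75
num_sampled = num_points * oversample_ratio
coords = torch.rand(2, num_sampled, 2)
scores = toy_uncertainty(toy_point_sample(seg_logits, coords))
k = int(importance_sample_ratio * num_points)
top_idx = scores[:, 0, :].topk(k, dim=1)[1]  # (B, k)
# ``torch.gather`` selects the same points as the flatten-and-shift
# indexing used above; both pick per-batch top-k coordinates.
picked = coords.gather(1, top_idx.unsqueeze(-1).expand(-1, -1, 2))
coords = torch.cat([picked, torch.rand(2, num_points - k, 2)], dim=1)
assert coords.shape == (2, num_points, 2)
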
def _mask_point_forward_test(self, x, rois, label_pred, mask_pred,
                             img_metas):
    """Mask refining process with point head in testing.

    Args:
        x (tuple[Tensor]): Feature maps of all scale levels.
        rois (Tensor): shape (num_rois, 5).
        label_pred (Tensor): The predicted class for each RoI.
        mask_pred (Tensor): The predicted coarse masks of shape
            (num_rois, num_classes, small_size, small_size).
        img_metas (list[dict]): Image meta info.

    Returns:
        Tensor: The refined masks of shape (num_rois, num_classes,
        large_size, large_size).
    """
    refined_mask_pred = mask_pred.clone()
    for subdivision_step in range(self.test_cfg.subdivision_steps):
        refined_mask_pred = F.interpolate(
            refined_mask_pred,
            scale_factor=self.test_cfg.scale_factor,
            mode='bilinear',
            align_corners=False)
        # If `subdivision_num_points` is larger than or equal to the
        # resolution of the next step, then we can skip this step.
        num_rois, channels, mask_height, mask_width = \
            refined_mask_pred.shape
        if (self.test_cfg.subdivision_num_points >=
                self.test_cfg.scale_factor**2 * mask_height * mask_width
                and
                subdivision_step < self.test_cfg.subdivision_steps - 1):
            continue
        point_indices, rel_roi_points = \
            self.point_head.get_roi_rel_points_test(
                refined_mask_pred, label_pred, cfg=self.test_cfg)
        fine_grained_point_feats = self._get_fine_grained_point_feats(
            x, rois, rel_roi_points, img_metas)
        coarse_point_feats = point_sample(mask_pred, rel_roi_points)
        mask_point_pred = self.point_head(fine_grained_point_feats,
                                          coarse_point_feats)
        point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
        refined_mask_pred = refined_mask_pred.reshape(
            num_rois, channels, mask_height * mask_width)
        refined_mask_pred = refined_mask_pred.scatter_(
            2, point_indices, mask_point_pred)
        refined_mask_pred = refined_mask_pred.view(num_rois, channels,
                                                   mask_height, mask_width)
    return refined_mask_pred

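# --- Sketch of the write-back step in the subdivision loop above (toy
# shapes, torch only; all values are assumptions). The refined per-point
# logits are scattered into the flattened (H*W) axis, then the mask is
# viewed back to 2-D:
import torch

num_rois, channels, h, w, num_points = 2, 1, 4, 4, 5
mask = torch.zeros(num_rois, channels, h, w)
point_indices = torch.randint(0, h * w, (num_rois, num_points))
mask_point_pred = torch.ones(num_rois, channels, num_points)

idx = point_indices.unsqueeze(1).expand(-1, channels, -1)  # (R, C, P)
flat = mask.reshape(num_rois, channels, h * w)
flat.scatter_(2, idx, mask_point_pred)  # in-place write at selected cells
mask = flat.view(num_rois, channels, h, w)
# At most P cells per (roi, channel) are now 1; duplicate indices
# collapse into a single write.
assert mask.sum() <= num_rois * channels * num_points
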
def _get_target_single(self, rois, rel_roi_points, pos_assigned_gt_inds,
                       gt_masks, cfg):
    """Get training target of MaskPointHead for each image."""
    num_pos = rois.size(0)
    num_points = cfg.num_points
    if num_pos > 0:
        gt_masks_th = (
            gt_masks.to_tensor(rois.dtype, rois.device).index_select(
                0, pos_assigned_gt_inds))
        gt_masks_th = gt_masks_th.unsqueeze(1)
        rel_img_points = rel_roi_point_to_rel_img_point(
            rois, rel_roi_points, gt_masks_th.shape[2:])
        point_targets = point_sample(gt_masks_th,
                                     rel_img_points).squeeze(1)
    else:
        point_targets = rois.new_zeros((0, num_points))
    return point_targets

def _get_coarse_point_feats(self, prev_output, points):
    """Sample from the coarse prediction of the previous decode head.

    Args:
        prev_output (Tensor): Prediction of previous decode head.
        points (Tensor): Point coordinates, shape (batch_size,
            num_points, 2).

    Returns:
        coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
        num_classes, num_points).
    """
    coarse_feats = point_sample(
        prev_output, points, align_corners=self.align_corners)
    return coarse_feats

def _mask_point_forward_train(self, x, sampling_results, mask_pred,
                              gt_masks, img_metas):
    """Run forward function and calculate loss for point head in
    training."""
    pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
    rel_roi_points = self.point_head.get_roi_rel_points_train(
        mask_pred, pos_labels, cfg=self.train_cfg)
    rois = bbox2roi([res.pos_bboxes for res in sampling_results])
    fine_grained_point_feats = self._get_fine_grained_point_feats(
        x, rois, rel_roi_points, img_metas)
    coarse_point_feats = point_sample(mask_pred, rel_roi_points)
    mask_point_pred = self.point_head(fine_grained_point_feats,
                                      coarse_point_feats)
    mask_point_target = self.point_head.get_targets(
        rois, rel_roi_points, sampling_results, gt_masks, self.train_cfg)
    loss_mask_point = self.point_head.loss(mask_point_pred,
                                           mask_point_target, pos_labels)
    return loss_mask_point

def _get_fine_grained_point_feats(self, x, points):
    """Sample from fine grained features.

    Args:
        x (list[Tensor]): Feature pyramid from neck or backbone.
        points (Tensor): Point coordinates, shape (batch_size,
            num_points, 2).

    Returns:
        fine_grained_feats (Tensor): Sampled fine grained feature,
        shape (batch_size, sum(channels of x), num_points).
    """
    fine_grained_feats_list = [
        point_sample(lvl_feats, points, align_corners=self.align_corners)
        for lvl_feats in x
    ]
    if len(fine_grained_feats_list) > 1:
        fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)
    else:
        fine_grained_feats = fine_grained_feats_list[0]
    return fine_grained_feats

def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
                  train_cfg):
    """Forward function for training.

    Args:
        inputs (list[Tensor]): List of multi-level img features.
        prev_output (Tensor): The output of previous decode head.
        img_metas (list[dict]): List of image info dict where each dict
            has: 'img_shape', 'scale_factor', 'flip', and may also
            contain 'filename', 'ori_shape', 'pad_shape', and
            'img_norm_cfg'. For details on the values of these keys see
            `mmseg/datasets/pipelines/formatting.py:Collect`.
        gt_semantic_seg (Tensor): Semantic segmentation masks used if the
            architecture supports semantic segmentation task.
        train_cfg (dict): The training config.

    Returns:
        dict[str, Tensor]: a dictionary of loss components
    """
    x = self._transform_inputs(inputs)
    with torch.no_grad():
        points = self.get_points_train(
            prev_output, calculate_uncertainty, cfg=train_cfg)
    fine_grained_point_feats = self._get_fine_grained_point_feats(
        x, points)
    coarse_point_feats = self._get_coarse_point_feats(prev_output, points)
    point_logits = self.forward(fine_grained_point_feats,
                                coarse_point_feats)
    point_label = point_sample(
        gt_semantic_seg.float(),
        points,
        mode='nearest',
        align_corners=self.align_corners)
    point_label = point_label.squeeze(1).long()

    losses = self.losses(point_logits, point_label)

    return losses

def _get_fine_grained_point_feats(self, x, rois, rel_roi_points,
                                  img_metas):
    """Sample fine grained feats from each level feature map and
    concatenate them together.

    Args:
        x (tuple[Tensor]): Feature maps of all scale levels.
        rois (Tensor): shape (num_rois, 5).
        rel_roi_points (Tensor): A tensor of shape (num_rois, num_points,
            2) that contains [0, 1] x [0, 1] normalized coordinates of the
            most uncertain points from the [mask_height, mask_width] grid.
        img_metas (list[dict]): Image meta info.

    Returns:
        Tensor: The fine grained features for each point, has shape
        (num_rois, feats_channels, num_points).
    """
    num_imgs = len(img_metas)
    fine_grained_feats = []
    for idx in range(self.mask_roi_extractor.num_inputs):
        feats = x[idx]
        spatial_scale = 1. / float(
            self.mask_roi_extractor.featmap_strides[idx])
        point_feats = []
        for batch_ind in range(num_imgs):
            # unravel batch dim
            feat = feats[batch_ind].unsqueeze(0)
            inds = (rois[:, 0].long() == batch_ind)
            if inds.any():
                rel_img_points = rel_roi_point_to_rel_img_point(
                    rois[inds], rel_roi_points[inds], feat.shape[2:],
                    spatial_scale).unsqueeze(0)
                point_feat = point_sample(feat, rel_img_points)
                point_feat = point_feat.squeeze(0).transpose(0, 1)
                point_feats.append(point_feat)
        fine_grained_feats.append(torch.cat(point_feats, dim=0))
    return torch.cat(fine_grained_feats, dim=1)

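# --- A simplified reading of ``rel_roi_point_to_rel_img_point`` (the real
# helper lives in mmcv.ops; this toy version is an assumption-labeled
# re-derivation, not the library code). RoI-relative [0, 1] points are
# first lifted to absolute image coordinates, then re-normalized against
# the target feature map size (image size times ``spatial_scale``):
import torch


def toy_rel_roi_to_rel_img(rois, rel_roi_points, img_shape,
                           spatial_scale=1.0):
    """rois: (R, 5) as (batch_ind, x1, y1, x2, y2); points: (R, P, 2)."""
    xy1 = rois[:, None, 1:3]
    wh = rois[:, None, 3:5] - xy1
    abs_img_points = xy1 + rel_roi_points * wh  # absolute (x, y) coords
    h, w = img_shape
    scale = abs_img_points.new_tensor([w, h]) * spatial_scale
    return abs_img_points / scale  # back to [0, 1] x [0, 1]


rois = torch.tensor([[0., 10., 20., 50., 60.]])
pts = torch.tensor([[[0.5, 0.5]]])  # RoI center
out = toy_rel_roi_to_rel_img(rois, pts, img_shape=(100, 200))
# RoI center (30, 40) normalized by (w=200, h=100) -> (0.15, 0.40)
assert torch.allclose(out, torch.tensor([[[0.15, 0.40]]]))
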
def loss_single(self, cls_scores, mask_preds, gt_labels_list,
                gt_masks_list, img_metas):
    """Loss function for outputs from a single decoder layer.

    Args:
        cls_scores (Tensor): Mask score logits from a single decoder
            layer for all images. Shape (batch_size, num_queries,
            cls_out_channels). Note `cls_out_channels` should include
            background.
        mask_preds (Tensor): Mask logits for a pixel decoder for all
            images. Shape (batch_size, num_queries, h, w).
        gt_labels_list (list[Tensor]): Ground truth class indices for
            each image, each with shape (num_gts, ).
        gt_masks_list (list[Tensor]): Ground truth mask for each image,
            each with shape (num_gts, h, w).
        img_metas (list[dict]): List of image meta information.

    Returns:
        tuple[Tensor]: Loss components for outputs from a single
            decoder layer.
    """
    num_imgs = cls_scores.size(0)
    cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
    mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
    (labels_list, label_weights_list, mask_targets_list,
     mask_weights_list, num_total_pos,
     num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list,
                                       gt_labels_list, gt_masks_list,
                                       img_metas)
    # shape (batch_size, num_queries)
    labels = torch.stack(labels_list, dim=0)
    # shape (batch_size, num_queries)
    label_weights = torch.stack(label_weights_list, dim=0)
    # shape (num_total_gts, h, w)
    mask_targets = torch.cat(mask_targets_list, dim=0)
    # shape (batch_size, num_queries)
    mask_weights = torch.stack(mask_weights_list, dim=0)

    # classification loss
    # shape (batch_size * num_queries, )
    cls_scores = cls_scores.flatten(0, 1)
    labels = labels.flatten(0, 1)
    label_weights = label_weights.flatten(0, 1)

    class_weight = cls_scores.new_tensor(self.class_weight)
    loss_cls = self.loss_cls(
        cls_scores,
        labels,
        label_weights,
        avg_factor=class_weight[labels].sum())

    num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
    num_total_masks = max(num_total_masks, 1)

    # extract positive ones
    # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
    mask_preds = mask_preds[mask_weights > 0]

    if mask_targets.shape[0] == 0:
        # zero match
        loss_dice = mask_preds.sum()
        loss_mask = mask_preds.sum()
        return loss_cls, loss_mask, loss_dice

    with torch.no_grad():
        points_coords = get_uncertain_point_coords_with_randomness(
            mask_preds.unsqueeze(1), None, self.num_points,
            self.oversample_ratio, self.importance_sample_ratio)
        # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
        mask_point_targets = point_sample(
            mask_targets.unsqueeze(1).float(), points_coords).squeeze(1)
    # shape (num_queries, h, w) -> (num_queries, num_points)
    mask_point_preds = point_sample(
        mask_preds.unsqueeze(1), points_coords).squeeze(1)

    # dice loss
    loss_dice = self.loss_dice(
        mask_point_preds, mask_point_targets, avg_factor=num_total_masks)

    # mask loss
    # shape (num_queries, num_points) -> (num_queries * num_points, )
    mask_point_preds = mask_point_preds.reshape(-1)
    # shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
    mask_point_targets = mask_point_targets.reshape(-1)
    loss_mask = self.loss_mask(
        mask_point_preds,
        mask_point_targets,
        avg_factor=num_total_masks * self.num_points)

    return loss_cls, loss_mask, loss_dice

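# --- Sketch of point-sampled loss computation as used in ``loss_single``
# above: one shared set of coordinates is sampled, and both predictions
# and targets are evaluated at exactly those points, so supervision costs
# O(num_points) instead of O(h * w). Torch-only demo; the shapes and the
# plain BCE standing in for the configurable ``loss_mask`` are
# assumptions:
import torch
import torch.nn.functional as F

preds = torch.randn(3, 32, 32)                   # (num_gts, h, w) logits
targets = (torch.rand(3, 32, 32) > 0.5).float()  # (num_gts, h, w) masks
coords = torch.rand(3, 12, 2)                    # per-mask point coords
grid = (2 * coords - 1).unsqueeze(2)             # [0, 1] -> [-1, 1]
point_preds = F.grid_sample(
    preds.unsqueeze(1), grid, align_corners=False).flatten(1)
point_targets = F.grid_sample(
    targets.unsqueeze(1), grid, align_corners=False).flatten(1)
# Bilinear sampling of hard masks yields soft targets in [0, 1], which
# BCE-with-logits accepts directly.
loss_mask = F.binary_cross_entropy_with_logits(point_preds, point_targets)
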
def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks,
                       img_metas):
    """Compute classification and mask targets for one image.

    Args:
        cls_score (Tensor): Mask score logits from a single decoder layer
            for one image. Shape (num_queries, cls_out_channels).
        mask_pred (Tensor): Mask logits for a single decoder layer for one
            image. Shape (num_queries, h, w).
        gt_labels (Tensor): Ground truth class indices for one image with
            shape (num_gts, ).
        gt_masks (Tensor): Ground truth mask for each image, each with
            shape (num_gts, h, w).
        img_metas (dict): Image information.

    Returns:
        tuple[Tensor]: A tuple containing the following for one image.

            - labels (Tensor): Labels of each image.
              shape (num_queries, ).
            - label_weights (Tensor): Label weights of each image.
              shape (num_queries, ).
            - mask_targets (Tensor): Mask targets of each image.
              shape (num_queries, h, w).
            - mask_weights (Tensor): Mask weights of each image.
              shape (num_queries, ).
            - pos_inds (Tensor): Sampled positive indices for each image.
            - neg_inds (Tensor): Sampled negative indices for each image.
    """
    # sample points
    num_queries = cls_score.shape[0]
    num_gts = gt_labels.shape[0]

    point_coords = torch.rand((1, self.num_points, 2),
                              device=cls_score.device)
    # shape (num_queries, num_points)
    mask_points_pred = point_sample(
        mask_pred.unsqueeze(1),
        point_coords.repeat(num_queries, 1, 1)).squeeze(1)
    # shape (num_gts, num_points)
    gt_points_masks = point_sample(
        gt_masks.unsqueeze(1).float(),
        point_coords.repeat(num_gts, 1, 1)).squeeze(1)

    # assign and sample
    assign_result = self.assigner.assign(cls_score, mask_points_pred,
                                         gt_labels, gt_points_masks,
                                         img_metas)
    sampling_result = self.sampler.sample(assign_result, mask_pred,
                                          gt_masks)
    pos_inds = sampling_result.pos_inds
    neg_inds = sampling_result.neg_inds

    # label target
    labels = gt_labels.new_full((self.num_queries, ),
                                self.num_classes,
                                dtype=torch.long)
    labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
    label_weights = gt_labels.new_ones((self.num_queries, ))

    # mask target
    mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
    mask_weights = mask_pred.new_zeros((self.num_queries, ))
    mask_weights[pos_inds] = 1.0

    return (labels, label_weights, mask_targets, mask_weights, pos_inds,
            neg_inds)

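# --- Sketch of the shared-point trick in ``_get_target_single`` above:
# one random coordinate set is broadcast to every query and every GT
# mask, so the assigner compares predictions and targets at identical
# locations. Torch-only toy; the shapes and the crude soft-BCE cost
# (standing in for the real assigner's cost) are assumptions:
import torch
import torch.nn.functional as F

num_queries, num_gts, P = 5, 2, 16
mask_pred = torch.randn(num_queries, 24, 24)
gt_masks = (torch.rand(num_gts, 24, 24) > 0.5).float()
point_coords = torch.rand(1, P, 2)          # one shared coordinate set
grid = (2 * point_coords - 1).unsqueeze(2)  # (1, P, 1, 2) in [-1, 1]
q = F.grid_sample(
    mask_pred.unsqueeze(1), grid.expand(num_queries, -1, -1, -1),
    align_corners=False).reshape(num_queries, P)
g = F.grid_sample(
    gt_masks.unsqueeze(1), grid.expand(num_gts, -1, -1, -1),
    align_corners=False).reshape(num_gts, P)
# A (num_queries, num_gts) cost matrix over the same P points; the real
# head feeds such a matrix to Hungarian matching.
cost = -(g.unsqueeze(0) * q.sigmoid().log().unsqueeze(1)).sum(-1)
assert cost.shape == (num_queries, num_gts)
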
def _mask_point_onnx_export(self, x, rois, label_pred, mask_pred):
    """Export mask refining process with point head to onnx.

    Args:
        x (tuple[Tensor]): Feature maps of all scale levels.
        rois (Tensor): shape (num_rois, 5).
        label_pred (Tensor): The predicted class for each RoI.
        mask_pred (Tensor): The predicted coarse masks of shape
            (num_rois, num_classes, small_size, small_size).

    Returns:
        Tensor: The refined masks of shape (num_rois, num_classes,
        large_size, large_size).
    """
    refined_mask_pred = mask_pred.clone()
    for subdivision_step in range(self.test_cfg.subdivision_steps):
        refined_mask_pred = F.interpolate(
            refined_mask_pred,
            scale_factor=self.test_cfg.scale_factor,
            mode='bilinear',
            align_corners=False)
        # If `subdivision_num_points` is larger than or equal to the
        # resolution of the next step, then we can skip this step.
        num_rois, channels, mask_height, mask_width = \
            refined_mask_pred.shape
        if (self.test_cfg.subdivision_num_points >=
                self.test_cfg.scale_factor**2 * mask_height * mask_width
                and
                subdivision_step < self.test_cfg.subdivision_steps - 1):
            continue
        point_indices, rel_roi_points = \
            self.point_head.get_roi_rel_points_test(
                refined_mask_pred, label_pred, cfg=self.test_cfg)
        fine_grained_point_feats = self._onnx_get_fine_grained_point_feats(
            x, rois, rel_roi_points)
        coarse_point_feats = point_sample(mask_pred, rel_roi_points)
        mask_point_pred = self.point_head(fine_grained_point_feats,
                                          coarse_point_feats)
        point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)

        refined_mask_pred = refined_mask_pred.reshape(
            num_rois, channels, mask_height * mask_width)

        is_trt_backend = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT'
        # avoid ScatterElements op in ONNX for TensorRT
        if is_trt_backend:
            mask_shape = refined_mask_pred.shape
            point_shape = point_indices.shape
            inds_dim0 = torch.arange(point_shape[0]).reshape(
                point_shape[0], 1, 1).expand_as(point_indices)
            inds_dim1 = torch.arange(point_shape[1]).reshape(
                1, point_shape[1], 1).expand_as(point_indices)
            inds_1d = inds_dim0.reshape(
                -1) * mask_shape[1] * mask_shape[2] + inds_dim1.reshape(
                    -1) * mask_shape[2] + point_indices.reshape(-1)
            refined_mask_pred = refined_mask_pred.reshape(-1)
            refined_mask_pred[inds_1d] = mask_point_pred.reshape(-1)
            refined_mask_pred = refined_mask_pred.reshape(*mask_shape)
        else:
            refined_mask_pred = refined_mask_pred.scatter_(
                2, point_indices, mask_point_pred)

        refined_mask_pred = refined_mask_pred.view(num_rois, channels,
                                                   mask_height, mask_width)

    return refined_mask_pred

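# --- Equivalence check for the TensorRT branch above (toy shapes; all
# values are assumptions). The flat 1-D index
#     r * C * HW + c * HW + point_index
# addresses the same element as ``scatter_`` on dim 2 of a (R, C, HW)
# tensor, which is why the export can avoid the ScatterElements op:
import torch

R, C, HW, P = 2, 3, 16, 4
mask = torch.zeros(R, C, HW)
# unique indices per roi so both write orders are deterministic
point_indices = torch.stack([torch.randperm(HW)[:P] for _ in range(R)])
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)  # (R, C, P)
pred = torch.randn(R, C, P)

ref = mask.clone().scatter_(2, point_indices, pred)

inds_r = torch.arange(R).reshape(R, 1, 1).expand_as(point_indices)
inds_c = torch.arange(C).reshape(1, C, 1).expand_as(point_indices)
flat_inds = (inds_r * C * HW + inds_c * HW + point_indices).reshape(-1)
flat = mask.reshape(-1).clone()
flat[flat_inds] = pred.reshape(-1)

assert torch.equal(flat.view(R, C, HW), ref)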