def inference_single_image(self, cate_preds, seg_preds, featmap_size, img_shape, ori_shape):
    """
    Args:
        cate_preds, seg_preds: see :meth:`inference`.
        featmap_size (tuple): size (height, width) of the mask-branch feature map.
        img_shape (tuple): the size of the image fed into the model (height and width).
        ori_shape (tuple): original image shape (height and width).

    Returns:
        result (Instances): predicted results of single image after post-processing.
    """
    assert len(cate_preds) == len(seg_preds)
    result = Instances(ori_shape)

    # overall info.
    h, w = img_shape
    upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)

    # process.
    inds = (cate_preds > self.score_threshold)
    # category scores.
    cate_scores = cate_preds[inds]
    if len(cate_scores) == 0:
        return result
    # category labels.
    inds = inds.nonzero(as_tuple=False)
    cate_labels = inds[:, 1]

    # strides.
    size_trans = cate_labels.new_tensor(self.seg_num_grids).pow(2).cumsum(0)  # [1600, 2896, 3472, 3728, 3872]
    strides = cate_scores.new_ones(size_trans[-1])
    n_stage = len(self.seg_num_grids)
    strides[:size_trans[0]] *= self.feature_strides[0]
    for ind_ in range(1, n_stage):
        strides[size_trans[ind_ - 1]:size_trans[ind_]] *= self.feature_strides[ind_]
    strides = strides[inds[:, 0]]

    # masks.
    seg_preds = seg_preds[inds[:, 0]]
    seg_masks = seg_preds > self.mask_threshold
    sum_masks = seg_masks.sum((1, 2)).float()

    # filter.
    keep = sum_masks > strides
    if keep.sum() == 0:
        return result

    seg_masks = seg_masks[keep, ...]
    seg_preds = seg_preds[keep, ...]
    sum_masks = sum_masks[keep]
    cate_scores = cate_scores[keep]
    cate_labels = cate_labels[keep]

    # mask scoring.
    seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
    cate_scores *= seg_scores

    # sort and keep top nms_pre.
    sort_inds = torch.argsort(cate_scores, descending=True)
    if len(sort_inds) > self.nms_per_image:
        sort_inds = sort_inds[:self.nms_per_image]
    seg_masks = seg_masks[sort_inds, :, :]
    seg_preds = seg_preds[sort_inds, :, :]
    sum_masks = sum_masks[sort_inds]
    cate_scores = cate_scores[sort_inds]
    cate_labels = cate_labels[sort_inds]

    # Matrix NMS.
    cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                             kernel=self.nms_kernel, sigma=self.nms_sigma,
                             sum_masks=sum_masks)

    # filter.
    keep = cate_scores >= self.update_threshold
    if keep.sum() == 0:
        return result
    seg_preds = seg_preds[keep, :, :]
    cate_scores = cate_scores[keep]
    cate_labels = cate_labels[keep]

    # sort and keep top_k.
    sort_inds = torch.argsort(cate_scores, descending=True)
    if len(sort_inds) > self.max_detections_per_image:
        sort_inds = sort_inds[:self.max_detections_per_image]
    seg_preds = seg_preds[sort_inds, :, :]
    cate_scores = cate_scores[sort_inds]
    cate_labels = cate_labels[sort_inds]

    # upsample to the padded input size, crop the padding, then resize to the original image size.
    seg_preds = F.interpolate(seg_preds.unsqueeze(0),
                              size=upsampled_size_out,
                              mode='bilinear')[:, :, :h, :w]
    seg_masks = F.interpolate(seg_preds,
                              size=ori_shape,
                              mode='bilinear').squeeze(0)
    seg_masks = seg_masks > self.mask_threshold

    seg_masks = BitMasks(seg_masks)
    result.pred_masks = seg_masks
    result.pred_boxes = seg_masks.get_bounding_boxes()
    result.scores = cate_scores
    result.pred_classes = cate_labels
    return result
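
# For reference, a minimal sketch of the Matrix NMS step called above. This is an
# illustrative re-implementation following the SOLOv2 paper's formulation and the
# keyword arguments used at the call sites (kernel, sigma, sum_masks); the actual
# `matrix_nms` helper imported by this module may differ in details, and the name
# `matrix_nms_sketch` is only used here to avoid shadowing it.
import torch


def matrix_nms_sketch(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0, sum_masks=None):
    """Decay category scores by pairwise mask IoU instead of hard suppression.

    Assumes the candidates are already sorted by descending score, as they are at
    the call sites above.
    """
    n_samples = len(cate_labels)
    if n_samples == 0:
        return cate_scores
    if sum_masks is None:
        sum_masks = seg_masks.sum((1, 2)).float()
    seg_masks = seg_masks.reshape(n_samples, -1).float()

    # pairwise intersection and IoU; the upper triangle keeps each pair once,
    # with rows corresponding to the higher-scoring mask of the pair.
    inter_matrix = torch.mm(seg_masks, seg_masks.transpose(1, 0))
    sum_masks_x = sum_masks.expand(n_samples, n_samples)
    iou_matrix = (inter_matrix / (sum_masks_x + sum_masks_x.transpose(1, 0) - inter_matrix)).triu(diagonal=1)

    # only masks of the same category suppress each other.
    cate_labels_x = cate_labels.expand(n_samples, n_samples)
    label_matrix = (cate_labels_x == cate_labels_x.transpose(1, 0)).float().triu(diagonal=1)
    decay_iou = iou_matrix * label_matrix

    # IoU compensation: for each candidate, the largest IoU it has with any
    # higher-scoring mask of the same class.
    compensate_iou, _ = decay_iou.max(0)
    compensate_iou = compensate_iou.expand(n_samples, n_samples).transpose(1, 0)

    # decay coefficient per candidate.
    if kernel == 'gaussian':
        decay_matrix = torch.exp(-1 * sigma * (decay_iou ** 2))
        compensate_matrix = torch.exp(-1 * sigma * (compensate_iou ** 2))
        decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0)
    elif kernel == 'linear':
        decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
        decay_coefficient, _ = decay_matrix.min(0)
    else:
        raise NotImplementedError(kernel)

    # update the scores.
    return cate_scores * decay_coefficient
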
def inference_single_image(self, cate_preds, seg_preds_x, seg_preds_y, featmap_size, img_shape, ori_shape):
    """
    Single-image inference for the decoupled head. Same arguments as the method
    above, except that seg_preds_x / seg_preds_y hold the per-axis mask
    predictions; the final mask of a grid cell is the element-wise product of
    its X-branch and Y-branch predictions.
    """
    result = Instances(ori_shape)

    # overall info.
    h, w = img_shape
    upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)

    # per-level offsets (trans_diff, seg_diff), grid counts and strides, laid out
    # over the flattened grid indices.
    trans_size = torch.Tensor(self.seg_num_grids).pow(2).cumsum(0).long()
    trans_diff = torch.ones(trans_size[-1].item(), device=self.device).long()
    num_grids = torch.ones(trans_size[-1].item(), device=self.device).long()
    seg_size = torch.Tensor(self.seg_num_grids).cumsum(0).long()
    seg_diff = torch.ones(trans_size[-1].item(), device=self.device).long()
    strides = torch.ones(trans_size[-1].item(), device=self.device)

    n_stage = len(self.seg_num_grids)
    trans_diff[:trans_size[0]] *= 0
    seg_diff[:trans_size[0]] *= 0
    num_grids[:trans_size[0]] *= self.seg_num_grids[0]
    strides[:trans_size[0]] *= self.feature_strides[0]
    for ind_ in range(1, n_stage):
        trans_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= trans_size[ind_ - 1]
        seg_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= seg_size[ind_ - 1]
        num_grids[trans_size[ind_ - 1]:trans_size[ind_]] *= self.seg_num_grids[ind_]
        strides[trans_size[ind_ - 1]:trans_size[ind_]] *= self.feature_strides[ind_]

    # process.
    inds = (cate_preds > self.score_threshold)
    # category scores.
    cate_scores = cate_preds[inds]
    # category labels.
    inds = inds.nonzero(as_tuple=False)
    trans_diff = torch.index_select(trans_diff, dim=0, index=inds[:, 0])
    seg_diff = torch.index_select(seg_diff, dim=0, index=inds[:, 0])
    num_grids = torch.index_select(num_grids, dim=0, index=inds[:, 0])
    strides = torch.index_select(strides, dim=0, index=inds[:, 0])

    y_inds = (inds[:, 0] - trans_diff) // num_grids
    x_inds = (inds[:, 0] - trans_diff) % num_grids
    y_inds += seg_diff
    x_inds += seg_diff
    cate_labels = inds[:, 1]

    seg_masks_soft = seg_preds_x[x_inds, ...] * seg_preds_y[y_inds, ...]
    seg_masks = seg_masks_soft > self.mask_threshold
    sum_masks = seg_masks.sum((1, 2)).float()

    # filter.
    keep = sum_masks > strides
    if keep.sum() == 0:
        return result

    seg_masks_soft = seg_masks_soft[keep, ...]
    seg_masks = seg_masks[keep, ...]
    cate_scores = cate_scores[keep]
    sum_masks = sum_masks[keep]
    cate_labels = cate_labels[keep]

    # mask scoring.
    seg_score = (seg_masks_soft * seg_masks.float()).sum((1, 2)) / sum_masks
    cate_scores *= seg_score
    if len(cate_scores) == 0:
        return result

    # sort and keep top nms_pre.
    sort_inds = torch.argsort(cate_scores, descending=True)
    if len(sort_inds) > self.nms_per_image:
        sort_inds = sort_inds[:self.nms_per_image]
    seg_masks_soft = seg_masks_soft[sort_inds, :, :]
    seg_masks = seg_masks[sort_inds, :, :]
    cate_scores = cate_scores[sort_inds]
    sum_masks = sum_masks[sort_inds]
    cate_labels = cate_labels[sort_inds]

    # Matrix NMS.
    cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                             kernel=self.nms_kernel, sigma=self.nms_sigma,
                             sum_masks=sum_masks)

    # filter.
    keep = cate_scores >= self.update_threshold
    if keep.sum() == 0:
        return result
    seg_masks_soft = seg_masks_soft[keep, :, :]
    cate_scores = cate_scores[keep]
    cate_labels = cate_labels[keep]

    # sort and keep top_k.
    sort_inds = torch.argsort(cate_scores, descending=True)
    if len(sort_inds) > self.max_detections_per_image:
        sort_inds = sort_inds[:self.max_detections_per_image]
    seg_masks_soft = seg_masks_soft[sort_inds, :, :]
    cate_scores = cate_scores[sort_inds]
    cate_labels = cate_labels[sort_inds]

    # upsample to the padded input size, crop the padding, then resize to the original image size.
    seg_masks_soft = F.interpolate(seg_masks_soft.unsqueeze(0),
                                   size=upsampled_size_out,
                                   mode='bilinear')[:, :, :h, :w]
    seg_masks = F.interpolate(seg_masks_soft,
                              size=ori_shape,
                              mode='bilinear').squeeze(0)
    seg_masks = seg_masks > self.mask_threshold

    seg_masks = BitMasks(seg_masks)
    result.pred_masks = seg_masks
    result.pred_boxes = seg_masks.get_bounding_boxes()
    result.scores = cate_scores
    result.pred_classes = cate_labels
    return result
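
# The index bookkeeping in the decoupled-head method above is easy to misread, so
# here is a small, self-contained illustration of the same arithmetic. The helper
# and its default grid sizes are hypothetical (SOLO commonly uses
# seg_num_grids = [40, 36, 24, 16, 12]); it is not used by the model.
from itertools import accumulate


def _demo_flat_index_to_level_xy(flat_ind, seg_num_grids=(40, 36, 24, 16, 12)):
    """Map a flattened grid index back to (level, y, x), where y / x already include
    the cumulative row/column offset (seg_diff) used to index the concatenated
    seg_preds_y / seg_preds_x tensors."""
    trans_size = list(accumulate(g * g for g in seg_num_grids))  # e.g. [1600, 2896, 3472, 3728, 3872]
    seg_size = list(accumulate(seg_num_grids))                   # e.g. [40, 76, 100, 116, 128]
    level = next(l for l, t in enumerate(trans_size) if flat_ind < t)
    trans_diff = 0 if level == 0 else trans_size[level - 1]      # cells in all previous levels
    seg_diff = 0 if level == 0 else seg_size[level - 1]          # rows/cols in all previous levels
    num_grid = seg_num_grids[level]
    y = (flat_ind - trans_diff) // num_grid + seg_diff
    x = (flat_ind - trans_diff) % num_grid + seg_diff
    return level, y, x


# Example: flat index 1700 falls in the second level (36x36 grid), local cell
# (2, 28), which becomes y=42, x=68 after the 40-row/column offset of level 0:
# _demo_flat_index_to_level_xy(1700) == (1, 42, 68)
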