def get_seg_single(self, cate_scores, cate_labels, seg_preds, attention_maps, strides, featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale=False, debug=False):
    """Decode one image's pre-selected candidates into final instance masks.

    Unlike the sibling ``get_seg_single`` variants, the candidates arrive
    already paired with scores/labels/strides; this routine only filters,
    rescores by maskness, runs Matrix NMS, and resizes the masks.

    Args:
        cate_scores: (n,) candidate classification scores.
        cate_labels: (n,) candidate class indices.
        seg_preds: soft mask predictions; channel 0 is selected below
            (presumably the only channel — TODO confirm against caller).
        attention_maps: attention maps aligned with ``seg_preds``.
        strides: (n,) per-candidate FPN stride, used as a minimum-area filter.
        featmap_size: (H, W) of the mask prediction maps (1/4 input scale,
            judging by the x4 upsampling below).
        img_shape: padded input image shape (h, w, c).
        ori_shape: original image shape before resize/pad.
        scale_factor, rescale, debug: unused in this body.
        cfg: test config providing mask_thr, nms_pre, kernel, sigma,
            update_thr and max_per_img.

    Returns:
        Tuple ``(seg_masks, cate_labels, cate_scores)`` or ``None`` when no
        candidate survives a filtering stage.
    """
    # overall info.
    h, w, _ = img_shape
    upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)
    seg_preds = seg_preds[:, 0]
    attention_maps = attention_maps[:, 0]
    # masks: binarize the soft predictions and measure each mask's area.
    seg_masks = seg_preds > cfg.mask_thr
    sum_masks = seg_masks.sum((1, 2)).float()
    # filter: drop masks whose area does not exceed their level's stride.
    keep = sum_masks > strides
    if keep.sum() == 0:
        return None
    seg_masks = seg_masks[keep, ...]
    seg_preds = seg_preds[keep, ...]
    attention_maps = attention_maps[keep, ...]
    sum_masks = sum_masks[keep]
    cate_scores = cate_scores[keep]
    cate_labels = cate_labels[keep]
    # mask scoring: weight each class score by the mean soft-mask
    # confidence inside its binarized mask ("maskness").
    seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
    cate_scores *= seg_scores
    # sort and keep top nms_pre
    sort_inds = torch.argsort(cate_scores, descending=True)
    if len(sort_inds) > cfg.nms_pre:
        sort_inds = sort_inds[:cfg.nms_pre]
    seg_masks = seg_masks[sort_inds, :, :]
    seg_preds = seg_preds[sort_inds, :, :]
    attention_maps = attention_maps[sort_inds, ...]
    sum_masks = sum_masks[sort_inds]
    cate_scores = cate_scores[sort_inds]
    cate_labels = cate_labels[sort_inds]
    # Matrix NMS: decays scores of overlapping candidates instead of
    # hard-suppressing them.
    cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                             kernel=cfg.kernel, sigma=cfg.sigma,
                             sum_masks=sum_masks)
    # filter by the decayed scores.
    keep = cate_scores >= cfg.update_thr
    if keep.sum() == 0:
        return None
    seg_preds = seg_preds[keep, :, :]
    attention_maps = attention_maps[keep, ...]
    cate_scores = cate_scores[keep]
    cate_labels = cate_labels[keep]
    # sort and keep top_k
    sort_inds = torch.argsort(cate_scores, descending=True)
    if len(sort_inds) > cfg.max_per_img:
        sort_inds = sort_inds[:cfg.max_per_img]
    seg_preds = seg_preds[sort_inds, :, :]
    attention_maps = attention_maps[sort_inds, ...]
    cate_scores = cate_scores[sort_inds]
    cate_labels = cate_labels[sort_inds]
    # Upsample to padded-input size, crop the padding, then resize to the
    # original image resolution and binarize.
    seg_preds = F.interpolate(seg_preds.unsqueeze(0),
                              size=upsampled_size_out,
                              mode='bilinear')[:, :, :h, :w]
    seg_masks = F.interpolate(seg_preds,
                              size=ori_shape[:2],
                              mode='bilinear').squeeze(0)
    seg_masks = seg_masks > cfg.mask_thr
    attention_maps = F.interpolate(attention_maps.unsqueeze(0),
                                   size=upsampled_size_out,
                                   mode='bilinear')[:, :, :h, :w]
    attention_masks = F.interpolate(attention_maps,
                                    size=ori_shape[:2],
                                    mode='bilinear').squeeze(0)
    # NOTE(review): attention_masks is computed but never returned — either
    # dead code or an omission in the return statement; confirm intent.
    attention_masks = attention_masks > 0
    return seg_masks, cate_labels, cate_scores
def get_seg_single(self, cate_preds, seg_preds, featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale=False, debug=False):
    """Turn one image's raw SOLO head outputs into final instance masks.

    Args:
        cate_preds: (num_cells, num_classes) category probabilities,
            flattened over all FPN levels.
        seg_preds: per-cell soft mask predictions, row-aligned with the
            flattened grid of ``cate_preds``.
        featmap_size: (H, W) of the mask prediction maps.
        img_shape: padded input image shape (h, w, c).
        ori_shape: original image shape before resize/pad.
        scale_factor, rescale, debug: unused in this body.
        cfg: test config providing score_thr, mask_thr, nms_pre, kernel,
            sigma, update_thr and max_per_img.

    Returns:
        Tuple ``(seg_masks, cate_labels, cate_scores)`` or ``None`` when no
        candidate survives.
    """
    assert len(cate_preds) == len(seg_preds)
    h, w, _ = img_shape
    # Masks are predicted at 1/4 of the padded input resolution.
    upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)

    # Select candidate (cell, class) pairs by classification score.
    pos = cate_preds > cfg.score_thr
    scores = cate_preds[pos]
    if len(scores) == 0:
        return None
    pos = pos.nonzero()
    labels = pos[:, 1]

    # Build a per-cell stride table (each level's cells get that level's
    # FPN stride), then gather the stride of each candidate.
    lvl_ends = labels.new_tensor(self.seg_num_grids).pow(2).cumsum(0)
    strides = scores.new_ones(lvl_ends[-1])
    start = 0
    for lvl, end in enumerate(lvl_ends):
        strides[start:end] *= self.strides[lvl]
        start = end
    strides = strides[pos[:, 0]]

    # Binarize the candidates' masks and drop those whose area does not
    # exceed their level's stride.
    masks_soft = seg_preds[pos[:, 0]]
    masks = masks_soft > cfg.mask_thr
    areas = masks.sum((1, 2)).float()
    valid = areas > strides
    if valid.sum() == 0:
        return None
    masks = masks[valid, ...]
    masks_soft = masks_soft[valid, ...]
    areas = areas[valid]
    scores = scores[valid]
    labels = labels[valid]

    # Maskness: weight each score by the mean soft-mask confidence inside
    # its binarized mask.
    maskness = (masks_soft * masks.float()).sum((1, 2)) / areas
    scores = scores * maskness

    # Keep the top nms_pre candidates by score (slice clamps when fewer).
    order = torch.argsort(scores, descending=True)[:cfg.nms_pre]
    masks = masks[order, :, :]
    masks_soft = masks_soft[order, :, :]
    areas = areas[order]
    scores = scores[order]
    labels = labels[order]

    # Matrix NMS decays the scores of overlapping candidates.
    scores = matrix_nms(masks, labels, scores,
                        kernel=cfg.kernel, sigma=cfg.sigma, sum_masks=areas)

    # Filter by the decayed scores, then keep the final top max_per_img.
    valid = scores >= cfg.update_thr
    if valid.sum() == 0:
        return None
    masks_soft = masks_soft[valid, :, :]
    scores = scores[valid]
    labels = labels[valid]
    order = torch.argsort(scores, descending=True)[:cfg.max_per_img]
    masks_soft = masks_soft[order, :, :]
    scores = scores[order]
    labels = labels[order]

    # Upsample to padded-input size, crop the padding, resize to the
    # original image resolution, and binarize.
    masks_soft = F.interpolate(masks_soft.unsqueeze(0),
                               size=upsampled_size_out,
                               mode='bilinear')[:, :, :h, :w]
    seg_masks = F.interpolate(masks_soft,
                              size=ori_shape[:2],
                              mode='bilinear').squeeze(0)
    seg_masks = seg_masks > cfg.mask_thr
    return seg_masks, labels, scores
def get_seg_single(self,
                   cls_scores,      # (5)[h/s_i*w/s_i, 80]
                   fcn_params,      # (5)[h/s_i*w/s_i, 169]
                   mask_feat_pred,  # (1, 8, h/8, w/8)
                   mlvl_points,
                   img_shape,
                   ori_shape,
                   scale_factor,
                   cfg,
                   rescale=False):
    """Decode one image's dynamic-filter (CondInst-style) predictions.

    Each positive location carries generated conv parameters that are run
    over the shared mask features via one grouped convolution; the
    resulting masks are rescored by maskness, deduplicated with Matrix NMS
    and resized to the original image resolution.

    Args:
        cls_scores: per-level (num_points, num_classes) class probabilities.
        fcn_params: per-level (num_points, num_params) generated filters.
        mask_feat_pred: (1, C, H, W) shared mask feature map.
        mlvl_points: unused in this body.
        img_shape: padded input image shape (H, W, c).
        ori_shape: original image shape before resize/pad.
        scale_factor, rescale: unused in this body.
        cfg: test config providing score_thr, mask_thr, nms_pre, kernel,
            sigma, update_thr and max_per_img.

    Returns:
        Tuple ``(seg_masks, cate_labels, cate_scores)`` or ``None`` when no
        candidate survives.
    """
    assert len(cls_scores) == len(fcn_params)
    featmap_size = mask_feat_pred.size()[-2:]
    upsampled_size_out = (featmap_size[0] * self.mask_downsample,
                          featmap_size[1] * self.mask_downsample)
    H, W, _ = img_shape
    cls_scores = torch.cat(cls_scores, dim=0)
    fcn_params = torch.cat(fcn_params, dim=0)

    # Candidate selection by classification score.
    inds = (cls_scores > cfg.score_thr)
    cate_scores = cls_scores[inds]
    if len(cate_scores) == 0:
        return None
    inds = inds.nonzero()          # [num_pos, 2]: (flat point index, class)
    cate_labels = inds[:, 1]
    param_preds = fcn_params[inds[:, 0]]

    # Dynamic head: run every candidate's generated filters over the shared
    # mask features in a single grouped convolution (one group per candidate).
    weight, bias = self.parse_dynamic_params(param_preds)
    mask_feat_pred = mask_feat_pred.repeat((param_preds.shape[0], 1, 1, 1))
    mask_feat_pred = mask_feat_pred.reshape(1, -1, featmap_size[0], featmap_size[1])
    for i, (w, b) in enumerate(zip(weight, bias)):
        mask_feat_pred = F.conv2d(mask_feat_pred, w, bias=b, stride=1,
                                  padding=0, groups=param_preds.shape[0])
        if i < len(weight) - 1:
            mask_feat_pred = F.relu(mask_feat_pred)

    # (1, num_pos, H, W) -> (num_pos, 1, H, W), then upsample to mask_out_stride.
    mask_feat_pred = mask_feat_pred.reshape(-1, 1, featmap_size[0], featmap_size[1])
    mask_logits = aligned_bilinear(
        mask_feat_pred, int(self.mask_downsample / self.mask_out_stride))
    # BUG FIX: Tensor.sigmoid() is out-of-place and the original discarded
    # its result, so cfg.mask_thr (a probability threshold) was applied to
    # raw logits. Bind the result so thresholds see probabilities.
    mask_logits = mask_logits.sigmoid()
    mask_logits = mask_logits.permute(1, 0, 2, 3).squeeze(0)  # (num_pos, H, W)
    seg_masks = mask_logits > cfg.mask_thr

    # Drop candidates whose binarized mask is empty (would NaN the maskness
    # division below).
    sum_masks = seg_masks.sum((1, 2)).float()
    ind_valid = sum_masks > 0
    mask_logits = mask_logits[ind_valid]
    seg_masks = seg_masks[ind_valid]
    sum_masks = sum_masks[ind_valid]
    cate_scores = cate_scores[ind_valid]
    cate_labels = cate_labels[ind_valid]

    # Maskness rescoring: mean confidence inside each binarized mask.
    seg_scores = (mask_logits * seg_masks.float()).sum((1, 2)) / sum_masks
    cate_scores *= seg_scores

    # sort and keep top nms_pre
    sort_inds = torch.argsort(cate_scores, descending=True)
    if len(sort_inds) > cfg.nms_pre:
        sort_inds = sort_inds[:cfg.nms_pre]
    seg_masks = seg_masks[sort_inds, :, :]
    mask_logits = mask_logits[sort_inds, :, :]
    sum_masks = sum_masks[sort_inds]
    cate_scores = cate_scores[sort_inds]
    cate_labels = cate_labels[sort_inds]

    # Matrix NMS: decays scores of overlapping candidates.
    cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                             kernel=cfg.kernel, sigma=cfg.sigma,
                             sum_masks=sum_masks)

    # filter by decayed scores.
    keep = cate_scores >= cfg.update_thr
    if keep.sum() == 0:
        return None
    mask_logits = mask_logits[keep, :, :]
    cate_scores = cate_scores[keep]
    cate_labels = cate_labels[keep]

    # sort and keep top_k
    sort_inds = torch.argsort(cate_scores, descending=True)
    if len(sort_inds) > cfg.max_per_img:
        sort_inds = sort_inds[:cfg.max_per_img]
    mask_logits = mask_logits[sort_inds, :, :]
    cate_scores = cate_scores[sort_inds]
    cate_labels = cate_labels[sort_inds]

    # Upsample to padded-input size, crop the padding, resize to the
    # original image resolution, and binarize.
    mask_logits = F.interpolate(mask_logits.unsqueeze(0),
                                size=upsampled_size_out,
                                mode='bilinear')[:, :, :H, :W]
    seg_masks = F.interpolate(mask_logits,
                              size=ori_shape[:2],
                              mode='bilinear').squeeze(0)
    seg_masks = seg_masks > cfg.mask_thr
    return seg_masks, cate_labels, cate_scores
def get_seg_single(self, cate_preds, seg_preds_x, seg_preds_y, featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale=False, debug=False):
    """Decode one image's Decoupled-SOLO outputs into final instance masks.

    The mask of a grid cell (i, j) is the product of the j-th x-branch mask
    and the i-th y-branch mask of that cell's level.

    Args:
        cate_preds: (num_cells, num_classes) category probabilities,
            flattened over all FPN levels.
        seg_preds_x: stacked x-branch soft masks for all levels.
        seg_preds_y: stacked y-branch soft masks for all levels.
        featmap_size: (H, W) of the mask prediction maps.
        img_shape: padded input image shape (h, w, c).
        ori_shape: original image shape before resize/pad.
        scale_factor, rescale, debug: unused in this body.
        cfg: test config providing score_thr, mask_thr, nms_pre, kernel,
            sigma, update_thr and max_per_img.

    Returns:
        Tuple ``(seg_masks, cate_labels, cate_scores)`` or ``None`` when no
        candidate survives.
    """
    # overall info.
    h, w, _ = img_shape
    upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)

    # Per-cell lookup tables over the flattened multi-level grid:
    #   trans_diff - offset of each level's first cell in the flat grid,
    #   seg_diff   - offset of each level's first row/col in the stacked
    #                x/y mask banks,
    #   num_grids  - grid size of each cell's level,
    #   strides    - FPN stride of each cell's level.
    trans_size = torch.Tensor(self.seg_num_grids).pow(2).cumsum(0).long()
    trans_diff = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
    num_grids = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
    seg_size = torch.Tensor(self.seg_num_grids).cumsum(0).long()
    seg_diff = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
    strides = torch.ones(trans_size[-1].item(), device=cate_preds.device)
    n_stage = len(self.seg_num_grids)
    trans_diff[:trans_size[0]] *= 0
    seg_diff[:trans_size[0]] *= 0
    num_grids[:trans_size[0]] *= self.seg_num_grids[0]
    strides[:trans_size[0]] *= self.strides[0]
    for ind_ in range(1, n_stage):
        trans_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= trans_size[ind_ - 1]
        seg_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= seg_size[ind_ - 1]
        num_grids[trans_size[ind_ - 1]:trans_size[ind_]] *= self.seg_num_grids[ind_]
        strides[trans_size[ind_ - 1]:trans_size[ind_]] *= self.strides[ind_]

    # Candidate selection by classification score.
    inds = (cate_preds > cfg.score_thr)
    cate_scores = cate_preds[inds]
    inds = inds.nonzero()
    trans_diff = torch.index_select(trans_diff, dim=0, index=inds[:, 0])
    seg_diff = torch.index_select(seg_diff, dim=0, index=inds[:, 0])
    num_grids = torch.index_select(num_grids, dim=0, index=inds[:, 0])
    strides = torch.index_select(strides, dim=0, index=inds[:, 0])

    # Recover each candidate's (row, col) inside its level's grid, then
    # shift into the stacked x/y mask banks.
    y_inds = (inds[:, 0] - trans_diff) // num_grids
    x_inds = (inds[:, 0] - trans_diff) % num_grids
    y_inds += seg_diff
    x_inds += seg_diff
    cate_labels = inds[:, 1]

    # Decoupled SOLO: the final soft mask is the product of the two branches.
    seg_masks_soft = seg_preds_x[x_inds, ...] * seg_preds_y[y_inds, ...]
    seg_masks = seg_masks_soft > cfg.mask_thr
    sum_masks = seg_masks.sum((1, 2)).float()

    # Drop masks whose area does not exceed their level's stride.
    keep = sum_masks > strides
    seg_masks_soft = seg_masks_soft[keep, ...]
    seg_masks = seg_masks[keep, ...]
    cate_scores = cate_scores[keep]
    sum_masks = sum_masks[keep]
    cate_labels = cate_labels[keep]

    # maskness: weight scores by mean confidence inside each mask.
    seg_score = (seg_masks_soft * seg_masks.float()).sum((1, 2)) / sum_masks
    cate_scores *= seg_score
    if len(cate_scores) == 0:
        return None

    # sort and keep top nms_pre
    sort_inds = torch.argsort(cate_scores, descending=True)
    if len(sort_inds) > cfg.nms_pre:
        sort_inds = sort_inds[:cfg.nms_pre]
    seg_masks_soft = seg_masks_soft[sort_inds, :, :]
    seg_masks = seg_masks[sort_inds, :, :]
    cate_scores = cate_scores[sort_inds]
    sum_masks = sum_masks[sort_inds]
    cate_labels = cate_labels[sort_inds]

    # Matrix NMS: decays scores of overlapping candidates.
    cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                             kernel=cfg.kernel, sigma=cfg.sigma,
                             sum_masks=sum_masks)
    keep = cate_scores >= cfg.update_thr
    # BUG FIX: the sibling decoders return None when nothing survives the
    # update_thr filter; without this guard empty tensors would flow into
    # F.interpolate below. Added for correctness and consistency.
    if keep.sum() == 0:
        return None
    seg_masks_soft = seg_masks_soft[keep, :, :]
    cate_scores = cate_scores[keep]
    cate_labels = cate_labels[keep]

    # sort and keep top_k
    sort_inds = torch.argsort(cate_scores, descending=True)
    if len(sort_inds) > cfg.max_per_img:
        sort_inds = sort_inds[:cfg.max_per_img]
    seg_masks_soft = seg_masks_soft[sort_inds, :, :]
    cate_scores = cate_scores[sort_inds]
    cate_labels = cate_labels[sort_inds]

    # Upsample to padded-input size, crop the padding, resize to the
    # original image resolution, and binarize.
    seg_masks_soft = F.interpolate(seg_masks_soft.unsqueeze(0),
                                   size=upsampled_size_out,
                                   mode='bilinear')[:, :, :h, :w]
    seg_masks = F.interpolate(seg_masks_soft,
                              size=ori_shape[:2],
                              mode='bilinear').squeeze(0)
    seg_masks = seg_masks > cfg.mask_thr
    return seg_masks, cate_labels, cate_scores
def aug_test(self, imgs, img_metas, rescale=False):
    """Test with augmentations.

    Runs the head on every augmented view, merges the per-view candidates
    (resizing them to the first view's resolution and un-flipping), then
    applies Matrix NMS over the pooled set.

    If rescale is False, then returned masks will fit the scale of imgs[0].
    """
    ori_shape = img_metas[0][0]['ori_shape'][:2]
    # Collect per-augmentation results from the head.
    meta_result_list = []
    for img, img_meta in zip(imgs, img_metas):
        x = self.extract_feat(img)
        seg_preds, cate_preds = self.bbox_head(x, eval=True)
        img_shape = img_meta[0]['img_shape']
        img_result_list = self.bbox_head.get_seg_aug(
            seg_preds, cate_preds, img_shape, self.test_cfg)
        meta_result_list.append(img_result_list)
    img_output = []
    # Each img_result groups the corresponding entries across augmentations.
    for img_result in zip(*meta_result_list):
        seg_masks, seg_preds, sum_masks, cate_scores, cate_labels = map(
            list, zip(*img_result))
        # Bring every augmentation's masks to the first view's resolution.
        # NOTE(review): index 0 is never resized or un-flipped — this assumes
        # the first augmentation is the unflipped reference; confirm the TTA
        # pipeline guarantees that.
        unified_size = tuple(seg_masks[0].shape[-2:])
        for i in range(1, len(seg_masks)):
            seg_masks[i] = F.interpolate(seg_masks[i].float().unsqueeze(0),
                                         size=unified_size,
                                         mode='bilinear',
                                         align_corners=False).squeeze(0)
            seg_preds[i] = F.interpolate(seg_preds[i].unsqueeze(0),
                                         size=unified_size,
                                         mode='bilinear',
                                         align_corners=False).squeeze(0)
            if img_metas[i][0]['flip']:
                # Undo horizontal flip (dim 2 is width).
                seg_masks[i] = torch.flip(seg_masks[i], dims=[2])
                seg_preds[i] = torch.flip(seg_preds[i], dims=[2])
        # Pool candidates from all augmentations.
        seg_masks = torch.cat(seg_masks, dim=0)
        seg_preds = torch.cat(seg_preds, dim=0)
        sum_masks = torch.cat(sum_masks, dim=0)
        cate_scores = torch.cat(cate_scores, dim=0)
        cate_labels = torch.cat(cate_labels, dim=0)
        # sort and keep top nms_pre
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > self.test_cfg.nms_pre:
            sort_inds = sort_inds[:self.test_cfg.nms_pre]
        seg_masks = seg_masks[sort_inds, :, :]
        seg_preds = seg_preds[sort_inds, :, :]
        sum_masks = sum_masks[sort_inds]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]
        # Matrix NMS: decays scores of overlapping candidates.
        cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                                 kernel=self.test_cfg.kernel,
                                 sigma=self.test_cfg.sigma,
                                 sum_masks=sum_masks)
        # filter by decayed scores.
        keep = cate_scores >= self.test_cfg.update_thr
        # NOTE(review): returning None here aborts the whole function and
        # discards results already accumulated in img_output — verify this
        # early exit is intended rather than `continue`.
        if keep.sum() == 0:
            return None
        seg_preds = seg_preds[keep, :, :]
        cate_scores = cate_scores[keep]
        cate_labels = cate_labels[keep]
        # sort and keep top_k
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > self.test_cfg.max_per_img:
            sort_inds = sort_inds[:self.test_cfg.max_per_img]
        seg_preds = seg_preds[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]
        # Resize the soft masks to the original image size and binarize.
        seg_masks = F.interpolate(seg_preds.unsqueeze(0),
                                  size=ori_shape,
                                  mode='bilinear',
                                  align_corners=False).squeeze(0)
        seg_masks = seg_masks > self.test_cfg.mask_thr
        output = (seg_masks, cate_labels, cate_scores)
        img_output.append(output)
    return img_output
def get_seg_single_threshed(self, cate_preds, seg_preds, cate_labels, featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale=False, debug=False):
    """Decode candidates that were already score-thresholded upstream.

    Unlike ``get_seg_single``, scores and labels arrive pre-selected, so no
    ``score_thr`` filtering (and no stride-based area filtering) happens
    here; the routine rescores by maskness, runs Matrix NMS and resizes.

    Args:
        cate_preds: (n,) pre-thresholded candidate scores.
        seg_preds: (n, H, W) soft mask predictions, row-aligned with scores.
        cate_labels: (n,) candidate class indices.
        featmap_size: (H, W) of the mask prediction maps (1/4 input scale).
        img_shape: padded input image shape (h, w, c).
        ori_shape: original image shape before resize/pad.
        scale_factor, rescale, debug: unused in this body.
        cfg: test config providing mask_thr, nms_pre, kernel, sigma,
            update_thr and max_per_img.

    Returns:
        Tuple ``(seg_masks, cate_labels, cate_scores)`` or ``None`` when no
        candidate survives.
    """
    assert len(cate_preds) == len(seg_preds)
    # overall info.
    h, w, _ = img_shape
    upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)
    # category scores (already thresholded by the caller).
    cate_scores = cate_preds
    if len(cate_scores) == 0:
        return None
    # masks.
    seg_masks = seg_preds > cfg.mask_thr
    sum_masks = seg_masks.sum((1, 2)).float()
    # BUG FIX: candidates whose binarized mask is empty made the maskness
    # division below produce NaN scores (the sibling decoders filter on
    # `sum_masks > strides`, so they never hit this). Drop them explicitly.
    # This indexing also copies cate_scores, so the in-place `*=` below no
    # longer mutates the caller's `cate_preds` tensor.
    keep = sum_masks > 0
    if keep.sum() == 0:
        return None
    seg_masks = seg_masks[keep, ...]
    seg_preds = seg_preds[keep, ...]
    sum_masks = sum_masks[keep]
    cate_scores = cate_scores[keep]
    cate_labels = cate_labels[keep]
    # mask scoring: average confidence over the mask area ("maskness").
    seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
    cate_scores *= seg_scores
    # sort and keep top nms_pre
    sort_inds = torch.argsort(cate_scores, descending=True)
    if len(sort_inds) > cfg.nms_pre:
        sort_inds = sort_inds[:cfg.nms_pre]
    seg_masks = seg_masks[sort_inds, :, :]
    seg_preds = seg_preds[sort_inds, :, :]
    sum_masks = sum_masks[sort_inds]
    cate_scores = cate_scores[sort_inds]
    cate_labels = cate_labels[sort_inds]
    # Matrix NMS: decays scores of overlapping candidates.
    cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                             kernel=cfg.kernel, sigma=cfg.sigma,
                             sum_masks=sum_masks)
    # filter by decayed scores.
    keep = cate_scores >= cfg.update_thr
    if keep.sum() == 0:
        return None
    seg_preds = seg_preds[keep, :, :]
    cate_scores = cate_scores[keep]
    cate_labels = cate_labels[keep]
    # sort and keep top_k
    sort_inds = torch.argsort(cate_scores, descending=True)
    if len(sort_inds) > cfg.max_per_img:
        sort_inds = sort_inds[:cfg.max_per_img]
    seg_preds = seg_preds[sort_inds, :, :]
    cate_scores = cate_scores[sort_inds]
    cate_labels = cate_labels[sort_inds]
    # Upsample to padded-input size, crop the padding, resize to the
    # original image resolution, and binarize.
    seg_preds = F.interpolate(seg_preds.unsqueeze(0),
                              size=upsampled_size_out,
                              mode='bilinear')[:, :, :h, :w]
    seg_masks = F.interpolate(seg_preds,
                              size=ori_shape[:2],
                              mode='bilinear').squeeze(0)
    seg_masks = seg_masks > cfg.mask_thr
    return seg_masks, cate_labels, cate_scores