def mask_heads_forward_with_coords(self, mask_feats, mask_feat_stride, instances):
    # Image-space coordinates of every location on the mask feature map
    # (stride 8, i.e. the P3 level).
    locations = compute_locations(
        mask_feats.size(2), mask_feats.size(3),
        stride=mask_feat_stride, device=mask_feats.device
    )
    n_inst = len(instances)

    # FCOS recorded, for every sampled location, the image it belongs to;
    # positive locations were then selected via the classification targets.
    # Here the mask features of the matching image are replicated once per
    # instance: (N, C, H, W) -> (n_inst, C, H*W).
    im_inds = instances.im_inds
    mask_head_params = instances.mask_head_params

    N, _, H, W = mask_feats.size()

    if not self.disable_rel_coords:
        # FCOS also recorded the center location of each positive sample.
        # Build relative coordinates: the instance's own location maps to 0,
        # every other location to its offset from it.
        instance_locations = instances.locations
        # (n_inst, 1, 2) - (1, H*W, 2) -> (n_inst, H*W, 2)
        relative_coords = instance_locations.reshape(-1, 1, 2) - locations.reshape(1, -1, 2)
        # (n_inst, H*W, 2) -> (n_inst, 2, H*W)
        relative_coords = relative_coords.permute(0, 2, 1).float()
        # Normalize the offsets by each instance's size of interest: larger
        # instances (from higher FPN levels) get a larger divisor, so the
        # relative coordinates stay comparable across levels.
        soi = self.sizes_of_interest.float()[instances.fpn_levels]
        relative_coords = relative_coords / soi.reshape(-1, 1, 1)
        relative_coords = relative_coords.to(dtype=mask_feats.dtype)

        mask_head_inputs = torch.cat([
            relative_coords,
            mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W)
        ], dim=1)
        # e.g. mask_head_inputs: (39, 10, 12880), mask_feats: (2, 8, 92, 140),
        # relative_coords: (39, 2, 12880)
    else:
        mask_head_inputs = mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W)

    mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W)

    weights, biases = parse_dynamic_params(
        mask_head_params, self.channels,
        self.weight_nums, self.bias_nums
    )

    # Each instance's slice of mask_head_inputs now holds the whole-image mask
    # features plus the relative offsets to that instance's predicted center.
    # e.g. inputs: (1, 580, 100, 136);
    # weights: [(464, 10, 1, 1), (464, 8, 1, 1), (58, 8, 1, 1)];
    # biases:  [(464,), (464,), (58,)]
    mask_logits = self.mask_heads_forward(mask_head_inputs, weights, biases, n_inst)

    mask_logits = mask_logits.reshape(-1, 1, H, W)

    assert mask_feat_stride >= self.mask_out_stride
    assert mask_feat_stride % self.mask_out_stride == 0
    mask_logits = aligned_bilinear(mask_logits, int(mask_feat_stride / self.mask_out_stride))

    return mask_logits.sigmoid()
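# NOTE: `parse_dynamic_params` and `self.mask_heads_forward` are called above
# but not shown. Below is a minimal sketch of how CondInst-style dynamic mask
# heads typically implement them, assuming the per-instance parameter vector
# packs the weights and biases of a few 1x1 conv layers (signatures follow the
# call sites above; treat this as an illustration, not the exact source):
import torch
from torch.nn import functional as F

def parse_dynamic_params(params, channels, weight_nums, bias_nums):
    # Split the flat (n_inst, num_params) tensor into per-layer conv weights
    # and biases, folding instances into the output-channel dimension so one
    # grouped convolution can run all instance-specific heads at once.
    assert params.dim() == 2
    assert params.size(1) == sum(weight_nums) + sum(bias_nums)
    num_insts = params.size(0)
    num_layers = len(weight_nums)
    splits = list(torch.split_with_sizes(params, weight_nums + bias_nums, dim=1))
    weight_splits, bias_splits = splits[:num_layers], splits[num_layers:]
    for l in range(num_layers):
        if l < num_layers - 1:
            # hidden layers: `channels` output channels per instance
            weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1)
            bias_splits[l] = bias_splits[l].reshape(num_insts * channels)
        else:
            # final layer: a single mask-logit channel per instance
            weight_splits[l] = weight_splits[l].reshape(num_insts, -1, 1, 1)
            bias_splits[l] = bias_splits[l].reshape(num_insts)
    return weight_splits, bias_splits

def mask_heads_forward(features, weights, biases, num_insts):
    # features: (1, num_insts * C, H, W); groups=num_insts ensures each
    # instance convolves only its own channel slice with its own filters.
    assert features.dim() == 4
    x = features
    for i, (w, b) in enumerate(zip(weights, biases)):
        x = F.conv2d(x, w, bias=b, stride=1, padding=0, groups=num_insts)
        if i < len(weights) - 1:
            x = F.relu(x)
    return x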
def postprocess(self, results, output_height, output_width, padded_im_h, padded_im_w, mask_threshold=0.5):
    """
    Resize the output instances.

    The input images are often resized when entering an object detector, so
    the raw outputs usually live at a different resolution than the desired
    output. This function resizes the raw outputs of the detector to the
    desired output resolution.

    Args:
        results (Instances): the raw outputs from the detector.
            `results.image_size` contains the input image resolution the
            detector sees. This object might be modified in-place.
        output_height, output_width: the desired output resolution.

    Returns:
        Instances: the resized output from the model, based on the output resolution
    """
    scale_x, scale_y = (output_width / results.image_size[1],
                        output_height / results.image_size[0])
    resized_im_h, resized_im_w = results.image_size
    results = Instances((output_height, output_width), **results.get_fields())

    if results.has("pred_boxes"):
        output_boxes = results.pred_boxes
    elif results.has("proposal_boxes"):
        output_boxes = results.proposal_boxes
    else:
        raise ValueError("results must contain pred_boxes or proposal_boxes")

    output_boxes.scale(scale_x, scale_y)
    output_boxes.clip(results.image_size)

    results = results[output_boxes.nonempty()]

    if results.has("pred_global_masks"):
        mask_h, mask_w = results.pred_global_masks.size()[-2:]
        factor_h = padded_im_h // mask_h
        factor_w = padded_im_w // mask_w
        assert factor_h == factor_w
        factor = factor_h
        # Upsample to the padded input size, crop away the padding, then
        # interpolate to the desired output resolution.
        pred_global_masks = aligned_bilinear(results.pred_global_masks, factor)
        pred_global_masks = pred_global_masks[:, :, :resized_im_h, :resized_im_w]
        pred_global_masks = F.interpolate(pred_global_masks,
                                          size=(output_height, output_width),
                                          mode="bilinear", align_corners=False)
        pred_global_masks = pred_global_masks[:, 0, :, :]
        results.pred_masks = (pred_global_masks > mask_threshold).float()

    return results
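# `aligned_bilinear` is used throughout for integer-factor upsampling. A
# sketch of the usual AdelaiDet-style implementation (an assumption here,
# reconstructed from how it is called above): it pads before interpolating so
# the upsampled pixels stay aligned with the stride-s grid centers.
import torch
from torch.nn import functional as F

def aligned_bilinear(tensor, factor):
    # Upsample an NCHW tensor by an integer factor, keeping the output grid
    # aligned with the input grid (plain F.interpolate shifts it slightly).
    assert tensor.dim() == 4
    assert factor >= 1 and int(factor) == factor
    if factor == 1:
        return tensor
    h, w = tensor.size()[2:]
    tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode="replicate")
    oh, ow = factor * h + 1, factor * w + 1
    tensor = F.interpolate(tensor, size=(oh, ow), mode="bilinear", align_corners=True)
    tensor = F.pad(tensor, pad=(factor // 2, 0, factor // 2, 0), mode="replicate")
    return tensor[:, :, :oh - 1, :ow - 1]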
def mask_heads_forward_with_coords(self, mask_feats, mask_feat_stride, instances):
    # Variant of the function above that returns raw mask logits instead of
    # probabilities; the caller is expected to apply the sigmoid itself
    # (e.g. inside the loss).
    locations = compute_locations(mask_feats.size(2), mask_feats.size(3),
                                  stride=mask_feat_stride, device=mask_feats.device)
    n_inst = len(instances)

    im_inds = instances.im_inds
    mask_head_params = instances.mask_head_params

    N, _, H, W = mask_feats.size()

    if not self.disable_rel_coords:
        instance_locations = instances.locations
        relative_coords = instance_locations.reshape(-1, 1, 2) - locations.reshape(1, -1, 2)
        relative_coords = relative_coords.permute(0, 2, 1).float()
        soi = self.sizes_of_interest.float()[instances.fpn_levels]
        relative_coords = relative_coords / soi.reshape(-1, 1, 1)
        relative_coords = relative_coords.to(dtype=mask_feats.dtype)

        mask_head_inputs = torch.cat([
            relative_coords,
            mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W)
        ], dim=1)
    else:
        mask_head_inputs = mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W)

    mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W)

    weights, biases = parse_dynamic_params(mask_head_params, self.channels,
                                           self.weight_nums, self.bias_nums)

    mask_logits = self.mask_heads_forward(mask_head_inputs, weights, biases, n_inst)
    mask_logits = mask_logits.reshape(-1, 1, H, W)

    assert mask_feat_stride >= self.mask_out_stride
    assert mask_feat_stride % self.mask_out_stride == 0
    mask_logits = aligned_bilinear(mask_logits, int(mask_feat_stride / self.mask_out_stride))

    return mask_logits
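# `compute_locations` (also not shown) maps a feature grid back to image
# coordinates. An FCOS-style sketch, assuming each location sits at the
# center of its stride-sized cell:
import torch

def compute_locations(h, w, stride, device):
    # (x, y) image coordinates for every cell of an h x w feature map;
    # + stride // 2 centers each location inside its cell.
    shifts_x = torch.arange(0, w * stride, step=stride, dtype=torch.float32, device=device)
    shifts_y = torch.arange(0, h * stride, step=stride, dtype=torch.float32, device=device)
    # indexing="ij" requires torch >= 1.10; older versions behave the same by default
    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
    locations = torch.stack((shift_x.reshape(-1), shift_y.reshape(-1)), dim=1) + stride // 2
    return locations  # (h * w, 2)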
def compute_mask_prob(self, instances, pixel_embed, mask_feat_stride):
    # EmbedMask-style mask probability: compare each pixel's embedding with
    # the proposal's embedding under a learned per-dimension margin.
    proposal_embed = instances.proposal_embed
    proposal_margin = instances.proposal_margin
    im_inds = instances.im_inds

    dim, m_h, m_w = pixel_embed.shape[-3:]
    obj_num = proposal_embed.shape[0]
    # (N, dim, H, W) -> (obj_num, H, W, dim): one pixel-embedding map per instance
    pixel_embed = pixel_embed.permute(0, 2, 3, 1)[im_inds]
    proposal_embed = proposal_embed.view(obj_num, 1, 1, -1).expand(-1, m_h, m_w, -1)
    proposal_margin = proposal_margin.view(obj_num, 1, 1, dim).expand(-1, m_h, m_w, -1)
    # Gaussian-like kernel: prob = exp(-sum_d margin_d * (pixel_d - proposal_d)^2)
    mask_var = (pixel_embed - proposal_embed) ** 2
    mask_prob = torch.exp(-torch.sum(mask_var * proposal_margin, dim=3))

    assert mask_feat_stride >= self.mask_out_stride
    assert mask_feat_stride % self.mask_out_stride == 0
    mask_prob = aligned_bilinear(mask_prob.unsqueeze(1),
                                 int(mask_feat_stride / self.mask_out_stride))

    return mask_prob
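# The probability above is a Gaussian-like kernel in embedding space,
# prob = exp(-sum_d margin_d * (pixel_d - proposal_d)^2), as in EmbedMask.
# A toy sanity check of that formula on dummy tensors (all names below are
# made up for the demo):
import torch

pixel_embed = torch.randn(8, 16, 16)      # (dim, H, W) per-pixel embeddings
proposal_embed = torch.randn(3, 8)        # one embedding per instance
proposal_margin = torch.rand(3, 8)        # positive per-dimension margins

pe = pixel_embed.permute(1, 2, 0)         # (H, W, dim)
var = (pe[None] - proposal_embed[:, None, None]) ** 2        # (3, H, W, dim)
prob = torch.exp(-(var * proposal_margin[:, None, None]).sum(dim=3))
assert prob.shape == (3, 16, 16)
assert (prob >= 0).all() and (prob <= 1).all()  # valid probabilities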
def recover_ins2all_test(self, mask_scores, pred_instances):
    # Test-time helper: upsample per-instance mask scores by a factor of 2
    # (instance resolution -> full-image resolution); `pred_instances` is
    # currently unused.
    mask_scores = aligned_bilinear(mask_scores, 2)
    return mask_scores
def forward(self, features, gt_instances=None):
    # Fuse the selected FPN levels: refine each level, upsample it to the
    # resolution of the first (highest-resolution) level, and sum.
    for i, f in enumerate(self.in_features):
        if i == 0:
            x = self.refine[i](features[f])
        else:
            x_p = self.refine[i](features[f])

            target_h, target_w = x.size()[2:]
            h, w = x_p.size()[2:]
            assert target_h % h == 0
            assert target_w % w == 0
            factor_h, factor_w = target_h // h, target_w // w
            assert factor_h == factor_w
            x_p = aligned_bilinear(x_p, factor_h)
            x = x + x_p

    mask_feats = self.tower(x)

    losses = {}
    # Auxiliary semantic segmentation loss on the highest-resolution level.
    if self.training and self.sem_loss_on:
        logits_pred = self.logits(self.seg_head(features[self.in_features[0]]))

        # Compute semantic targets: each pixel takes the class of the
        # smallest instance covering it (classes are shifted by +1 so that
        # 0 means background).
        semantic_targets = []
        for per_im_gt in gt_instances:
            h, w = per_im_gt.gt_bitmasks_full.size()[-2:]
            areas = per_im_gt.gt_bitmasks_full.sum(dim=-1).sum(dim=-1)
            areas = areas[:, None, None].repeat(1, h, w)
            areas[per_im_gt.gt_bitmasks_full == 0] = INF
            areas = areas.permute(1, 2, 0).reshape(h * w, -1)
            min_areas, inds = areas.min(dim=1)
            per_im_semantic_targets = per_im_gt.gt_classes[inds] + 1
            per_im_semantic_targets[min_areas == INF] = 0
            per_im_semantic_targets = per_im_semantic_targets.reshape(h, w)
            semantic_targets.append(per_im_semantic_targets)

        semantic_targets = torch.stack(semantic_targets, dim=0)

        # Downsample the targets to the prediction stride to reduce memory.
        semantic_targets = semantic_targets[:, None,
                                            self.out_stride // 2::self.out_stride,
                                            self.out_stride // 2::self.out_stride]

        # Prepare one-hot targets.
        num_classes = logits_pred.size(1)
        class_range = torch.arange(num_classes, dtype=logits_pred.dtype,
                                   device=logits_pred.device)[:, None, None]
        class_range = class_range + 1
        one_hot = (semantic_targets == class_range).float()
        num_pos = (one_hot > 0).sum().float().clamp(min=1.0)

        loss_sem = sigmoid_focal_loss_jit(
            logits_pred, one_hot,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos
        losses['loss_sem'] = loss_sem

    return mask_feats, losses
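# A toy check of the per-pixel semantic-target rule used above: every pixel
# gets the class of the *smallest* instance covering it (shifted by +1), and
# 0 where nothing covers it. Dummy data, same ops as the loop above:
import torch

INF = 100000000
bitmasks = torch.zeros(2, 4, 4)
bitmasks[0] = 1                  # large instance (area 16), class 2
bitmasks[1, 1:3, 1:3] = 1        # small instance (area 4), class 5
gt_classes = torch.tensor([2, 5])

h, w = bitmasks.shape[-2:]
areas = bitmasks.sum(dim=-1).sum(dim=-1)[:, None, None].repeat(1, h, w)
areas[bitmasks == 0] = INF
areas = areas.permute(1, 2, 0).reshape(h * w, -1)
min_areas, inds = areas.min(dim=1)
targets = gt_classes[inds] + 1
targets[min_areas == INF] = 0
print(targets.reshape(h, w))     # center 2x2 -> 6 (class 5), elsewhere -> 3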